Spaces:
Running
Running
File size: 2,378 Bytes
126f215 45c6c27 126f215 031f538 126f215 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import os
import tempfile
import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydub import AudioSegment
from src.config.config import DatasetConfig
from src.models.predict import AudioPredictor
# Shared dataset configuration; the endpoints below read the ESC-50 label
# table from it.
dataset_cfg = DatasetConfig()

# The ASGI application object that the route decorators below attach to.
app = FastAPI(
    version="1.0.0",
    title="ESC50 Audio Classifier API",
    description="API for environmental sound classification",
)

# CORS: any origin and any request header are accepted, but only GET and POST.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST"],
    allow_origins=["*"],
)

# Use the GPU when torch reports one as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The predictor is built once at import time from the checkpoint file
# "final_model.pt" and shared by every request.
predictor = AudioPredictor("final_model.pt", device=device)
@app.get("/")
async def root():
    """Serve the ``index.html`` file as the response for the site root."""
    landing_page = "index.html"
    return FileResponse(landing_page)
@app.get("/labels")
def get_labels():
    """Return the ESC-50 label table under the ``"labels"`` key.

    Uses the module-level ``dataset_cfg`` instance, the same one the
    prediction endpoint reads, instead of constructing a fresh
    ``DatasetConfig`` on every request.
    """
    return {"labels": dataset_cfg.esc50_labels}
@app.get("/api/status")
async def status():
    """Health-check endpoint: always answers with a fixed "running" payload."""
    payload = dict(status="running")
    return payload
@app.post("/predict-top-k")
async def predict_top_k(file: UploadFile = File(...), k: int = 5):
    """Classify an uploaded audio clip and return its top-``k`` predictions.

    The upload is stored in a private temporary directory, converted to WAV
    with pydub, and handed to the module-level ``predictor``. The directory
    (and both files in it) is removed when the request finishes, whether it
    succeeds or fails.

    Args:
        file: The uploaded audio file. Its filename extension, if any, is
            kept on the temporary copy passed to pydub.
        k: Number of ranked predictions to return. Values larger than the
            number of known labels are clamped to that number.

    Returns:
        A dict with ``predicted_class`` (label of the predicted class),
        ``confidence`` (first entry of the returned probabilities) and
        ``top_predictions`` (list of ``{"class", "confidence"}`` dicts).

    Raises:
        HTTPException: 503 when no model is loaded; 400 when ``k`` is below
            1 or the upload cannot be decoded as audio.
    """
    # pydub is already a dependency of this module; the decode error is
    # imported locally so the file's import block stays untouched.
    from pydub.exceptions import CouldntDecodeError

    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if k < 1:
        # Without this guard an empty result would fail at top_probs[0].
        raise HTTPException(status_code=400, detail="k must be at least 1")
    # There cannot be more ranked predictions than there are labels.
    k = min(k, len(dataset_cfg.esc50_labels))

    # The client may omit the filename, in which case it is None.
    suffix = os.path.splitext(file.filename or "")[1]

    # A private directory replaces the deprecated, race-prone tempfile.mktemp
    # and guarantees cleanup of both files on every exit path.
    with tempfile.TemporaryDirectory() as workdir:
        upload_path = os.path.join(workdir, f"upload{suffix}")
        wav_path = os.path.join(workdir, "converted.wav")

        with open(upload_path, "wb") as upload:
            upload.write(await file.read())

        # NOTE(review): conversion and inference below are blocking calls
        # inside an async handler; consider moving them to a worker thread
        # once the predictor's thread-safety is confirmed.
        print("[1] Converting to wav...")
        try:
            exported = AudioSegment.from_file(upload_path).export(wav_path, format="wav")
        except CouldntDecodeError as exc:
            raise HTTPException(
                status_code=400, detail="Could not decode audio file"
            ) from exc
        # export() hands back the output file object; close it explicitly so
        # the handle is released before the file is read and later deleted.
        exported.close()

        print("[2] Running inference...")
        predicted_class, top_probs, top_indices = predictor.predict_file(wav_path, top_k=k)
        print(f"[3] Done: {predicted_class} = {dataset_cfg.esc50_labels[predicted_class]}")

        return {
            "predicted_class": dataset_cfg.esc50_labels[predicted_class],
            "confidence": float(top_probs[0]),
            "top_predictions": [
                {"class": dataset_cfg.esc50_labels[idx], "confidence": float(prob)}
                for prob, idx in zip(top_probs, top_indices)
            ],
        }
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on every interface,
    # port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")