File size: 2,378 Bytes
126f215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45c6c27
126f215
 
 
 
 
 
031f538
 
 
 
126f215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import tempfile

import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydub import AudioSegment

from src.config.config import DatasetConfig
from src.models.predict import AudioPredictor

# Shared dataset configuration; its ``esc50_labels`` maps predicted class
# indices to human-readable class names.
dataset_cfg = DatasetConfig()

app = FastAPI(
    title="ESC50 Audio Classifier API",
    version="1.0.0",
    description="API for environmental sound classification",
)

# Accept cross-origin GET and POST requests from any origin, with any headers.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST"],
    allow_origins=["*"],
)

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model once at import time. A load failure is reported and leaves
# ``predictor`` as None, so /predict-top-k can answer 503 ("Model not loaded")
# instead of the entire app failing to start.
try:
    predictor = AudioPredictor("final_model.pt", device=device)
except Exception as exc:  # top-level boundary: report and degrade gracefully
    print(f"Failed to load model from 'final_model.pt': {exc!r}")
    predictor = None


@app.get("/")
async def root():
    """Serve the front-end page.

    The path is relative, so ``index.html`` is looked up in the process's
    current working directory.
    """
    index_page = "index.html"
    return FileResponse(index_page)

@app.get("/labels")
def get_labels():
    """Return the ESC-50 class labels as ``{"labels": <esc50_labels>}``.

    Reuses the module-level ``dataset_cfg``, the same config the prediction
    endpoint uses to map indices to names. This avoids constructing a new
    ``DatasetConfig`` on every request.
    """
    return {"labels": dataset_cfg.esc50_labels}

@app.get("/api/status")
async def status():
    """Report that the server is running; the payload is always constant."""
    payload = {"status": "running"}
    return payload

@app.post("/predict-top-k")
async def predict_top_k(file: UploadFile = File(...), k: int = 5):
    """Classify an uploaded audio clip and return its top-``k`` ESC-50 classes.

    The upload is copied to a temporary file (keeping its extension, if any),
    converted to WAV with pydub, and passed to ``predictor.predict_file``.
    Both temporary files are removed before returning.

    Args:
        file: Uploaded audio in a format pydub can decode.
        k: Number of ranked predictions to return; must be between 1 and the
            number of labels in ``dataset_cfg.esc50_labels``.

    Returns:
        A dict with ``predicted_class`` (label name), ``confidence`` (first
        top-k probability) and ``top_predictions`` (list of
        ``{"class", "confidence"}`` dicts in the order the predictor returns).

    Raises:
        HTTPException: 503 if the model is not loaded, 422 if ``k`` is out of
            range, 400 if the upload cannot be decoded as audio.
    """
    from pydub.exceptions import CouldntDecodeError

    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    labels = dataset_cfg.esc50_labels
    # Reject bad k up front instead of letting it fail deep inside inference.
    if not 1 <= k <= len(labels):
        raise HTTPException(
            status_code=422,
            detail=f"k must be between 1 and {len(labels)}",
        )

    # UploadFile.filename may be None; treat a missing name as "no suffix".
    suffix = os.path.splitext(file.filename or "")[1]
    # Read first so a failed read does not leave an orphaned temp file.
    data = await file.read()
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(data)
        tmp_path = tmp.name

    # Bound before ``try`` so the ``finally`` clause can always reference it.
    wav_path = None
    try:
        # mkstemp creates the file atomically (unlike the deprecated, racy
        # mktemp); close our handle so pydub can overwrite it by path.
        fd, wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)

        # NOTE(review): conversion and inference are synchronous and block the
        # event loop; consider run_in_threadpool once AudioPredictor is
        # confirmed safe to call from several threads concurrently.
        print("[1] Converting to wav...")
        try:
            AudioSegment.from_file(tmp_path).export(wav_path, format="wav")
        except CouldntDecodeError as exc:
            raise HTTPException(
                status_code=400, detail="Could not decode audio file"
            ) from exc

        print("[2] Running inference...")
        predicted_class, top_probs, top_indices = predictor.predict_file(wav_path, top_k=k)
        print(f"[3] Done: {predicted_class} = {labels[predicted_class]}")

        return {
            "predicted_class": labels[predicted_class],
            "confidence": float(top_probs[0]),
            "top_predictions": [
                {"class": labels[idx], "confidence": float(prob)}
                for prob, idx in zip(top_probs, top_indices)
            ],
        }
    finally:
        os.unlink(tmp_path)
        if wav_path is not None and os.path.exists(wav_path):
            os.unlink(wav_path)


if __name__ == "__main__":
    # When executed directly, serve the app on all interfaces at port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
    )