File size: 2,143 Bytes
05b56c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import tempfile, shutil
from app.preprocess import preprocess_audio
from app.model import predict
import numpy as np
# ASGI application object; the route handlers in this module register on it.
app = FastAPI(
    title="General Audio Classifier",
)
def _combine_chunk_predictions(labels, confidences):
    """Reduce per-chunk predictions to a single label and confidence.

    The label predicted for the most chunks wins. When several labels share
    the highest count, the one whose chunks have the larger summed confidence
    wins; if that is also equal, the label seen first is kept.

    Args:
        labels: Per-chunk predicted labels (hashable), in chunk order.
        confidences: Per-chunk confidences as floats, parallel to ``labels``.

    Returns:
        A ``(label, mean_confidence)`` tuple, where ``mean_confidence`` is
        the average confidence over the chunks that predicted ``label``.

    Raises:
        ValueError: If ``labels`` is empty.
    """
    from collections import Counter, defaultdict

    if not labels:
        raise ValueError("cannot combine an empty list of predictions")

    votes = Counter(labels)
    confidence_sums = defaultdict(float)
    for label, confidence in zip(labels, confidences):
        confidence_sums[label] += confidence

    top_count = max(votes.values())
    # Counter preserves first-seen order, so ties resolve to the earliest label.
    candidates = [label for label, count in votes.items() if count == top_count]
    winner = max(candidates, key=confidence_sums.__getitem__)
    return winner, confidence_sums[winner] / votes[winner]


@app.post("/predict")
async def predict_audio(file: UploadFile = File(...)):
    """Classify an uploaded audio file.

    The upload is written to a temporary ``.wav`` file, passed to
    ``preprocess_audio`` to obtain one image per chunk, and each image is
    classified with ``predict``. The per-chunk results are combined by
    majority vote with a summed-confidence tie-breaker.

    Args:
        file: The uploaded audio file.

    Returns:
        A ``JSONResponse``. On success it carries ``predicted_label``,
        ``confidence`` (mean confidence of the winning label, rounded to 3
        places), ``all_predictions`` and ``all_confidences``. If
        preprocessing yields no chunks the response is a 400 with an
        ``error`` message; any other failure is reported as a 500 with the
        exception text under ``error``.
    """
    import os

    # NOTE(review): the file copy, preprocessing and inference below are
    # synchronous calls inside an async handler -- confirm that blocking the
    # event loop for their duration is acceptable.
    tmp_path = None
    try:
        # Save uploaded file temporarily
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp_path = tmp.name
            shutil.copyfileobj(file.file, tmp)
        # The file is closed (and therefore flushed) before it is read back.

        # Preprocess -> multiple spectrograms
        imgs = preprocess_audio(tmp_path)

        # Predict all chunks
        all_preds = []
        all_confidences = []
        for img in imgs:
            label, confidence, _probs = predict(img)
            all_preds.append(label)
            # Coerce to a plain float so NumPy/tensor scalars stay JSON-serializable.
            all_confidences.append(float(confidence))

        if not all_preds:
            # Presumably the audio was empty or too short to produce a chunk.
            return JSONResponse(
                content={"error": "No audio chunks could be extracted from the uploaded file"},
                status_code=400,
            )

        # Combine predictions (majority vote with confidence tiebreaker)
        final_label, final_confidence = _combine_chunk_predictions(
            all_preds, all_confidences
        )

        return JSONResponse(content={
            "predicted_label": final_label,
            "confidence": round(final_confidence, 3),
            "all_predictions": all_preds,
            "all_confidences": [round(c, 3) for c in all_confidences],
        })
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
    finally:
        # delete=False means nothing else removes the temp file; do it here so
        # each request does not leave a file behind.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup: a failed delete must not mask the response.
                pass
|