# app.py
"""FastAPI service that classifies speaker gender from an uploaded audio clip.

The Hugging Face processor/model are lazy-loaded on the first request (see
``get_model``) so the server process starts quickly, e.g. on Hugging Face
Spaces where cold-start time matters.
"""

import contextlib
import os
import subprocess
import tempfile
from pathlib import Path

import torch

torch.set_num_threads(1)  # keep CPU inference single-threaded for small containers

import numpy as np
import soundfile as sf
import torchaudio
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse

# NOTE: we lazy-load these inside get_model()
processor = None
model = None

TARGET_SR = 16000  # wav2vec2 expects 16 kHz


def get_model():
    """
    Lazily load processor and model on first call and cache them globally.

    Uses a custom HF cache dir (``HF_HOME``, default ``/app/hf_cache``) to
    avoid permission issues on Hugging Face Spaces.

    Returns:
        tuple: ``(processor, model)`` — the cached Wav2Vec2 processor and the
        audio-classification model, already switched to eval mode.
    """
    global processor, model
    if processor is None or model is None:
        print("🔁 Loading HF processor & model (this may take 10–60s on first request)...")
        # Imported here (not at module top) so the heavy transformers import
        # only happens on first use.
        from transformers import Wav2Vec2Processor, AutoModelForAudioClassification

        cache_dir = os.getenv("HF_HOME", "/app/hf_cache")
        processor = Wav2Vec2Processor.from_pretrained(
            "facebook/wav2vec2-base-960h", cache_dir=cache_dir
        )
        model = AutoModelForAudioClassification.from_pretrained(
            "prithivMLmods/Common-Voice-Gender-Detection", cache_dir=cache_dir
        )
        model.eval()
        print("✅ Model & processor loaded.")
    return processor, model


app = FastAPI(title="Gender Detection API (lazy model load)")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/", response_class=HTMLResponse)
async def home():
    """Minimal landing page describing the prediction endpoint."""
    return """
POST /predict (multipart form-data, field name "file")
"""


@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "ok"}


@app.get("/labels")
async def labels():
    """Return the model's id→label mapping (triggers lazy model load)."""
    proc, mdl = get_model()
    return mdl.config.id2label


def _read_audio(path):
    """Decode the audio file at *path* into ``(float32 samples, sample_rate)``.

    Tries libsndfile first; for formats it cannot read (e.g. mp3/webm
    uploads), falls back to an ffmpeg transcode to 16 kHz mono WAV.

    Raises:
        ValueError: if ffmpeg exits non-zero, chained to the original
            libsndfile decode error so both causes are visible.
    """
    try:
        return sf.read(path, dtype="float32")
    except Exception as exc:
        print("⚠️ soundfile could not read directly, trying ffmpeg conversion:", exc)
        converted = path + ".converted.wav"
        ffmpeg_cmd = [
            "ffmpeg", "-y",
            "-i", path,
            "-ar", str(TARGET_SR),
            "-ac", "1",
            converted,
        ]
        result = subprocess.run(
            ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
        )
        # Previously check=False silently ignored ffmpeg failures, which then
        # surfaced as a confusing sf.read error on a missing file. Fail loudly.
        if result.returncode != 0:
            raise ValueError(
                f"ffmpeg could not decode the uploaded file (exit code {result.returncode})"
            ) from exc
        try:
            return sf.read(converted, dtype="float32")
        finally:
            # Best-effort cleanup of the transcoded temp file.
            with contextlib.suppress(OSError):
                os.unlink(converted)


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the gender of the speaker in the uploaded audio file.

    Accepts any format readable by libsndfile or ffmpeg; the audio is
    downmixed to mono and resampled to 16 kHz before inference.

    Returns:
        JSONResponse: ``{"top": <label>, "scores": {<label>: <prob>, ...}}``
        on success, or ``{"error": <message>}`` with status 400 on failure.
    """
    try:
        proc, mdl = get_model()

        # Persist the upload to disk so libsndfile/ffmpeg can open it by path.
        suffix = Path(file.filename or "").suffix or ".wav"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(await file.read())
            tmp_path = tmp.name

        try:
            waveform_np, sr = _read_audio(tmp_path)
        finally:
            with contextlib.suppress(OSError):
                os.unlink(tmp_path)

        # Downmix multi-channel audio to mono by averaging channels.
        if waveform_np.ndim > 1:
            waveform_np = waveform_np.mean(axis=1)
        # Guard against files that decode to zero samples, which would
        # otherwise crash deep inside the processor/model.
        if waveform_np.size == 0:
            raise ValueError("uploaded file contains no audio samples")

        waveform = torch.tensor(waveform_np, dtype=torch.float32).unsqueeze(0)
        if sr != TARGET_SR:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SR)
            waveform = resampler(waveform)
            sr = TARGET_SR

        inputs = proc(
            waveform.squeeze().numpy(),
            sampling_rate=sr,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            logits = mdl(**inputs).logits

        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
        labels_map = mdl.config.id2label
        result = {labels_map[i]: float(probs[i]) for i in range(len(labels_map))}
        top_idx = int(probs.argmax())
        return JSONResponse(content={"top": labels_map[top_idx], "scores": result})

    except Exception as e:
        # Top-level API boundary: log the full traceback, report the message.
        import traceback
        print("🔥 Error in /predict:", e)
        traceback.print_exc()
        return JSONResponse(status_code=400, content={"error": str(e)})


if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 8000))
    print(f"🚀 Starting app on port {port}")
    uvicorn.run(app, host="0.0.0.0", port=port)