Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import tempfile | |
| import subprocess | |
| from pathlib import Path | |
| import torch | |
| torch.set_num_threads(1) | |
| import torchaudio | |
| import soundfile as sf | |
| import numpy as np | |
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, HTMLResponse | |
# Cached Hugging Face objects. Both start out empty and are filled in by
# get_model() the first time a request needs them (lazy loading).
processor = None  # Wav2Vec2Processor once loaded
model = None  # AutoModelForAudioClassification once loaded

# Sample rate (Hz) that audio is converted to before it reaches wav2vec2.
TARGET_SR = 16_000
def get_model():
    """Return the ``(processor, model)`` pair, loading it on first use.

    On the first call the processor and the audio-classification model are
    loaded with ``from_pretrained`` and cached in the module-level
    ``processor`` / ``model`` globals. Later calls return the cached objects.

    The cache directory comes from the ``HF_HOME`` environment variable and
    falls back to ``/app/hf_cache``. This custom cache dir avoids permission
    issues on Hugging Face Spaces.

    Returns:
        tuple: ``(processor, model)``, with the model already in eval mode.

    Raises:
        Any exception raised by ``from_pretrained`` (download/cache errors).
        In that case the globals are left untouched, so the next call retries
        the full load.
    """
    global processor, model
    if processor is None or model is None:
        print("Loading HF processor & model (this may take 10-60s on first request)...")
        # Imported lazily: transformers is only needed once a model is requested.
        from transformers import AutoModelForAudioClassification, Wav2Vec2Processor

        cache_dir = os.getenv("HF_HOME", "/app/hf_cache")

        # Load into locals and publish both at once. A failure while loading
        # the model therefore never leaves a half-initialised cache behind.
        new_processor = Wav2Vec2Processor.from_pretrained(
            "facebook/wav2vec2-base-960h",
            cache_dir=cache_dir,
        )
        new_model = AutoModelForAudioClassification.from_pretrained(
            "prithivMLmods/Common-Voice-Gender-Detection",
            cache_dir=cache_dir,
        )
        new_model.eval()  # evaluation mode for inference

        processor, model = new_processor, new_model
        print("Model & processor loaded.")
    return processor, model
app = FastAPI(title="Gender Detection API (lazy model load)")

# CORS: accept cross-origin requests from any origin, with any HTTP method
# and any request header.
# NOTE(review): with allow_credentials=True, allow_origins=["*"] makes
# Starlette reflect the caller's Origin header. Credentialed requests are
# therefore accepted from every site. Nothing visible here uses cookies or
# auth, so consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve a minimal HTML upload form that POSTs the file to ``/predict``.

    ``response_class=HTMLResponse`` makes FastAPI send the string as
    ``text/html``. Without it the string would be JSON-encoded.
    """
    return """
<html>
  <body>
    <h2>Upload Audio for Gender Detection</h2>
    <form action="/predict" enctype="multipart/form-data" method="post">
      <input name="file" type="file" accept=".wav,.mp3,.flac,.ogg" />
      <input type="submit" value="Upload" />
    </form>
    <p>POST /predict (multipart form-data, field name "file")</p>
  </body>
</html>
"""
@app.get("/health")
async def health():
    """Report that the web server is up.

    This does not call get_model(), so it answers immediately, even before
    the model has been loaded.
    """
    return {"status": "ok"}
@app.get("/labels")
async def labels():
    """Return the classifier's ``config.id2label`` mapping (index -> label).

    Triggers the lazy model load on the first call.
    """
    _, mdl = get_model()
    return mdl.config.id2label
def _read_audio_file(path):
    """Decode the audio file at *path* into ``(samples, sample_rate)``.

    ``soundfile`` is tried first. If it cannot read the file, ``ffmpeg``
    transcodes it to a mono ``TARGET_SR`` WAV next to *path*. That WAV is
    then read and always deleted afterwards, even on failure.

    Returns:
        tuple: float32 NumPy samples and the sample rate in Hz. Multi-channel
        input comes back as ``(frames, channels)``.

    Raises:
        ValueError: ffmpeg exited with a non-zero status.
        FileNotFoundError: soundfile failed and ffmpeg is not installed.
    """
    try:
        return sf.read(path, dtype="float32")
    except Exception as e:  # soundfile's error class differs between versions
        print("Warning: soundfile could not read directly, trying ffmpeg conversion:", e)

    converted = path + ".converted.wav"
    ffmpeg_cmd = [
        "ffmpeg", "-y", "-i", path,
        "-ar", str(TARGET_SR), "-ac", "1", converted,
    ]
    try:
        # Argument list without a shell: the path is never shell-interpreted.
        result = subprocess.run(
            ffmpeg_cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
        if result.returncode != 0:
            raise ValueError(
                f"Could not decode audio: ffmpeg exited with code {result.returncode}"
            )
        return sf.read(converted, dtype="float32")
    finally:
        # Remove the intermediate file even when conversion or reading failed.
        try:
            os.unlink(converted)
        except OSError:
            pass


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the speaker gender in an uploaded audio file.

    Processing steps:

    1. Write the upload to a temporary file and decode it (soundfile, with
       an ffmpeg fallback).
    2. Down-mix the audio to mono and resample it to ``TARGET_SR``.
    3. Run the classifier.

    Args:
        file: multipart upload (form field ``file``).

    Returns:
        JSONResponse: ``{"top": label, "scores": {label: probability}}`` on
        success, or HTTP 400 with ``{"error": message}`` on any failure.
    """
    # NOTE(review): decoding, resampling and inference are synchronous calls
    # inside an async handler, so they block the event loop while running.
    try:
        proc, mdl = get_model()

        # Read the upload before touching the filesystem. A failed read then
        # cannot leave an orphaned temp file behind.
        raw = await file.read()
        if not raw:
            raise ValueError("Uploaded file is empty")

        suffix = Path(file.filename or "").suffix or ".wav"
        fd, tmp_path = tempfile.mkstemp(suffix=suffix)
        try:
            with os.fdopen(fd, "wb") as tmp:
                tmp.write(raw)
            waveform_np, sr = _read_audio_file(tmp_path)
        finally:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

        if waveform_np.size == 0:
            raise ValueError("Decoded audio contains no samples")

        # Down-mix (frames, channels) to mono.
        if waveform_np.ndim > 1:
            waveform_np = waveform_np.mean(axis=1)

        waveform = torch.tensor(waveform_np, dtype=torch.float32).unsqueeze(0)
        if sr != TARGET_SR:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SR)
            waveform = resampler(waveform)
            sr = TARGET_SR

        inputs = proc(
            # Index away the batch dim. squeeze() would turn a 1-sample
            # clip into a 0-d array.
            waveform[0].numpy(),
            sampling_rate=sr,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            logits = mdl(**inputs).logits
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]

        labels_map = mdl.config.id2label
        scores = {labels_map[i]: float(p) for i, p in enumerate(probs)}
        top_idx = int(probs.argmax())
        return JSONResponse(content={"top": labels_map[top_idx], "scores": scores})
    except Exception as e:
        import traceback

        print("Error in /predict:", e)
        traceback.print_exc()
        return JSONResponse(status_code=400, content={"error": str(e)})
if __name__ == "__main__":
    # Direct entry point (`python app.py`). Serves on all interfaces, on the
    # port given by $PORT (default 8000).
    import uvicorn

    port = int(os.environ.get("PORT", 8000))
    print(f"Starting app on port {port}")
    uvicorn.run(app, host="0.0.0.0", port=port)