Midnightar committed on
Commit
75fe029
·
verified ·
1 Parent(s): 61102eb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -0
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import tempfile
4
+ import subprocess
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ # Limit PyTorch threads to reduce memory/CPU pressure on small containers
9
+ torch.set_num_threads(1)
10
+
11
+ import torchaudio
12
+ import soundfile as sf
13
+ import numpy as np
14
+
15
+ from fastapi import FastAPI, File, UploadFile
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from fastapi.responses import JSONResponse, HTMLResponse
18
+
19
# Sample rate (Hz) that all audio is resampled to before inference;
# wav2vec2 expects 16 kHz input.
TARGET_SR = 16000

# Lazily populated caches: both stay None until get_model() fills them in
# on the first request that needs the model.
processor = None
model = None
24
+
25
def get_model():
    """Return the cached ``(processor, model)`` pair, loading it on first use.

    The Hugging Face processor and classifier are stored in the module-level
    ``processor`` and ``model`` globals.  ``transformers`` is imported inside
    this function, so nothing heavy is loaded until a request handler
    actually calls it.

    Returns:
        A ``(processor, model)`` tuple; the model has been put in eval mode.
    """
    global processor, model

    # Fast path: both halves are already cached.
    if processor is not None and model is not None:
        return processor, model

    print("🔁 Loading HF processor & model (this may take 10-60s on first request)...")
    from transformers import Wav2Vec2Processor, AutoModelForAudioClassification

    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    model = AutoModelForAudioClassification.from_pretrained(
        "prithivMLmods/Common-Voice-Gender-Detection"
    )
    model.eval()
    print("✅ Model & processor loaded.")
    return processor, model
41
+
42
+
43
app = FastAPI(title="Gender Detection API (lazy model load)")

# Open CORS policy: any origin, method and header may call this API.
# Credentials are deliberately disabled. A wildcard origin must not be
# combined with allow_credentials=True (FastAPI's CORS documentation requires
# explicit origins in that case), and none of the endpoints in this file read
# cookies or Authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
52
+
53
+
54
# Static landing page: a minimal upload form that posts to /predict.
_HOME_PAGE = """
<html>
<body>
<h2>Upload Audio for Gender Detection</h2>
<form action="/predict" enctype="multipart/form-data" method="post">
<input name="file" type="file" accept=".wav,.mp3,.flac,.ogg" />
<input type="submit" value="Upload" />
</form>
<p>POST /predict (multipart form-data, field name "file")</p>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the HTML upload form for manual testing in a browser."""
    return _HOME_PAGE
68
+
69
+
70
@app.get("/health")
async def health():
    """Health probe: returns a constant payload and never calls get_model()."""
    return dict(status="ok")
73
+
74
+
75
@app.get("/labels")
async def labels():
    """Return the classifier's ``id2label`` mapping.

    Calls get_model(), so the first request may trigger the model load.
    """
    _, classifier = get_model()
    return classifier.config.id2label
79
+
80
+
81
def _remove_quietly(path):
    """Delete the file at *path*, ignoring OS-level failures."""
    try:
        os.unlink(path)
    except OSError:
        # Best-effort temp-file cleanup: a failure here (already gone,
        # permissions) must not mask the real outcome of the request.
        pass


def _read_audio(path):
    """Decode the audio file at *path* into float32 samples.

    ``soundfile`` (libsndfile) is tried first.  If it cannot decode the file
    (some mp3/ogg), the ``ffmpeg`` CLI converts it to a mono WAV at
    ``TARGET_SR`` and that file is read instead.  The intermediate WAV is
    always removed, even when decoding fails.

    Returns:
        The ``(samples, sample_rate)`` tuple produced by ``soundfile.read``.

    Raises:
        RuntimeError: if ffmpeg is not installed or exits with an error.
    """
    try:
        return sf.read(path, dtype="float32")
    except Exception as e:
        print("⚠️ soundfile could not read directly, trying ffmpeg conversion:", e)

    converted = path + ".converted.wav"
    # Use ffmpeg CLI (ffmpeg must be installed in the container)
    ffmpeg_cmd = [
        "ffmpeg", "-y", "-i", path,
        "-ar", str(TARGET_SR), "-ac", "1", converted,
    ]
    try:
        try:
            completed = subprocess.run(
                ffmpeg_cmd,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=False,
            )
        except FileNotFoundError as err:
            raise RuntimeError(
                "ffmpeg is not installed; cannot decode this audio format"
            ) from err
        if completed.returncode != 0:
            # Fail here with a clear message instead of letting soundfile
            # complain about a missing or truncated converted file.
            raise RuntimeError(
                f"ffmpeg could not decode the uploaded audio "
                f"(exit code {completed.returncode})"
            )
        return sf.read(converted, dtype="float32")
    finally:
        _remove_quietly(converted)


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded audio file and return per-label probabilities.

    The upload is written to a temporary file, decoded (with an ffmpeg
    fallback), mixed down to mono, resampled to ``TARGET_SR`` if needed and
    passed through the model returned by get_model().

    Args:
        file: The multipart upload (form field name ``file``).

    Returns:
        A JSONResponse with ``{"top": <label>, "scores": {<label>: <prob>}}``
        on success, or status 400 with ``{"error": <message>}`` when any step
        raises.
    """
    try:
        raw = await file.read()
        if not raw:
            raise ValueError("Uploaded file is empty")

        proc, mdl = get_model()

        # Save upload to a temporary file.  The path is recorded before
        # writing so the file is removed even if the write itself fails.
        suffix = Path(file.filename or "").suffix or ".wav"
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp_path = tmp.name
                tmp.write(raw)
            waveform_np, sr = _read_audio(tmp_path)
        finally:
            # remove uploaded tmp file as soon as possible
            if tmp_path is not None:
                _remove_quietly(tmp_path)

        # waveform_np shape: (n_samples,) or (n_samples, channels)
        if waveform_np.ndim > 1:
            # average channels to mono
            waveform_np = waveform_np.mean(axis=1)
        if waveform_np.size == 0:
            raise ValueError("Audio contains no samples")

        # Convert to torch tensor shape [1, n_samples]
        waveform = torch.tensor(waveform_np, dtype=torch.float32).unsqueeze(0)

        # Resample if necessary
        if sr != TARGET_SR:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SR)
            waveform = resampler(waveform)
            sr = TARGET_SR

        # Prepare inputs for HF model.  squeeze(0) drops only the batch axis,
        # so a one-sample clip stays 1-D instead of collapsing to a scalar.
        inputs = proc(
            waveform.squeeze(0).numpy(),
            sampling_rate=sr,
            return_tensors="pt",
            padding=True,
        )

        with torch.no_grad():
            logits = mdl(**inputs).logits
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]

        labels_map = mdl.config.id2label
        result = {labels_map[i]: float(p) for i, p in enumerate(probs)}
        top_idx = int(probs.argmax())

        return JSONResponse(content={"top": labels_map[top_idx], "scores": result})

    except Exception as e:
        import traceback
        print("🔥 Error in /predict:", e)
        traceback.print_exc()
        # Return the error string (400) so client can see the reason
        return JSONResponse(status_code=400, content={"error": str(e)})
158
+
159
+
160
if __name__ == "__main__":
    # Direct-execution path for local development only
    # (Railway/Gunicorn uses CMD from Dockerfile).
    import uvicorn

    # Listen on $PORT when set, otherwise on 8000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 8000)))