Midnightar commited on
Commit
0209771
Β·
verified Β·
1 Parent(s): acc7904

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -16
app.py CHANGED
@@ -5,7 +5,6 @@ import subprocess
5
  from pathlib import Path
6
 
7
  import torch
8
- # Limit PyTorch threads to reduce memory/CPU pressure on small containers
9
  torch.set_num_threads(1)
10
 
11
  import torchaudio
@@ -25,15 +24,22 @@ TARGET_SR = 16000 # wav2vec2 expects 16 kHz
25
  def get_model():
26
  """
27
  Lazily load processor and model on first call and cache them globally.
28
- Call inside request handlers to avoid heavy startup on cold starts.
29
  """
30
  global processor, model
31
  if processor is None or model is None:
32
- print("πŸ” Loading HF processor & model (this may take 10-60s on first request)...")
33
  from transformers import Wav2Vec2Processor, AutoModelForAudioClassification
34
- processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
 
 
 
 
 
 
35
  model = AutoModelForAudioClassification.from_pretrained(
36
- "prithivMLmods/Common-Voice-Gender-Detection"
 
37
  )
38
  model.eval()
39
  print("βœ… Model & processor loaded.")
@@ -95,10 +101,9 @@ async def predict(file: UploadFile = File(...)):
95
  try:
96
  waveform_np, sr = sf.read(tmp_path, dtype="float32")
97
  except Exception as e:
98
- # If soundfile fails (some mp3/ogg), try using ffmpeg to convert to WAV then read
99
  print("⚠️ soundfile could not read directly, trying ffmpeg conversion:", e)
100
  converted = tmp_path + ".converted.wav"
101
- # Use ffmpeg CLI (ffmpeg must be installed in the container)
102
  ffmpeg_cmd = [
103
  "ffmpeg", "-y", "-i", tmp_path,
104
  "-ar", str(TARGET_SR), "-ac", "1", converted
@@ -111,27 +116,21 @@ async def predict(file: UploadFile = File(...)):
111
  pass
112
 
113
  finally:
114
- # remove uploaded tmp file as soon as possible
115
  try:
116
  os.unlink(tmp_path)
117
  except Exception:
118
  pass
119
 
120
- # waveform_np shape: (n_samples,) or (n_samples, channels)
121
  if waveform_np.ndim > 1:
122
- # average channels to mono
123
  waveform_np = waveform_np.mean(axis=1)
124
 
125
- # Convert to torch tensor shape [1, n_samples]
126
  waveform = torch.tensor(waveform_np, dtype=torch.float32).unsqueeze(0)
127
 
128
- # Resample if necessary
129
  if sr != TARGET_SR:
130
  resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SR)
131
  waveform = resampler(waveform)
132
  sr = TARGET_SR
133
 
134
- # Prepare inputs for HF model
135
  inputs = proc(
136
  waveform.squeeze().numpy(),
137
  sampling_rate=sr,
@@ -153,12 +152,11 @@ async def predict(file: UploadFile = File(...)):
153
  import traceback
154
  print("πŸ”₯ Error in /predict:", e)
155
  traceback.print_exc()
156
- # Return the error string (400) so client can see the reason
157
  return JSONResponse(status_code=400, content={"error": str(e)})
158
 
159
 
160
  if __name__ == "__main__":
161
- # Local dev fallback (Railway/Gunicorn uses CMD from Dockerfile)
162
  import uvicorn
163
  port = int(os.environ.get("PORT", 8000))
164
- uvicorn.run(app, host="0.0.0.0", port=port)
 
 
5
  from pathlib import Path
6
 
7
  import torch
 
8
  torch.set_num_threads(1)
9
 
10
  import torchaudio
 
24
  def get_model():
25
  """
26
  Lazily load processor and model on first call and cache them globally.
27
+ Uses a custom HF cache dir to avoid permission issues on Hugging Face Spaces.
28
  """
29
  global processor, model
30
  if processor is None or model is None:
31
+ print("πŸ” Loading HF processor & model (this may take 10–60s on first request)...")
32
  from transformers import Wav2Vec2Processor, AutoModelForAudioClassification
33
+
34
+ cache_dir = os.getenv("HF_HOME", "/app/hf_cache")
35
+
36
+ processor = Wav2Vec2Processor.from_pretrained(
37
+ "facebook/wav2vec2-base-960h",
38
+ cache_dir=cache_dir
39
+ )
40
  model = AutoModelForAudioClassification.from_pretrained(
41
+ "prithivMLmods/Common-Voice-Gender-Detection",
42
+ cache_dir=cache_dir
43
  )
44
  model.eval()
45
  print("βœ… Model & processor loaded.")
 
101
  try:
102
  waveform_np, sr = sf.read(tmp_path, dtype="float32")
103
  except Exception as e:
104
+ # If soundfile fails, convert with ffmpeg then read
105
  print("⚠️ soundfile could not read directly, trying ffmpeg conversion:", e)
106
  converted = tmp_path + ".converted.wav"
 
107
  ffmpeg_cmd = [
108
  "ffmpeg", "-y", "-i", tmp_path,
109
  "-ar", str(TARGET_SR), "-ac", "1", converted
 
116
  pass
117
 
118
  finally:
 
119
  try:
120
  os.unlink(tmp_path)
121
  except Exception:
122
  pass
123
 
 
124
  if waveform_np.ndim > 1:
 
125
  waveform_np = waveform_np.mean(axis=1)
126
 
 
127
  waveform = torch.tensor(waveform_np, dtype=torch.float32).unsqueeze(0)
128
 
 
129
  if sr != TARGET_SR:
130
  resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SR)
131
  waveform = resampler(waveform)
132
  sr = TARGET_SR
133
 
 
134
  inputs = proc(
135
  waveform.squeeze().numpy(),
136
  sampling_rate=sr,
 
152
  import traceback
153
  print("πŸ”₯ Error in /predict:", e)
154
  traceback.print_exc()
 
155
  return JSONResponse(status_code=400, content={"error": str(e)})
156
 
157
 
158
  if __name__ == "__main__":
 
159
  import uvicorn
160
  port = int(os.environ.get("PORT", 8000))
161
+ print(f"πŸš€ Starting app on port {port}")
162
+ uvicorn.run(app, host="0.0.0.0", port=port)