Gamortsey committed on
Commit
ddb6bb8
·
verified ·
1 Parent(s): e6584a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -105
app.py CHANGED
@@ -16,12 +16,9 @@ from transformers import (
16
  )
17
 
18
  # ---------- Configuration ----------
19
- # Use smaller models suitable for CPU-only Hugging Face Spaces (free tier)
20
  WHISPER_MODEL = os.environ.get("WHISPER_MODEL", "openai/whisper-small")
21
  NLLB_MODEL = os.environ.get("NLLB_MODEL", "facebook/nllb-200-distilled-600M")
22
 
23
- # Map frontend language names -> (whisper_lang_arg, nllb_src_lang_tag)
24
- # Adjust tags if you have different NLLB language tags for specific dialects
25
  LANG_MAP = {
26
  "akan": (None, "aka_Latn"),
27
  "hausa": ("ha", "hau_Latn"),
@@ -31,13 +28,12 @@ LANG_MAP = {
31
  "english": ("en", None),
32
  }
33
 
34
- # Force CPU for free Spaces
35
- DEVICE = torch.device("cpu")
36
 
37
  app = Flask(__name__)
38
  CORS(app)
39
 
40
- # ---------- Model manager (lazy load) ----------
41
  class ModelManager:
42
  def __init__(self):
43
  self.whisper_processor = None
@@ -70,7 +66,6 @@ class ModelManager:
70
  if self.whisper_processor is None or self.whisper_model is None:
71
  raise RuntimeError("Whisper model not loaded")
72
 
73
- # Load audio and resample if needed
74
  waveform, sr = torchaudio.load(audio_path)
75
  if sr != 16000:
76
  waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
@@ -97,13 +92,11 @@ class ModelManager:
97
  if not src_text:
98
  return ""
99
  if not nllb_src_lang_tag:
100
- # Already English or no NLLB mapping — return source
101
  return src_text
102
 
103
  if self.nllb_tokenizer is None or self.nllb_model is None:
104
  raise RuntimeError("NLLB model not loaded")
105
 
106
- # Set tokenizer source lang if supported
107
  try:
108
  self.nllb_tokenizer.src_lang = nllb_src_lang_tag
109
  except Exception:
@@ -111,7 +104,6 @@ class ModelManager:
111
 
112
  inputs = self.nllb_tokenizer(src_text, return_tensors="pt").to(DEVICE)
113
 
114
- # Attempt to get forced BOS token id for English; fallback to no forced token
115
  forced_bos = None
116
  try:
117
  forced_bos = self.nllb_tokenizer.convert_tokens_to_ids("eng_Latn")
@@ -129,7 +121,6 @@ class ModelManager:
129
 
130
  with torch.no_grad():
131
  translated_tokens = self.nllb_model.generate(**inputs, **gen_kwargs)
132
-
133
  translated = self.nllb_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
134
  return translated.strip()
135
 
@@ -138,13 +129,6 @@ model_manager = ModelManager()
138
  # ---------- REST endpoint ----------
139
  @app.route("/transcribe", methods=["POST"])
140
  def transcribe_endpoint():
141
- """
142
- POST multipart/form-data:
143
- - field 'audio': file (wav/mp3/ogg etc.)
144
- - field 'language': string key (akan, hausa, swahili, french, arabic, english)
145
- Response:
146
- - Plain text body with the translated text (Content-Type: text/plain)
147
- """
148
  if "audio" not in request.files:
149
  return Response("No audio file provided", status=400, mimetype="text/plain")
150
 
@@ -156,13 +140,11 @@ def transcribe_endpoint():
156
 
157
  whisper_lang_arg, nllb_src_tag = LANG_MAP[language]
158
 
159
- # Load models (lazy)
160
  try:
161
  model_manager.load()
162
  except Exception as e:
163
  return Response(f"Model loading failed: {e}", status=500, mimetype="text/plain")
164
 
165
- # Save audio to a temp file
166
  tmp_fd, tmp_path = tempfile.mkstemp(suffix=Path(audio_file.filename).suffix or ".wav")
167
  os.close(tmp_fd)
168
  audio_file.save(tmp_path)
@@ -170,7 +152,6 @@ def transcribe_endpoint():
170
  try:
171
  transcription = model_manager.transcribe(tmp_path, whisper_language_arg=whisper_lang_arg)
172
  if not transcription:
173
- # nothing transcribed -> return empty body (204)
174
  return Response("", status=204, mimetype="text/plain")
175
 
176
  translation = model_manager.translate_to_english(transcription, nllb_src_tag)
@@ -183,102 +164,131 @@ def transcribe_endpoint():
183
  except Exception:
184
  pass
185
 
186
- # ---------- Robust Gradio UI mount (optional) ----------
187
  gradio_mounted = False
188
- try:
189
- import gradio as gr
190
- import soundfile as sf
191
- import numpy as np
192
-
193
- def _ui_transcribe(audio, language):
194
- """
195
- Accept many audio input shapes from different gradio versions:
196
- - filepath (str)
197
- - tuple (sr, ndarray)
198
- - ndarray (numpy)
199
- We normalize to a temporary wav file.
200
- """
201
- if audio is None:
202
- return "No audio", ""
203
-
204
- audio_path = None
205
- if isinstance(audio, str) and Path(audio).exists():
206
- audio_path = audio
207
- elif isinstance(audio, (tuple, list)) and len(audio) >= 2:
208
- sr, data = audio[0], audio[1]
209
- tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
210
- sf.write(tmp.name, data, sr)
211
- audio_path = tmp.name
212
- elif isinstance(audio, (np.ndarray,)) or hasattr(audio, "shape"):
213
- sr = 16000
214
- tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
215
- sf.write(tmp.name, audio, sr)
216
- audio_path = tmp.name
217
- else:
 
 
218
  try:
219
- audio_path = getattr(audio, "name", None)
220
- except Exception:
221
- audio_path = None
 
 
 
 
 
 
 
 
222
 
223
- if not audio_path:
224
- return "Unsupported audio format from Gradio", ""
 
 
 
 
225
 
 
226
  try:
227
- model_manager.load()
228
- whisper_lang, nllb_tag = LANG_MAP.get(language.lower(), (None, None))
229
- transcription = model_manager.transcribe(audio_path, whisper_language_arg=whisper_lang)
230
- translation = model_manager.translate_to_english(transcription, nllb_tag)
231
- return transcription, translation
232
- finally:
233
- # try cleanup
234
- try:
235
- if audio_path and Path(audio_path).exists() and "/tmp" in str(audio_path):
236
- os.remove(audio_path)
237
- except Exception:
238
- pass
239
 
240
- demo = None
241
- try:
242
- # modern API
243
- audio_component = gr.Audio(source="microphone", type="filepath")
244
- dropdown = gr.Dropdown(choices=list(LANG_MAP.keys()), value="english", label="Language")
245
- demo = gr.Interface(
246
- fn=_ui_transcribe,
247
- inputs=[audio_component, dropdown],
248
- outputs=[gr.Textbox(label="Transcription"), gr.Textbox(label="Translation (English)")],
249
- title="Multilingual Transcriber (server)"
250
- )
251
- except TypeError:
252
- # fallback for older gradio versions
253
  try:
254
- audio_component = gr.inputs.Audio(source="microphone", type="filepath")
255
- dropdown = gr.inputs.Dropdown(choices=list(LANG_MAP.keys()), default="english")
256
- outputs = [gr.outputs.Textbox(), gr.outputs.Textbox()]
257
- demo = gr.Interface(fn=_ui_transcribe, inputs=[audio_component, dropdown], outputs=outputs,
258
- title="Multilingual Transcriber (server)")
259
- except Exception as e:
260
- print("Gradio fallback constructor failed:", e)
261
- demo = None
262
- except Exception as e:
263
- print("Gradio constructor failed:", e)
264
- demo = None
265
 
266
- if demo is not None:
267
  try:
268
- app = gr.mount_gradio_app(app, demo, path="/ui")
269
- gradio_mounted = True
270
- print("Gradio mounted at /ui")
271
- except Exception as e:
272
- print("Failed to mount Gradio app:", e)
273
- gradio_mounted = False
274
- else:
275
- print("Gradio demo not created; continuing without mounted UI.")
276
 
277
- except Exception as e:
278
- print("Gradio UI unavailable or failed to mount:", e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  gradio_mounted = False
280
 
281
- # Root endpoint: redirect to /ui if mounted, otherwise status text
282
  @app.route("/")
283
  def index():
284
  if gradio_mounted:
 
16
  )
17
 
18
  # ---------- Configuration ----------
 
19
  WHISPER_MODEL = os.environ.get("WHISPER_MODEL", "openai/whisper-small")
20
  NLLB_MODEL = os.environ.get("NLLB_MODEL", "facebook/nllb-200-distilled-600M")
21
 
 
 
22
  LANG_MAP = {
23
  "akan": (None, "aka_Latn"),
24
  "hausa": ("ha", "hau_Latn"),
 
28
  "english": ("en", None),
29
  }
30
 
31
+ DEVICE = torch.device("cpu") # Free HF Spaces = CPU
 
32
 
33
  app = Flask(__name__)
34
  CORS(app)
35
 
36
+ # ---------- Model manager ----------
37
  class ModelManager:
38
  def __init__(self):
39
  self.whisper_processor = None
 
66
  if self.whisper_processor is None or self.whisper_model is None:
67
  raise RuntimeError("Whisper model not loaded")
68
 
 
69
  waveform, sr = torchaudio.load(audio_path)
70
  if sr != 16000:
71
  waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
 
92
  if not src_text:
93
  return ""
94
  if not nllb_src_lang_tag:
 
95
  return src_text
96
 
97
  if self.nllb_tokenizer is None or self.nllb_model is None:
98
  raise RuntimeError("NLLB model not loaded")
99
 
 
100
  try:
101
  self.nllb_tokenizer.src_lang = nllb_src_lang_tag
102
  except Exception:
 
104
 
105
  inputs = self.nllb_tokenizer(src_text, return_tensors="pt").to(DEVICE)
106
 
 
107
  forced_bos = None
108
  try:
109
  forced_bos = self.nllb_tokenizer.convert_tokens_to_ids("eng_Latn")
 
121
 
122
  with torch.no_grad():
123
  translated_tokens = self.nllb_model.generate(**inputs, **gen_kwargs)
 
124
  translated = self.nllb_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
125
  return translated.strip()
126
 
 
129
  # ---------- REST endpoint ----------
130
  @app.route("/transcribe", methods=["POST"])
131
  def transcribe_endpoint():
 
 
 
 
 
 
 
132
  if "audio" not in request.files:
133
  return Response("No audio file provided", status=400, mimetype="text/plain")
134
 
 
140
 
141
  whisper_lang_arg, nllb_src_tag = LANG_MAP[language]
142
 
 
143
  try:
144
  model_manager.load()
145
  except Exception as e:
146
  return Response(f"Model loading failed: {e}", status=500, mimetype="text/plain")
147
 
 
148
  tmp_fd, tmp_path = tempfile.mkstemp(suffix=Path(audio_file.filename).suffix or ".wav")
149
  os.close(tmp_fd)
150
  audio_file.save(tmp_path)
 
152
  try:
153
  transcription = model_manager.transcribe(tmp_path, whisper_language_arg=whisper_lang_arg)
154
  if not transcription:
 
155
  return Response("", status=204, mimetype="text/plain")
156
 
157
  translation = model_manager.translate_to_english(transcription, nllb_src_tag)
 
164
  except Exception:
165
  pass
166
 
167
# ---------- Robust Gradio UI mount ----------
# Best-effort: build and mount a Gradio UI at /ui, tolerating API differences
# across gradio versions. Any failure is logged and the Flask REST API keeps
# working without the UI.
gradio_mounted = False
if os.environ.get("DISABLE_GRADIO", "0") != "1":
    try:
        import gradio as gr
        import soundfile as sf
        import numpy as np

        def _ui_transcribe(audio, language):
            """Transcribe and translate audio handed over by the Gradio UI.

            Gradio versions deliver audio in several shapes — a filepath
            string, an (sr, ndarray) tuple, a bare ndarray, or a file-like
            object with a ``.name`` — so normalize to a wav file on disk
            first.

            Returns a (transcription, translation) pair of strings; on
            unusable input returns an explanatory message and "".
            """
            if audio is None:
                return "No audio", ""

            audio_path = None
            if isinstance(audio, str) and Path(audio).exists():
                audio_path = audio
            elif isinstance(audio, (tuple, list)) and len(audio) >= 2:
                sr, data = audio[0], audio[1]
                # mkstemp + close instead of an unclosed NamedTemporaryFile:
                # the old code leaked the open fd and the write-while-open
                # pattern fails on Windows. Matches the REST endpoint's style.
                fd, audio_path = tempfile.mkstemp(suffix=".wav")
                os.close(fd)
                sf.write(audio_path, data, sr)
            elif isinstance(audio, np.ndarray) or hasattr(audio, "shape"):
                # Bare array carries no sample rate; assume 16 kHz — the rate
                # the Whisper pipeline resamples to. TODO confirm against the
                # gradio versions that deliver raw arrays.
                fd, audio_path = tempfile.mkstemp(suffix=".wav")
                os.close(fd)
                sf.write(audio_path, audio, 16000)
            else:
                try:
                    audio_path = getattr(audio, "name", None)
                except Exception:
                    audio_path = None

            if not audio_path:
                return "Unsupported audio format from Gradio", ""

            try:
                model_manager.load()
                # Guard against a None/unexpected dropdown value; unknown
                # languages fall back to (None, None) as before.
                whisper_lang, nllb_tag = LANG_MAP.get((language or "").lower(), (None, None))
                transcription = model_manager.transcribe(audio_path, whisper_language_arg=whisper_lang)
                translation = model_manager.translate_to_english(transcription, nllb_tag)
                return transcription, translation
            finally:
                # Best-effort cleanup of temp files we (or gradio) created.
                try:
                    if audio_path and Path(audio_path).exists() and "/tmp" in str(audio_path):
                        os.remove(audio_path)
                except Exception:
                    pass

        demo = None
        # Create components robustly across gradio versions.
        audio_component = None
        dropdown_component = None
        textbox_out1 = None
        textbox_out2 = None

        # Modern API: top-level classes, or the gr.components namespace.
        try:
            if hasattr(gr, "Audio"):
                audio_component = gr.Audio(source="microphone", type="filepath")
            elif hasattr(gr, "components") and hasattr(gr.components, "Audio"):
                audio_component = gr.components.Audio(source="microphone", type="filepath")
        except Exception:
            audio_component = None

        # Dropdown
        try:
            if hasattr(gr, "Dropdown"):
                dropdown_component = gr.Dropdown(choices=list(LANG_MAP.keys()), value="english", label="Language")
            elif hasattr(gr, "components") and hasattr(gr.components, "Dropdown"):
                dropdown_component = gr.components.Dropdown(choices=list(LANG_MAP.keys()), value="english", label="Language")
        except Exception:
            dropdown_component = None

        # Output textboxes
        try:
            if hasattr(gr, "Textbox"):
                textbox_out1 = gr.Textbox(label="Transcription")
                textbox_out2 = gr.Textbox(label="Translation (English)")
            elif hasattr(gr, "components") and hasattr(gr.components, "Textbox"):
                textbox_out1 = gr.components.Textbox(label="Transcription")
                textbox_out2 = gr.components.Textbox(label="Translation (English)")
        except Exception:
            textbox_out1 = textbox_out2 = None

        # Final fallback: legacy gr.inputs / gr.outputs API.
        if audio_component is None or dropdown_component is None or textbox_out1 is None:
            try:
                # BUGFIX: the second test was a duplicated hasattr(gr, "inputs");
                # the Textbox fallback below actually needs gr.outputs.
                if hasattr(gr, "inputs") and hasattr(gr, "outputs"):
                    audio_component = gr.inputs.Audio(source="microphone", type="filepath")
                    dropdown_component = gr.inputs.Dropdown(choices=list(LANG_MAP.keys()), default="english")
                    textbox_out1 = gr.outputs.Textbox()
                    textbox_out2 = gr.outputs.Textbox()
            except Exception:
                pass

        # Build the Interface only when every required component exists.
        if audio_component is not None and dropdown_component is not None and textbox_out1 is not None:
            try:
                demo = gr.Interface(
                    fn=_ui_transcribe,
                    inputs=[audio_component, dropdown_component],
                    outputs=[textbox_out1, textbox_out2],
                    title="Multilingual Transcriber (server)",
                )
            except Exception as e:
                print("Failed to create gr.Interface:", e)
                demo = None

        if demo is not None:
            try:
                app = gr.mount_gradio_app(app, demo, path="/ui")
                gradio_mounted = True
                print("Gradio mounted at /ui")
            except Exception as e:
                print("Failed to mount Gradio app:", e)
                gradio_mounted = False
        else:
            print("Gradio demo not created; continuing without mounted UI.")
    except Exception as e:
        print("Gradio UI unavailable or failed to mount:", e)
        gradio_mounted = False
else:
    print("Gradio mounting disabled via DISABLE_GRADIO=1")
    gradio_mounted = False
290
 
291
+ # Root endpoint
292
  @app.route("/")
293
  def index():
294
  if gradio_mounted: