Gamortsey committed on
Commit
e6584a7
·
verified ·
1 Parent(s): f6b0045

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -73
app.py CHANGED
@@ -1,10 +1,13 @@
1
  # app.py
2
  import os
3
  import tempfile
4
- from flask import Flask, request, Response, jsonify
 
5
  from flask_cors import CORS
6
  import torch
7
  import torchaudio
 
 
8
  from transformers import (
9
  AutoProcessor,
10
  AutoModelForSpeechSeq2Seq,
@@ -13,27 +16,28 @@ from transformers import (
13
  )
14
 
15
  # ---------- Configuration ----------
16
- # Use small CPU-friendly models for free HF Spaces
17
- WHISPER_MODEL = "openai/whisper-small"
18
- NLLB_MODEL = "facebook/nllb-200-distilled-600M"
19
 
20
- # Map frontend language names -> (whisper_lang_code, nllb_src_code)
 
21
  LANG_MAP = {
22
- # language_key: (whisper_language_arg, nllb_src_lang_tag)
23
- "akan": (None, "aka_Latn"), # if you have a specialized Akan whisper model, change whisper arg
24
  "hausa": ("ha", "hau_Latn"),
25
  "swahili": ("sw", "swh_Latn"),
26
  "french": ("fr", "fra_Latn"),
27
- "arabic": ("ar", "arb_Arab"), # nllb code for Arabic variants may vary (this is illustrative)
28
  "english": ("en", None),
29
  }
30
 
31
- DEVICE = torch.device("cpu") # Free Spaces = CPU-only
 
32
 
33
  app = Flask(__name__)
34
  CORS(app)
35
 
36
- # ---------- Lazy model manager ----------
37
  class ModelManager:
38
  def __init__(self):
39
  self.whisper_processor = None
@@ -45,26 +49,28 @@ class ModelManager:
45
  def load(self):
46
  if self._loaded:
47
  return
48
- # Whisper processor & model
49
- print("Loading Whisper processor/model (small)...")
50
- self.whisper_processor = AutoProcessor.from_pretrained(WHISPER_MODEL)
51
- self.whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
52
- WHISPER_MODEL
53
- ).to(DEVICE)
54
 
55
- # NLLB tokenizer & model (600M)
56
- print("Loading NLLB tokenizer/model (600M)...")
57
- self.nllb_tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL)
58
- self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_MODEL).to(DEVICE)
 
 
59
 
60
  self._loaded = True
61
- print("Models loaded.")
62
 
63
  def transcribe(self, audio_path, whisper_language_arg=None):
64
- # loads and runs whisper-small to produce transcription string
65
  if self.whisper_processor is None or self.whisper_model is None:
66
- raise RuntimeError("Whisper not loaded")
67
 
 
68
  waveform, sr = torchaudio.load(audio_path)
69
  if sr != 16000:
70
  waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
@@ -88,33 +94,44 @@ class ModelManager:
88
  return decoded[0].strip()
89
 
90
  def translate_to_english(self, src_text, nllb_src_lang_tag):
91
- # returns english translation string using nllb
 
92
  if not nllb_src_lang_tag:
93
- # if src already english or no mapping, return original
94
  return src_text
 
95
  if self.nllb_tokenizer is None or self.nllb_model is None:
96
- raise RuntimeError("NLLB not loaded")
97
 
98
- # set src_lang on tokenizer (some NLLB tokenizers use this attribute)
99
  try:
100
  self.nllb_tokenizer.src_lang = nllb_src_lang_tag
101
  except Exception:
102
  pass
103
 
104
  inputs = self.nllb_tokenizer(src_text, return_tensors="pt").to(DEVICE)
105
- forced_bos_token_id = self.nllb_tokenizer.convert_tokens_to_ids("eng_Latn")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  with torch.no_grad():
108
- translated_tokens = self.nllb_model.generate(
109
- **inputs,
110
- forced_bos_token_id=forced_bos_token_id,
111
- max_length=512,
112
- num_beams=4,
113
- no_repeat_ngram_size=2,
114
- early_stopping=True
115
- )
116
- out = self.nllb_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
117
- return out.strip()
118
 
119
  model_manager = ModelManager()
120
 
@@ -122,10 +139,10 @@ model_manager = ModelManager()
122
  @app.route("/transcribe", methods=["POST"])
123
  def transcribe_endpoint():
124
  """
125
- Accepts:
126
- - multipart form file field: 'audio' (wav/mp3 etc.)
127
- - form field 'language' (one of keys in LANG_MAP: akan, hausa, swahili, french, arabic, english)
128
- Returns:
129
  - Plain text body with the translated text (Content-Type: text/plain)
130
  """
131
  if "audio" not in request.files:
@@ -133,32 +150,30 @@ def transcribe_endpoint():
133
 
134
  audio_file = request.files["audio"]
135
  language = (request.form.get("language") or request.args.get("language") or "english").lower()
 
136
  if language not in LANG_MAP:
137
  return Response(f"Unsupported language: {language}", status=400, mimetype="text/plain")
138
 
139
  whisper_lang_arg, nllb_src_tag = LANG_MAP[language]
140
 
141
- # Ensure models loaded lazily (first request)
142
  try:
143
  model_manager.load()
144
  except Exception as e:
145
  return Response(f"Model loading failed: {e}", status=500, mimetype="text/plain")
146
 
147
- # Save audio to temp file
148
- tmp_fd, tmp_path = tempfile.mkstemp(suffix=os.path.splitext(audio_file.filename)[1] or ".wav")
149
  os.close(tmp_fd)
150
  audio_file.save(tmp_path)
151
 
152
  try:
153
- # Transcribe (may be slow on CPU)
154
  transcription = model_manager.transcribe(tmp_path, whisper_language_arg=whisper_lang_arg)
155
  if not transcription:
156
- return Response("", status=204, mimetype="text/plain") # empty body if nothing transcribed
 
157
 
158
- # Translate to English (if applicable)
159
  translation = model_manager.translate_to_english(transcription, nllb_src_tag)
160
-
161
- # Return only the translated text (plain text)
162
  return Response(translation, status=200, mimetype="text/plain")
163
  except Exception as e:
164
  return Response(f"Processing failed: {e}", status=500, mimetype="text/plain")
@@ -168,37 +183,108 @@ def transcribe_endpoint():
168
  except Exception:
169
  pass
170
 
171
- # Optional: a tiny Gradio UI for testing (mounted)
 
172
  try:
173
  import gradio as gr
 
 
174
 
175
  def _ui_transcribe(audio, language):
176
- # audio comes as file path from gradio
 
 
 
 
 
 
177
  if audio is None:
178
  return "No audio", ""
179
- # call local endpoint function for consistent behavior
180
- whisper_lang, nllb_tag = LANG_MAP.get(language.lower(), (None, None))
181
- model_manager.load()
182
- trans = model_manager.transcribe(audio, whisper_lang)
183
- trans_en = model_manager.translate_to_english(trans, nllb_tag)
184
- return trans, trans_en
185
-
186
- demo = gr.Interface(
187
- fn=_ui_transcribe,
188
- inputs=[
189
- gr.Audio(source="microphone", type="filepath"),
190
- gr.Dropdown(choices=list(LANG_MAP.keys()), label="Language", value="english")
191
- ],
192
- outputs=[gr.Textbox(label="Transcription"), gr.Textbox(label="Translation (English)")],
193
- title="Multilingual Transcriber (server)"
194
- )
195
- # mount gradio app under /ui so the REST API remains at /transcribe
196
- from gradio.routes import MountableApp
197
- app = gr.mount_gradio_app(app, demo, path="/ui")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  except Exception as e:
199
  print("Gradio UI unavailable or failed to mount:", e)
 
200
 
 
 
 
 
 
 
201
 
202
  if __name__ == "__main__":
203
- # For local debug only
204
- app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)
 
1
  # app.py
2
  import os
3
  import tempfile
4
+ from pathlib import Path
5
+ from flask import Flask, request, Response, redirect
6
  from flask_cors import CORS
7
  import torch
8
  import torchaudio
9
+
10
+ # Transformers imports (lazy loaded in ModelManager.load to reduce startup overhead)
11
  from transformers import (
12
  AutoProcessor,
13
  AutoModelForSpeechSeq2Seq,
 
16
  )
17
 
18
  # ---------- Configuration ----------
19
+ # Use smaller models suitable for CPU-only Hugging Face Spaces (free tier)
20
+ WHISPER_MODEL = os.environ.get("WHISPER_MODEL", "openai/whisper-small")
21
+ NLLB_MODEL = os.environ.get("NLLB_MODEL", "facebook/nllb-200-distilled-600M")
22
 
23
+ # Map frontend language names -> (whisper_lang_arg, nllb_src_lang_tag)
24
+ # Adjust tags if you have different NLLB language tags for specific dialects
25
  LANG_MAP = {
26
+ "akan": (None, "aka_Latn"),
 
27
  "hausa": ("ha", "hau_Latn"),
28
  "swahili": ("sw", "swh_Latn"),
29
  "french": ("fr", "fra_Latn"),
30
+ "arabic": ("ar", "arb_Arab"),
31
  "english": ("en", None),
32
  }
33
 
34
+ # Force CPU for free Spaces
35
+ DEVICE = torch.device("cpu")
36
 
37
  app = Flask(__name__)
38
  CORS(app)
39
 
40
+ # ---------- Model manager (lazy load) ----------
41
  class ModelManager:
42
  def __init__(self):
43
  self.whisper_processor = None
 
49
  def load(self):
50
  if self._loaded:
51
  return
52
+ print(f"Loading Whisper model: {WHISPER_MODEL}")
53
+ try:
54
+ self.whisper_processor = AutoProcessor.from_pretrained(WHISPER_MODEL)
55
+ self.whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(WHISPER_MODEL).to(DEVICE)
56
+ except Exception as e:
57
+ raise RuntimeError(f"Failed to load Whisper model ({WHISPER_MODEL}): {e}")
58
 
59
+ print(f"Loading NLLB tokenizer/model: {NLLB_MODEL}")
60
+ try:
61
+ self.nllb_tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL)
62
+ self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_MODEL).to(DEVICE)
63
+ except Exception as e:
64
+ raise RuntimeError(f"Failed to load NLLB model ({NLLB_MODEL}): {e}")
65
 
66
  self._loaded = True
67
+ print("Models loaded successfully.")
68
 
69
  def transcribe(self, audio_path, whisper_language_arg=None):
 
70
  if self.whisper_processor is None or self.whisper_model is None:
71
+ raise RuntimeError("Whisper model not loaded")
72
 
73
+ # Load audio and resample if needed
74
  waveform, sr = torchaudio.load(audio_path)
75
  if sr != 16000:
76
  waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
 
94
  return decoded[0].strip()
95
 
96
  def translate_to_english(self, src_text, nllb_src_lang_tag):
97
+ if not src_text:
98
+ return ""
99
  if not nllb_src_lang_tag:
100
+ # Already English or no NLLB mapping return source
101
  return src_text
102
+
103
  if self.nllb_tokenizer is None or self.nllb_model is None:
104
+ raise RuntimeError("NLLB model not loaded")
105
 
106
+ # Set tokenizer source lang if supported
107
  try:
108
  self.nllb_tokenizer.src_lang = nllb_src_lang_tag
109
  except Exception:
110
  pass
111
 
112
  inputs = self.nllb_tokenizer(src_text, return_tensors="pt").to(DEVICE)
113
+
114
+ # Attempt to get forced BOS token id for English; fallback to no forced token
115
+ forced_bos = None
116
+ try:
117
+ forced_bos = self.nllb_tokenizer.convert_tokens_to_ids("eng_Latn")
118
+ except Exception:
119
+ forced_bos = None
120
+
121
+ gen_kwargs = {
122
+ "max_length": 512,
123
+ "num_beams": 4,
124
+ "no_repeat_ngram_size": 2,
125
+ "early_stopping": True
126
+ }
127
+ if forced_bos is not None:
128
+ gen_kwargs["forced_bos_token_id"] = forced_bos
129
 
130
  with torch.no_grad():
131
+ translated_tokens = self.nllb_model.generate(**inputs, **gen_kwargs)
132
+
133
+ translated = self.nllb_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
134
+ return translated.strip()
 
 
 
 
 
 
135
 
136
  model_manager = ModelManager()
137
 
 
139
  @app.route("/transcribe", methods=["POST"])
140
  def transcribe_endpoint():
141
  """
142
+ POST multipart/form-data:
143
+ - field 'audio': file (wav/mp3/ogg etc.)
144
+ - field 'language': string key (akan, hausa, swahili, french, arabic, english)
145
+ Response:
146
  - Plain text body with the translated text (Content-Type: text/plain)
147
  """
148
  if "audio" not in request.files:
 
150
 
151
  audio_file = request.files["audio"]
152
  language = (request.form.get("language") or request.args.get("language") or "english").lower()
153
+
154
  if language not in LANG_MAP:
155
  return Response(f"Unsupported language: {language}", status=400, mimetype="text/plain")
156
 
157
  whisper_lang_arg, nllb_src_tag = LANG_MAP[language]
158
 
159
+ # Load models (lazy)
160
  try:
161
  model_manager.load()
162
  except Exception as e:
163
  return Response(f"Model loading failed: {e}", status=500, mimetype="text/plain")
164
 
165
+ # Save audio to a temp file
166
+ tmp_fd, tmp_path = tempfile.mkstemp(suffix=Path(audio_file.filename).suffix or ".wav")
167
  os.close(tmp_fd)
168
  audio_file.save(tmp_path)
169
 
170
  try:
 
171
  transcription = model_manager.transcribe(tmp_path, whisper_language_arg=whisper_lang_arg)
172
  if not transcription:
173
+ # nothing transcribed -> return empty body (204)
174
+ return Response("", status=204, mimetype="text/plain")
175
 
 
176
  translation = model_manager.translate_to_english(transcription, nllb_src_tag)
 
 
177
  return Response(translation, status=200, mimetype="text/plain")
178
  except Exception as e:
179
  return Response(f"Processing failed: {e}", status=500, mimetype="text/plain")
 
183
  except Exception:
184
  pass
185
 
186
# ---------- Robust Gradio UI mount (optional) ----------
gradio_mounted = False
try:
    import gradio as gr
    import soundfile as sf
    import numpy as np

    def _ui_transcribe(audio, language):
        """Normalize the audio Gradio hands us, then transcribe + translate.

        Different Gradio versions return different shapes from the Audio
        component:
          - filepath (str)
          - tuple (sample_rate, ndarray)
          - bare ndarray (sample rate unknown)
          - file-like object with a .name attribute
        Everything is normalized to a wav file on disk.

        Returns (transcription, translation) strings.
        """
        if audio is None:
            return "No audio", ""

        audio_path = None
        created_tmp = False  # only delete files that *we* created below
        if isinstance(audio, str) and Path(audio).exists():
            audio_path = audio
        elif isinstance(audio, (tuple, list)) and len(audio) >= 2:
            sr, data = audio[0], audio[1]
            tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            tmp.close()  # close before reopening by name (required on Windows)
            sf.write(tmp.name, data, sr)
            audio_path = tmp.name
            created_tmp = True
        elif isinstance(audio, np.ndarray) or hasattr(audio, "shape"):
            # Bare array: sample rate unknown -- assume 16 kHz (Whisper's
            # rate). TODO confirm against the Gradio version in use.
            tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            tmp.close()
            sf.write(tmp.name, audio, 16000)
            audio_path = tmp.name
            created_tmp = True
        else:
            # getattr with a default never raises; no try/except needed.
            audio_path = getattr(audio, "name", None)

        if not audio_path:
            return "Unsupported audio format from Gradio", ""

        try:
            model_manager.load()
            whisper_lang, nllb_tag = LANG_MAP.get(language.lower(), (None, None))
            transcription = model_manager.transcribe(audio_path, whisper_language_arg=whisper_lang)
            translation = model_manager.translate_to_english(transcription, nllb_tag)
            return transcription, translation
        finally:
            # Remove only temp files created above. (The old "/tmp" substring
            # check could delete Gradio-owned or user-supplied files and
            # skipped cleanup entirely on non-/tmp platforms.)
            if created_tmp:
                try:
                    os.remove(audio_path)
                except OSError:
                    pass

    demo = None
    try:
        # Modern Gradio API.
        audio_component = gr.Audio(source="microphone", type="filepath")
        dropdown = gr.Dropdown(choices=list(LANG_MAP.keys()), value="english", label="Language")
        demo = gr.Interface(
            fn=_ui_transcribe,
            inputs=[audio_component, dropdown],
            outputs=[gr.Textbox(label="Transcription"), gr.Textbox(label="Translation (English)")],
            title="Multilingual Transcriber (server)"
        )
    except TypeError:
        # Fallback for older Gradio versions that used gr.inputs/gr.outputs.
        try:
            audio_component = gr.inputs.Audio(source="microphone", type="filepath")
            dropdown = gr.inputs.Dropdown(choices=list(LANG_MAP.keys()), default="english")
            outputs = [gr.outputs.Textbox(), gr.outputs.Textbox()]
            demo = gr.Interface(fn=_ui_transcribe, inputs=[audio_component, dropdown], outputs=outputs,
                                title="Multilingual Transcriber (server)")
        except Exception as e:
            print("Gradio fallback constructor failed:", e)
            demo = None
    except Exception as e:
        print("Gradio constructor failed:", e)
        demo = None

    if demo is not None:
        try:
            # NOTE(review): gr.mount_gradio_app documents a FastAPI/Starlette
            # app; with a Flask app this likely raises and is caught here,
            # leaving only the REST API -- confirm against the gradio version.
            app = gr.mount_gradio_app(app, demo, path="/ui")
            gradio_mounted = True
            print("Gradio mounted at /ui")
        except Exception as e:
            print("Failed to mount Gradio app:", e)
            gradio_mounted = False
    else:
        print("Gradio demo not created; continuing without mounted UI.")

except Exception as e:
    print("Gradio UI unavailable or failed to mount:", e)
    gradio_mounted = False
280
 
281
# Root endpoint: send browsers to the Gradio UI when it mounted,
# otherwise answer with a plain-text status line.
@app.route("/")
def index():
    """Landing page for the service root."""
    if not gradio_mounted:
        return Response("Server running. REST endpoint available at /transcribe", status=200, mimetype="text/plain")
    return redirect("/ui")
287
 
288
if __name__ == "__main__":
    # Dev entry point; Hugging Face Spaces supplies PORT via the environment
    # (default 7860). Debug stays off for production parity.
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        debug=False,
    )