nexusbert commited on
Commit
fc4a5de
·
verified ·
1 Parent(s): cd0d2d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -103
app.py CHANGED
@@ -21,10 +21,24 @@ nest_asyncio.apply()
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  @asynccontextmanager
25
  async def lifespan(app: FastAPI):
26
- load_models()
 
27
  yield
 
28
 
29
  app = FastAPI(title="Farmlingua AI Speech Interface", version="1.0.0", lifespan=lifespan)
30
 
@@ -36,55 +50,50 @@ app.add_middleware(
36
  allow_headers=["*"],
37
  )
38
 
39
-
40
- ASK_URL = "https://remostart-milestone-one-farmlingua-ai.hf.space/ask"
41
- tts_ha, tts_en, tts_yo, tts_ig = None, None, None, None
42
- natlas_tokenizer, natlas_model = None, None
43
-
44
- asr_models = {
45
- "ha": {"repo": "NCAIR1/Hausa-ASR", "model": None, "proc": None},
46
- "yo": {"repo": "NCAIR1/Yoruba-ASR", "model": None, "proc": None},
47
- "ig": {"repo": "NCAIR1/Igbo-ASR", "model": None, "proc": None},
48
- "en": {"repo": "NCAIR1/NigerianAccentedEnglish", "model": None, "proc": None},
49
- }
50
 
51
  def load_models():
52
- global tts_ha, tts_en, tts_yo, tts_ig, natlas_tokenizer, natlas_model
53
  device = 0 if torch.cuda.is_available() else -1
54
- hf_token = os.getenv("HF_TOKEN")
55
- if hf_token:
56
- hf_token = hf_token.strip()
57
- if not hf_token:
58
- logger.warning("HF_TOKEN not set! This may cause authentication failures for gated repositories.")
59
- logger.warning("Please set HF_TOKEN environment variable to access restricted models.")
60
- else:
61
- logger.info("HF_TOKEN is set and ready for authenticated model access.")
62
- logger.info("Loading TTS models...")
63
- try:
64
- tts_ha = pipeline("text-to-speech", model="facebook/mms-tts-hau", device=device)
65
- logger.info("Loaded TTS (Hausa)")
66
- except Exception as e:
67
- logger.exception("Failed to load TTS (Hausa)")
68
- tts_ha = None
69
- try:
70
- tts_en = pipeline("text-to-speech", model="facebook/mms-tts-eng", device=device)
71
- logger.info("Loaded TTS (English)")
72
- except Exception:
73
- logger.exception("Failed to load TTS (English)")
74
- tts_en = None
75
- try:
76
- tts_yo = pipeline("text-to-speech", model="facebook/mms-tts-yor", device=device)
77
- logger.info("Loaded TTS (Yoruba)")
78
- except Exception:
79
- logger.exception("Failed to load TTS (Yoruba)")
80
- tts_yo = None
81
-
82
- tts_ig = None
83
- logger.info("Igbo TTS model disabled - will return text responses for Igbo language")
84
 
85
- logger.info("N-ATLaS language identification model will be lazy-loaded on first use")
86
-
87
- logger.info("Deferred ASR model loads: will lazy-load per language on first use")
88
 
89
  def _get_asr(lang_code: str):
90
  entry = asr_models.get(lang_code)
@@ -93,9 +102,7 @@ def _get_asr(lang_code: str):
93
  if entry["model"] is not None and entry["proc"] is not None:
94
  return entry["model"], entry["proc"]
95
  repo_id = entry["repo"]
96
- hf_token = os.getenv("HF_TOKEN")
97
- if hf_token:
98
- hf_token = hf_token.strip()
99
  try:
100
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
101
  logger.info(f"Lazy-loading ASR for {lang_code} from {repo_id}...")
@@ -162,17 +169,14 @@ def speech_to_text(audio_data: bytes) -> str:
162
  text = _run_whisper(model, proc, audio_array)
163
  if text:
164
  candidates.append((code, text))
165
-
166
  for lang_code, text in candidates:
167
  det = detect_language(text)
168
  if lang_code == det:
169
  return text
170
-
171
  if candidates:
172
  return max((t for _, t in candidates), key=lambda s: len(s or ""))
173
  return ""
174
 
175
-
176
  def get_ai_response(text: str) -> str:
177
  try:
178
  response = requests.post(ASK_URL, json={"query": text}, timeout=30)
@@ -200,19 +204,12 @@ def _load_natlas():
200
  global natlas_tokenizer, natlas_model
201
  if natlas_tokenizer is not None and natlas_model is not None:
202
  return True
203
-
204
- hf_token = os.getenv("HF_TOKEN")
205
- if hf_token:
206
- hf_token = hf_token.strip()
207
-
208
  if not hf_token:
209
  logger.error("HF_TOKEN not available for N-ATLaS model access")
210
  return False
211
-
212
  try:
213
- logger.info("Lazy-loading N-ATLaS language identification model...")
214
- logger.info("This may take a few minutes as the model loads its shards...")
215
-
216
  natlas_tokenizer = AutoTokenizer.from_pretrained("NCAIR1/N-ATLaS", token=hf_token)
217
  natlas_model = AutoModelForCausalLM.from_pretrained(
218
  "NCAIR1/N-ATLaS",
@@ -232,40 +229,26 @@ def _load_natlas():
232
 
233
  def detect_language(text: str) -> str:
234
  logger.info(f"Detecting language for text: '{text[:50]}...'")
235
-
236
  if not _load_natlas():
237
  logger.warning("N-ATLaS model not available, falling back to keyword detection")
238
  text_lower = text.lower()
239
  if any(word in text_lower for word in HAUSA_WORDS):
240
- logger.info("Keyword detection: Hausa")
241
  return "ha"
242
  elif any(word in text_lower for word in YORUBA_WORDS):
243
- logger.info("Keyword detection: Yoruba")
244
  return "yo"
245
  elif any(word in text_lower for word in IGBO_WORDS):
246
- logger.info("Keyword detection: Igbo")
247
  return "ig"
248
  else:
249
- logger.info("Keyword detection: English (default)")
250
  return "en"
251
-
252
  try:
253
- logger.info("Using N-ATLaS for language detection")
254
  messages = [
255
  {'role': 'system', 'content': 'You are a language identification assistant. Identify the language of the given text and respond with only the language code: "en" for English, "ha" for Hausa, "yo" for Yoruba, or "ig" for Igbo.'},
256
  {'role': 'user', 'content': f'What language is this text written in? "{text}"'}
257
  ]
258
-
259
- formatted_text = natlas_tokenizer.apply_chat_template(
260
- messages,
261
- add_generation_prompt=True,
262
- tokenize=False
263
- )
264
-
265
  input_tokens = natlas_tokenizer(formatted_text, return_tensors='pt', add_special_tokens=False)
266
  if torch.cuda.is_available():
267
  input_tokens = input_tokens.to('cuda')
268
-
269
  with torch.no_grad():
270
  outputs = natlas_model.generate(
271
  **input_tokens,
@@ -275,28 +258,18 @@ def detect_language(text: str) -> str:
275
  temperature=0.1,
276
  do_sample=False
277
  )
278
-
279
  response = natlas_tokenizer.batch_decode(outputs)[0]
280
  response_text = response.split(messages[1]['content'])[-1].strip().lower()
281
-
282
- logger.info(f"N-ATLaS response: '{response_text}'")
283
-
284
  if 'ha' in response_text:
285
- logger.info("N-ATLaS detection: Hausa")
286
  return "ha"
287
  elif 'yo' in response_text:
288
- logger.info("N-ATLaS detection: Yoruba")
289
  return "yo"
290
  elif 'ig' in response_text:
291
- logger.info("N-ATLaS detection: Igbo")
292
  return "ig"
293
  else:
294
- logger.info("N-ATLaS detection: English (default)")
295
  return "en"
296
-
297
  except Exception as e:
298
  logger.exception(f"Language detection failed: {e}")
299
- logger.warning("Falling back to keyword detection due to N-ATLaS error")
300
  text_lower = text.lower()
301
  if any(word in text_lower for word in HAUSA_WORDS):
302
  return "ha"
@@ -308,9 +281,8 @@ def detect_language(text: str) -> str:
308
  return "en"
309
 
310
  def text_to_speech_file(text: str) -> str:
 
311
  lang = detect_language(text)
312
- print(f"Detected language: {lang}")
313
-
314
  if lang == "ig":
315
  logger.info("Igbo language detected - returning text response instead of audio")
316
  fd, path = tempfile.mkstemp(suffix=".txt")
@@ -324,38 +296,25 @@ def text_to_speech_file(text: str) -> str:
324
  tts_model = tts_yo
325
  else:
326
  tts_model = tts_en
327
-
328
  if tts_model is None:
329
- logger.error(f"TTS model for {lang} is not available")
330
  raise HTTPException(status_code=500, detail=f"TTS model for {lang} is not available")
331
-
332
  speech_output = tts_model(text)
333
  audio_raw = speech_output["audio"]
334
- sampling_rate = int(speech_output["sampling_rate"])
335
-
336
-
337
  if isinstance(audio_raw, torch.Tensor):
338
  audio_np = audio_raw.detach().cpu().numpy()
339
  else:
340
  audio_np = np.asarray(audio_raw)
341
-
342
  if audio_np.ndim > 1:
343
  audio_np = audio_np.reshape(-1)
344
  audio_np = audio_np.astype(np.float32, copy=False)
345
-
346
-
347
  audio_clipped = np.clip(audio_np, -1.0, 1.0)
348
  audio_int16 = (audio_clipped * 32767.0).astype(np.int16)
349
-
350
-
351
  fd, path = tempfile.mkstemp(suffix=".wav")
352
  os.close(fd)
353
-
354
-
355
  sf.write(path, audio_int16, sampling_rate, format='WAV', subtype='PCM_16')
356
  return path
357
 
358
-
359
  @app.get("/")
360
  async def root():
361
  return {"status": "ok", "message": "System ready"}
@@ -412,5 +371,4 @@ async def speak_to_ai(audio_file: UploadFile = File(...), speak: bool = True):
412
  return {"transcription": transcription, "ai_response": ai_response}
413
 
414
  if __name__ == "__main__":
415
- import uvicorn
416
  uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
 
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
24
+ tts_ha, tts_en, tts_yo, tts_ig = None, None, None, None
25
+ natlas_tokenizer, natlas_model = None, None
26
+
27
+ ASK_URL = "https://remostart-milestone-one-farmlingua-ai.hf.space/ask"
28
+
29
+ asr_models = {
30
+ "ha": {"repo": "NCAIR1/Hausa-ASR", "model": None, "proc": None},
31
+ "yo": {"repo": "NCAIR1/Yoruba-ASR", "model": None, "proc": None},
32
+ "ig": {"repo": "NCAIR1/Igbo-ASR", "model": None, "proc": None},
33
+ "en": {"repo": "NCAIR1/NigerianAccentedEnglish", "model": None, "proc": None},
34
+ }
35
+
36
  @asynccontextmanager
37
  async def lifespan(app: FastAPI):
38
+ logger.info("Starting Farmlingua AI Speech Interface...")
39
+ preload_natlas()
40
  yield
41
+ logger.info("Shutting down Farmlingua service...")
42
 
43
  app = FastAPI(title="Farmlingua AI Speech Interface", version="1.0.0", lifespan=lifespan)
44
 
 
50
  allow_headers=["*"],
51
  )
52
 
53
+ def preload_natlas():
54
+ global natlas_tokenizer, natlas_model
55
+ if natlas_tokenizer is not None and natlas_model is not None:
56
+ logger.info("N-ATLaS already loaded.")
57
+ return
58
+ success = _load_natlas()
59
+ if success:
60
+ logger.info("N-ATLaS successfully preloaded at startup.")
61
+ else:
62
+ logger.warning("N-ATLaS preload failed. It will retry on first use.")
 
63
 
64
  def load_models():
65
+ global tts_ha, tts_en, tts_yo, tts_ig
66
  device = 0 if torch.cuda.is_available() else -1
67
+ hf_token = os.getenv("HF_TOKEN", "").strip()
68
+
69
+ logger.info("Lazy-loading TTS models on first use...")
70
+
71
+ if tts_ha is None:
72
+ try:
73
+ tts_ha = pipeline("text-to-speech", model="facebook/mms-tts-hau", device=device, token=hf_token)
74
+ logger.info("Loaded TTS (Hausa)")
75
+ except Exception:
76
+ logger.exception("Failed to load TTS (Hausa)")
77
+ tts_ha = None
78
+
79
+ if tts_en is None:
80
+ try:
81
+ tts_en = pipeline("text-to-speech", model="facebook/mms-tts-eng", device=device, token=hf_token)
82
+ logger.info("Loaded TTS (English)")
83
+ except Exception:
84
+ logger.exception("Failed to load TTS (English)")
85
+ tts_en = None
86
+
87
+ if tts_yo is None:
88
+ try:
89
+ tts_yo = pipeline("text-to-speech", model="facebook/mms-tts-yor", device=device, token=hf_token)
90
+ logger.info("Loaded TTS (Yoruba)")
91
+ except Exception:
92
+ logger.exception("Failed to load TTS (Yoruba)")
93
+ tts_yo = None
 
 
 
94
 
95
+ tts_ig = None
96
+ logger.info("Igbo TTS model disabled - returning text responses for Igbo language")
 
97
 
98
  def _get_asr(lang_code: str):
99
  entry = asr_models.get(lang_code)
 
102
  if entry["model"] is not None and entry["proc"] is not None:
103
  return entry["model"], entry["proc"]
104
  repo_id = entry["repo"]
105
+ hf_token = os.getenv("HF_TOKEN", "").strip()
 
 
106
  try:
107
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
108
  logger.info(f"Lazy-loading ASR for {lang_code} from {repo_id}...")
 
169
  text = _run_whisper(model, proc, audio_array)
170
  if text:
171
  candidates.append((code, text))
 
172
  for lang_code, text in candidates:
173
  det = detect_language(text)
174
  if lang_code == det:
175
  return text
 
176
  if candidates:
177
  return max((t for _, t in candidates), key=lambda s: len(s or ""))
178
  return ""
179
 
 
180
  def get_ai_response(text: str) -> str:
181
  try:
182
  response = requests.post(ASK_URL, json={"query": text}, timeout=30)
 
204
  global natlas_tokenizer, natlas_model
205
  if natlas_tokenizer is not None and natlas_model is not None:
206
  return True
207
+ hf_token = os.getenv("HF_TOKEN", "").strip()
 
 
 
 
208
  if not hf_token:
209
  logger.error("HF_TOKEN not available for N-ATLaS model access")
210
  return False
 
211
  try:
212
+ logger.info("Loading N-ATLaS language identification model...")
 
 
213
  natlas_tokenizer = AutoTokenizer.from_pretrained("NCAIR1/N-ATLaS", token=hf_token)
214
  natlas_model = AutoModelForCausalLM.from_pretrained(
215
  "NCAIR1/N-ATLaS",
 
229
 
230
  def detect_language(text: str) -> str:
231
  logger.info(f"Detecting language for text: '{text[:50]}...'")
 
232
  if not _load_natlas():
233
  logger.warning("N-ATLaS model not available, falling back to keyword detection")
234
  text_lower = text.lower()
235
  if any(word in text_lower for word in HAUSA_WORDS):
 
236
  return "ha"
237
  elif any(word in text_lower for word in YORUBA_WORDS):
 
238
  return "yo"
239
  elif any(word in text_lower for word in IGBO_WORDS):
 
240
  return "ig"
241
  else:
 
242
  return "en"
 
243
  try:
 
244
  messages = [
245
  {'role': 'system', 'content': 'You are a language identification assistant. Identify the language of the given text and respond with only the language code: "en" for English, "ha" for Hausa, "yo" for Yoruba, or "ig" for Igbo.'},
246
  {'role': 'user', 'content': f'What language is this text written in? "{text}"'}
247
  ]
248
+ formatted_text = natlas_tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
 
 
 
 
 
 
249
  input_tokens = natlas_tokenizer(formatted_text, return_tensors='pt', add_special_tokens=False)
250
  if torch.cuda.is_available():
251
  input_tokens = input_tokens.to('cuda')
 
252
  with torch.no_grad():
253
  outputs = natlas_model.generate(
254
  **input_tokens,
 
258
  temperature=0.1,
259
  do_sample=False
260
  )
 
261
  response = natlas_tokenizer.batch_decode(outputs)[0]
262
  response_text = response.split(messages[1]['content'])[-1].strip().lower()
 
 
 
263
  if 'ha' in response_text:
 
264
  return "ha"
265
  elif 'yo' in response_text:
 
266
  return "yo"
267
  elif 'ig' in response_text:
 
268
  return "ig"
269
  else:
 
270
  return "en"
 
271
  except Exception as e:
272
  logger.exception(f"Language detection failed: {e}")
 
273
  text_lower = text.lower()
274
  if any(word in text_lower for word in HAUSA_WORDS):
275
  return "ha"
 
281
  return "en"
282
 
283
  def text_to_speech_file(text: str) -> str:
284
+ load_models()
285
  lang = detect_language(text)
 
 
286
  if lang == "ig":
287
  logger.info("Igbo language detected - returning text response instead of audio")
288
  fd, path = tempfile.mkstemp(suffix=".txt")
 
296
  tts_model = tts_yo
297
  else:
298
  tts_model = tts_en
 
299
  if tts_model is None:
 
300
  raise HTTPException(status_code=500, detail=f"TTS model for {lang} is not available")
 
301
  speech_output = tts_model(text)
302
  audio_raw = speech_output["audio"]
303
+ sampling_rate = int(speech_output["sampling_rate"])
 
 
304
  if isinstance(audio_raw, torch.Tensor):
305
  audio_np = audio_raw.detach().cpu().numpy()
306
  else:
307
  audio_np = np.asarray(audio_raw)
 
308
  if audio_np.ndim > 1:
309
  audio_np = audio_np.reshape(-1)
310
  audio_np = audio_np.astype(np.float32, copy=False)
 
 
311
  audio_clipped = np.clip(audio_np, -1.0, 1.0)
312
  audio_int16 = (audio_clipped * 32767.0).astype(np.int16)
 
 
313
  fd, path = tempfile.mkstemp(suffix=".wav")
314
  os.close(fd)
 
 
315
  sf.write(path, audio_int16, sampling_rate, format='WAV', subtype='PCM_16')
316
  return path
317
 
 
318
  @app.get("/")
319
  async def root():
320
  return {"status": "ok", "message": "System ready"}
 
371
  return {"transcription": transcription, "ai_response": ai_response}
372
 
373
  if __name__ == "__main__":
 
374
  uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))