nexusbert committed
Commit 32ad752 · Parent(s): 63703a0
Files changed (1):
  1. app.py +107 -34
app.py CHANGED
@@ -6,6 +6,7 @@ import requests
 import torch
 import numpy as np
 import soundfile as sf
+import torchaudio
 from fastapi import FastAPI, File, UploadFile, HTTPException, Form
 from fastapi.responses import FileResponse
 from fastapi.middleware.cors import CORSMiddleware
@@ -164,6 +165,41 @@ def _run_mms(model: Wav2Vec2ForCTC, proc: Wav2Vec2Processor, audio_array: np.nda
         logging.exception("MMS ASR inference failed")
         return ""
 
+def chunk_audio(audio_data: bytes, chunk_len: int = 15) -> list:
+    """Split audio into smaller chunks for better transcription."""
+    try:
+        with tempfile.NamedTemporaryFile(suffix='.input', delete=False) as in_file:
+            in_file.write(audio_data)
+            in_path = in_file.name
+
+        waveform, sr = torchaudio.load(in_path)
+
+        if sr != 16000:
+            waveform = torchaudio.functional.resample(waveform, sr, 16000)
+            sr = 16000
+
+        if waveform.shape[0] > 1:
+            waveform = waveform.mean(dim=0, keepdim=True)
+
+        num_samples = waveform.size(1)
+        chunk_size = sr * chunk_len
+        chunks = []
+
+        for i in range(0, num_samples, chunk_size):
+            chunk_waveform = waveform[:, i:i+chunk_size]
+            if chunk_waveform.size(1) == 0:
+                continue
+
+            chunk_path = tempfile.mktemp(suffix=f"_chunk_{i//chunk_size}.wav")
+            torchaudio.save(chunk_path, chunk_waveform, sr)
+            chunks.append(chunk_path)
+
+        os.unlink(in_path)
+        return chunks
+    except Exception as e:
+        logger.error(f"Audio chunking failed: {e}")
+        raise HTTPException(status_code=400, detail="Audio chunking failed.")
+
 def preprocess_audio_ffmpeg(audio_data: bytes, target_sr: int = 16000) -> np.ndarray:
     try:
         with tempfile.NamedTemporaryFile(suffix='.input', delete=False) as in_file:
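An aside on the chunk_audio hunk above: tempfile.mktemp() is deprecated in the Python standard library because the returned name can race with other processes. A minimal sketch of the same chunk-writing step using tempfile.NamedTemporaryFile instead; the helper name save_chunk is illustrative and not part of this commit.

# Illustrative alternative to tempfile.mktemp(); not part of the commit.
import tempfile

import torch
import torchaudio


def save_chunk(chunk_waveform: torch.Tensor, sr: int, index: int) -> str:
    """Write one audio chunk to a uniquely named WAV file and return its path."""
    # NamedTemporaryFile reserves the name atomically; delete=False keeps the
    # file on disk so torchaudio can write to it and a later step can read it.
    with tempfile.NamedTemporaryFile(suffix=f"_chunk_{index}.wav", delete=False) as f:
        chunk_path = f.name
    torchaudio.save(chunk_path, chunk_waveform, sr)
    return chunk_path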
@@ -194,40 +230,77 @@ def preprocess_audio_ffmpeg(audio_data: bytes, target_sr: int = 16000) -> np.nda
 
 
 def speech_to_text(audio_data: bytes) -> str:
-    audio_array = preprocess_audio_ffmpeg(audio_data)
-    candidates = []
-
-    mms_result = _get_mms()
-    if mms_result and mms_result[0] is not None and mms_result[1] is not None:
-        mms_model, mms_proc = mms_result
-        mms_text = _run_mms(mms_model, mms_proc, audio_array)
-        if mms_text:
-            candidates.append(("mms", mms_text))
-            logger.info(f"MMS result: '{mms_text}'")
-
-    igbo_result = _get_igbo_asr()
-    if igbo_result[0] is not None and igbo_result[1] is not None:
-        igbo_model, igbo_proc = igbo_result
-        igbo_text = _run_whisper(igbo_model, igbo_proc, audio_array, language="igbo")
-        if igbo_text:
-            candidates.append(("igbo", igbo_text))
-            logger.info(f"Igbo ASR result: '{igbo_text}'")
-
-    for model_name, text in candidates:
-        detected_lang = detect_language(text)
-        if detected_lang == "ig" and model_name == "igbo":
-            logger.info(f"Using {model_name} ASR result (detected {detected_lang} language)")
-            return text
-        elif detected_lang in ["ha", "yo", "en"] and model_name == "mms":
-            logger.info(f"Using {model_name} ASR result (detected {detected_lang} language)")
-            return text
-
-    if candidates:
-        best_text = max((t for _, t in candidates), key=lambda s: len(s or ""))
-        logger.info(f"Using best result by length: '{best_text}'")
-        return best_text
-
-    return ""
+    """Transcribe audio using chunking technique for better accuracy."""
+    try:
+        chunks = chunk_audio(audio_data, chunk_len=15)
+        logger.info(f"Split audio into {len(chunks)} chunks")
+
+        candidates = []
+
+        mms_result = _get_mms()
+        if mms_result and mms_result[0] is not None and mms_result[1] is not None:
+            mms_model, mms_proc = mms_result
+            mms_full_text = ""
+
+            for chunk_path in chunks:
+                try:
+                    waveform, sr = torchaudio.load(chunk_path)
+                    audio_array = waveform.squeeze().numpy()
+                    chunk_text = _run_mms(mms_model, mms_proc, audio_array)
+                    if chunk_text:
+                        mms_full_text += " " + chunk_text
+                except Exception as e:
+                    logger.warning(f"MMS chunk processing failed: {e}")
+                    continue
+
+            if mms_full_text.strip():
+                candidates.append(("mms", mms_full_text.strip()))
+                logger.info(f"MMS result: '{mms_full_text.strip()}'")
+
+        igbo_result = _get_igbo_asr()
+        if igbo_result[0] is not None and igbo_result[1] is not None:
+            igbo_model, igbo_proc = igbo_result
+            igbo_full_text = ""
+
+            for chunk_path in chunks:
+                try:
+                    waveform, sr = torchaudio.load(chunk_path)
+                    audio_array = waveform.squeeze().numpy()
+                    chunk_text = _run_whisper(igbo_model, igbo_proc, audio_array, language="igbo")
+                    if chunk_text:
+                        igbo_full_text += " " + chunk_text
+                except Exception as e:
+                    logger.warning(f"Igbo ASR chunk processing failed: {e}")
+                    continue
+
+            if igbo_full_text.strip():
+                candidates.append(("igbo", igbo_full_text.strip()))
+                logger.info(f"Igbo ASR result: '{igbo_full_text.strip()}'")
+
+        for chunk_path in chunks:
+            try:
+                os.unlink(chunk_path)
+            except:
+                pass
+
+        for model_name, text in candidates:
+            detected_lang = detect_language(text)
+            if detected_lang == "ig" and model_name == "igbo":
+                logger.info(f"Using {model_name} ASR result (detected {detected_lang} language)")
+                return text
+            elif detected_lang in ["ha", "yo", "en"] and model_name == "mms":
+                logger.info(f"Using {model_name} ASR result (detected {detected_lang} language)")
+                return text
+
+        if candidates:
+            best_text = max((t for _, t in candidates), key=lambda s: len(s or ""))
+            logger.info(f"Using best result by length: '{best_text}'")
+            return best_text
+
+        return ""
+    except Exception as e:
+        logger.error(f"Speech-to-text chunking failed: {e}")
+        return ""
 
 
 def get_ai_response(text: str, response_language: str = None) -> str:
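For context, a hypothetical local smoke test of the chunked pipeline introduced by this commit; it assumes app.py is importable with the speech_to_text() shown above, and the file name sample_recording.wav is illustrative.

# Hypothetical smoke test, assuming app.py exposes the chunked speech_to_text() from the diff.
from app import speech_to_text

with open("sample_recording.wav", "rb") as f:
    audio_bytes = f.read()

# chunk_audio splits at 15 s * 16 kHz = 240,000 samples per chunk before transcription,
# so a 40-second clip would be transcribed as three chunks and re-joined.
print(f"transcript: {speech_to_text(audio_bytes)!r}")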