nexusbert commited on
Commit
4f2310e
·
1 Parent(s): 652b643
Files changed (1) hide show
  1. app.py +21 -4
app.py CHANGED
@@ -122,13 +122,27 @@ def _get_igbo_asr():
122
  logger.exception(f"Failed to load Igbo ASR model: {e}")
123
  igbo_model, igbo_processor = None, None
124
  return None, None
125
- def _run_whisper(model: WhisperForConditionalGeneration, proc: WhisperProcessor, audio_array: np.ndarray) -> str:
126
  try:
127
  device = next(model.parameters()).device
128
  inputs = proc(audio_array, sampling_rate=16000, return_tensors="pt")
129
  input_features = inputs.input_features.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  with torch.no_grad():
131
- predicted_ids = model.generate(input_features)
132
  text_list = proc.batch_decode(predicted_ids, skip_special_tokens=True)
133
  return text_list[0] if text_list else ""
134
  except Exception:
@@ -184,7 +198,7 @@ def speech_to_text(audio_data: bytes) -> str:
184
  igbo_result = _get_igbo_asr()
185
  if igbo_result[0] is not None and igbo_result[1] is not None:
186
  igbo_model, igbo_proc = igbo_result
187
- igbo_text = _run_whisper(igbo_model, igbo_proc, audio_array)
188
  if igbo_text and igbo_text.strip():
189
  logger.info("Using Igbo ASR result")
190
  return igbo_text
@@ -233,7 +247,10 @@ YORUBA_WORDS = [
233
  ]
234
 
235
  IGBO_WORDS = [
236
- "ugbo","akụkọ","mmiri","ala","ọrụ","ncheta","ọhụrụ","ugwu","nri","ahụhụ"
 
 
 
237
  ]
238
 
239
  def detect_language(text: str) -> str:
 
122
  logger.exception(f"Failed to load Igbo ASR model: {e}")
123
  igbo_model, igbo_processor = None, None
124
  return None, None
125
+ def _run_whisper(model: WhisperForConditionalGeneration, proc: WhisperProcessor, audio_array: np.ndarray, language: str = None) -> str:
126
  try:
127
  device = next(model.parameters()).device
128
  inputs = proc(audio_array, sampling_rate=16000, return_tensors="pt")
129
  input_features = inputs.input_features.to(device)
130
+
131
+
132
+ generation_kwargs = {
133
+ "max_length": 448,
134
+ "num_beams": 1,
135
+ "do_sample": False,
136
+ "early_stopping": True
137
+ }
138
+
139
+
140
+ if language == "igbo" or "igbo" in str(model.config).lower():
141
+ generation_kwargs["language"] = "igbo"
142
+ generation_kwargs["task"] = "transcribe"
143
+
144
  with torch.no_grad():
145
+ predicted_ids = model.generate(input_features, **generation_kwargs)
146
  text_list = proc.batch_decode(predicted_ids, skip_special_tokens=True)
147
  return text_list[0] if text_list else ""
148
  except Exception:
 
198
  igbo_result = _get_igbo_asr()
199
  if igbo_result[0] is not None and igbo_result[1] is not None:
200
  igbo_model, igbo_proc = igbo_result
201
+ igbo_text = _run_whisper(igbo_model, igbo_proc, audio_array, language="igbo")
202
  if igbo_text and igbo_text.strip():
203
  logger.info("Using Igbo ASR result")
204
  return igbo_text
 
247
  ]
248
 
249
  IGBO_WORDS = [
250
+ "ugbo","akụkọ","mmiri","ala","ọrụ","ncheta","ọhụrụ","ugwu","nri","ahụhụ",
251
+ "kedu","ka","si","na","bụ","nke","a","na","ọ","bụ","na","n'ime","n'elu","n'okpuru",
252
+ "n'akụkụ","n'ebe","n'ụlọ","n'ọfịs","n'ụlọ","n'ime","n'elu","n'okpuru","n'akụkụ",
253
+ "ọrụ","ugbo","mmiri","ala","nri","ahụhụ","ọhụrụ","ncheta","akụkọ","ugwu"
254
  ]
255
 
256
  def detect_language(text: str) -> str: