nexusbert committed on
Commit
4b62031
·
1 Parent(s): 4f110d3

push igbo model

Browse files
Files changed (2) hide show
  1. Dockerfile +2 -0
  2. app.py +51 -8
Dockerfile CHANGED
@@ -38,10 +38,12 @@ RUN python -c "from huggingface_hub import snapshot_download; snapshot_download(
38
  && python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-hau')" \
39
  && python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-eng')" \
40
  && python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-yor')" \
 
41
  && find /models/huggingface -name '*.lock' -delete
42
 
43
  # Preload tokenizers (avoid runtime delays)
44
  RUN python -c "from transformers import Wav2Vec2Processor; Wav2Vec2Processor.from_pretrained('facebook/mms-1b-all')" \
 
45
  && python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-hau')" \
46
  && python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-eng')" \
47
  && python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-yor')"
 
38
  && python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-hau')" \
39
  && python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-eng')" \
40
  && python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-yor')" \
41
+ && python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='NCAIR1/Igbo-ASR')" \
42
  && find /models/huggingface -name '*.lock' -delete
43
 
44
  # Preload tokenizers (avoid runtime delays)
45
  RUN python -c "from transformers import Wav2Vec2Processor; Wav2Vec2Processor.from_pretrained('facebook/mms-1b-all')" \
46
+ && python -c "from transformers import WhisperProcessor; WhisperProcessor.from_pretrained('NCAIR1/Igbo-ASR')" \
47
  && python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-hau')" \
48
  && python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-eng')" \
49
  && python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-yor')"
app.py CHANGED
@@ -9,7 +9,7 @@ import soundfile as sf
9
  from fastapi import FastAPI, File, UploadFile, HTTPException, Form
10
  from fastapi.responses import FileResponse
11
  from fastapi.middleware.cors import CORSMiddleware
12
- from transformers import pipeline, Wav2Vec2Processor, Wav2Vec2ForCTC
13
  from langdetect import detect
14
  import imageio_ffmpeg
15
  import logging
@@ -43,6 +43,8 @@ tts_ha, tts_en, tts_yo, tts_ig = None, None, None, None
43
 
44
  mms_model = None
45
  mms_processor = None
 
 
46
 
47
  def load_models():
48
  global tts_ha, tts_en, tts_yo, tts_ig
@@ -74,7 +76,7 @@ def load_models():
74
  logger.info("Igbo TTS disabled: will fallback to text response")
75
 
76
 
77
- logger.info("Deferred MMS model load: will lazy-load on first use")
78
 
79
  def _get_mms():
80
  global mms_model, mms_processor
@@ -94,7 +96,36 @@ def _get_mms():
94
  except Exception:
95
  logger.exception("Failed to load MMS ASR model")
96
  mms_model, mms_processor = None, None
97
- return None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  def _run_mms(model: Wav2Vec2ForCTC, proc: Wav2Vec2Processor, audio_array: np.ndarray) -> str:
100
  try:
@@ -140,12 +171,24 @@ def preprocess_audio_ffmpeg(audio_data: bytes, target_sr: int = 16000) -> np.nda
140
 
141
  def speech_to_text(audio_data: bytes) -> str:
142
  audio_array = preprocess_audio_ffmpeg(audio_data)
143
- model, proc = _get_mms()
144
- if model is None or proc is None:
145
- return ""
146
 
147
- text = _run_mms(model, proc, audio_array)
148
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
 
151
  def get_ai_response(text: str, response_language: str = None) -> str:
 
9
  from fastapi import FastAPI, File, UploadFile, HTTPException, Form
10
  from fastapi.responses import FileResponse
11
  from fastapi.middleware.cors import CORSMiddleware
12
+ from transformers import pipeline, Wav2Vec2Processor, Wav2Vec2ForCTC, WhisperProcessor, WhisperForConditionalGeneration
13
  from langdetect import detect
14
  import imageio_ffmpeg
15
  import logging
 
43
 
44
  mms_model = None
45
  mms_processor = None
46
+ igbo_model = None
47
+ igbo_processor = None
48
 
49
  def load_models():
50
  global tts_ha, tts_en, tts_yo, tts_ig
 
76
  logger.info("Igbo TTS disabled: will fallback to text response")
77
 
78
 
79
+ logger.info("Deferred MMS and Igbo ASR model loads: will lazy-load on first use")
80
 
81
  def _get_mms():
82
  global mms_model, mms_processor
 
96
  except Exception:
97
  logger.exception("Failed to load MMS ASR model")
98
  mms_model, mms_processor = None, None
99
def _get_igbo_asr():
    """Lazily load the NCAIR1/Igbo-ASR Whisper model and processor.

    Returns:
        (model, processor) on success, or (None, None) when loading fails,
        so callers can always unpack the result into two names.
    """
    global igbo_model, igbo_processor
    # Reuse the cached pair once both halves have loaded successfully.
    if igbo_model is not None and igbo_processor is not None:
        return igbo_model, igbo_processor

    # NCAIR1/Igbo-ASR may be a gated repo; pass the token if one is configured.
    hf_token = os.getenv("HF_TOKEN")
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Lazy-loading Igbo ASR model...")
        igbo_processor = WhisperProcessor.from_pretrained("NCAIR1/Igbo-ASR", token=hf_token)
        igbo_model = WhisperForConditionalGeneration.from_pretrained("NCAIR1/Igbo-ASR", token=hf_token)
        igbo_model.to(device)
        igbo_model.eval()  # inference only — disable dropout etc.
        logger.info("Loaded Igbo ASR model")
        return igbo_model, igbo_processor
    except Exception:
        logger.exception("Failed to load Igbo ASR model")
        igbo_model, igbo_processor = None, None
        # BUG FIX: the original fell off the end here, implicitly returning a
        # bare None; callers that unpack the result into (model, proc) then
        # raised TypeError on every load failure.
        return None, None
117
def _run_whisper(model: WhisperForConditionalGeneration, proc: WhisperProcessor, audio_array: np.ndarray) -> str:
    """Transcribe a mono float waveform with a Whisper model.

    Args:
        model: loaded WhisperForConditionalGeneration (any device).
        proc: matching WhisperProcessor.
        audio_array: 1-D float waveform; assumed 16 kHz — TODO confirm the
            preprocessing upstream always resamples to 16000.

    Returns:
        The decoded transcript, or "" if inference fails.
    """
    try:
        # Run the features on whatever device the model lives on.
        device = next(model.parameters()).device
        inputs = proc(audio_array, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(device)
        with torch.no_grad():
            predicted_ids = model.generate(input_features)
        text_list = proc.batch_decode(predicted_ids, skip_special_tokens=True)
        return text_list[0] if text_list else ""
    except Exception:
        # Consistency fix: use the module-level `logger` like the rest of the
        # file (the original called the root-logger helper logging.exception).
        logger.exception("Whisper ASR inference failed")
        return ""
129
 
130
  def _run_mms(model: Wav2Vec2ForCTC, proc: Wav2Vec2Processor, audio_array: np.ndarray) -> str:
131
  try:
 
171
 
172
def speech_to_text(audio_data: bytes) -> str:
    """Transcribe raw audio bytes, preferring the Igbo ASR model.

    Tries the Igbo Whisper model first and falls back to the multilingual
    MMS model; returns "" when neither produces a non-empty transcript.
    """
    audio_array = preprocess_audio_ffmpeg(audio_data)

    # Try Igbo ASR first for better Igbo detection.  The getter can return a
    # bare None (instead of (None, None)) when the model fails to load, so
    # guard before unpacking to avoid a TypeError.  Local names are also kept
    # distinct from the module globals `igbo_model` / `mms_model`, which the
    # original shadowed.
    igbo_pair = _get_igbo_asr()
    if igbo_pair:
        ig_model, ig_proc = igbo_pair
        if ig_model is not None and ig_proc is not None:
            igbo_text = _run_whisper(ig_model, ig_proc, audio_array)
            if igbo_text and igbo_text.strip():
                logger.info("Using Igbo ASR result")
                return igbo_text

    # Fallback to MMS for other languages (same None-guard as above).
    mms_pair = _get_mms()
    if mms_pair:
        m_model, m_proc = mms_pair
        if m_model is not None and m_proc is not None:
            mms_text = _run_mms(m_model, m_proc, audio_array)
            if mms_text and mms_text.strip():
                logger.info("Using MMS ASR result")
                return mms_text

    return ""
192
 
193
 
194
  def get_ai_response(text: str, response_language: str = None) -> str: