Spaces:
Sleeping
Sleeping
push igbo model
Browse files
- Dockerfile: +2 −0
- app.py: +51 −8
Dockerfile
CHANGED
|
@@ -38,10 +38,12 @@ RUN python -c "from huggingface_hub import snapshot_download; snapshot_download(
|
|
| 38 |
&& python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-hau')" \
|
| 39 |
&& python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-eng')" \
|
| 40 |
&& python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-yor')" \
|
|
|
|
| 41 |
&& find /models/huggingface -name '*.lock' -delete
|
| 42 |
|
| 43 |
# Preload tokenizers (avoid runtime delays)
|
| 44 |
RUN python -c "from transformers import Wav2Vec2Processor; Wav2Vec2Processor.from_pretrained('facebook/mms-1b-all')" \
|
|
|
|
| 45 |
&& python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-hau')" \
|
| 46 |
&& python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-eng')" \
|
| 47 |
&& python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-yor')"
|
|
|
|
| 38 |
&& python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-hau')" \
|
| 39 |
&& python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-eng')" \
|
| 40 |
&& python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-yor')" \
|
| 41 |
+
&& python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='NCAIR1/Igbo-ASR')" \
|
| 42 |
&& find /models/huggingface -name '*.lock' -delete
|
| 43 |
|
| 44 |
# Preload tokenizers (avoid runtime delays)
|
| 45 |
RUN python -c "from transformers import Wav2Vec2Processor; Wav2Vec2Processor.from_pretrained('facebook/mms-1b-all')" \
|
| 46 |
+
&& python -c "from transformers import WhisperProcessor; WhisperProcessor.from_pretrained('NCAIR1/Igbo-ASR')" \
|
| 47 |
&& python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-hau')" \
|
| 48 |
&& python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-eng')" \
|
| 49 |
&& python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-yor')"
|
app.py
CHANGED
|
@@ -9,7 +9,7 @@ import soundfile as sf
|
|
| 9 |
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
|
| 10 |
from fastapi.responses import FileResponse
|
| 11 |
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
-
from transformers import pipeline, Wav2Vec2Processor, Wav2Vec2ForCTC
|
| 13 |
from langdetect import detect
|
| 14 |
import imageio_ffmpeg
|
| 15 |
import logging
|
|
@@ -43,6 +43,8 @@ tts_ha, tts_en, tts_yo, tts_ig = None, None, None, None
|
|
| 43 |
|
| 44 |
mms_model = None
|
| 45 |
mms_processor = None
|
|
|
|
|
|
|
| 46 |
|
| 47 |
def load_models():
|
| 48 |
global tts_ha, tts_en, tts_yo, tts_ig
|
|
@@ -74,7 +76,7 @@ def load_models():
|
|
| 74 |
logger.info("Igbo TTS disabled: will fallback to text response")
|
| 75 |
|
| 76 |
|
| 77 |
-
logger.info("Deferred MMS model
|
| 78 |
|
| 79 |
def _get_mms():
|
| 80 |
global mms_model, mms_processor
|
|
@@ -94,7 +96,36 @@ def _get_mms():
|
|
| 94 |
except Exception:
|
| 95 |
logger.exception("Failed to load MMS ASR model")
|
| 96 |
mms_model, mms_processor = None, None
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
def _run_mms(model: Wav2Vec2ForCTC, proc: Wav2Vec2Processor, audio_array: np.ndarray) -> str:
|
| 100 |
try:
|
|
@@ -140,12 +171,24 @@ def preprocess_audio_ffmpeg(audio_data: bytes, target_sr: int = 16000) -> np.nda
|
|
| 140 |
|
| 141 |
def speech_to_text(audio_data: bytes) -> str:
|
| 142 |
audio_array = preprocess_audio_ffmpeg(audio_data)
|
| 143 |
-
model, proc = _get_mms()
|
| 144 |
-
if model is None or proc is None:
|
| 145 |
-
return ""
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
|
| 151 |
def get_ai_response(text: str, response_language: str = None) -> str:
|
|
|
|
| 9 |
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
|
| 10 |
from fastapi.responses import FileResponse
|
| 11 |
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
+
from transformers import pipeline, Wav2Vec2Processor, Wav2Vec2ForCTC, WhisperProcessor, WhisperForConditionalGeneration
|
| 13 |
from langdetect import detect
|
| 14 |
import imageio_ffmpeg
|
| 15 |
import logging
|
|
|
|
| 43 |
|
| 44 |
mms_model = None
|
| 45 |
mms_processor = None
|
| 46 |
+
igbo_model = None
|
| 47 |
+
igbo_processor = None
|
| 48 |
|
| 49 |
def load_models():
|
| 50 |
global tts_ha, tts_en, tts_yo, tts_ig
|
|
|
|
| 76 |
logger.info("Igbo TTS disabled: will fallback to text response")
|
| 77 |
|
| 78 |
|
| 79 |
+
logger.info("Deferred MMS and Igbo ASR model loads: will lazy-load on first use")
|
| 80 |
|
| 81 |
def _get_mms():
|
| 82 |
global mms_model, mms_processor
|
|
|
|
| 96 |
except Exception:
|
| 97 |
logger.exception("Failed to load MMS ASR model")
|
| 98 |
mms_model, mms_processor = None, None
|
| 99 |
+
def _get_igbo_asr():
    """Lazy-load and cache the NCAIR1/Igbo-ASR Whisper model and processor.

    Returns:
        ``(model, processor)`` on success, or ``(None, None)`` if loading
        fails. Results are memoized in the module-level ``igbo_model`` /
        ``igbo_processor`` globals so the download/initialization cost is
        paid at most once per process.
    """
    global igbo_model, igbo_processor
    # Already loaded: reuse the cached pair.
    if igbo_model is not None and igbo_processor is not None:
        return igbo_model, igbo_processor

    # Token for gated/private repos; None is fine for public models.
    hf_token = os.getenv("HF_TOKEN")
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Lazy-loading Igbo ASR model...")
        igbo_processor = WhisperProcessor.from_pretrained("NCAIR1/Igbo-ASR", token=hf_token)
        igbo_model = WhisperForConditionalGeneration.from_pretrained("NCAIR1/Igbo-ASR", token=hf_token)
        igbo_model.to(device)
        igbo_model.eval()
        logger.info("Loaded Igbo ASR model")
        return igbo_model, igbo_processor
    except Exception:
        logger.exception("Failed to load Igbo ASR model")
        igbo_model, igbo_processor = None, None
        # BUG FIX: the original fell through here and implicitly returned
        # None, so callers unpacking the result (``m, p = _get_igbo_asr()``)
        # raised TypeError on every load failure.
        return None, None
| 117 |
+
def _run_whisper(model: WhisperForConditionalGeneration, proc: WhisperProcessor, audio_array: np.ndarray) -> str:
|
| 118 |
+
try:
|
| 119 |
+
device = next(model.parameters()).device
|
| 120 |
+
inputs = proc(audio_array, sampling_rate=16000, return_tensors="pt")
|
| 121 |
+
input_features = inputs.input_features.to(device)
|
| 122 |
+
with torch.no_grad():
|
| 123 |
+
predicted_ids = model.generate(input_features)
|
| 124 |
+
text_list = proc.batch_decode(predicted_ids, skip_special_tokens=True)
|
| 125 |
+
return text_list[0] if text_list else ""
|
| 126 |
+
except Exception:
|
| 127 |
+
logging.exception("Whisper ASR inference failed")
|
| 128 |
+
return ""
|
| 129 |
|
| 130 |
def _run_mms(model: Wav2Vec2ForCTC, proc: Wav2Vec2Processor, audio_array: np.ndarray) -> str:
|
| 131 |
try:
|
|
|
|
| 171 |
|
| 172 |
def speech_to_text(audio_data: bytes) -> str:
    """Transcribe raw audio bytes to text.

    Decodes/resamples the input via ffmpeg, then tries the Igbo Whisper
    ASR first and falls back to the multilingual MMS model.

    Args:
        audio_data: Encoded audio bytes (any container ffmpeg can read).

    Returns:
        The transcript, or ``""`` if no model is available or none
        produced non-empty text.
    """
    audio_array = preprocess_audio_ffmpeg(audio_data)

    # Try Igbo ASR first for better Igbo detection.
    # NOTE(review): Whisper emits *some* text for most audio, so any
    # non-empty Igbo result short-circuits the MMS fallback even for
    # non-Igbo speech — confirm this ordering is intended.
    # Locals renamed from ``igbo_model``/``mms_model`` to stop shadowing
    # the module-level cache globals of the same names.
    igbo_pair = _get_igbo_asr()
    # Defensive unpack: the loader's failure path may return None rather
    # than a tuple, which would otherwise raise TypeError here.
    ig_model, ig_proc = igbo_pair if igbo_pair is not None else (None, None)
    if ig_model is not None and ig_proc is not None:
        igbo_text = _run_whisper(ig_model, ig_proc, audio_array)
        if igbo_text and igbo_text.strip():
            logger.info("Using Igbo ASR result")
            return igbo_text

    # Fallback to MMS for other languages.
    mms_pair = _get_mms()
    fb_model, fb_proc = mms_pair if mms_pair is not None else (None, None)
    if fb_model is not None and fb_proc is not None:
        mms_text = _run_mms(fb_model, fb_proc, audio_array)
        if mms_text and mms_text.strip():
            logger.info("Using MMS ASR result")
            return mms_text

    return ""
|
| 192 |
|
| 193 |
|
| 194 |
def get_ai_response(text: str, response_language: str = None) -> str:
|