nexusbert commited on
Commit
c3c0d65
·
1 Parent(s): 06b99eb

push milestone 3

Browse files
Files changed (3) hide show
  1. Dockerfile +56 -0
  2. app.py +427 -0
  3. requirements.txt +24 -0
Dockerfile ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base Image
2
+ FROM python:3.10-slim
3
+
4
+ ENV DEBIAN_FRONTEND=noninteractive \
5
+ PYTHONUNBUFFERED=1 \
6
+ PYTHONDONTWRITEBYTECODE=1
7
+
8
+ WORKDIR /code
9
+
10
+ # System Dependencies
11
+ RUN apt-get update && apt-get install -y --no-install-recommends \
12
+ build-essential \
13
+ git \
14
+ curl \
15
+ libopenblas-dev \
16
+ libomp-dev \
17
+ ffmpeg \
18
+ && rm -rf /var/lib/apt/lists/*
19
+
20
+ # Copy requirements and install Python dependencies
21
+ COPY requirements.txt .
22
+ RUN pip install --no-cache-dir -r requirements.txt
23
+
24
+ # Hugging Face + model tools
25
+ RUN pip install --no-cache-dir huggingface-hub sentencepiece accelerate fasttext
26
+
27
+ # Hugging Face cache environment
28
+ ENV HF_HOME=/models/huggingface \
29
+ TRANSFORMERS_CACHE=/models/huggingface \
30
+ HUGGINGFACE_HUB_CACHE=/models/huggingface \
31
+ HF_HUB_CACHE=/models/huggingface
32
+
33
+ # Created cache dir and set permissions
34
+ RUN mkdir -p /models/huggingface && chmod -R 777 /models/huggingface
35
+
36
+ # Pre-download models at build time
37
+ RUN python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-1b-all')" \
38
+ && python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-hau')" \
39
+ && python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-eng')" \
40
+ && python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='facebook/mms-tts-yor')" \
41
+ && find /models/huggingface -name '*.lock' -delete
42
+
43
+ # Preload tokenizers (avoid runtime delays)
44
+ RUN python -c "from transformers import Wav2Vec2Processor; Wav2Vec2Processor.from_pretrained('facebook/mms-1b-all')" \
45
+ && python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-hau')" \
46
+ && python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-eng')" \
47
+ && python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-yor')"
48
+
49
+ # Copy project files
50
+ COPY . .
51
+
52
+ # Expose FastAPI port
53
+ EXPOSE 7860
54
+
55
+ # Run FastAPI app with uvicorn (1 workers for concurrency)
56
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
app.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import tempfile
4
+ import subprocess
5
+ import requests
6
+ import torch
7
+ import numpy as np
8
+ import soundfile as sf
9
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Form
10
+ from fastapi.responses import FileResponse
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from transformers import pipeline, Wav2Vec2Processor, Wav2Vec2ForCTC
13
+ from langdetect import detect
14
+ import imageio_ffmpeg
15
+ import logging
16
+ from contextlib import asynccontextmanager
17
+ import uvicorn
18
+ import nest_asyncio
19
+
20
+ nest_asyncio.apply()
21
+
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+ @asynccontextmanager
26
+ async def lifespan(app: FastAPI):
27
+ load_models()
28
+ yield
29
+
30
+ app = FastAPI(title="Farmlingua AI Speech Interface", version="1.0.0", lifespan=lifespan)
31
+
32
+ app.add_middleware(
33
+ CORSMiddleware,
34
+ allow_origins=["*"],
35
+ allow_credentials=True,
36
+ allow_methods=["*"],
37
+ allow_headers=["*"],
38
+ )
39
+
40
+
41
+ ASK_URL = "https://remostart-milestone-one-farmlingua-ai.hf.space/ask"
42
+ tts_ha, tts_en, tts_yo, tts_ig = None, None, None, None
43
+
44
+ mms_model = None
45
+ mms_processor = None
46
+
47
+ def load_models():
48
+ global tts_ha, tts_en, tts_yo, tts_ig
49
+ device = 0 if torch.cuda.is_available() else -1
50
+ hf_token = os.getenv("HF_TOKEN")
51
+ if not hf_token:
52
+ logger.info("HF_TOKEN not set; gated repos may fail to load. Set HF_TOKEN to access restricted models.")
53
+ logger.info("Loading TTS models...")
54
+ try:
55
+ tts_ha = pipeline("text-to-speech", model="facebook/mms-tts-hau", device=device)
56
+ logger.info("Loaded TTS (Hausa)")
57
+ except Exception as e:
58
+ logger.exception("Failed to load TTS (Hausa)")
59
+ tts_ha = None
60
+ try:
61
+ tts_en = pipeline("text-to-speech", model="facebook/mms-tts-eng", device=device)
62
+ logger.info("Loaded TTS (English)")
63
+ except Exception:
64
+ logger.exception("Failed to load TTS (English)")
65
+ tts_en = None
66
+ try:
67
+ tts_yo = pipeline("text-to-speech", model="facebook/mms-tts-yor", device=device)
68
+ logger.info("Loaded TTS (Yoruba)")
69
+ except Exception:
70
+ logger.exception("Failed to load TTS (Yoruba)")
71
+ tts_yo = None
72
+
73
+ tts_ig = None
74
+ logger.info("Igbo TTS disabled: will fallback to text response")
75
+
76
+
77
+ logger.info("Deferred MMS model load: will lazy-load on first use")
78
+
79
+ def _get_mms():
80
+ global mms_model, mms_processor
81
+ if mms_model is not None and mms_processor is not None:
82
+ return mms_model, mms_processor
83
+
84
+ hf_token = os.getenv("HF_TOKEN")
85
+ try:
86
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
87
+ logger.info("Lazy-loading MMS ASR model...")
88
+ mms_processor = Wav2Vec2Processor.from_pretrained("facebook/mms-1b-all", token=hf_token)
89
+ mms_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all", token=hf_token)
90
+ mms_model.to(device)
91
+ mms_model.eval()
92
+ logger.info("Loaded MMS ASR model")
93
+ return mms_model, mms_processor
94
+ except Exception:
95
+ logger.exception("Failed to load MMS ASR model")
96
+ mms_model, mms_processor = None, None
97
+ return None, None
98
+
99
+ def _run_mms(model: Wav2Vec2ForCTC, proc: Wav2Vec2Processor, audio_array: np.ndarray) -> str:
100
+ try:
101
+ device = next(model.parameters()).device
102
+ inputs = proc(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
103
+ input_values = inputs.input_values.to(device)
104
+ with torch.no_grad():
105
+ logits = model(input_values).logits
106
+ predicted_ids = torch.argmax(logits, dim=-1)
107
+ text = proc.batch_decode(predicted_ids)[0]
108
+ return text.strip() if text else ""
109
+ except Exception:
110
+ logging.exception("MMS ASR inference failed")
111
+ return ""
112
+
113
+ def preprocess_audio_ffmpeg(audio_data: bytes, target_sr: int = 16000) -> np.ndarray:
114
+ try:
115
+ with tempfile.NamedTemporaryFile(suffix='.input', delete=False) as in_file:
116
+ in_file.write(audio_data)
117
+ in_path = in_file.name
118
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as out_file:
119
+ out_path = out_file.name
120
+
121
+ ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe()
122
+ subprocess.run([
123
+ ffmpeg_exe, '-y', '-i', in_path,
124
+ '-ac', '1', '-ar', str(target_sr), out_path
125
+ ], check=True, capture_output=True)
126
+
127
+ with open(out_path, 'rb') as f:
128
+ wav_data = f.read()
129
+
130
+ os.unlink(in_path)
131
+ os.unlink(out_path)
132
+
133
+ audio_array, sr = sf.read(io.BytesIO(wav_data))
134
+ if len(audio_array.shape) > 1:
135
+ audio_array = np.mean(audio_array, axis=1)
136
+ return audio_array.astype(np.float32)
137
+ except Exception as e:
138
+ logger.error(f"FFmpeg preprocessing failed: {e}")
139
+ raise HTTPException(status_code=400, detail="Audio preprocessing failed. Ensure ffmpeg is installed.")
140
+
141
+ def speech_to_text(audio_data: bytes) -> str:
142
+ audio_array = preprocess_audio_ffmpeg(audio_data)
143
+ model, proc = _get_mms()
144
+ if model is None or proc is None:
145
+ return ""
146
+
147
+ text = _run_mms(model, proc, audio_array)
148
+ return text
149
+
150
+
151
+ def get_ai_response(text: str, response_language: str = None) -> str:
152
+ try:
153
+ if response_language and response_language != "en":
154
+ language_instructions = {
155
+ "ha": "Please respond in Hausa language.",
156
+ "yo": "Please respond in Yoruba language.",
157
+ "ig": "Please respond in Igbo language.",
158
+ "en": "Please respond in English."
159
+ }
160
+ language_instruction = language_instructions.get(response_language, "")
161
+ enhanced_query = f"{text}. {language_instruction}" if language_instruction else text
162
+ else:
163
+ enhanced_query = text
164
+
165
+ response = requests.post(ASK_URL, json={"query": enhanced_query}, timeout=30)
166
+ response.raise_for_status()
167
+ result = response.json()
168
+ return result.get("answer", "Sorry, no answer returned.")
169
+ except Exception as e:
170
+ logger.error(f"AI request error: {e}")
171
+ return f"I'm sorry, I couldn't connect to the AI service. You said: '{text}'."
172
+
173
+ HAUSA_WORDS = [
174
+ "aikin","manoma","gona","amfanin","yanayi","tsaba","fasaha","bisa","noman","shuka",
175
+ "daji","rani","damina","amfani","bidi'a","noma","bashi","manure","tsiro","gishiri"
176
+ ]
177
+
178
+ YORUBA_WORDS = [
179
+ "ilé","ọmọ","òun","awọn","agbẹ","oko","ọgbà","irugbin","àkọsílẹ","omi","ojo","àgbàlá","irọlẹ"
180
+ ]
181
+
182
+ IGBO_WORDS = [
183
+ "ugbo","akụkọ","mmiri","ala","ọrụ","ncheta","ọhụrụ","ugwu","nri","ahụhụ"
184
+ ]
185
+
186
+ def detect_language(text: str) -> str:
187
+ text_lower = text.lower()
188
+ if any(word in text_lower for word in HAUSA_WORDS):
189
+ return "ha"
190
+ elif any(word in text_lower for word in YORUBA_WORDS):
191
+ return "yo"
192
+ elif any(word in text_lower for word in IGBO_WORDS):
193
+ return "ig"
194
+ lang = detect(text)
195
+ if lang.startswith("ha"):
196
+ return "ha"
197
+ elif lang.startswith("yo"):
198
+ return "yo"
199
+ elif lang.startswith("ig"):
200
+ return "ig"
201
+ else:
202
+ return "en"
203
+
204
+ def text_to_speech_file(text: str) -> str:
205
+ lang = detect_language(text)
206
+ print(f"Detected language: {lang}")
207
+
208
+ supported_tts_languages = ["ha", "yo", "en"]
209
+ if lang not in supported_tts_languages:
210
+ logger.warning(f"Language '{lang}' not supported for TTS, falling back to English")
211
+ lang = "en"
212
+
213
+ global tts_ig
214
+ if lang == "ha":
215
+ tts_model = tts_ha
216
+ elif lang == "yo":
217
+ tts_model = tts_yo
218
+ elif lang == "ig":
219
+ logger.warning("Igbo TTS not available, raising exception for text fallback")
220
+ raise Exception("Igbo TTS not available - returning text response")
221
+ else:
222
+ tts_model = tts_en
223
+
224
+ if tts_model is None:
225
+ raise Exception(f"TTS model not available for language '{lang}'")
226
+
227
+ speech_output = tts_model(text)
228
+ audio_raw = speech_output["audio"]
229
+ sampling_rate = int(speech_output["sampling_rate"])
230
+
231
+
232
+ if isinstance(audio_raw, torch.Tensor):
233
+ audio_np = audio_raw.detach().cpu().numpy()
234
+ else:
235
+ audio_np = np.asarray(audio_raw)
236
+
237
+ if audio_np.ndim > 1:
238
+ audio_np = audio_np.reshape(-1)
239
+ audio_np = audio_np.astype(np.float32, copy=False)
240
+
241
+
242
+ audio_clipped = np.clip(audio_np, -1.0, 1.0)
243
+ audio_int16 = (audio_clipped * 32767.0).astype(np.int16)
244
+
245
+
246
+ fd, path = tempfile.mkstemp(suffix=".wav")
247
+ os.close(fd)
248
+
249
+
250
+ sf.write(path, audio_int16, sampling_rate, format='WAV', subtype='PCM_16')
251
+ return path
252
+
253
+
254
+ @app.get("/")
255
+ async def root():
256
+ return {"status": "ok", "message": "System ready"}
257
+
258
+ @app.get("/health")
259
+ async def health():
260
+ return {"message": "Farmlingua AI Speech Interface is running!"}
261
+
262
+ @app.post("/chat")
263
+ async def chat(text: str = Form(...), speak: bool = False, raw: bool = False):
264
+ if not text.strip():
265
+ raise HTTPException(status_code=400, detail="Text cannot be empty")
266
+
267
+ input_language = detect_language(text)
268
+ final_text = text if raw else get_ai_response(text, response_language=input_language)
269
+
270
+ if speak:
271
+ try:
272
+ audio_path = text_to_speech_file(final_text)
273
+ return FileResponse(audio_path, media_type="audio/wav", filename="response.wav")
274
+ except Exception as e:
275
+ logger.warning(f"TTS failed for chat endpoint: {e}")
276
+ return {
277
+ "question": text,
278
+ "answer": final_text,
279
+ "input_language": input_language,
280
+ "tts_available": False,
281
+ "message": f"TTS not available: {str(e)}"
282
+ }
283
+ return {
284
+ "question": text,
285
+ "answer": final_text,
286
+ "input_language": input_language
287
+ }
288
+
289
+ @app.post("/speak")
290
+ async def speak_to_ai(audio_file: UploadFile = File(...), speak: bool = True):
291
+ if not audio_file.content_type.startswith('audio/'):
292
+ raise HTTPException(status_code=400, detail="File must be an audio file")
293
+ audio_data = await audio_file.read()
294
+ transcription = speech_to_text(audio_data)
295
+
296
+ input_language = detect_language(transcription)
297
+ ai_response = get_ai_response(transcription, response_language=input_language)
298
+
299
+ if speak:
300
+ try:
301
+ audio_path = text_to_speech_file(ai_response)
302
+ return FileResponse(audio_path, media_type="audio/wav", filename="response.wav")
303
+ except Exception as e:
304
+ logger.warning(f"TTS failed for speak endpoint: {e}")
305
+ return {
306
+ "transcription": transcription,
307
+ "ai_response": ai_response,
308
+ "input_language": input_language,
309
+ "tts_available": False,
310
+ "message": f"TTS not available: {str(e)}"
311
+ }
312
+ return {
313
+ "transcription": transcription,
314
+ "ai_response": ai_response,
315
+ "input_language": input_language
316
+ }
317
+
318
+ @app.post("/stt")
319
+ async def speech_to_text_endpoint(audio_file: UploadFile = File(...)):
320
+ if not audio_file.content_type.startswith('audio/'):
321
+ raise HTTPException(status_code=400, detail="File must be an audio file")
322
+
323
+ try:
324
+ audio_data = await audio_file.read()
325
+ transcription = speech_to_text(audio_data)
326
+
327
+ if not transcription.strip():
328
+ return {"transcription": "", "error": "No speech detected or transcription failed"}
329
+
330
+ return {
331
+ "transcription": transcription,
332
+ "language_detected": detect_language(transcription),
333
+ "success": True
334
+ }
335
+ except Exception as e:
336
+ logger.error(f"STT endpoint error: {e}")
337
+ raise HTTPException(status_code=500, detail=f"Speech-to-text conversion failed: {str(e)}")
338
+
339
+ @app.post("/tts")
340
+ async def text_to_speech_endpoint(text: str = Form(...), language: str = Form(None)):
341
+ if not text.strip():
342
+ raise HTTPException(status_code=400, detail="Text cannot be empty")
343
+
344
+ try:
345
+ if language and language in ["ha", "yo", "ig", "en"]:
346
+ lang = language
347
+ else:
348
+ lang = detect_language(text)
349
+
350
+ logger.info(f"TTS using language: {lang}")
351
+
352
+ supported_tts_languages = ["ha", "yo", "en"]
353
+ if lang not in supported_tts_languages:
354
+ logger.warning(f"Language '{lang}' not supported for TTS, returning text-only response")
355
+ return {
356
+ "text": text,
357
+ "language_detected": lang,
358
+ "tts_available": False,
359
+ "message": f"TTS not available for language '{lang}'. Supported languages: {', '.join(supported_tts_languages)}",
360
+ "note": "AI response is already in the detected language"
361
+ }
362
+
363
+ global tts_ig
364
+ if lang == "ha":
365
+ tts_model = tts_ha
366
+ elif lang == "yo":
367
+ tts_model = tts_yo
368
+ elif lang == "ig":
369
+ logger.warning("Igbo TTS not available, returning text-only response")
370
+ return {
371
+ "text": text,
372
+ "language_detected": lang,
373
+ "tts_available": False,
374
+ "message": "Igbo TTS not available - returning text response",
375
+ "note": "AI response is already in Igbo language"
376
+ }
377
+ else:
378
+ tts_model = tts_en
379
+
380
+ if tts_model is None:
381
+ logger.warning(f"TTS model not available for language '{lang}', returning text-only response")
382
+ return {
383
+ "text": text,
384
+ "language_detected": lang,
385
+ "tts_available": False,
386
+ "message": f"TTS model not available for language '{lang}'"
387
+ }
388
+
389
+ speech_output = tts_model(text)
390
+ audio_raw = speech_output["audio"]
391
+ sampling_rate = int(speech_output["sampling_rate"])
392
+
393
+ if isinstance(audio_raw, torch.Tensor):
394
+ audio_np = audio_raw.detach().cpu().numpy()
395
+ else:
396
+ audio_np = np.asarray(audio_raw)
397
+
398
+ if audio_np.ndim > 1:
399
+ audio_np = audio_np.reshape(-1)
400
+ audio_np = audio_np.astype(np.float32, copy=False)
401
+
402
+ audio_clipped = np.clip(audio_np, -1.0, 1.0)
403
+ audio_int16 = (audio_clipped * 32767.0).astype(np.int16)
404
+
405
+ fd, path = tempfile.mkstemp(suffix=".wav")
406
+ os.close(fd)
407
+
408
+ sf.write(path, audio_int16, sampling_rate, format='WAV', subtype='PCM_16')
409
+
410
+ return FileResponse(
411
+ path,
412
+ media_type="audio/wav",
413
+ filename=f"tts_{lang}_{hash(text) % 10000}.wav"
414
+ )
415
+
416
+ except Exception as e:
417
+ logger.error(f"TTS endpoint error: {e}")
418
+ return {
419
+ "text": text,
420
+ "language_detected": lang if 'lang' in locals() else "unknown",
421
+ "tts_available": False,
422
+ "message": f"TTS conversion failed: {str(e)}"
423
+ }
424
+
425
+ if __name__ == "__main__":
426
+ import uvicorn
427
+ uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ websockets
4
+ torch
5
+ torchaudio
6
+ transformers
7
+ soundfile
8
+ requests
9
+ numpy
10
+ scipy
11
+ librosa
12
+ imageio-ffmpeg
13
+ python-multipart
14
+ aiofiles
15
+ accelerate
16
+ sentencepiece
17
+ protobuf
18
+ langdetect
19
+ nest-asyncio
20
+
21
+
22
+
23
+
24
+