import os
import io
import tempfile
import subprocess
import unicodedata
import requests
import torch
import numpy as np
import soundfile as sf
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM
import imageio_ffmpeg
import logging
from contextlib import asynccontextmanager
import uvicorn
import nest_asyncio

nest_asyncio.apply()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the TTS models once when the application starts."""
    load_models()
    yield


app = FastAPI(title="Farmlingua AI Speech Interface", version="1.0.0", lifespan=lifespan)

# NOTE(review): wildcard origins together with allow_credentials=True means any
# site may issue credentialed requests; restrict allow_origins if auth/cookies
# are ever used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

ASK_URL = "https://remostart-milestone-one-farmlingua-ai.hf.space/ask"

# Sample rate the ASR pipeline works at (ffmpeg resamples every upload to it).
TARGET_SAMPLE_RATE = 16000
# Recordings up to this many seconds are transcribed in a single pass; longer
# ones are split into overlapping 10 s chunks.
SINGLE_PASS_MAX_SECONDS = 15
# Order in which ASR models are tried when the spoken language is unknown.
_AUTO_DETECT_ORDER = ("yo", "ha", "ig", "en")

# Text-to-speech pipelines, populated by load_models(); None when unavailable.
tts_ha, tts_en, tts_yo, tts_ig = None, None, None, None

# Whisper ASR models per language code. "model"/"proc" are filled lazily by _get_asr().
asr_models = {
    "ha": {"repo": "NCAIR1/Hausa-ASR", "model": None, "proc": None},
    "yo": {"repo": "NCAIR1/Yoruba-ASR", "model": None, "proc": None},
    "ig": {"repo": "NCAIR1/Igbo-ASR", "model": None, "proc": None},
    "en": {"repo": "NCAIR1/NigerianAccentedEnglish", "model": None, "proc": None},
}


def _hf_token():
    """Return the HF_TOKEN environment variable with whitespace stripped.

    A missing or blank token yields None, so Hugging Face falls back to its
    default credential lookup instead of being handed an empty string.
    """
    token = os.getenv("HF_TOKEN")
    if token:
        token = token.strip()
    return token or None


def _load_tts(model_id: str, label: str, device: int):
    """Build a text-to-speech pipeline for model_id; log and return None on failure."""
    try:
        tts = pipeline("text-to-speech", model=model_id, device=device)
    except Exception:
        logger.exception(f"Failed to load TTS ({label})")
        return None
    logger.info(f"Loaded TTS ({label})")
    return tts


def load_models():
    """Load the Hausa, English and Yoruba TTS pipelines into the module globals.

    Igbo TTS is intentionally disabled (tts_ig stays None). ASR models are not
    loaded here; _get_asr() loads them on first use per language.
    """
    global tts_ha, tts_en, tts_yo, tts_ig
    # transformers pipelines take a device index: 0 = first GPU, -1 = CPU.
    device = 0 if torch.cuda.is_available() else -1

    if _hf_token():
        logger.info("HF_TOKEN is set and ready for authenticated model access.")
    else:
        logger.warning("HF_TOKEN not set! This may cause authentication failures for gated repositories.")
        logger.warning("Please set HF_TOKEN environment variable to access restricted models.")

    logger.info("Using lightweight keyword-based language detection (no heavy models)")
    logger.info("Loading TTS models...")
    tts_ha = _load_tts("facebook/mms-tts-hau", "Hausa", device)
    tts_en = _load_tts("facebook/mms-tts-eng", "English", device)
    tts_yo = _load_tts("facebook/mms-tts-yor", "Yoruba", device)

    tts_ig = None
    logger.info("Igbo TTS model disabled - will return text responses for Igbo language")
    logger.info("Deferred ASR model loads: will lazy-load per language on first use")


def _get_asr(lang_code: str):
    """Return the cached (model, processor) Whisper pair for lang_code, loading it if needed.

    Returns (None, None) for an unknown language code or when loading fails;
    a failed load is retried on the next call.
    """
    entry = asr_models.get(lang_code)
    if not entry:
        return None, None
    if entry["model"] is not None and entry["proc"] is not None:
        return entry["model"], entry["proc"]

    repo_id = entry["repo"]
    hf_token = _hf_token()
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Lazy-loading ASR for {lang_code} from {repo_id}...")
        proc = WhisperProcessor.from_pretrained(repo_id, token=hf_token)
        model = WhisperForConditionalGeneration.from_pretrained(repo_id, token=hf_token)
        model.to(device)
        model.eval()
    except Exception:
        logger.exception(f"Failed to load ASR for {lang_code} ({repo_id})")
        entry["model"], entry["proc"] = None, None
        return None, None

    entry["model"], entry["proc"] = model, proc
    logger.info(f"Loaded ASR for {lang_code}")
    return model, proc


def _run_whisper(model: WhisperForConditionalGeneration, proc: WhisperProcessor, audio_array: np.ndarray) -> str:
    """Transcribe 16 kHz mono samples with a Whisper model; return "" on any failure."""
    try:
        device = next(model.parameters()).device
        inputs = proc(audio_array, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt")
        input_features = inputs.input_features.to(device)
        with torch.no_grad():
            predicted_ids = model.generate(input_features)
        text_list = proc.batch_decode(predicted_ids, skip_special_tokens=True)
        return text_list[0] if text_list else ""
    except Exception:
        logger.exception("Whisper ASR inference failed")
        return ""


def _ffmpeg_decode(audio_data: bytes, target_sr: int):
    """Decode audio_data to mono WAV at target_sr with ffmpeg; return (samples, sample_rate).

    ffmpeg also applies a high-pass/low-pass/loudness filter chain. Both
    temporary files are removed on every path, including when ffmpeg fails.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited non-zero (stderr captured).
    """
    in_path = None
    out_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".input", delete=False) as in_file:
            in_path = in_file.name  # record before writing so a failed write is still cleaned up
            in_file.write(audio_data)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as out_file:
            out_path = out_file.name

        ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe()
        subprocess.run(
            [
                ffmpeg_exe, "-y", "-i", in_path,
                "-ac", "1",
                "-ar", str(target_sr),
                "-af",
                "highpass=f=80,"
                "lowpass=f=8000,"
                "dynaudnorm=p=0.95:m=10.0,"
                "volume=1.0,"
                "aresample=resampler=soxr",
                out_path,
            ],
            check=True,
            capture_output=True,
        )
        return sf.read(out_path)
    finally:
        for path in (in_path, out_path):
            if path is None:
                continue
            try:
                os.unlink(path)
            except OSError:
                logger.warning(f"Could not remove temporary file {path}")


def _resample_linear(samples: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    """Resample by linear interpolation from src_sr to dst_sr.

    Sample positions span [0, len - 1] so that no position falls past the
    last real sample.
    """
    ratio = dst_sr / src_sr
    new_length = int(len(samples) * ratio)
    if len(samples) == 0 or new_length <= 0:
        return np.zeros(0, dtype=np.float64)
    return np.interp(
        np.linspace(0, len(samples) - 1, new_length),
        np.arange(len(samples)),
        samples,
    )


def preprocess_audio_ffmpeg(audio_data: bytes, target_sr: int = 16000) -> np.ndarray:
    """Decode an uploaded audio blob into normalized mono float32 samples at target_sr.

    Args:
        audio_data: Raw bytes of the uploaded file (any format ffmpeg can read).
        target_sr: Output sample rate in Hz.

    Returns:
        1-D float32 array, DC-free, loudness-normalized and clipped to ±0.99.

    Raises:
        HTTPException: 400 when the upload is empty, ffmpeg cannot decode it,
            or it decodes to zero samples.
    """
    if not audio_data:
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty.")

    try:
        audio_array, sr = _ffmpeg_decode(audio_data, target_sr)
    except subprocess.CalledProcessError as e:
        stderr = (e.stderr or b"").decode("utf-8", errors="replace").strip()
        # Keep only the tail of ffmpeg's stderr; the actual error is at the end.
        logger.error(f"FFmpeg preprocessing failed: {e}\n{stderr[-2000:]}")
        raise HTTPException(status_code=400, detail="Audio preprocessing failed. Ensure ffmpeg is installed.") from e
    except Exception as e:
        logger.error(f"FFmpeg preprocessing failed: {e}")
        raise HTTPException(status_code=400, detail="Audio preprocessing failed. Ensure ffmpeg is installed.") from e

    if audio_array.ndim > 1:
        audio_array = np.mean(audio_array, axis=1)
    if audio_array.size == 0:
        raise HTTPException(status_code=400, detail="No audio samples could be decoded from the uploaded file.")

    if sr != target_sr:
        # ffmpeg was asked for target_sr, so this is only a safety net.
        logger.warning(f"Audio sampling rate {sr} != target {target_sr}, applying additional resampling...")
        try:
            from scipy import signal
            ratio = target_sr / sr
            audio_array = signal.resample(audio_array, int(len(audio_array) * ratio))
            logger.info(f"Successfully resampled using scipy to {target_sr}Hz")
        except ImportError:
            logger.warning("scipy not available, using numpy interpolation")
            audio_array = _resample_linear(audio_array, sr, target_sr)

    audio_array = _validate_and_normalize_audio(audio_array)
    logger.info(f"Audio preprocessing complete: {len(audio_array)} samples at {target_sr}Hz")
    return audio_array.astype(np.float32)


def _validate_and_normalize_audio(audio_array: np.ndarray) -> np.ndarray:
    """Log quality warnings, remove DC offset, normalize loudness and clip.

    Steps:
      1. Warn when the input peaks above 0.95 (likely clipped at capture).
      2. Subtract the mean. This is done before measuring RMS so the offset
         does not inflate the loudness estimate, and before clipping so the
         final ±0.99 bound is not undone afterwards.
      3. Scale toward an RMS of 0.1, with the gain capped at 2x.
      4. Clip to [-0.99, 0.99].

    Empty input is returned unchanged.
    """
    if audio_array.size == 0:
        return audio_array

    max_val = np.max(np.abs(audio_array))
    if max_val > 0.95:
        logger.warning(f"Audio may be clipped (max: {max_val:.3f})")

    audio_array = audio_array - np.mean(audio_array)

    current_rms = np.sqrt(np.mean(audio_array ** 2))
    if current_rms < 0.001:
        logger.warning("Audio appears to be very quiet or silent")

    target_rms = 0.1
    if current_rms > 0:
        gain = min(target_rms / current_rms, 2.0)
        audio_array = audio_array * gain
        logger.info(f"Normalized audio RMS from {current_rms:.4f} to {np.sqrt(np.mean(audio_array ** 2)):.4f}")

    return np.clip(audio_array, -0.99, 0.99)


def chunk_audio(audio_array: np.ndarray, chunk_length: float = 10.0, overlap: float = 1.0, sample_rate: int = 16000) -> list:
    """Split audio into fixed-length, overlapping chunks with short fades at each edge.

    Each chunk is chunk_length seconds long; consecutive chunks share
    overlap seconds. The final chunk is zero-padded to full length. The input
    array is never modified; chunks are independent copies.

    Raises:
        ValueError: chunk_length is not positive or overlap is not smaller
            than chunk_length (the original loop would never terminate).
    """
    chunk_samples = int(chunk_length * sample_rate)
    overlap_samples = int(overlap * sample_rate)
    step_samples = chunk_samples - overlap_samples
    if chunk_samples <= 0 or step_samples <= 0:
        raise ValueError("chunk_length must be positive and larger than overlap")

    fade_samples = int(0.05 * sample_rate)  # 50 ms fade-in/out
    fade_in = np.linspace(0, 1, fade_samples)
    fade_out = np.linspace(1, 0, fade_samples)

    chunks = []
    total = len(audio_array)
    for start in range(0, total, step_samples):
        end = min(start + chunk_samples, total)
        # Copy so the fades do not alter the caller's array -- the next chunk
        # re-reads the overlap region and would otherwise contain a faded dip.
        chunk = audio_array[start:end].copy()
        if fade_samples > 0 and len(chunk) > 2 * fade_samples:
            chunk[:fade_samples] *= fade_in
            chunk[-fade_samples:] *= fade_out
        if len(chunk) < chunk_samples:
            chunk = np.pad(chunk, (0, chunk_samples - len(chunk)), mode="constant")

        chunk_rms = np.sqrt(np.mean(chunk ** 2))
        if chunk_rms < 0.001:
            logger.warning(f"Chunk {len(chunks)+1} appears to be very quiet (RMS: {chunk_rms:.6f})")
        chunks.append(chunk)
        if end >= total:
            break

    logger.info(f"Split audio into {len(chunks)} chunks of {chunk_length}s each with quality preservation")
    return chunks


def _transcribe_chunks(chunks: list, model, proc, language: str) -> str:
    """Transcribe each chunk with one model and join the non-empty texts with spaces ("" if none)."""
    chunk_texts = []
    for i, chunk in enumerate(chunks, start=1):
        try:
            text = _run_whisper(model, proc, chunk)
        except Exception as e:
            logger.warning(f"Failed to process chunk {i} with {language}: {e}")
            continue
        text = text.strip() if text else ""
        if text:
            chunk_texts.append(text)
            logger.info(f"Chunk {i}/{len(chunks)} ({language}): {text[:50]}...")

    if not chunk_texts:
        return ""
    combined_text = " ".join(chunk_texts)
    logger.info(f"Combined {language} result: {combined_text[:100]}...")
    return combined_text


def speech_to_text_with_language(audio_data: bytes, language: str) -> str:
    """Transcribe audio with the ASR model for an explicitly chosen language.

    Returns "" when the model cannot be loaded or nothing is recognised.
    Raises HTTPException(400) via preprocess_audio_ffmpeg for undecodable audio.
    """
    audio_array = preprocess_audio_ffmpeg(audio_data)
    audio_duration = len(audio_array) / TARGET_SAMPLE_RATE
    logger.info(f"Audio duration: {audio_duration:.2f} seconds, processing with {language} model")

    model, proc = _get_asr(language)
    if model is None or proc is None:
        logger.error(f"Failed to load {language} ASR model")
        return ""

    if audio_duration <= SINGLE_PASS_MAX_SECONDS:
        return _process_single_chunk_with_language(audio_array, model, proc, language)
    return _process_chunked_audio_with_language(audio_array, model, proc, language)


def _process_single_chunk_with_language(audio_array: np.ndarray, model, proc, language: str) -> str:
    """Transcribe short audio in one pass with the given model; "" if nothing was recognised."""
    text = _run_whisper(model, proc, audio_array)
    if text and text.strip():
        logger.info(f"Transcription ({language}): {text[:100]}...")
        return text.strip()
    return ""


def _process_chunked_audio_with_language(audio_array: np.ndarray, model, proc, language: str) -> str:
    """Transcribe long audio chunk by chunk with the given model and join the results."""
    chunks = chunk_audio(audio_array, chunk_length=10.0, overlap=1.0)
    return _transcribe_chunks(chunks, model, proc, language)


def speech_to_text(audio_data: bytes) -> str:
    """Transcribe audio of unknown language by trying every ASR model and keeping the best text."""
    audio_array = preprocess_audio_ffmpeg(audio_data)
    audio_duration = len(audio_array) / TARGET_SAMPLE_RATE
    logger.info(f"Audio duration: {audio_duration:.2f} seconds")

    if audio_duration <= SINGLE_PASS_MAX_SECONDS:
        return _process_single_chunk(audio_array)
    return _process_chunked_audio(audio_array)


def _process_single_chunk(audio_array: np.ndarray) -> str:
    """Run every available ASR model on short audio and return the best-scoring transcription."""
    candidates = []
    for code in _AUTO_DETECT_ORDER:
        model, proc = _get_asr(code)
        if model is None or proc is None:
            continue
        text = _run_whisper(model, proc, audio_array).strip()
        if text:
            candidates.append((code, text))
    return _select_best_transcription(candidates)


def _process_chunked_audio(audio_array: np.ndarray) -> str:
    """Run every available ASR model over the chunks of long audio; return the best combined text."""
    chunks = chunk_audio(audio_array, chunk_length=10.0, overlap=1.0)
    language_results = []
    for code in _AUTO_DETECT_ORDER:
        model, proc = _get_asr(code)
        if model is None or proc is None:
            continue
        combined_text = _transcribe_chunks(chunks, model, proc, code)
        if combined_text:
            language_results.append((code, combined_text))
    return _select_best_transcription(language_results)


def _select_best_transcription(candidates: list) -> str:
    """Pick the highest-scoring text from (lang_code, text) pairs; ties go to the earliest pair."""
    if not candidates:
        return ""
    if len(candidates) == 1:
        return candidates[0][1]

    best_score, best_lang, best_text = None, None, None
    for lang_code, text in candidates:
        score = _score_transcription_quality(text, lang_code)
        logger.info(f"Transcription quality score for {lang_code}: {score:.2f} - '{text[:50]}...'")
        if best_score is None or score > best_score:
            best_score, best_lang, best_text = score, lang_code, text

    logger.info(f"Selected best transcription: {best_lang} (score: {best_score:.2f}) - '{best_text[:100]}...'")
    return best_text


def _score_transcription_quality(text: str, lang_code: str) -> float:
    """Heuristically score how plausible text is as output of the lang_code ASR model.

    Rewards length (capped) and language-specific orthography, and penalises
    punctuation and runs of three identical characters. Never negative.

    Orthography cues are matched on NFC-normalised, lower-cased text, so
    decomposed diacritics and capital letters count. Digraphs ("gb" for
    Yoruba, "ts" for Hausa) are counted as digraphs, not as every occurrence
    of their individual letters.
    """
    if not text or not text.strip():
        return 0.0
    text = unicodedata.normalize("NFC", text.strip())
    word_count = len(text.split())
    if word_count == 0:
        return 0.0

    score = min(word_count * 0.5, 10.0) + min(len(text) * 0.1, 5.0)

    lowered = text.lower()
    if lang_code == "yo":
        score += 2.0 * (sum(lowered.count(c) for c in "ẹọṣ") + lowered.count("gb"))
    elif lang_code == "ig":
        score += 2.0 * sum(lowered.count(c) for c in "ụọị")
    elif lang_code == "ha":
        score += 1.5 * lowered.count("ts")

    punctuation_penalty = text.count("?") + text.count("!") + text.count(".")
    score -= punctuation_penalty * 0.5

    repeated_chars = sum(1 for a, b, c in zip(text, text[1:], text[2:]) if a == b == c)
    score -= repeated_chars * 1.0

    return max(score, 0.0)


def get_ai_response(text: str) -> str:
    """Ask the Farmlingua AI service about text and return its answer.

    Network/HTTP/JSON failures produce an apology that echoes text. A
    response with no usable "answer" string produces "Sorry, no answer
    returned." -- never an empty string, which would break TTS downstream.
    """
    try:
        response = requests.post(ASK_URL, json={"query": text}, timeout=30)
        response.raise_for_status()
        result = response.json()
    except (requests.RequestException, ValueError) as e:
        logger.error(f"AI request error: {e}")
        return f"I'm sorry, I couldn't connect to the AI service. You said: '{text}'."

    answer = result.get("answer") if isinstance(result, dict) else None
    if not isinstance(answer, str) or not answer.strip():
        return "Sorry, no answer returned."
    return answer
# Names used only by the language-detection, TTS and HTTP sections below.
import string
import threading
import unicodedata
from typing import Optional

from fastapi import BackgroundTasks
from fastapi.concurrency import run_in_threadpool

# Keyword vocabularies for lightweight language detection. Entries are matched
# against whole words of the text (not substrings), and each distinct entry
# found adds one point to its language. An entry ending in an apostrophe
# (Igbo "n'") is treated as a word prefix.
HAUSA_WORDS = [
    "aikin", "manoma", "gona", "amfanin", "yanayi", "tsaba", "fasaha", "bisa",
    "noman", "shuka", "daji", "rani", "damina", "amfani", "bidi'a", "noma",
    "bashi", "manure", "tsiro", "gishiri", "gonaki", "gonar", "tsirrai", "kayan",
]

YORUBA_WORDS = [
    "ilé", "ọmọ", "òun", "awọn", "agbẹ", "oko", "ọgbà", "irugbin", "àkọsílẹ",
    "omi", "ojo", "àgbàlá", "irọlẹ", "ni", "ti", "si", "fun", "lati", "ninu",
    "lori", "labe", "pelu", "ati", "tabi", "sugbon", "o", "a", "e", "won", "mi",
    "re", "wa", "yin", "kan", "kankan", "die", "pupo", "gbogbo", "kookan",
    "nibi", "nibe", "igba", "akoko", "osu", "odun", "ise", "owo", "láàsìbà",
    "dára", "jùlẹ̀", "ìwẹ̀", "ṣe", "kú", "tún", "fi", "wo", "jẹ", "wá", "lọ",
    "dúró", "sọ", "gbọ", "rí", "mọ", "fẹ", "ní", "pẹlu", "nitori", "tori",
    "ṣugbọn",
]

IGBO_WORDS = [
    "ugbo", "akụkọ", "mmiri", "ala", "ọrụ", "ncheta", "ọhụrụ", "ugwu", "nri",
    "ahụhụ", "ọkụkọ", "ewu", "atụrụ", "ehi", "azụ", "osisi", "mkpụrụ", "ubi",
    "ọka", "ji", "akwụkwọ", "ofe", "azu", "anụ", "nnu", "mmanụ", "ngwọ", "ọgwụ",
    "ahịhịa", "n'", "maka", "n'ihi", "n'ime", "n'elu", "n'okpuru", "ya", "anyị",
    "unu", "ha", "otu", "ọtụtụ", "ebe", "oge", "bụ", "nọ", "ga", "dị", "ka", "ma",
]

ENGLISH_WORDS = [
    "farm", "farming", "agriculture", "crop", "crops", "plant", "plants", "seed",
    "seeds", "soil", "water", "rain", "weather", "harvest", "yield", "field",
    "fields", "farmer", "farmers", "grow", "growing", "fertilizer", "pesticide",
    "irrigation", "livestock", "cattle", "chicken", "goat", "sheep", "maize",
    "corn", "rice", "wheat", "vegetable", "vegetables", "fruit", "fruits", "tree",
    "trees", "cultivate", "cultivation", "plow", "plowing", "sow", "sowing",
    "reap", "reaping", "season", "seasons",
    # Common function words: the other vocabularies contain function words, so
    # without these, ordinary English sentences scored almost nothing.
    "the", "and", "is", "are", "to", "of", "for", "with", "you", "your", "this",
    "that", "it", "can", "should", "will", "be", "in", "on", "or", "if", "how",
    "what", "when",
]

# Orthographic cues matched as substrings anywhere in the text; each cue that
# is present adds one point. These were single-letter / digraph entries in
# the word lists, where substring matching is the intended behaviour.
_LANGUAGE_MARKERS = {
    "yo": ("ẹ", "ọ", "ṣ", "gb"),
    "ig": ("ụ", "ọ", "ị"),
}

# Punctuation stripped from both ends of each token. The ASCII apostrophe is
# kept because vocabulary words contain it (n'ihi, bidi'a).
_TOKEN_STRIP_CHARS = string.punctuation.replace("'", "") + "\u201c\u201d\u2018\u00ab\u00bb\u2026\u2013\u2014"


def _normalize_for_matching(text: str) -> str:
    """NFC-normalise, lower-case and map the typographic apostrophe (U+2019) to "'"."""
    return unicodedata.normalize("NFC", text).lower().replace("\u2019", "'")


def _build_vocab(words):
    """Split a word list into (set of whole words, tuple of apostrophe-terminated prefixes)."""
    normalized = {_normalize_for_matching(w) for w in words}
    prefixes = tuple(sorted(w for w in normalized if w.endswith("'")))
    return frozenset(normalized.difference(prefixes)), prefixes


_VOCABULARIES = {
    "ha": _build_vocab(HAUSA_WORDS),
    "yo": _build_vocab(YORUBA_WORDS),
    "ig": _build_vocab(IGBO_WORDS),
    "en": _build_vocab(ENGLISH_WORDS),
}
_NORMALIZED_MARKERS = {
    lang: tuple(_normalize_for_matching(m) for m in markers)
    for lang, markers in _LANGUAGE_MARKERS.items()
}


def _tokenize(normalized_text: str) -> set:
    """Return the set of whitespace-separated tokens with surrounding punctuation removed."""
    tokens = set()
    for raw in normalized_text.split():
        token = raw.strip(_TOKEN_STRIP_CHARS)
        tokens.add(token)
        tokens.add(token.strip("'"))  # also match words wrapped in single quotes
    tokens.discard("")
    return tokens


def _language_score(normalized_text: str, tokens: set, lang: str) -> int:
    """Count distinct vocabulary words, prefixes and orthographic markers of lang found in the text."""
    words, prefixes = _VOCABULARIES[lang]
    score = len(tokens & words)
    score += sum(1 for p in prefixes if any(t.startswith(p) for t in tokens))
    score += sum(1 for m in _NORMALIZED_MARKERS.get(lang, ()) if m in normalized_text)
    return score


def detect_language_keywords(text: str) -> str:
    """Guess the language of text ("ha", "yo", "ig" or "en") from keyword hits.

    A language is chosen only when its score is strictly higher than every
    other language's; ties, empty text and English wins all return "en".
    """
    normalized = _normalize_for_matching(text).strip()
    if not normalized:
        return "en"

    tokens = _tokenize(normalized)
    hausa_count = _language_score(normalized, tokens, "ha")
    yoruba_count = _language_score(normalized, tokens, "yo")
    igbo_count = _language_score(normalized, tokens, "ig")
    english_count = _language_score(normalized, tokens, "en")

    logger.info(f"Language detection scores - Hausa: {hausa_count}, Yoruba: {yoruba_count}, Igbo: {igbo_count}, English: {english_count}")

    if hausa_count > max(yoruba_count, igbo_count, english_count):
        logger.info("Keyword detection: Hausa")
        return "ha"
    if yoruba_count > max(hausa_count, igbo_count, english_count):
        logger.info("Keyword detection: Yoruba")
        return "yo"
    if igbo_count > max(hausa_count, yoruba_count, english_count):
        logger.info("Keyword detection: Igbo")
        return "ig"
    logger.info("Keyword detection: English (default)")
    return "en"


def detect_language(text: str) -> str:
    """Detect the language of text; currently delegates to keyword detection."""
    logger.info(f"Detecting language for text: '{text[:50]}...'")
    return detect_language_keywords(text)


def _remove_file(path: str) -> None:
    """Delete path, logging instead of raising if it cannot be removed."""
    try:
        os.unlink(path)
    except OSError:
        logger.warning(f"Could not remove temporary file {path}")


def text_to_speech_file(text: str, lang: Optional[str] = None) -> str:
    """Synthesize text into a temporary file and return its path.

    Args:
        text: Text to speak.
        lang: Language code ("ha", "yo", "ig", "en"). When omitted it is
            detected from text with detect_language. Any code other than
            "ha"/"yo"/"ig" uses the English voice.

    Returns:
        Path to a 16 kHz mono PCM_16 WAV file or, for Igbo (no TTS model), a
        UTF-8 .txt file containing text. The caller owns and must delete it.

    Raises:
        HTTPException: 400 for empty text; 500 when the TTS model is not
            loaded or synthesis fails.
    """
    if lang is None:
        lang = detect_language(text)
        logger.info(f"Detected language: {lang}")
    else:
        logger.info(f"Using language for TTS: {lang}")

    if lang == "ig":
        logger.info("Igbo language detected - returning text response instead of audio")
        fd, path = tempfile.mkstemp(suffix=".txt")
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(text)
        return path

    if not text or not text.strip():
        raise HTTPException(status_code=400, detail="Cannot synthesize speech from empty text")

    tts_model = {"ha": tts_ha, "yo": tts_yo}.get(lang, tts_en)
    if tts_model is None:
        logger.error(f"TTS model for {lang} is not available")
        raise HTTPException(status_code=500, detail=f"TTS model for {lang} is not available")

    try:
        speech_output = tts_model(text)
    except Exception as e:
        logger.exception(f"TTS synthesis failed for {lang}")
        raise HTTPException(status_code=500, detail="Speech synthesis failed") from e

    audio_raw = speech_output["audio"]
    sampling_rate = int(speech_output["sampling_rate"])
    if isinstance(audio_raw, torch.Tensor):
        audio_np = audio_raw.detach().cpu().numpy()
    else:
        audio_np = np.asarray(audio_raw)
    audio_np = audio_np.reshape(-1).astype(np.float32, copy=False)

    target_sr = 16000
    if sampling_rate != target_sr and audio_np.size > 0:
        logger.info(f"Resampling TTS audio from {sampling_rate}Hz to {target_sr}Hz")
        ratio = target_sr / sampling_rate
        new_length = int(len(audio_np) * ratio)
        # Positions span [0, len - 1] so none falls past the last real sample.
        audio_np = np.interp(
            np.linspace(0, len(audio_np) - 1, new_length),
            np.arange(len(audio_np)),
            audio_np,
        )
        sampling_rate = target_sr

    audio_int16 = (np.clip(audio_np, -1.0, 1.0) * 32767.0).astype(np.int16)

    fd, path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        sf.write(path, audio_int16, sampling_rate, format="WAV", subtype="PCM_16")
    except Exception:
        _remove_file(path)
        raise
    return path


# Model inference (ASR + TTS) runs in worker threads so it no longer blocks the
# event loop, but HF pipelines/tokenizers are not guaranteed thread-safe, so
# those calls are serialised with this lock.
_inference_lock = threading.Lock()

SUPPORTED_LANGUAGES = ("yo", "ha", "ig", "en")


def _call_with_inference_lock(func, *args):
    """Call func(*args) while holding the global inference lock."""
    with _inference_lock:
        return func(*args)


async def _run_inference(func, *args):
    """Run a blocking ASR/TTS call in the threadpool, one call at a time."""
    return await run_in_threadpool(_call_with_inference_lock, func, *args)


def _file_response(path: str, media_type: str, filename: str) -> FileResponse:
    """Stream path to the client and delete it once the response has been sent."""
    cleanup = BackgroundTasks()
    cleanup.add_task(_remove_file, path)
    return FileResponse(path, media_type=media_type, filename=filename, background=cleanup)


def _require_audio_upload(audio_file: UploadFile) -> None:
    """Reject uploads whose declared content type is missing or not audio/*."""
    if not (audio_file.content_type or "").startswith("audio/"):
        raise HTTPException(status_code=400, detail="File must be an audio file")


def _require_transcription(transcription: str) -> None:
    """Reject requests whose audio produced no text, instead of querying the AI with ""."""
    if not transcription or not transcription.strip():
        raise HTTPException(status_code=422, detail="Could not transcribe any speech from the audio")


@app.get("/")
async def root():
    """Liveness probe."""
    return {"status": "ok", "message": "System ready"}


@app.get("/health")
async def health():
    """Report which TTS models loaded successfully."""
    return {
        "message": "Farmlingua AI Speech Interface is running!",
        "language_detection": "keyword-based (lightweight)",
        "tts_models": {
            "hausa": tts_ha is not None,
            "english": tts_en is not None,
            "yoruba": tts_yo is not None,
            "igbo": False,
        },
    }


@app.get("/status")
async def status():
    """Describe the language-detection mode."""
    return {
        "language_detection": "keyword-based (lightweight)",
        "status": "ready",
        "message": "Using lightweight keyword-based language detection - no heavy models required",
    }


@app.get("/languages")
async def get_languages():
    """List supported languages; tts_available reflects which TTS models actually loaded."""
    return {
        "supported_languages": {
            "yo": {
                "name": "Yoruba",
                "code": "yo",
                "tts_available": tts_yo is not None,
                "asr_model": "NCAIR1/Yoruba-ASR",
            },
            "ha": {
                "name": "Hausa",
                "code": "ha",
                "tts_available": tts_ha is not None,
                "asr_model": "NCAIR1/Hausa-ASR",
            },
            "ig": {
                "name": "Igbo",
                "code": "ig",
                "tts_available": False,
                "asr_model": "NCAIR1/Igbo-ASR",
                "note": "Text response only - TTS not available",
            },
            "en": {
                "name": "English",
                "code": "en",
                "tts_available": tts_en is not None,
                "asr_model": "NCAIR1/NigerianAccentedEnglish",
            },
        },
        "usage": {
            "specify_language": "Use /speak endpoint with language parameter for best accuracy",
            "auto_detection": "Use /speak-auto endpoint when language is unknown",
            "igbo_response": "Igbo responses return text only (no audio TTS available)",
        },
    }


@app.post("/chat")
async def chat(text: str = Form(...), speak: bool = False, raw: bool = False):
    """Answer a text question; with speak=true return the answer as a file.

    With raw=true the input text is spoken/returned as-is instead of being
    sent to the AI service. Igbo answers come back as text/plain.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    final_text = text if raw else await run_in_threadpool(get_ai_response, text)
    if not speak:
        return {"question": text, "answer": final_text}

    lang = detect_language(final_text)
    output_path = await _run_inference(text_to_speech_file, final_text, lang)
    if lang == "ig":
        return _file_response(output_path, "text/plain", "response.txt")
    return _file_response(output_path, "audio/wav", "response.wav")


@app.post("/speak")
async def speak_to_ai(audio_file: UploadFile = File(...), language: str = Form(...), speak: bool = True):
    """Transcribe speech in a known language, ask the AI, and answer as audio or JSON."""
    _require_audio_upload(audio_file)
    if language not in SUPPORTED_LANGUAGES:
        raise HTTPException(status_code=400, detail="Language must be one of: yo (Yoruba), ha (Hausa), ig (Igbo), en (English)")

    audio_data = await audio_file.read()
    transcription = await _run_inference(speech_to_text_with_language, audio_data, language)
    _require_transcription(transcription)
    ai_response = await run_in_threadpool(get_ai_response, transcription)

    if not speak:
        return {"transcription": transcription, "ai_response": ai_response, "language": language}

    # The TTS voice follows the language detected in the reply. If that is
    # Igbo there is no audio to serve, so answer with text rather than
    # returning a .txt file labelled as WAV.
    response_lang = "ig" if language == "ig" else detect_language(ai_response)
    if response_lang == "ig":
        return {
            "transcription": transcription,
            "ai_response": ai_response,
            "language": language,
            "response_type": "text",
            "note": "Igbo TTS not available - returning text response",
        }
    output_path = await _run_inference(text_to_speech_file, ai_response, response_lang)
    return _file_response(output_path, "audio/wav", "response.wav")


@app.post("/speak-auto")
async def speak_to_ai_auto(audio_file: UploadFile = File(...), speak: bool = True):
    """Transcribe speech in an unknown language, ask the AI, and answer as audio or JSON."""
    _require_audio_upload(audio_file)

    audio_data = await audio_file.read()
    transcription = await _run_inference(speech_to_text, audio_data)
    _require_transcription(transcription)
    ai_response = await run_in_threadpool(get_ai_response, transcription)

    if not speak:
        return {"transcription": transcription, "ai_response": ai_response}

    detected_lang = detect_language(ai_response)
    if detected_lang == "ig":
        return {
            "transcription": transcription,
            "ai_response": ai_response,
            "detected_language": detected_lang,
            "response_type": "text",
            "note": "Igbo TTS not available - returning text response",
        }
    output_path = await _run_inference(text_to_speech_file, ai_response, detected_lang)
    return _file_response(output_path, "audio/wav", "response.wav")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))