import os
import io
import tempfile
import subprocess
import requests
import torch
import numpy as np
import soundfile as sf
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration
import imageio_ffmpeg
import logging
from contextlib import asynccontextmanager
import uvicorn
import nest_asyncio

# Allow uvicorn to run inside environments that already own an event loop
# (e.g. notebooks / Hugging Face Spaces).
nest_asyncio.apply()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load TTS models once at startup; ASR models are lazy-loaded on first use.
    load_models()
    yield

app = FastAPI(title="Farmlingua AI Speech Interface", version="1.0.0", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

ASK_URL = "https://remostart-milestone-one-farmlingua-ai.hf.space/ask"

tts_ha, tts_en, tts_yo, tts_ig = None, None, None, None

asr_models = {
    "ha": {"repo": "NCAIR1/Hausa-ASR", "model": None, "proc": None},
    "yo": {"repo": "NCAIR1/Yoruba-ASR", "model": None, "proc": None},
    "ig": {"repo": "NCAIR1/Igbo-ASR", "model": None, "proc": None},
    "en": {"repo": "NCAIR1/NigerianAccentedEnglish", "model": None, "proc": None},
}
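# Design note: each entry above acts as a per-language cache slot. _get_asr()
# below fills "model"/"proc" on first use, so a cold start only pays for the
# TTS pipelines; Whisper checkpoints download when a language is first
# requested. A minimal sketch of how a caller exercises the cache (names taken
# from this file):
#
#     model, proc = _get_asr("yo")   # first call downloads NCAIR1/Yoruba-ASR
#     model, proc = _get_asr("yo")   # second call returns the cached objects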
def load_models():
    """Load the TTS pipelines eagerly; ASR models are deferred to _get_asr()."""
    global tts_ha, tts_en, tts_yo, tts_ig
    device = 0 if torch.cuda.is_available() else -1
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        hf_token = hf_token.strip()
    if not hf_token:
        logger.warning("HF_TOKEN not set! This may cause authentication failures for gated repositories.")
        logger.warning("Please set the HF_TOKEN environment variable to access restricted models.")
    else:
        logger.info("HF_TOKEN is set and ready for authenticated model access.")
    logger.info("Using lightweight keyword-based language detection (no heavy models)")
    logger.info("Loading TTS models...")
    try:
        tts_ha = pipeline("text-to-speech", model="facebook/mms-tts-hau", device=device)
        logger.info("Loaded TTS (Hausa)")
    except Exception:
        logger.exception("Failed to load TTS (Hausa)")
        tts_ha = None
    try:
        tts_en = pipeline("text-to-speech", model="facebook/mms-tts-eng", device=device)
        logger.info("Loaded TTS (English)")
    except Exception:
        logger.exception("Failed to load TTS (English)")
        tts_en = None
    try:
        tts_yo = pipeline("text-to-speech", model="facebook/mms-tts-yor", device=device)
        logger.info("Loaded TTS (Yoruba)")
    except Exception:
        logger.exception("Failed to load TTS (Yoruba)")
        tts_yo = None
    # No Igbo TTS checkpoint is wired up, so Igbo answers fall back to text.
    tts_ig = None
    logger.info("Igbo TTS model disabled - will return text responses for Igbo language")
    logger.info("Deferred ASR model loads: will lazy-load per language on first use")
def _get_asr(lang_code: str):
    """Return (model, processor) for a language, lazy-loading and caching on first use."""
    entry = asr_models.get(lang_code)
    if not entry:
        return None, None
    if entry["model"] is not None and entry["proc"] is not None:
        return entry["model"], entry["proc"]
    repo_id = entry["repo"]
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        hf_token = hf_token.strip()
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Lazy-loading ASR for {lang_code} from {repo_id}...")
        proc = WhisperProcessor.from_pretrained(repo_id, token=hf_token)
        model = WhisperForConditionalGeneration.from_pretrained(repo_id, token=hf_token)
        model.to(device)
        model.eval()
        entry["model"], entry["proc"] = model, proc
        logger.info(f"Loaded ASR for {lang_code}")
        return model, proc
    except Exception:
        logger.exception(f"Failed to load ASR for {lang_code} ({repo_id})")
        entry["model"], entry["proc"] = None, None
        return None, None
def _run_whisper(model: WhisperForConditionalGeneration, proc: WhisperProcessor, audio_array: np.ndarray) -> str:
    """Transcribe one mono 16 kHz float array; returns "" on failure."""
    try:
        device = next(model.parameters()).device
        inputs = proc(audio_array, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(device)
        with torch.no_grad():
            predicted_ids = model.generate(input_features)
        text_list = proc.batch_decode(predicted_ids, skip_special_tokens=True)
        return text_list[0] if text_list else ""
    except Exception:
        logger.exception("Whisper ASR inference failed")
        return ""
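# Note: _run_whisper assumes its input is already mono float32 at 16 kHz
# (Whisper's native rate). preprocess_audio_ffmpeg() below enforces exactly
# that, so the two functions are always used together in this file.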
def preprocess_audio_ffmpeg(audio_data: bytes, target_sr: int = 16000) -> np.ndarray:
    """Decode arbitrary input audio into a cleaned mono float32 array at target_sr."""
    in_path, out_path = None, None
    try:
        with tempfile.NamedTemporaryFile(suffix='.input', delete=False) as in_file:
            in_file.write(audio_data)
            in_path = in_file.name
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as out_file:
            out_path = out_file.name
        ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe()
        # Downmix to mono, resample to target_sr, and apply a light cleanup
        # chain: high-pass (rumble), low-pass (hiss), dynamic loudness
        # normalization, and high-quality soxr resampling.
        subprocess.run([
            ffmpeg_exe, '-y', '-i', in_path,
            '-ac', '1',
            '-ar', str(target_sr),
            '-af',
            'highpass=f=80,'
            'lowpass=f=8000,'
            'dynaudnorm=p=0.95:m=10.0,'
            'volume=1.0,'
            'aresample=resampler=soxr',
            out_path
        ], check=True, capture_output=True)
        with open(out_path, 'rb') as f:
            wav_data = f.read()
        audio_array, sr = sf.read(io.BytesIO(wav_data))
        if len(audio_array.shape) > 1:
            audio_array = np.mean(audio_array, axis=1)
        if sr != target_sr:
            logger.warning(f"Audio sampling rate {sr} != target {target_sr}, applying additional resampling...")
            try:
                from scipy import signal
                ratio = target_sr / sr
                audio_array = signal.resample(audio_array, int(len(audio_array) * ratio))
                logger.info(f"Successfully resampled using scipy to {target_sr}Hz")
            except ImportError:
                logger.warning("scipy not available, using numpy interpolation")
                ratio = target_sr / sr
                new_length = int(len(audio_array) * ratio)
                audio_array = np.interp(
                    np.linspace(0, len(audio_array), new_length),
                    np.arange(len(audio_array)),
                    audio_array
                )
        audio_array = _validate_and_normalize_audio(audio_array)
        logger.info(f"Audio preprocessing complete: {len(audio_array)} samples at {target_sr}Hz")
        return audio_array.astype(np.float32)
    except Exception as e:
        logger.error(f"FFmpeg preprocessing failed: {e}")
        raise HTTPException(status_code=400, detail="Audio preprocessing failed. Ensure ffmpeg is installed.")
    finally:
        # Clean up temp files even when ffmpeg or decoding fails.
        for path in (in_path, out_path):
            if path and os.path.exists(path):
                os.unlink(path)
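# For reference, the subprocess call above is equivalent to running:
#
#     ffmpeg -y -i input.any -ac 1 -ar 16000 \
#       -af "highpass=f=80,lowpass=f=8000,dynaudnorm=p=0.95:m=10.0,volume=1.0,aresample=resampler=soxr" \
#       output.wav
#
# which is handy for reproducing the cleanup chain outside this service.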
def _validate_and_normalize_audio(audio_array: np.ndarray) -> np.ndarray:
    """Warn on silent/clipped input, normalize RMS toward 0.1, clip, and remove DC offset."""
    rms = np.sqrt(np.mean(audio_array**2))
    if rms < 0.001:
        logger.warning("Audio appears to be very quiet or silent")
    max_val = np.max(np.abs(audio_array))
    if max_val > 0.95:
        logger.warning(f"Audio may be clipped (max: {max_val:.3f})")
    target_rms = 0.1
    current_rms = np.sqrt(np.mean(audio_array**2))
    if current_rms > 0:
        # Boost (or attenuate) toward the target RMS, but never by more than 2x.
        normalization_factor = min(target_rms / current_rms, 2.0)
        audio_array = audio_array * normalization_factor
        logger.info(f"Normalized audio RMS from {current_rms:.4f} to {np.sqrt(np.mean(audio_array**2)):.4f}")
    audio_array = np.clip(audio_array, -0.99, 0.99)
    audio_array = audio_array - np.mean(audio_array)
    return audio_array
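# Worked example of the normalization above: a quiet recording with RMS 0.02
# gets factor min(0.1 / 0.02, 2.0) = 2.0, i.e. the gain is capped and the
# result has RMS 0.04 rather than the full target of 0.1; a signal with RMS
# 0.08 gets factor 1.25 and lands exactly on 0.1.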
def chunk_audio(audio_array: np.ndarray, chunk_length: float = 10.0, overlap: float = 1.0, sample_rate: int = 16000) -> list:
    """Split audio into overlapping, fade-edged, zero-padded fixed-length chunks."""
    chunk_samples = int(chunk_length * sample_rate)
    overlap_samples = int(overlap * sample_rate)
    step_samples = chunk_samples - overlap_samples
    chunks = []
    start = 0
    while start < len(audio_array):
        end = min(start + chunk_samples, len(audio_array))
        # Copy the slice: the fades below mutate the chunk in place, and a
        # bare slice would be a view that corrupts the overlapping region of
        # the source array for the next chunk.
        chunk = audio_array[start:end].copy()
        fade_samples = int(0.05 * sample_rate)
        if len(chunk) > 2 * fade_samples:
            chunk[:fade_samples] *= np.linspace(0, 1, fade_samples)
            chunk[-fade_samples:] *= np.linspace(1, 0, fade_samples)
        if len(chunk) < chunk_samples:
            chunk = np.pad(chunk, (0, chunk_samples - len(chunk)), mode='constant')
        chunk_rms = np.sqrt(np.mean(chunk**2))
        if chunk_rms < 0.001:
            logger.warning(f"Chunk {len(chunks)+1} appears to be very quiet (RMS: {chunk_rms:.6f})")
        chunks.append(chunk)
        start += step_samples
        if end >= len(audio_array):
            break
    logger.info(f"Split audio into {len(chunks)} chunks of {chunk_length}s each with quality preservation")
    return chunks
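# Worked example: 25 s of 16 kHz audio with the defaults (10 s chunks, 1 s
# overlap, hence a 9 s step) yields three chunks covering 0-10 s, 9-19 s, and
# 18-25 s, with the last one zero-padded back up to the full 10 s window.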
def speech_to_text_with_language(audio_data: bytes, language: str) -> str:
    audio_array = preprocess_audio_ffmpeg(audio_data)
    audio_duration = len(audio_array) / 16000
    logger.info(f"Audio duration: {audio_duration:.2f} seconds, processing with {language} model")
    model, proc = _get_asr(language)
    if model is None or proc is None:
        logger.error(f"Failed to load {language} ASR model")
        return ""
    if audio_duration <= 15:
        return _process_single_chunk_with_language(audio_array, model, proc, language)
    else:
        return _process_chunked_audio_with_language(audio_array, model, proc, language)

def _process_single_chunk_with_language(audio_array: np.ndarray, model, proc, language: str) -> str:
    text = _run_whisper(model, proc, audio_array)
    if text and text.strip():
        logger.info(f"Transcription ({language}): {text[:100]}...")
        return text.strip()
    return ""

def _process_chunked_audio_with_language(audio_array: np.ndarray, model, proc, language: str) -> str:
    chunks = chunk_audio(audio_array, chunk_length=10.0, overlap=1.0)
    chunk_texts = []
    for i, chunk in enumerate(chunks):
        try:
            text = _run_whisper(model, proc, chunk)
            if text and text.strip():
                chunk_texts.append(text.strip())
                logger.info(f"Chunk {i+1}/{len(chunks)} ({language}): {text[:50]}...")
        except Exception as e:
            logger.warning(f"Failed to process chunk {i+1} with {language}: {e}")
            continue
    if chunk_texts:
        combined_text = " ".join(chunk_texts)
        logger.info(f"Combined {language} result: {combined_text[:100]}...")
        return combined_text
    return ""
def speech_to_text(audio_data: bytes) -> str:
    audio_array = preprocess_audio_ffmpeg(audio_data)
    audio_duration = len(audio_array) / 16000
    logger.info(f"Audio duration: {audio_duration:.2f} seconds")
    if audio_duration <= 15:
        return _process_single_chunk(audio_array)
    else:
        return _process_chunked_audio(audio_array)

def _process_single_chunk(audio_array: np.ndarray) -> str:
    # Transcribe with every available language model and keep the best result.
    candidates = []
    for code in ["yo", "ha", "ig", "en"]:
        model, proc = _get_asr(code)
        if model is None or proc is None:
            continue
        text = _run_whisper(model, proc, audio_array)
        if text and text.strip():
            candidates.append((code, text.strip()))
    if not candidates:
        return ""
    return _select_best_transcription(candidates)
def _process_chunked_audio(audio_array: np.ndarray) -> str:
    chunks = chunk_audio(audio_array, chunk_length=10.0, overlap=1.0)
    language_results = {}
    for code in ["yo", "ha", "ig", "en"]:
        model, proc = _get_asr(code)
        if model is None or proc is None:
            continue
        chunk_texts = []
        for i, chunk in enumerate(chunks):
            try:
                text = _run_whisper(model, proc, chunk)
                if text and text.strip():
                    chunk_texts.append(text.strip())
                    logger.info(f"Chunk {i+1}/{len(chunks)} ({code}): {text[:50]}...")
            except Exception as e:
                logger.warning(f"Failed to process chunk {i+1} with {code}: {e}")
                continue
        if chunk_texts:
            combined_text = " ".join(chunk_texts)
            language_results[code] = combined_text
            logger.info(f"Combined {code} result: {combined_text[:100]}...")
    if not language_results:
        return ""
    return _select_best_transcription(list(language_results.items()))
def _select_best_transcription(candidates: list) -> str:
    """Pick the highest-scoring (language, text) candidate by heuristic quality score."""
    if not candidates:
        return ""
    if len(candidates) == 1:
        return candidates[0][1]
    scored_candidates = []
    for lang_code, text in candidates:
        score = _score_transcription_quality(text, lang_code)
        scored_candidates.append((score, lang_code, text))
        logger.info(f"Transcription quality score for {lang_code}: {score:.2f} - '{text[:50]}...'")
    scored_candidates.sort(key=lambda x: x[0], reverse=True)
    best_score, best_lang, best_text = scored_candidates[0]
    logger.info(f"Selected best transcription: {best_lang} (score: {best_score:.2f}) - '{best_text[:100]}...'")
    return best_text
def _score_transcription_quality(text: str, lang_code: str) -> float:
    """Heuristic quality score: rewards length and language-specific orthography,
    penalizes heavy punctuation and character repetition."""
    if not text or not text.strip():
        return 0.0
    text = text.strip()
    score = 0.0
    word_count = len(text.split())
    if word_count == 0:
        return 0.0
    score += min(word_count * 0.5, 10.0)
    char_count = len(text)
    score += min(char_count * 0.1, 5.0)
    # Language-specific orthography bonuses. The digraphs ("gb", "ts") are
    # counted as substrings; counting their letters individually would also
    # reward plain English text.
    if lang_code == "yo":
        yoruba_hits = sum(1 for c in text if c in "ẹọṣ") + text.count("gb")
        score += yoruba_hits * 2.0
    elif lang_code == "ig":
        igbo_chars = sum(1 for c in text if c in "ụọị")
        score += igbo_chars * 2.0
    elif lang_code == "ha":
        hausa_hits = text.count("ts")
        score += hausa_hits * 1.5
    punctuation_penalty = text.count('?') + text.count('!') + text.count('.')
    score -= punctuation_penalty * 0.5
    # Penalize runs of three identical characters, a common ASR failure mode.
    repeated_chars = sum(1 for i in range(len(text)-2) if text[i] == text[i+1] == text[i+2])
    score -= repeated_chars * 1.0
    return max(score, 0.0)
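# Worked example with the Yoruba branch above: for text "mo fẹ gbin ọka"
# (4 words, 14 characters), the length terms give 4*0.5 + 14*0.1 = 3.4, and
# 2 diacritics (ẹ, ọ) plus 1 "gb" digraph give 3 hits * 2.0 = 6.0, for a
# total of 9.4 (no punctuation or repetition penalties apply).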
def get_ai_response(text: str) -> str:
    try:
        response = requests.post(ASK_URL, json={"query": text}, timeout=30)
        response.raise_for_status()
        result = response.json()
        return result.get("answer", "Sorry, no answer returned.")
    except Exception as e:
        logger.error(f"AI request error: {e}")
        return f"I'm sorry, I couldn't connect to the AI service. You said: '{text}'."
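# The upstream contract implied by get_ai_response: POST to ASK_URL with a
# JSON body of the form {"query": "<user text>"} and expect a JSON reply
# containing an "answer" key, e.g. {"answer": "Plant maize after the first
# rains..."} (the example answer is illustrative, not from the service).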
HAUSA_WORDS = [
    "aikin", "manoma", "gona", "amfanin", "yanayi", "tsaba", "fasaha", "bisa", "noman", "shuka",
    "daji", "rani", "damina", "amfani", "bidi'a", "noma", "bashi", "manure", "tsiro", "gishiri",
    "gonaki", "gonar", "tsirrai", "kayan"
]
YORUBA_WORDS = [
    "ilé", "ọmọ", "òun", "awọn", "agbẹ", "oko", "ọgbà", "irugbin", "àkọsílẹ", "omi", "ojo", "àgbàlá", "irọlẹ",
    "ni", "ti", "si", "fun", "lati", "ninu", "lori", "labe", "pelu", "ati", "tabi", "sugbon",
    "o", "a", "e", "won", "mi", "re", "wa", "yin",
    "kan", "kankan", "die", "pupo", "gbogbo", "kookan",
    "nibi", "nibe", "igba", "akoko", "osu", "odun", "ise", "owo",
    "láàsìbà", "dára", "jùlẹ̀", "ìwẹ̀", "ṣe", "kú", "tún", "fi", "wo",
    "ẹ", "ọ", "ṣ", "gb",
    "jẹ", "wá", "lọ", "dúró", "sọ", "gbọ", "rí", "mọ", "fẹ", "ní",
    "pẹlu", "nitori", "tori", "ṣugbọn"
]
IGBO_WORDS = [
    "ugbo", "akụkọ", "mmiri", "ala", "ọrụ", "ncheta", "ọhụrụ", "ugwu", "nri", "ahụhụ",
    "ọkụkọ", "ewu", "atụrụ", "ehi", "azụ", "osisi", "mkpụrụ", "ubi", "ọka", "ji",
    "akwụkwọ", "ofe", "azu", "anụ", "nnu", "mmanụ", "ngwọ", "ọgwụ", "ahịhịa",
    "n'", "maka", "n'ihi", "n'ime", "n'elu", "n'okpuru", "ya", "anyị", "unu", "ha",
    "otu", "ọtụtụ", "ebe", "oge", "ụ", "ọ", "ị", "bụ", "nọ", "ga", "dị", "ka", "ma"
]
ENGLISH_WORDS = [
    "farm", "farming", "agriculture", "crop", "crops", "plant", "plants", "seed", "seeds", "soil",
    "water", "rain", "weather", "harvest", "yield", "field", "fields", "farmer", "farmers", "grow",
    "growing", "fertilizer", "pesticide", "irrigation", "livestock", "cattle", "chicken", "goat", "sheep",
    "maize", "corn", "rice", "wheat", "vegetable", "vegetables", "fruit", "fruits", "tree", "trees",
    "cultivate", "cultivation", "plow", "plowing", "sow", "sowing", "reap", "reaping", "season", "seasons"
]
def detect_language_keywords(text: str) -> str:
    """Keyword-vote language detection over whole tokens; defaults to English."""
    text_lower = text.lower().strip()
    if not text_lower:
        return "en"
    # Match whole tokens rather than substrings: short entries such as "o",
    # "a", or "ni" occur inside almost any English word, so a plain
    # `word in text_lower` check would massively over-count.
    cleaned = text_lower.translate(str.maketrans("", "", ".,!?;:\"()"))
    tokens = set(cleaned.split())
    hausa_count = sum(1 for word in HAUSA_WORDS if word in tokens)
    yoruba_count = sum(1 for word in YORUBA_WORDS if word in tokens)
    igbo_count = sum(1 for word in IGBO_WORDS if word in tokens)
    english_count = sum(1 for word in ENGLISH_WORDS if word in tokens)
    logger.info(f"Language detection scores - Hausa: {hausa_count}, Yoruba: {yoruba_count}, Igbo: {igbo_count}, English: {english_count}")
    if hausa_count > yoruba_count and hausa_count > igbo_count and hausa_count > english_count:
        logger.info("Keyword detection: Hausa")
        return "ha"
    elif yoruba_count > hausa_count and yoruba_count > igbo_count and yoruba_count > english_count:
        logger.info("Keyword detection: Yoruba")
        return "yo"
    elif igbo_count > hausa_count and igbo_count > yoruba_count and igbo_count > english_count:
        logger.info("Keyword detection: Igbo")
        return "ig"
    else:
        logger.info("Keyword detection: English (default)")
        return "en"

def detect_language(text: str) -> str:
    logger.info(f"Detecting language for text: '{text[:50]}...'")
    return detect_language_keywords(text)
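# Example of the detection above: for "How do I plant maize?", the token set
# {"how", "do", "i", "plant", "maize"} matches two ENGLISH_WORDS entries
# ("plant", "maize") and none of the Hausa/Yoruba/Igbo keywords as whole
# tokens, so detect_language() returns "en".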
def text_to_speech_file(text: str) -> str:
    """Synthesize speech for text and return a temp file path (.wav, or .txt for Igbo)."""
    lang = detect_language(text)
    logger.info(f"Detected language: {lang}")
    if lang == "ig":
        # No Igbo TTS model is loaded, so fall back to a plain-text file.
        logger.info("Igbo language detected - returning text response instead of audio")
        fd, path = tempfile.mkstemp(suffix=".txt")
        os.close(fd)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(text)
        return path
    if lang == "ha":
        tts_model = tts_ha
    elif lang == "yo":
        tts_model = tts_yo
    else:
        tts_model = tts_en
    if tts_model is None:
        logger.error(f"TTS model for {lang} is not available")
        raise HTTPException(status_code=500, detail=f"TTS model for {lang} is not available")
    speech_output = tts_model(text)
    audio_raw = speech_output["audio"]
    sampling_rate = int(speech_output["sampling_rate"])
    if isinstance(audio_raw, torch.Tensor):
        audio_np = audio_raw.detach().cpu().numpy()
    else:
        audio_np = np.asarray(audio_raw)
    if audio_np.ndim > 1:
        audio_np = audio_np.reshape(-1)
    audio_np = audio_np.astype(np.float32, copy=False)
    target_sr = 16000
    if sampling_rate != target_sr:
        logger.info(f"Resampling TTS audio from {sampling_rate}Hz to {target_sr}Hz")
        ratio = target_sr / sampling_rate
        new_length = int(len(audio_np) * ratio)
        audio_np = np.interp(
            np.linspace(0, len(audio_np), new_length),
            np.arange(len(audio_np)),
            audio_np
        )
        sampling_rate = target_sr
    # Convert float [-1, 1] samples to 16-bit PCM for broad player compatibility.
    audio_clipped = np.clip(audio_np, -1.0, 1.0)
    audio_int16 = (audio_clipped * 32767.0).astype(np.int16)
    fd, path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    sf.write(path, audio_int16, sampling_rate, format='WAV', subtype='PCM_16')
    return path
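# Note: the MMS TTS pipelines used above return a dict with an "audio" array
# (often shaped (1, n_samples), hence the reshape(-1)) and a "sampling_rate"
# that is 16000 for the facebook/mms-tts-* checkpoints, so the resampling
# branch is normally a no-op and exists as a safety net.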
# NOTE: route decorators appear to have been stripped from this listing; the
# paths below are reasonable reconstructions (only /speak and /speak-auto are
# named explicitly in the /languages usage text further down).
@app.get("/")
async def root():
    return {"status": "ok", "message": "System ready"}

@app.get("/health")
async def health():
    return {
        "message": "Farmlingua AI Speech Interface is running!",
        "language_detection": "keyword-based (lightweight)",
        "tts_models": {
            "hausa": tts_ha is not None,
            "english": tts_en is not None,
            "yoruba": tts_yo is not None,
            "igbo": False
        }
    }

@app.get("/status")
async def status():
    return {
        "language_detection": "keyword-based (lightweight)",
        "status": "ready",
        "message": "Using lightweight keyword-based language detection - no heavy models required"
    }
@app.get("/languages")
async def get_languages():
    return {
        "supported_languages": {
            "yo": {
                "name": "Yoruba",
                "code": "yo",
                "tts_available": True,
                "asr_model": "NCAIR1/Yoruba-ASR"
            },
            "ha": {
                "name": "Hausa",
                "code": "ha",
                "tts_available": True,
                "asr_model": "NCAIR1/Hausa-ASR"
            },
            "ig": {
                "name": "Igbo",
                "code": "ig",
                "tts_available": False,
                "asr_model": "NCAIR1/Igbo-ASR",
                "note": "Text response only - TTS not available"
            },
            "en": {
                "name": "English",
                "code": "en",
                "tts_available": True,
                "asr_model": "NCAIR1/NigerianAccentedEnglish"
            }
        },
        "usage": {
            "specify_language": "Use /speak endpoint with language parameter for best accuracy",
            "auto_detection": "Use /speak-auto endpoint when language is unknown",
            "igbo_response": "Igbo responses return text only (no audio TTS available)"
        }
    }
@app.post("/chat")
async def chat(text: str = Form(...), speak: bool = False, raw: bool = False):
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    final_text = text if raw else get_ai_response(text)
    if speak:
        output_path = text_to_speech_file(final_text)
        lang = detect_language(final_text)
        if lang == "ig":
            return FileResponse(output_path, media_type="text/plain", filename="response.txt")
        else:
            return FileResponse(output_path, media_type="audio/wav", filename="response.wav")
    return {"question": text, "answer": final_text}
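# Assuming the /chat path reconstructed above, a text-only round trip looks
# like this (hypothetical host):
#
#     curl -X POST https://<space-host>/chat -F "text=When should I plant rice?"
#
# and adding ?speak=true returns a WAV file (or a .txt file when the answer
# is detected as Igbo).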
@app.post("/speak")
async def speak_to_ai(audio_file: UploadFile = File(...), language: str = Form(...), speak: bool = True):
    # content_type can be None for some clients, so guard before startswith().
    if not audio_file.content_type or not audio_file.content_type.startswith('audio/'):
        raise HTTPException(status_code=400, detail="File must be an audio file")
    if language not in ["yo", "ha", "ig", "en"]:
        raise HTTPException(status_code=400, detail="Language must be one of: yo (Yoruba), ha (Hausa), ig (Igbo), en (English)")
    audio_data = await audio_file.read()
    transcription = speech_to_text_with_language(audio_data, language)
    ai_response = get_ai_response(transcription)
    if speak:
        if language == "ig":
            return {"transcription": transcription, "ai_response": ai_response, "language": language, "response_type": "text", "note": "Igbo TTS not available - returning text response"}
        else:
            output_path = text_to_speech_file(ai_response)
            return FileResponse(output_path, media_type="audio/wav", filename="response.wav")
    return {"transcription": transcription, "ai_response": ai_response, "language": language}
@app.post("/speak-auto")
async def speak_to_ai_auto(audio_file: UploadFile = File(...), speak: bool = True):
    if not audio_file.content_type or not audio_file.content_type.startswith('audio/'):
        raise HTTPException(status_code=400, detail="File must be an audio file")
    audio_data = await audio_file.read()
    transcription = speech_to_text(audio_data)
    ai_response = get_ai_response(transcription)
    if speak:
        detected_lang = detect_language(ai_response)
        if detected_lang == "ig":
            return {"transcription": transcription, "ai_response": ai_response, "detected_language": detected_lang, "response_type": "text", "note": "Igbo TTS not available - returning text response"}
        else:
            output_path = text_to_speech_file(ai_response)
            return FileResponse(output_path, media_type="audio/wav", filename="response.wav")
    return {"transcription": transcription, "ai_response": ai_response}
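# Example client calls for the two speech endpoints (the /speak and
# /speak-auto names come from the /languages usage text; the host is
# hypothetical):
#
#     curl -X POST https://<space-host>/speak \
#       -F "audio_file=@question.wav;type=audio/wav" -F "language=yo" -o answer.wav
#
#     curl -X POST https://<space-host>/speak-auto \
#       -F "audio_file=@question.wav;type=audio/wav" -o answer.wav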
if __name__ == "__main__":
    # uvicorn is already imported at the top of the module.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))