# milestone3/app.py
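"""Farmlingua AI speech interface.

FastAPI app that transcribes farmers' spoken questions (Yoruba, Hausa, Igbo,
or Nigerian-accented English) with per-language Whisper models, forwards the
text to the milestone-1 /ask backend, and speaks the answer back with MMS TTS
where a voice exists (Igbo answers are returned as text).

Example requests (assuming the server runs on the default port 7860):

    curl -X POST -F "text=How do I plant maize?" http://localhost:7860/chat
    curl -X POST -F "audio_file=@question.wav;type=audio/wav" -F "language=yo" \
         http://localhost:7860/speak -o response.wav
"""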
import os
import io
import tempfile
import subprocess
import requests
import torch
import numpy as np
import soundfile as sf
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration
import imageio_ffmpeg
import logging
from contextlib import asynccontextmanager
import uvicorn
import nest_asyncio

# Allow nested event loops (e.g. when the app is run from a notebook).
nest_asyncio.apply()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load TTS pipelines once at startup; ASR weights are lazy-loaded later.
    load_models()
    yield


app = FastAPI(title="Farmlingua AI Speech Interface", version="1.0.0", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
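
# Question-answering backend (milestone 1); get_ai_response() POSTs the
# transcribed query here and reads back the "answer" field.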
ASK_URL = "https://remostart-milestone-one-farmlingua-ai.hf.space/ask"
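
# TTS pipelines, populated once by load_models() at startup. tts_ig stays
# None: no Igbo voice is wired up, so Igbo answers come back as text.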
tts_ha, tts_en, tts_yo, tts_ig = None, None, None, None
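
# Per-language Whisper ASR checkpoints; weights are lazy-loaded on first use
# by _get_asr() to keep startup light.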
asr_models = {
    "ha": {"repo": "NCAIR1/Hausa-ASR", "model": None, "proc": None},
    "yo": {"repo": "NCAIR1/Yoruba-ASR", "model": None, "proc": None},
    "ig": {"repo": "NCAIR1/Igbo-ASR", "model": None, "proc": None},
    "en": {"repo": "NCAIR1/NigerianAccentedEnglish", "model": None, "proc": None},
}


def load_models():
    """Load TTS pipelines; ASR models are deferred to _get_asr()."""
    global tts_ha, tts_en, tts_yo, tts_ig
    device = 0 if torch.cuda.is_available() else -1
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        hf_token = hf_token.strip()
    if not hf_token:
        logger.warning("HF_TOKEN not set! This may cause authentication failures for gated repositories.")
        logger.warning("Please set HF_TOKEN environment variable to access restricted models.")
    else:
        logger.info("HF_TOKEN is set and ready for authenticated model access.")
    logger.info("Using lightweight keyword-based language detection (no heavy models)")
    logger.info("Loading TTS models...")
    try:
        tts_ha = pipeline("text-to-speech", model="facebook/mms-tts-hau", device=device)
        logger.info("Loaded TTS (Hausa)")
    except Exception:
        logger.exception("Failed to load TTS (Hausa)")
        tts_ha = None
    try:
        tts_en = pipeline("text-to-speech", model="facebook/mms-tts-eng", device=device)
        logger.info("Loaded TTS (English)")
    except Exception:
        logger.exception("Failed to load TTS (English)")
        tts_en = None
    try:
        tts_yo = pipeline("text-to-speech", model="facebook/mms-tts-yor", device=device)
        logger.info("Loaded TTS (Yoruba)")
    except Exception:
        logger.exception("Failed to load TTS (Yoruba)")
        tts_yo = None
    tts_ig = None
    logger.info("Igbo TTS model disabled - will return text responses for Igbo language")
    logger.info("Deferred ASR model loads: will lazy-load per language on first use")


def _get_asr(lang_code: str):
    """Return (model, processor) for a language code, loading them on first use."""
    entry = asr_models.get(lang_code)
    if not entry:
        return None, None
    if entry["model"] is not None and entry["proc"] is not None:
        return entry["model"], entry["proc"]
    repo_id = entry["repo"]
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        hf_token = hf_token.strip()
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Lazy-loading ASR for {lang_code} from {repo_id}...")
        proc = WhisperProcessor.from_pretrained(repo_id, token=hf_token)
        model = WhisperForConditionalGeneration.from_pretrained(repo_id, token=hf_token)
        model.to(device)
        model.eval()
        entry["model"], entry["proc"] = model, proc
        logger.info(f"Loaded ASR for {lang_code}")
        return model, proc
    except Exception:
        logger.exception(f"Failed to load ASR for {lang_code} ({repo_id})")
        entry["model"], entry["proc"] = None, None
        return None, None


def _run_whisper(model: WhisperForConditionalGeneration, proc: WhisperProcessor, audio_array: np.ndarray) -> str:
    """Transcribe a 16 kHz mono float array with a loaded Whisper model."""
    try:
        device = next(model.parameters()).device
        inputs = proc(audio_array, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(device)
        with torch.no_grad():
            predicted_ids = model.generate(input_features)
        text_list = proc.batch_decode(predicted_ids, skip_special_tokens=True)
        return text_list[0] if text_list else ""
    except Exception:
        logger.exception("Whisper ASR inference failed")
        return ""


def preprocess_audio_ffmpeg(audio_data: bytes, target_sr: int = 16000) -> np.ndarray:
    """Decode arbitrary uploaded audio to a cleaned, mono, 16 kHz float32 array."""
    try:
        with tempfile.NamedTemporaryFile(suffix='.input', delete=False) as in_file:
            in_file.write(audio_data)
            in_path = in_file.name
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as out_file:
            out_path = out_file.name
        ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe()
        # Downmix to mono, resample, and run a light clean-up chain:
        # speech band-pass, dynamic loudness normalization, soxr resampling.
        subprocess.run([
            ffmpeg_exe, '-y', '-i', in_path,
            '-ac', '1',
            '-ar', str(target_sr),
            '-af',
            'highpass=f=80,' +
            'lowpass=f=8000,' +
            'dynaudnorm=p=0.95:m=10.0,' +
            'volume=1.0,' +
            'aresample=resampler=soxr',
            out_path
        ], check=True, capture_output=True)
        with open(out_path, 'rb') as f:
            wav_data = f.read()
        os.unlink(in_path)
        os.unlink(out_path)
        audio_array, sr = sf.read(io.BytesIO(wav_data))
        if len(audio_array.shape) > 1:
            audio_array = np.mean(audio_array, axis=1)
        if sr != target_sr:
            logger.warning(f"Audio sampling rate {sr} != target {target_sr}, applying additional resampling...")
            try:
                from scipy import signal
                ratio = target_sr / sr
                audio_array = signal.resample(audio_array, int(len(audio_array) * ratio))
                logger.info(f"Successfully resampled using scipy to {target_sr}Hz")
            except ImportError:
                logger.warning("scipy not available, using numpy interpolation")
                ratio = target_sr / sr
                new_length = int(len(audio_array) * ratio)
                # Sample positions end at len-1 so np.interp stays within range.
                audio_array = np.interp(
                    np.linspace(0, len(audio_array) - 1, new_length),
                    np.arange(len(audio_array)),
                    audio_array
                )
        audio_array = _validate_and_normalize_audio(audio_array)
        logger.info(f"Audio preprocessing complete: {len(audio_array)} samples at {target_sr}Hz")
        return audio_array.astype(np.float32)
    except Exception as e:
        logger.error(f"FFmpeg preprocessing failed: {e}")
        raise HTTPException(status_code=400, detail="Audio preprocessing failed. Ensure ffmpeg is installed.")


def _validate_and_normalize_audio(audio_array: np.ndarray) -> np.ndarray:
    """Warn on silent/clipped input, normalize RMS toward 0.1, remove DC offset."""
    rms = np.sqrt(np.mean(audio_array**2))
    if rms < 0.001:
        logger.warning("Audio appears to be very quiet or silent")
    max_val = np.max(np.abs(audio_array))
    if max_val > 0.95:
        logger.warning(f"Audio may be clipped (max: {max_val:.3f})")
    target_rms = 0.1
    if rms > 0:
        # Boost quiet audio toward the target RMS, capped at 2x gain.
        normalization_factor = min(target_rms / rms, 2.0)
        audio_array = audio_array * normalization_factor
        logger.info(f"Normalized audio RMS from {rms:.4f} to {np.sqrt(np.mean(audio_array**2)):.4f}")
    audio_array = np.clip(audio_array, -0.99, 0.99)
    audio_array = audio_array - np.mean(audio_array)
    return audio_array


def chunk_audio(audio_array: np.ndarray, chunk_length: float = 10.0, overlap: float = 1.0, sample_rate: int = 16000) -> list:
    """Split audio into fixed-length, overlapping chunks with short edge fades."""
    chunk_samples = int(chunk_length * sample_rate)
    overlap_samples = int(overlap * sample_rate)
    step_samples = chunk_samples - overlap_samples
    chunks = []
    start = 0
    while start < len(audio_array):
        end = min(start + chunk_samples, len(audio_array))
        # Copy the slice so the fades below don't mutate the source array
        # (overlapping regions would otherwise be faded twice).
        chunk = audio_array[start:end].copy()
        fade_samples = int(0.05 * sample_rate)
        if len(chunk) > 2 * fade_samples:
            chunk[:fade_samples] *= np.linspace(0, 1, fade_samples)
            chunk[-fade_samples:] *= np.linspace(1, 0, fade_samples)
        if len(chunk) < chunk_samples:
            chunk = np.pad(chunk, (0, chunk_samples - len(chunk)), mode='constant')
        chunk_rms = np.sqrt(np.mean(chunk**2))
        if chunk_rms < 0.001:
            logger.warning(f"Chunk {len(chunks)+1} appears to be very quiet (RMS: {chunk_rms:.6f})")
        chunks.append(chunk)
        start += step_samples
        if end >= len(audio_array):
            break
    logger.info(f"Split audio into {len(chunks)} chunks of {chunk_length}s each with quality preservation")
    return chunks


def speech_to_text_with_language(audio_data: bytes, language: str) -> str:
    """Transcribe audio with the ASR model for a known language code."""
    audio_array = preprocess_audio_ffmpeg(audio_data)
    audio_duration = len(audio_array) / 16000
    logger.info(f"Audio duration: {audio_duration:.2f} seconds, processing with {language} model")
    model, proc = _get_asr(language)
    if model is None or proc is None:
        logger.error(f"Failed to load {language} ASR model")
        return ""
    if audio_duration <= 15:
        return _process_single_chunk_with_language(audio_array, model, proc, language)
    else:
        return _process_chunked_audio_with_language(audio_array, model, proc, language)


def _process_single_chunk_with_language(audio_array: np.ndarray, model, proc, language: str) -> str:
    text = _run_whisper(model, proc, audio_array)
    if text and text.strip():
        logger.info(f"Transcription ({language}): {text[:100]}...")
        return text.strip()
    return ""


def _process_chunked_audio_with_language(audio_array: np.ndarray, model, proc, language: str) -> str:
    chunks = chunk_audio(audio_array, chunk_length=10.0, overlap=1.0)
    chunk_texts = []
    for i, chunk in enumerate(chunks):
        try:
            text = _run_whisper(model, proc, chunk)
            if text and text.strip():
                chunk_texts.append(text.strip())
                logger.info(f"Chunk {i+1}/{len(chunks)} ({language}): {text[:50]}...")
        except Exception as e:
            logger.warning(f"Failed to process chunk {i+1} with {language}: {e}")
            continue
    if chunk_texts:
        combined_text = " ".join(chunk_texts)
        logger.info(f"Combined {language} result: {combined_text[:100]}...")
        return combined_text
    return ""


def speech_to_text(audio_data: bytes) -> str:
    """Transcribe audio of unknown language by trying every ASR model."""
    audio_array = preprocess_audio_ffmpeg(audio_data)
    audio_duration = len(audio_array) / 16000
    logger.info(f"Audio duration: {audio_duration:.2f} seconds")
    if audio_duration <= 15:
        return _process_single_chunk(audio_array)
    else:
        return _process_chunked_audio(audio_array)


def _process_single_chunk(audio_array: np.ndarray) -> str:
    candidates = []
    for code in ["yo", "ha", "ig", "en"]:
        model, proc = _get_asr(code)
        if model is None or proc is None:
            continue
        text = _run_whisper(model, proc, audio_array)
        if text and text.strip():
            candidates.append((code, text.strip()))
    if not candidates:
        return ""
    return _select_best_transcription(candidates)


def _process_chunked_audio(audio_array: np.ndarray) -> str:
    chunks = chunk_audio(audio_array, chunk_length=10.0, overlap=1.0)
    language_results = {}
    for code in ["yo", "ha", "ig", "en"]:
        model, proc = _get_asr(code)
        if model is None or proc is None:
            continue
        chunk_texts = []
        for i, chunk in enumerate(chunks):
            try:
                text = _run_whisper(model, proc, chunk)
                if text and text.strip():
                    chunk_texts.append(text.strip())
                    logger.info(f"Chunk {i+1}/{len(chunks)} ({code}): {text[:50]}...")
            except Exception as e:
                logger.warning(f"Failed to process chunk {i+1} with {code}: {e}")
                continue
        if chunk_texts:
            combined_text = " ".join(chunk_texts)
            language_results[code] = combined_text
            logger.info(f"Combined {code} result: {combined_text[:100]}...")
    if not language_results:
        return ""
    return _select_best_transcription(list(language_results.items()))


def _select_best_transcription(candidates: list) -> str:
    """Pick the highest-scoring (language, text) candidate."""
    if not candidates:
        return ""
    if len(candidates) == 1:
        return candidates[0][1]
    scored_candidates = []
    for lang_code, text in candidates:
        score = _score_transcription_quality(text, lang_code)
        scored_candidates.append((score, lang_code, text))
        logger.info(f"Transcription quality score for {lang_code}: {score:.2f} - '{text[:50]}...'")
    scored_candidates.sort(key=lambda x: x[0], reverse=True)
    best_score, best_lang, best_text = scored_candidates[0]
    logger.info(f"Selected best transcription: {best_lang} (score: {best_score:.2f}) - '{best_text[:100]}...'")
    return best_text


def _score_transcription_quality(text: str, lang_code: str) -> float:
    """Heuristic score: rewards length and language-specific characters,
    penalizes heavy punctuation and repeated-character runs."""
    if not text or not text.strip():
        return 0.0
    text = text.strip()
    score = 0.0
    word_count = len(text.split())
    if word_count == 0:
        return 0.0
    score += min(word_count * 0.5, 10.0)
    char_count = len(text)
    score += min(char_count * 0.1, 5.0)
    # Character-level language cues. These are rough: the membership test is
    # per character, so "gb" matches any 'g' or 'b' and "ts" matches any
    # 't' or 's', not the digraphs.
    if lang_code == "yo":
        yoruba_chars = sum(1 for c in text if c in "ẹọṣgb")
        score += yoruba_chars * 2.0
    elif lang_code == "ig":
        igbo_chars = sum(1 for c in text if c in "ụọị")
        score += igbo_chars * 2.0
    elif lang_code == "ha":
        hausa_chars = sum(1 for c in text if c in "ts")
        score += hausa_chars * 1.5
    punctuation_penalty = text.count('?') + text.count('!') + text.count('.')
    score -= punctuation_penalty * 0.5
    # Penalize runs of three identical characters (a common ASR failure mode).
    repeated_chars = sum(1 for i in range(len(text) - 2) if text[i] == text[i+1] == text[i+2])
    score -= repeated_chars * 1.0
    return max(score, 0.0)


def get_ai_response(text: str) -> str:
    """Forward the user's question to the /ask backend and return its answer."""
    try:
        response = requests.post(ASK_URL, json={"query": text}, timeout=30)
        response.raise_for_status()
        result = response.json()
        return result.get("answer", "Sorry, no answer returned.")
    except Exception as e:
        logger.error(f"AI request error: {e}")
        return f"I'm sorry, I couldn't connect to the AI service. You said: '{text}'."
HAUSA_WORDS = [
    "aikin", "manoma", "gona", "amfanin", "yanayi", "tsaba", "fasaha", "bisa", "noman", "shuka",
    "daji", "rani", "damina", "amfani", "bidi'a", "noma", "bashi", "manure", "tsiro", "gishiri",
    "gonaki", "gonar", "tsirrai", "kayan"
]
YORUBA_WORDS = [
    "ilé", "ọmọ", "òun", "awọn", "agbẹ", "oko", "ọgbà", "irugbin", "àkọsílẹ", "omi", "ojo", "àgbàlá", "irọlẹ",
    "ni", "ti", "si", "fun", "lati", "ninu", "lori", "labe", "pelu", "ati", "tabi", "sugbon",
    "o", "a", "e", "won", "mi", "re", "wa", "yin",
    "kan", "kankan", "die", "pupo", "gbogbo", "kookan",
    "nibi", "nibe", "igba", "akoko", "osu", "odun", "ise", "owo",
    "láàsìbà", "dára", "jùlẹ̀", "ìwẹ̀", "ṣe", "kú", "tún", "fi", "wo",
    "ẹ", "ọ", "ṣ", "gb",
    "jẹ", "wá", "lọ", "dúró", "sọ", "gbọ", "rí", "mọ", "fẹ", "ní",
    "pẹlu", "nitori", "tori", "ṣugbọn"
]
IGBO_WORDS = [
    "ugbo", "akụkọ", "mmiri", "ala", "ọrụ", "ncheta", "ọhụrụ", "ugwu", "nri", "ahụhụ",
    "ọkụkọ", "ewu", "atụrụ", "ehi", "azụ", "osisi", "mkpụrụ", "ubi", "ọka", "ji",
    "akwụkwọ", "ofe", "azu", "anụ", "nnu", "mmanụ", "ngwọ", "ọgwụ", "ahịhịa",
    "n'", "maka", "n'ihi", "n'ime", "n'elu", "n'okpuru", "ya", "anyị", "unu", "ha",
    "otu", "ọtụtụ", "ebe", "oge", "ụ", "ọ", "ị", "bụ", "nọ", "ga", "dị", "ka", "ma"
]
ENGLISH_WORDS = [
    "farm", "farming", "agriculture", "crop", "crops", "plant", "plants", "seed", "seeds", "soil",
    "water", "rain", "weather", "harvest", "yield", "field", "fields", "farmer", "farmers", "grow",
    "growing", "fertilizer", "pesticide", "irrigation", "livestock", "cattle", "chicken", "goat", "sheep",
    "maize", "corn", "rice", "wheat", "vegetable", "vegetables", "fruit", "fruits", "tree", "trees",
    "cultivate", "cultivation", "plow", "plowing", "sow", "sowing", "reap", "reaping", "season", "seasons"
]


def detect_language_keywords(text: str) -> str:
    """Guess the language by counting keyword hits; defaults to English."""
    text_lower = text.lower().strip()
    if not text_lower:
        return "en"
    # Substring matching: a keyword counts if it occurs anywhere in the text.
    hausa_count = sum(1 for word in HAUSA_WORDS if word in text_lower)
    yoruba_count = sum(1 for word in YORUBA_WORDS if word in text_lower)
    igbo_count = sum(1 for word in IGBO_WORDS if word in text_lower)
    english_count = sum(1 for word in ENGLISH_WORDS if word in text_lower)
    logger.info(f"Language detection scores - Hausa: {hausa_count}, Yoruba: {yoruba_count}, Igbo: {igbo_count}, English: {english_count}")
    if hausa_count > yoruba_count and hausa_count > igbo_count and hausa_count > english_count:
        logger.info("Keyword detection: Hausa")
        return "ha"
    elif yoruba_count > hausa_count and yoruba_count > igbo_count and yoruba_count > english_count:
        logger.info("Keyword detection: Yoruba")
        return "yo"
    elif igbo_count > hausa_count and igbo_count > yoruba_count and igbo_count > english_count:
        logger.info("Keyword detection: Igbo")
        return "ig"
    else:
        logger.info("Keyword detection: English (default)")
        return "en"


def detect_language(text: str) -> str:
    logger.info(f"Detecting language for text: '{text[:50]}...'")
    return detect_language_keywords(text)


def text_to_speech_file(text: str) -> str:
    """Synthesize text to a temp WAV file; Igbo gets a temp .txt file instead."""
    lang = detect_language(text)
    logger.info(f"Detected language: {lang}")
    if lang == "ig":
        logger.info("Igbo language detected - returning text response instead of audio")
        fd, path = tempfile.mkstemp(suffix=".txt")
        os.close(fd)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(text)
        return path
    if lang == "ha":
        tts_model = tts_ha
    elif lang == "yo":
        tts_model = tts_yo
    else:
        tts_model = tts_en
    if tts_model is None:
        logger.error(f"TTS model for {lang} is not available")
        raise HTTPException(status_code=500, detail=f"TTS model for {lang} is not available")
    speech_output = tts_model(text)
    audio_raw = speech_output["audio"]
    sampling_rate = int(speech_output["sampling_rate"])
    if isinstance(audio_raw, torch.Tensor):
        audio_np = audio_raw.detach().cpu().numpy()
    else:
        audio_np = np.asarray(audio_raw)
    if audio_np.ndim > 1:
        audio_np = audio_np.reshape(-1)
    audio_np = audio_np.astype(np.float32, copy=False)
    target_sr = 16000
    if sampling_rate != target_sr:
        logger.info(f"Resampling TTS audio from {sampling_rate}Hz to {target_sr}Hz")
        ratio = target_sr / sampling_rate
        new_length = int(len(audio_np) * ratio)
        # Sample positions end at len-1 so np.interp stays within range.
        audio_np = np.interp(
            np.linspace(0, len(audio_np) - 1, new_length),
            np.arange(len(audio_np)),
            audio_np
        )
        sampling_rate = target_sr
    # Convert float audio to 16-bit PCM for broad player compatibility.
    audio_clipped = np.clip(audio_np, -1.0, 1.0)
    audio_int16 = (audio_clipped * 32767.0).astype(np.int16)
    fd, path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    sf.write(path, audio_int16, sampling_rate, format='WAV', subtype='PCM_16')
    return path
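

# HTTP endpoints. /speak expects the caller to name the language; /speak-auto
# transcribes with every ASR model and keeps the best-scoring result.
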
@app.get("/")
async def root():
return {"status": "ok", "message": "System ready"}
@app.get("/health")
async def health():
return {
"message": "Farmlingua AI Speech Interface is running!",
"language_detection": "keyword-based (lightweight)",
"tts_models": {
"hausa": tts_ha is not None,
"english": tts_en is not None,
"yoruba": tts_yo is not None,
"igbo": False
}
}
@app.get("/status")
async def status():
return {
"language_detection": "keyword-based (lightweight)",
"status": "ready",
"message": "Using lightweight keyword-based language detection - no heavy models required"
}
@app.get("/languages")
async def get_languages():
return {
"supported_languages": {
"yo": {
"name": "Yoruba",
"code": "yo",
"tts_available": True,
"asr_model": "NCAIR1/Yoruba-ASR"
},
"ha": {
"name": "Hausa",
"code": "ha",
"tts_available": True,
"asr_model": "NCAIR1/Hausa-ASR"
},
"ig": {
"name": "Igbo",
"code": "ig",
"tts_available": False,
"asr_model": "NCAIR1/Igbo-ASR",
"note": "Text response only - TTS not available"
},
"en": {
"name": "English",
"code": "en",
"tts_available": True,
"asr_model": "NCAIR1/NigerianAccentedEnglish"
}
},
"usage": {
"specify_language": "Use /speak endpoint with language parameter for best accuracy",
"auto_detection": "Use /speak-auto endpoint when language is unknown",
"igbo_response": "Igbo responses return text only (no audio TTS available)"
}
}
@app.post("/chat")
async def chat(text: str = Form(...), speak: bool = False, raw: bool = False):
if not text.strip():
raise HTTPException(status_code=400, detail="Text cannot be empty")
final_text = text if raw else get_ai_response(text)
if speak:
output_path = text_to_speech_file(final_text)
lang = detect_language(final_text)
if lang == "ig":
return FileResponse(output_path, media_type="text/plain", filename="response.txt")
else:
return FileResponse(output_path, media_type="audio/wav", filename="response.wav")
return {"question": text, "answer": final_text}
@app.post("/speak")
async def speak_to_ai(audio_file: UploadFile = File(...), language: str = Form(...), speak: bool = True):
if not audio_file.content_type.startswith('audio/'):
raise HTTPException(status_code=400, detail="File must be an audio file")
if language not in ["yo", "ha", "ig", "en"]:
raise HTTPException(status_code=400, detail="Language must be one of: yo (Yoruba), ha (Hausa), ig (Igbo), en (English)")
audio_data = await audio_file.read()
transcription = speech_to_text_with_language(audio_data, language)
ai_response = get_ai_response(transcription)
if speak:
if language == "ig":
return {"transcription": transcription, "ai_response": ai_response, "language": language, "response_type": "text", "note": "Igbo TTS not available - returning text response"}
else:
output_path = text_to_speech_file(ai_response)
return FileResponse(output_path, media_type="audio/wav", filename="response.wav")
return {"transcription": transcription, "ai_response": ai_response, "language": language}
@app.post("/speak-auto")
async def speak_to_ai_auto(audio_file: UploadFile = File(...), speak: bool = True):
if not audio_file.content_type.startswith('audio/'):
raise HTTPException(status_code=400, detail="File must be an audio file")
audio_data = await audio_file.read()
transcription = speech_to_text(audio_data)
ai_response = get_ai_response(transcription)
if speak:
detected_lang = detect_language(ai_response)
if detected_lang == "ig":
return {"transcription": transcription, "ai_response": ai_response, "detected_language": detected_lang, "response_type": "text", "note": "Igbo TTS not available - returning text response"}
else:
output_path = text_to_speech_file(ai_response)
return FileResponse(output_path, media_type="audio/wav", filename="response.wav")
return {"transcription": transcription, "ai_response": ai_response}


if __name__ == "__main__":
    # uvicorn is already imported at module level; honor the PORT env var.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))