"""Flask service: English speech -> Kinyarwanda text (and optional audio).

Pipeline per request: Whisper "small" speech-to-text, NLLB-200 distilled
600M translation (eng_Latn -> kin_Latn), then Coqui TTS synthesis when the
TTS model could be loaded. Models are loaded lazily and warmed up in a
background thread at import time.
"""
import io
import logging
import os
import secrets
import tempfile
import threading
import time

from flask import Flask, jsonify, render_template, request, send_file
from flask_cors import CORS

# ── Logging ───────────────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)

app = Flask(__name__)
CORS(app)

# ── Model registry ────────────────────────────────────────────────────────────
_models = {}
_lock = threading.Lock()
_warmup_ok = False  # flips True once all models are loaded
_MODEL_COUNT = 3    # stt + nllb + tts; reported as "total" to the frontend

# ── Persistent cache dir (HF Spaces mounts /data when paid storage is enabled)
# Falls back to /tmp if /data isn't available (free tier without storage)
_DATA_DIR = "/data" if os.path.isdir("/data") else "/tmp"
os.makedirs(f"{_DATA_DIR}/.huggingface", exist_ok=True)
os.makedirs(f"{_DATA_DIR}/audio_cache", exist_ok=True)

# Provide cache locations at runtime too (belt-and-suspenders). setdefault
# keeps any value already present in the environment; it does not override.
os.environ.setdefault("HF_HOME", f"{_DATA_DIR}/.huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", f"{_DATA_DIR}/.huggingface/hub")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", f"{_DATA_DIR}/.huggingface/hub")
os.environ.setdefault("WHISPER_CACHE", f"{_DATA_DIR}/.huggingface/whisper")


# ── Model loaders (lazy + cached) ─────────────────────────────────────────────
def get_stt_model():
    """Return the Whisper "small" model, loading it on first use."""
    with _lock:
        if "stt" not in _models:
            import whisper

            log.info("Loading Whisper small…")
            _models["stt"] = whisper.load_model(
                "small",
                download_root=os.environ["WHISPER_CACHE"],
            )
            log.info("Whisper ready.")
        return _models["stt"]


def get_translation_pipeline():
    """Return the NLLB eng_Latn -> kin_Latn pipeline, loading it on first use."""
    with _lock:
        if "nllb" not in _models:
            from transformers import pipeline

            log.info("Loading NLLB-200 distilled 600M…")
            _models["nllb"] = pipeline(
                "translation",
                model="facebook/nllb-200-distilled-600M",
                src_lang="eng_Latn",
                tgt_lang="kin_Latn",
                device=-1,  # CPU
                max_length=512,
            )
            log.info("NLLB ready.")
        return _models["nllb"]


def get_tts_engine():
    """Return the Coqui TTS engine, or None when it could not be loaded.

    A load failure is remembered as ``None`` so it is not retried on every
    request; callers must treat a falsy result as "audio disabled".
    """
    with _lock:
        if "tts" not in _models:
            try:
                from TTS.api import TTS

                log.info("Loading Coqui TTS (Kinyarwanda)…")
                _models["tts"] = TTS("tts_models/rw/cv/vits")
                log.info("TTS ready.")
            except Exception as e:
                # Deliberately broad: TTS is optional, any failure disables it.
                log.warning("TTS unavailable: %s. Audio playback disabled.", e)
                _models["tts"] = None
        return _models["tts"]


def warmup_models():
    """Load all models at startup so the first real request is fast."""
    global _warmup_ok
    try:
        get_stt_model()
        get_translation_pipeline()
        get_tts_engine()
        _warmup_ok = True
        log.info("All models warmed up ✓")
    except Exception as e:
        # Runs in a background thread: log with traceback instead of dying
        # silently; _warmup_ok stays False so /api/translate keeps returning 503.
        log.exception("Warmup failed: %s", e)


# Start warmup in background so the server responds to /health immediately
threading.Thread(target=warmup_models, daemon=True).start()

# ── In-memory audio cache (token → filepath) ──────────────────────────────────
_audio_cache: dict[str, str] = {}
_AUDIO_DIR = f"{_DATA_DIR}/audio_cache"


# ── Routes ────────────────────────────────────────────────────────────────────
@app.route("/")
def index():
    """Serve the single-page frontend."""
    return render_template("index.html")


@app.route("/api/health")
def health():
    """Liveness probe; also reports warmup state and the storage root."""
    return jsonify({
        "status": "ok",
        "models_ready": _warmup_ok,
        "storage": _DATA_DIR,
    })


@app.route("/api/status")
def status():
    """Lightweight endpoint the frontend polls to show warmup progress."""
    loaded = list(_models.keys())
    return jsonify({
        "ready": _warmup_ok,
        "loaded": loaded,
        "total": _MODEL_COUNT,
    })


@app.route("/api/translate", methods=["POST"])
def translate():
    """
    POST multipart/form-data with field 'audio' (webm or wav blob).
    Returns JSON { transcript, translation, audio_token?, timing }

    Responds 503 while models are still loading, 400 when the 'audio' field
    is missing, and 422 when the transcription comes back empty.
    """
    if not _warmup_ok:
        # Return 503 so the frontend can show "warming up" message
        loaded = list(_models.keys())
        return jsonify({
            "error": "Models are still loading, please wait…",
            "warming": True,
            "loaded": len(loaded),
            "total": _MODEL_COUNT,
        }), 503

    if "audio" not in request.files:
        return jsonify({"error": "No audio file provided."}), 400

    audio_file = request.files["audio"]
    suffix = ".webm" if "webm" in (audio_file.content_type or "") else ".wav"

    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False, dir="/tmp") as tmp:
        # Write through the handle we already hold rather than reopening the
        # same path by name while it is still open.
        audio_file.save(tmp)
        tmp_path = tmp.name

    try:
        t0 = time.time()

        # ── 1. STT ────────────────────────────────────────────────────────────
        stt = get_stt_model()
        result = stt.transcribe(tmp_path, language="en", fp16=False)
        transcript = result["text"].strip()
        stt_ms = int((time.time() - t0) * 1000)

        if not transcript:
            return jsonify({"error": "Could not transcribe audio. Please speak clearly."}), 422

        # ── 2. Translation ────────────────────────────────────────────────────
        t1 = time.time()
        pipe = get_translation_pipeline()
        out = pipe(transcript)
        translation = out[0]["translation_text"]
        trans_ms = int((time.time() - t1) * 1000)

        # ── 3. TTS ────────────────────────────────────────────────────────────
        t2 = time.time()
        tts = get_tts_engine()
        wav_path = None
        # Random token doubles as the file name and the cache key. A
        # millisecond timestamp could collide between concurrent requests
        # and would make other users' audio URLs guessable.
        token = secrets.token_urlsafe(16)
        if tts:
            wav_path = os.path.join(_AUDIO_DIR, f"out_{token}.wav")
            tts.tts_to_file(text=translation, file_path=wav_path)
        tts_ms = int((time.time() - t2) * 1000)

        response = {
            "transcript": transcript,
            "translation": translation,
            "timing": {
                "stt_ms": stt_ms,
                "trans_ms": trans_ms,
                "tts_ms": tts_ms,
                "total_ms": int((time.time() - t0) * 1000),
            },
        }

        if wav_path and os.path.exists(wav_path):
            _audio_cache[token] = wav_path
            response["audio_token"] = token
            # Keep cache tidy — delete files older than 10 min
            _prune_audio_cache()

        return jsonify(response)

    finally:
        # Always remove the uploaded temp file, including on early returns.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)


@app.route("/api/audio/<token>")
def serve_audio(token):
    """Stream the synthesized wav registered under *token*, or 404."""
    path = _audio_cache.get(token)
    if not path or not os.path.exists(path):
        return jsonify({"error": "Audio not found or expired."}), 404
    return send_file(path, mimetype="audio/wav")


def _prune_audio_cache(max_age_s: int = 600) -> None:
    """Drop cache entries whose wav file is older than *max_age_s* seconds.

    Best-effort: entries whose file has already disappeared are removed from
    the cache, and a failed delete is logged rather than raised so it cannot
    fail the request that triggered the prune.
    """
    now = time.time()
    # Iterate a snapshot: other request threads may insert while we prune.
    for token, path in list(_audio_cache.items()):
        try:
            expired = now - os.path.getmtime(path) > max_age_s
        except OSError:
            # File is gone (or unreadable); the entry can never be served.
            _audio_cache.pop(token, None)
            continue
        if not expired:
            continue
        _audio_cache.pop(token, None)
        try:
            os.unlink(path)
        except OSError as e:
            log.warning("Could not delete cached audio %s: %s", path, e)


# ── Dev entry point ───────────────────────────────────────────────────────────
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)