| import os |
| import io |
| import time |
| import tempfile |
| import threading |
| import logging |
|
|
| from flask import Flask, request, jsonify, send_file, render_template |
| from flask_cors import CORS |
|
|
| |
# Root logging: INFO and above, timestamped, to the default handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
log = logging.getLogger(__name__)

# The Flask application, with Flask-CORS applied using its default settings.
app = Flask(__name__)
CORS(app)
|
|
| |
# Lazily populated model registry; the loaders below use the keys
# "stt" (Whisper), "nllb" (translation pipeline) and "tts" (Coqui, may be None).
_models: dict = {}
# Held while a model is being constructed so each one is built at most once.
_lock = threading.Lock()
# Flipped to True by warmup_models() once every loader has returned.
_warmup_ok: bool = False
|
|
| |
| |
# Storage root: use /data when that directory exists (presumably a mounted,
# persistent volume — confirm with the deployment), otherwise /tmp.
if os.path.isdir("/data"):
    _DATA_DIR = "/data"
else:
    _DATA_DIR = "/tmp"

for _subdir in (".huggingface", "audio_cache"):
    os.makedirs(f"{_DATA_DIR}/{_subdir}", exist_ok=True)

# Route the model download caches into _DATA_DIR unless the environment already
# sets them (setdefault never overrides an existing value). WHISPER_CACHE is
# this app's own variable: get_stt_model() passes it as Whisper's download_root.
# NOTE(review): recent transformers releases treat TRANSFORMERS_CACHE as
# deprecated in favour of HF_HOME — kept for older versions.
_cache_env = {
    "HF_HOME": f"{_DATA_DIR}/.huggingface",
    "TRANSFORMERS_CACHE": f"{_DATA_DIR}/.huggingface/hub",
    "HUGGINGFACE_HUB_CACHE": f"{_DATA_DIR}/.huggingface/hub",
    "WHISPER_CACHE": f"{_DATA_DIR}/.huggingface/whisper",
}
for _name, _path in _cache_env.items():
    os.environ.setdefault(_name, _path)
del _subdir, _cache_env, _name, _path
|
|
|
|
| |
def get_stt_model():
    """Return the shared Whisper "small" speech-to-text model.

    The model is constructed on the first call (under ``_lock``, so
    concurrent callers never build two copies), stored in ``_models["stt"]``
    and reused afterwards. Weights are downloaded to ``$WHISPER_CACHE``.
    """
    with _lock:
        if "stt" in _models:
            return _models["stt"]

        import whisper  # heavy import, deferred until the model is needed

        log.info("Loading Whisper small…")
        model = whisper.load_model("small", download_root=os.environ["WHISPER_CACHE"])
        _models["stt"] = model
        log.info("Whisper ready.")
        return model
|
|
|
|
def get_translation_pipeline():
    """Return the shared NLLB-200 (distilled 600M) translation pipeline.

    The pipeline is fixed to ``eng_Latn`` -> ``kin_Latn``, runs on CPU
    (``device=-1``) and caps generation at ``max_length=512``. It is built
    once, under ``_lock``, and cached in ``_models["nllb"]``.
    """
    with _lock:
        if "nllb" in _models:
            return _models["nllb"]

        from transformers import pipeline  # deferred heavy import

        log.info("Loading NLLB-200 distilled 600M…")
        translator = pipeline(
            task="translation",
            model="facebook/nllb-200-distilled-600M",
            src_lang="eng_Latn",
            tgt_lang="kin_Latn",
            device=-1,
            max_length=512,
        )
        _models["nllb"] = translator
        log.info("NLLB ready.")
        return translator
|
|
|
|
def get_tts_engine():
    """Return the shared Coqui TTS engine, or ``None`` if it cannot be loaded.

    Any exception while importing ``TTS.api`` or constructing the model is
    logged as a warning. In that case ``None`` is cached in ``_models["tts"]``,
    so audio stays disabled until the process restarts. Callers must handle
    the ``None`` case.
    """
    with _lock:
        if "tts" in _models:
            return _models["tts"]

        try:
            from TTS.api import TTS

            log.info("Loading Coqui TTS (Kinyarwanda)…")
            engine = TTS("tts_models/rw/cv/vits")
            log.info("TTS ready.")
        except Exception as e:
            log.warning(f"TTS unavailable: {e}. Audio playback disabled.")
            engine = None

        _models["tts"] = engine
        return engine
|
|
|
|
def warmup_models():
    """Load all models at startup so the first real request is fast.

    Intended to run in the background thread started at import time. Sets the
    module-level ``_warmup_ok`` flag only after all three loaders return
    without raising. If the Whisper or NLLB load fails, the flag stays False
    for the life of the process, so /api/translate keeps answering 503; the
    failure and its full traceback are logged.
    """
    global _warmup_ok
    try:
        get_stt_model()
        get_translation_pipeline()
        get_tts_engine()  # handles its own load errors by caching None
    except Exception as e:
        # log.exception (instead of log.error with only str(e)) keeps the
        # traceback, which is essential for diagnosing download/OOM failures
        # in a background thread nobody is watching.
        log.exception("Warmup failed: %s", e)
        return
    _warmup_ok = True
    log.info("All models warmed up ✓")
|
|
|
|
| |
# Kick off model loading in a daemon thread so the server starts accepting
# requests immediately: /api/status reports progress and /api/translate
# answers 503 until warmup_models() sets _warmup_ok.
_warmup_thread = threading.Thread(target=warmup_models, daemon=True)
_warmup_thread.start()

# token -> path of a synthesized wav file. Entries are served by
# /api/audio/<token> and expired by _prune_audio_cache().
_audio_cache: dict[str, str] = {}
_AUDIO_DIR = f"{_DATA_DIR}/audio_cache"
|
|
|
|
| |
@app.route("/")
def index():
    """Serve the frontend page rendered from the ``index.html`` template."""
    page = render_template("index.html")
    return page
|
|
|
|
@app.route("/api/health")
def health():
    """Health probe: always 200, reporting model readiness and the storage root."""
    payload = {
        "status": "ok",
        "models_ready": _warmup_ok,
        "storage": _DATA_DIR,
    }
    return jsonify(payload)
|
|
|
|
@app.route("/api/status")
def status():
    """Report warmup progress for the frontend's polling loop.

    ``loaded`` lists the registry keys built so far; ``total`` is the number
    of loaders (stt, nllb, tts).
    """
    return jsonify({
        "ready": _warmup_ok,
        "loaded": list(_models),
        "total": 3,
    })
|
|
|
|
@app.route("/api/translate", methods=["POST"])
def translate():
    """Transcribe English speech, translate it, and optionally synthesize audio.

    POST multipart/form-data with field 'audio' (webm or wav blob).
    Returns JSON { transcript, translation, audio_token?, timing }.

    Status codes:
        200 -- success. ``audio_token`` is present only when TTS produced a
               file; fetch it from ``/api/audio/<token>``.
        400 -- no ``audio`` field in the upload.
        422 -- Whisper produced an empty transcript.
        503 -- background warmup has not (yet) completed.
    """
    if not _warmup_ok:
        return jsonify({
            "error": "Models are still loading, please wait…",
            "warming": True,
            "loaded": len(_models),
            "total": 3,
        }), 503

    if "audio" not in request.files:
        return jsonify({"error": "No audio file provided."}), 400

    audio_file = request.files["audio"]
    suffix = ".webm" if "webm" in (audio_file.content_type or "") else ".wav"

    # Reserve the temp path before writing, so the finally-clause also removes
    # it when saving the upload fails. Previously the NamedTemporaryFile
    # (delete=False) leaked in that case.
    fd, tmp_path = tempfile.mkstemp(suffix=suffix, dir="/tmp")
    os.close(fd)
    try:
        audio_file.save(tmp_path)
        t0 = time.perf_counter()

        # 1) Speech-to-text (English).
        result = get_stt_model().transcribe(tmp_path, language="en", fp16=False)
        transcript = result["text"].strip()
        stt_ms = int((time.perf_counter() - t0) * 1000)

        if not transcript:
            return jsonify({"error": "Could not transcribe audio. Please speak clearly."}), 422

        # 2) Machine translation (language pair is fixed by the pipeline).
        t1 = time.perf_counter()
        out = get_translation_pipeline()(transcript)
        translation = out[0]["translation_text"]
        trans_ms = int((time.perf_counter() - t1) * 1000)

        # 3) Optional text-to-speech. A missing engine or a synthesis error
        #    yields a text-only response rather than losing the result to a 500.
        t2 = time.perf_counter()
        tts = get_tts_engine()
        audio = _synthesize_audio(tts, translation) if tts else None
        tts_ms = int((time.perf_counter() - t2) * 1000)

        response = {
            "transcript": transcript,
            "translation": translation,
            "timing": {
                "stt_ms": stt_ms,
                "trans_ms": trans_ms,
                "tts_ms": tts_ms,
                "total_ms": int((time.perf_counter() - t0) * 1000),
            },
        }

        if audio is not None:
            token, wav_path = audio
            _audio_cache[token] = wav_path
            response["audio_token"] = token

        # Housekeeping: expire old synthesized files on every request.
        _prune_audio_cache()

        return jsonify(response)

    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)


def _synthesize_audio(tts, text):
    """Render *text* to a uniquely named wav file in ``_AUDIO_DIR``.

    Returns ``(token, path)`` on success, or ``None`` if synthesis raised or
    produced no file. The token doubles as the filename stem.
    """
    import secrets

    # Random and unguessable. The old millisecond timestamp let two concurrent
    # requests overwrite each other's file/token, and let any client enumerate
    # /api/audio/<token> to fetch other users' audio.
    token = secrets.token_hex(16)
    path = os.path.join(_AUDIO_DIR, f"out_{token}.wav")
    try:
        tts.tts_to_file(text=text, file_path=path)
    except Exception:
        log.exception("TTS synthesis failed; returning text only.")
        try:
            os.unlink(path)  # drop any partially written file
        except OSError:
            pass
        return None
    if not os.path.exists(path):
        return None
    return token, path
|
|
|
|
@app.route("/api/audio/<token>")
def serve_audio(token):
    """Stream the wav registered under *token*, or return a JSON 404.

    Only paths previously stored in ``_audio_cache`` can be served; the token
    is a dictionary key and is never used to build a filesystem path.
    """
    path = _audio_cache.get(token)
    if path and os.path.exists(path):
        return send_file(path, mimetype="audio/wav")
    return jsonify({"error": "Audio not found or expired."}), 404
|
|
|
|
def _prune_audio_cache(max_age_s: int = 600, cache: "dict[str, str] | None" = None) -> None:
    """Delete synthesized wav files whose mtime is older than *max_age_s* seconds.

    Args:
        max_age_s: Maximum file age, in seconds, before an entry expires.
        cache: token -> wav path mapping to prune. Defaults to the module-level
            ``_audio_cache``; the parameter lets the logic run on a private dict.

    Entries whose file has already disappeared are dropped too. Never raises
    for missing files, since it runs inside /api/translate right before the
    successful response is returned.
    """
    if cache is None:
        cache = _audio_cache
    now = time.time()

    expired = []
    # Iterate a snapshot: os.path.getmtime is a syscall that releases the GIL,
    # so another request thread may insert into the live dict mid-iteration
    # ("dictionary changed size during iteration").
    for token, path in list(cache.items()):
        try:
            age = now - os.path.getmtime(path)
        except OSError:
            expired.append(token)  # file already gone: drop the stale entry
            continue
        if age > max_age_s:
            expired.append(token)

    for token in expired:
        path = cache.pop(token, None)
        if path is None:
            continue  # a concurrent prune removed it first
        try:
            os.unlink(path)
        except OSError:
            pass  # already deleted; nothing left to clean up
|
|
|
|
| |
if __name__ == "__main__":
    # Development entry point: listen on all interfaces; PORT overrides 7860.
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        debug=False,
    )
|
|