Kenya / app.py
Walelign's picture
Upload 7 files
088fe97 verified
import io
import logging
import os
import secrets
import tempfile
import threading
import time

from flask import Flask, request, jsonify, send_file, render_template
from flask_cors import CORS
# ── Logging ───────────────────────────────────────────────────────────────────
# Root logger emits INFO and above, each record prefixed with time and level.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
log = logging.getLogger(__name__)

# Flask application object; CORS is enabled with flask_cors' default settings.
app = Flask(__name__)
CORS(app)
# ── Model registry ────────────────────────────────────────────────────────────
_models = {}              # loaded models, keyed by short name
_lock = threading.Lock()  # held while a model is being looked up or loaded
_warmup_ok = False        # flips True once all models are loaded

# ── Persistent cache dir ──────────────────────────────────────────────────────
# HF Spaces mounts /data when paid storage is enabled; without it (free tier)
# everything falls back to /tmp.
if os.path.isdir("/data"):
    _DATA_DIR = "/data"
else:
    _DATA_DIR = "/tmp"

_HF_HOME_DIR = f"{_DATA_DIR}/.huggingface"
os.makedirs(_HF_HOME_DIR, exist_ok=True)
os.makedirs(f"{_DATA_DIR}/audio_cache", exist_ok=True)

# Point the cache env vars at that directory at runtime too, without
# overriding values that are already set (belt-and-suspenders).
os.environ.setdefault("HF_HOME", _HF_HOME_DIR)
os.environ.setdefault("TRANSFORMERS_CACHE", f"{_HF_HOME_DIR}/hub")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", f"{_HF_HOME_DIR}/hub")
os.environ.setdefault("WHISPER_CACHE", f"{_HF_HOME_DIR}/whisper")
# ── Model loaders (lazy + cached) ─────────────────────────────────────────────
def get_stt_model():
    """Return the Whisper "small" model, loading it on first use.

    The model is stored in ``_models`` under the key ``"stt"``; lookup and
    loading both happen while ``_lock`` is held. Weights are downloaded to the
    directory named by the ``WHISPER_CACHE`` environment variable.
    """
    with _lock:
        if "stt" in _models:
            return _models["stt"]

        # Imported here so the dependency is only pulled in when needed.
        import whisper

        log.info("Loading Whisper small…")
        model = whisper.load_model(
            "small",
            download_root=os.environ["WHISPER_CACHE"],
        )
        _models["stt"] = model
        log.info("Whisper ready.")
        return model
def get_translation_pipeline():
    """Return the NLLB-200 translation pipeline, building it on first use.

    The pipeline is configured with ``src_lang="eng_Latn"`` and
    ``tgt_lang="kin_Latn"``, runs on CPU, and is stored in ``_models`` under
    the key ``"nllb"``. Lookup and construction happen while ``_lock`` is held.
    """
    with _lock:
        if "nllb" in _models:
            return _models["nllb"]

        # Imported here so the dependency is only pulled in when needed.
        from transformers import pipeline

        log.info("Loading NLLB-200 distilled 600M…")
        translator = pipeline(
            "translation",
            model="facebook/nllb-200-distilled-600M",
            src_lang="eng_Latn",
            tgt_lang="kin_Latn",
            device=-1,  # CPU
            max_length=512,
        )
        _models["nllb"] = translator
        log.info("NLLB ready.")
        return translator
def get_tts_engine():
    """Return the Coqui TTS engine, or ``None`` when it cannot be loaded.

    The first call tries to import and construct the engine while ``_lock`` is
    held. Any failure is logged as a warning and ``None`` is stored instead, so
    later calls return ``None`` without retrying.
    """
    with _lock:
        if "tts" in _models:
            return _models["tts"]

        try:
            from TTS.api import TTS

            log.info("Loading Coqui TTS (Kinyarwanda)…")
            _models["tts"] = TTS("tts_models/rw/cv/vits")
            log.info("TTS ready.")
        except Exception as err:
            # Best effort: the app keeps working without audio output.
            log.warning(f"TTS unavailable: {err}. Audio playback disabled.")
            _models["tts"] = None

        return _models["tts"]
def warmup_models():
    """Load every model up front so the first real request is fast.

    Calls the three loaders in order (STT, translation, TTS) and sets the
    module-level ``_warmup_ok`` flag to True once all of them have returned.
    If any loader raises, the error is logged and the flag stays False.
    """
    global _warmup_ok
    loaders = (get_stt_model, get_translation_pipeline, get_tts_engine)
    try:
        for load in loaders:
            load()
        _warmup_ok = True
        log.info("All models warmed up ✓")
    except Exception as err:
        log.error(f"Warmup failed: {err}")
# Run warmup on a daemon thread so the server can answer /api/health
# immediately instead of blocking until the models are loaded.
_warmup_thread = threading.Thread(target=warmup_models, daemon=True)
_warmup_thread.start()

# ── In-memory audio cache (token → filepath) ──────────────────────────────────
_AUDIO_DIR = f"{_DATA_DIR}/audio_cache"
_audio_cache: dict[str, str] = {}
# ── Routes ─────────────────────────────────────────────────────────────────────
@app.route("/")
def index():
    """Render and return the ``index.html`` template for the root URL."""
    page = render_template("index.html")
    return page
@app.route("/api/health")
def health():
    """Return a JSON health report.

    The payload carries a constant ``status`` of ``"ok"``, whether warmup has
    completed (``models_ready``), and the storage root in use (``storage``).
    """
    payload = {
        "status": "ok",
        "models_ready": _warmup_ok,
        "storage": _DATA_DIR,
    }
    return jsonify(payload)
@app.route("/api/status")
def status():
    """Lightweight endpoint the frontend polls to show warmup progress.

    Returns JSON with ``ready`` (warmup finished), ``loaded`` (keys currently
    present in the model registry) and ``total`` (always 3).
    """
    loaded_names = list(_models)
    payload = {
        "ready": _warmup_ok,
        "loaded": loaded_names,
        "total": 3,
    }
    return jsonify(payload)
@app.route("/api/translate", methods=["POST"])
def translate():
    """
    Transcribe an uploaded audio clip, translate the text, and voice it if possible.

    Expects POST multipart/form-data with field 'audio' (webm or wav blob).

    Returns:
        200 JSON { transcript, translation, audio_token?, timing }.
            ``audio_token`` is present only when a TTS engine is available and
            produced a file; it is served back by /api/audio/<token>.
        400 JSON { error } when the 'audio' field is missing.
        422 JSON { error } when the transcript comes back empty.
        503 JSON { error, warming, loaded, total } while models are loading.

    Exceptions raised by the STT, translation or TTS steps are not caught here
    and propagate to Flask.
    """
    if not _warmup_ok:
        # Return 503 so the frontend can show "warming up" message
        return jsonify({
            "error": "Models are still loading, please wait…",
            "warming": True,
            "loaded": len(_models),
            "total": 3,
        }), 503

    if "audio" not in request.files:
        return jsonify({"error": "No audio file provided."}), 400

    audio_file = request.files["audio"]
    suffix = ".webm" if "webm" in (audio_file.content_type or "") else ".wav"

    # Reserve a unique temp path and close the handle before anything writes
    # to it; delete=False keeps the file around for the transcription step.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False, dir="/tmp") as tmp:
        tmp_path = tmp.name

    try:
        # Saving inside the try means a failed upload cannot leak the temp file.
        audio_file.save(tmp_path)

        t0 = time.time()

        # ── 1. STT ────────────────────────────────────────────────────────────
        stt = get_stt_model()
        result = stt.transcribe(tmp_path, language="en", fp16=False)
        transcript = result["text"].strip()
        stt_ms = int((time.time() - t0) * 1000)
        if not transcript:
            return jsonify({"error": "Could not transcribe audio. Please speak clearly."}), 422

        # ── 2. Translation ────────────────────────────────────────────────────
        t1 = time.time()
        pipe = get_translation_pipeline()
        out = pipe(transcript)
        translation = out[0]["translation_text"]
        trans_ms = int((time.time() - t1) * 1000)

        # ── 3. TTS ────────────────────────────────────────────────────────────
        t2 = time.time()
        tts = get_tts_engine()
        token = None
        wav_path = None
        if tts:
            # A random token names both the file and the cache entry, so
            # concurrent requests cannot collide and tokens cannot be guessed
            # (a millisecond timestamp offered neither guarantee).
            token = secrets.token_urlsafe(16)
            wav_path = os.path.join(_AUDIO_DIR, f"out_{token}.wav")
            tts.tts_to_file(text=translation, file_path=wav_path)
        tts_ms = int((time.time() - t2) * 1000)

        response = {
            "transcript": transcript,
            "translation": translation,
            "timing": {
                "stt_ms": stt_ms,
                "trans_ms": trans_ms,
                "tts_ms": tts_ms,
                "total_ms": int((time.time() - t0) * 1000),
            },
        }

        if wav_path and os.path.exists(wav_path):
            _audio_cache[token] = wav_path
            response["audio_token"] = token
            # Keep cache tidy — delete files older than 10 min
            _prune_audio_cache()

        return jsonify(response)
    finally:
        # The uploaded clip is never needed once the request has been handled.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
@app.route("/api/audio/<token>")
def serve_audio(token):
    """Send the WAV file registered under *token*.

    Responds with a 404 JSON error when the token is not in the audio cache or
    the cached file no longer exists on disk.
    """
    wav_file = _audio_cache.get(token)
    if wav_file and os.path.exists(wav_file):
        return send_file(wav_file, mimetype="audio/wav")
    return jsonify({"error": "Audio not found or expired."}), 404
def _prune_audio_cache(max_age_s: int = 600):
    """Evict cached audio older than *max_age_s* seconds and delete the files.

    Age is measured from each file's modification time. Entries whose file can
    no longer be stat'ed (for example because it was already removed) are
    evicted as well; previously such an entry made every prune raise and was
    never cleared.

    Args:
        max_age_s: Maximum age in seconds before an entry is evicted.
    """
    now = time.time()
    # Walk a snapshot so entries can be removed from the dict as we go.
    for token, path in list(_audio_cache.items()):
        try:
            expired = now - os.path.getmtime(path) > max_age_s
        except OSError:
            # The file is gone or unreadable, so the entry can never be served.
            expired = True
        if not expired:
            continue

        _audio_cache.pop(token, None)
        try:
            os.unlink(path)
        except OSError:
            # Best effort: the entry is already evicted; a missing or
            # undeletable file must not fail the calling request.
            pass
# ── Dev entry point ───────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Listen on all interfaces; the PORT env var overrides the default 7860.
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        debug=False,
    )