"""Afrolingo TTS microservice — HuggingFace Spaces deployment. Hosts MMS-TTS VITS models in-memory and exposes a ``/synthesize`` endpoint. Designed to run on HuggingFace Spaces free tier (2 vCPU, 16 GB RAM) where there is ample headroom for multiple loaded models. The main Afrolingo backend on Render (512 MB) calls this service via HTTP, keeping torch/transformers out of the gateway's memory budget. """ from __future__ import annotations import io import logging import os from collections import OrderedDict from threading import Lock import numpy as np import torch from fastapi import FastAPI, HTTPException from fastapi.responses import Response from pydantic import BaseModel from scipy.io import wavfile from transformers import AutoTokenizer, VitsModel logger = logging.getLogger("tts-service") logging.basicConfig(level=logging.INFO) app = FastAPI(title="Afrolingo TTS Service", version="0.1.0") MAX_MODELS = int(os.getenv("MAX_MODELS", "5")) # ------------------------------------------------------------------ # Thread-safe LRU model cache # ------------------------------------------------------------------ class _ModelCache: """Keep up to *max_models* ``(VitsModel, AutoTokenizer)`` pairs in memory.""" def __init__(self, max_models: int) -> None: self._max = max_models self._cache: OrderedDict[str, tuple[VitsModel, AutoTokenizer]] = OrderedDict() self._lock = Lock() def get_or_load(self, checkpoint: str) -> tuple[VitsModel, AutoTokenizer]: with self._lock: if checkpoint in self._cache: self._cache.move_to_end(checkpoint) return self._cache[checkpoint] # Load outside lock (slow I/O) logger.info("Loading model %s ...", checkpoint) model = VitsModel.from_pretrained(checkpoint) tokenizer = AutoTokenizer.from_pretrained(checkpoint) model.eval() logger.info("Loaded model %s", checkpoint) with self._lock: # Double-check after re-acquiring lock if checkpoint in self._cache: self._cache.move_to_end(checkpoint) return self._cache[checkpoint] if len(self._cache) >= self._max: evicted_key, _ = self._cache.popitem(last=False) logger.info("Evicted model %s", evicted_key) self._cache[checkpoint] = (model, tokenizer) return (model, tokenizer) @property def size(self) -> int: return len(self._cache) _cache = _ModelCache(MAX_MODELS) # ------------------------------------------------------------------ # Routes # ------------------------------------------------------------------ class SynthesizeRequest(BaseModel): text: str checkpoint: str @app.post("/synthesize") async def synthesize(req: SynthesizeRequest) -> Response: """Synthesize speech and return WAV audio bytes.""" # Load / retrieve model try: model, tokenizer = _cache.get_or_load(req.checkpoint) except Exception as exc: logger.error("Model load failed for %s: %s", req.checkpoint, exc) raise HTTPException(status_code=503, detail=f"Model loading failed: {exc}") # Inference try: inputs = tokenizer(req.text, return_tensors="pt") with torch.no_grad(): output = model(**inputs) waveform = output.waveform[0].cpu().numpy() sample_rate: int = model.config.sampling_rate waveform = np.clip(waveform, -1.0, 1.0) waveform_int16 = (waveform * 32767).astype(np.int16) buf = io.BytesIO() wavfile.write(buf, sample_rate, waveform_int16) return Response(content=buf.getvalue(), media_type="audio/wav") except Exception as exc: logger.error("Inference failed for %s: %s", req.checkpoint, exc) raise HTTPException(status_code=500, detail=f"Inference failed: {exc}") @app.get("/health") async def health() -> dict: return {"status": "healthy", "cached_models": _cache.size}