# Afrolingo / app.py
# Sammydynamo — "Initial TTS microservice deployment" (commit c6fda30)
"""Afrolingo TTS microservice — HuggingFace Spaces deployment.
Hosts MMS-TTS VITS models in-memory and exposes a ``/synthesize``
endpoint. Designed to run on HuggingFace Spaces free tier (2 vCPU,
16 GB RAM) where there is ample headroom for multiple loaded models.
The main Afrolingo backend on Render (512 MB) calls this service via
HTTP, keeping torch/transformers out of the gateway's memory budget.
"""
from __future__ import annotations
import io
import logging
import os
from collections import OrderedDict
from threading import Lock
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel
from scipy.io import wavfile
from transformers import AutoTokenizer, VitsModel
# Module-level service wiring: logger, ASGI app, and cache-size knob.
logger = logging.getLogger("tts-service")
logging.basicConfig(level=logging.INFO)
app = FastAPI(title="Afrolingo TTS Service", version="0.1.0")
# Max checkpoints held in RAM at once; override via the MAX_MODELS env var.
MAX_MODELS = int(os.getenv("MAX_MODELS", "5"))
# ------------------------------------------------------------------
# Thread-safe LRU model cache
# ------------------------------------------------------------------
class _ModelCache:
"""Keep up to *max_models* ``(VitsModel, AutoTokenizer)`` pairs in memory."""
def __init__(self, max_models: int) -> None:
self._max = max_models
self._cache: OrderedDict[str, tuple[VitsModel, AutoTokenizer]] = OrderedDict()
self._lock = Lock()
def get_or_load(self, checkpoint: str) -> tuple[VitsModel, AutoTokenizer]:
with self._lock:
if checkpoint in self._cache:
self._cache.move_to_end(checkpoint)
return self._cache[checkpoint]
# Load outside lock (slow I/O)
logger.info("Loading model %s ...", checkpoint)
model = VitsModel.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model.eval()
logger.info("Loaded model %s", checkpoint)
with self._lock:
# Double-check after re-acquiring lock
if checkpoint in self._cache:
self._cache.move_to_end(checkpoint)
return self._cache[checkpoint]
if len(self._cache) >= self._max:
evicted_key, _ = self._cache.popitem(last=False)
logger.info("Evicted model %s", evicted_key)
self._cache[checkpoint] = (model, tokenizer)
return (model, tokenizer)
@property
def size(self) -> int:
return len(self._cache)
_cache = _ModelCache(MAX_MODELS)
# ------------------------------------------------------------------
# Routes
# ------------------------------------------------------------------
class SynthesizeRequest(BaseModel):
    """JSON body for ``POST /synthesize``."""

    text: str        # text to synthesize
    checkpoint: str  # HuggingFace checkpoint id, passed to from_pretrained()
@app.post("/synthesize")
async def synthesize(req: SynthesizeRequest) -> Response:
    """Synthesize speech for ``req.text`` with the checkpoint named in
    ``req.checkpoint`` and return the audio as WAV bytes.

    Responses:
        200 — ``audio/wav`` body (16-bit PCM).
        400 — text is empty or whitespace-only.
        503 — model could not be loaded.
        500 — inference failed.
    """
    # Reject empty input up front: otherwise it reaches the model and
    # surfaces as an opaque 500 instead of a client error.
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="Text must not be empty")
    # Load / retrieve model (may be slow on a cold cache)
    try:
        model, tokenizer = _cache.get_or_load(req.checkpoint)
    except Exception as exc:
        # logger.exception records the full traceback, unlike logger.error.
        logger.exception("Model load failed for %s", req.checkpoint)
        raise HTTPException(status_code=503, detail=f"Model loading failed: {exc}") from exc
    # Inference
    try:
        inputs = tokenizer(req.text, return_tensors="pt")
        with torch.no_grad():  # no autograd graph needed for inference
            output = model(**inputs)
        waveform = output.waveform[0].cpu().numpy()
        sample_rate: int = model.config.sampling_rate
        # Clip to [-1, 1] then scale to the int16 PCM range for WAV encoding.
        waveform = np.clip(waveform, -1.0, 1.0)
        waveform_int16 = (waveform * 32767).astype(np.int16)
        buf = io.BytesIO()
        wavfile.write(buf, sample_rate, waveform_int16)
        return Response(content=buf.getvalue(), media_type="audio/wav")
    except Exception as exc:
        logger.exception("Inference failed for %s", req.checkpoint)
        raise HTTPException(status_code=500, detail=f"Inference failed: {exc}") from exc
@app.get("/health")
async def health() -> dict:
    """Liveness probe; also reports how many models are resident in RAM."""
    payload = {"status": "healthy", "cached_models": _cache.size}
    return payload