"""FastAPI service exposing Arabic text-to-speech via a single POST /tts endpoint.

Supports a fixed whitelist of Hugging Face model repositories spanning two
engine families: VITS (single-model) and SpeechT5 (processor + model + HiFi-GAN
vocoder + x-vector speaker embedding).
"""

import io
from functools import lru_cache
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from huggingface_hub import hf_hub_download
from transformers import (
    AutoTokenizer,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    SpeechT5Processor,
    VitsModel,
)

# Supported models map (subset matching the Streamlit app).
SUPPORTED = {
    "facebook/mms-tts-ara": {"engine": "vits"},
    "wasmdashai/vits-ar-sa-A": {"engine": "vits"},
    "MBZUAI/speecht5_tts_clartts_ar": {"engine": "speecht5"},
}

app = FastAPI(title="Arabic TTS API")


def _ensure_valid_tokens(token_batch: dict) -> None:
    """Raise ValueError when tokenization yields fewer than two tokens.

    Very short or non-Arabic input can tokenize to an empty / single-token
    sequence, which the TTS models cannot synthesize from.

    Args:
        token_batch: tokenizer/processor output containing an "input_ids"
            tensor whose last dimension is the sequence length.

    Raises:
        ValueError: if the tokenized sequence is shorter than 2 tokens.
    """
    seq_len = token_batch["input_ids"].shape[-1]
    if seq_len < 2:
        raise ValueError("Input produced no valid tokens – provide more Arabic text.")


@lru_cache(maxsize=8)
def _load_vits(repo_id: str, cache_dir: str):
    """Load (and memoize) a VITS model and its tokenizer from the Hub.

    Cached per (repo_id, cache_dir) so repeated requests reuse the loaded
    weights instead of re-reading them from disk.
    """
    model = VitsModel.from_pretrained(repo_id, cache_dir=cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(repo_id, cache_dir=cache_dir)
    return model, tokenizer


@lru_cache(maxsize=8)
def _load_speecht5(repo_id: str, cache_dir: str):
    """Load (and memoize) the full SpeechT5 pipeline.

    Returns:
        Tuple of (processor, acoustic model, HiFi-GAN vocoder, speaker
        embedding tensor).
    """
    processor = SpeechT5Processor.from_pretrained(repo_id, cache_dir=cache_dir)
    model = SpeechT5ForTextToSpeech.from_pretrained(repo_id, cache_dir=cache_dir)
    # The vocoder repo is fixed; only the acoustic model varies per repo_id.
    vocoder = SpeechT5HifiGan.from_pretrained(
        "microsoft/speecht5_hifigan", cache_dir=cache_dir
    )
    spk = _load_speecht5_speaker_embedding(cache_dir)
    return processor, model, vocoder, spk


def _load_speecht5_speaker_embedding(cache_dir: str) -> torch.Tensor:
    """Fetch a (1, 512) x-vector speaker embedding for SpeechT5.

    Downloads a single CMU ARCTIC x-vector from the Hub. On any failure
    (offline, missing file, ...) falls back to a zero vector so synthesis
    still runs, albeit with a generic voice — a deliberate best-effort
    fallback, not an error path.
    """
    try:
        xvector_path = hf_hub_download(
            repo_id="Matthijs/cmu-arctic-xvectors",
            filename="validation/000000.xvector.npy",
            repo_type="dataset",
            cache_dir=cache_dir,
        )
        arr = np.load(xvector_path)
        vec = torch.from_numpy(arr)
        if vec.ndim == 1:
            # generate_speech expects a batched (1, 512) embedding.
            vec = vec.unsqueeze(0)
        return vec
    except Exception:
        return torch.zeros((1, 512), dtype=torch.float32)


@app.post("/tts")
async def tts(payload: dict):
    """Synthesize Arabic speech from JSON payload and return a WAV response.

    Payload keys:
        text (str, required): Arabic text to synthesize.
        model_id (str, required): one of the keys in SUPPORTED.
        sample_rate (int, optional): fallback rate when the model config
            does not declare one (default 16000).
        cache_dir (str, optional): Hugging Face cache directory
            (default "models_cache").

    Returns:
        Response with media type "audio/wav" containing the rendered audio.

    Raises:
        HTTPException 400: missing/invalid input or too-short text.
        HTTPException 500: any unexpected synthesis failure.
    """
    text = (payload.get("text") or "").strip()
    model_id = payload.get("model_id")
    sample_rate = int(payload.get("sample_rate", 16000))
    cache_dir = str(Path(payload.get("cache_dir") or "models_cache").expanduser())

    if not text:
        raise HTTPException(status_code=400, detail="'text' is required")
    if model_id not in SUPPORTED:
        raise HTTPException(
            status_code=400, detail=f"Unsupported or missing model_id: {model_id}"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        engine = SUPPORTED[model_id]["engine"]
        if engine == "vits":
            model, tokenizer = _load_vits(model_id, cache_dir)
            model.to(device).eval()
            inputs = tokenizer(text, return_tensors="pt")
            _ensure_valid_tokens(inputs)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            with torch.inference_mode():
                outputs = model(**inputs)
            waveform = outputs.waveform.squeeze(0).cpu().numpy()
            sr = getattr(model.config, "sampling_rate", sample_rate)
        elif engine == "speecht5":
            processor, model, vocoder, speaker = _load_speecht5(model_id, cache_dir)
            # .eval() added for consistency with the VITS path (disables
            # dropout etc. during inference).
            model.to(device).eval()
            vocoder.to(device).eval()
            inputs = processor(text=text, return_tensors="pt")
            _ensure_valid_tokens(inputs)
            input_ids = inputs["input_ids"].to(device)
            speaker_embedding = speaker.to(device)
            with torch.inference_mode():
                speech = model.generate_speech(
                    input_ids, speaker_embedding, vocoder=vocoder
                )
            waveform = speech.cpu().numpy()
            sr = getattr(model.config, "sampling_rate", 16000)
        else:
            raise HTTPException(status_code=400, detail="Engine not supported via API")

        wav_io = io.BytesIO()
        sf.write(wav_io, waveform, int(sr), format="WAV", closefd=False)
        # Note: BytesIO.getvalue() ignores the stream position, so no seek(0)
        # is needed before reading the buffer back out.
        return Response(content=wav_io.getvalue(), media_type="audio/wav")
    except HTTPException:
        # BUGFIX: HTTPException subclasses Exception, so without this clause
        # the deliberate 400 raised in the `else` branch above was swallowed
        # by the broad handler below and re-wrapped as a generic 500.
        raise
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS failed: {e}") from e