# NOTE: hosting-page status residue captured with this file ("Spaces: Build error", shown twice);
# it is not part of the program.
import io
import logging
from functools import lru_cache
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from huggingface_hub import hf_hub_download
from transformers import (
    VitsModel,
    AutoTokenizer,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
)
# Supported models map (subset matching the Streamlit app)
# Keys are the Hugging Face repo ids the API accepts; each value names the
# synthesis engine ("vits" or "speecht5") used to run that repo.
SUPPORTED = {
    repo_id: {"engine": engine}
    for repo_id, engine in (
        ("facebook/mms-tts-ara", "vits"),
        ("wasmdashai/vits-ar-sa-A", "vits"),
        ("MBZUAI/speecht5_tts_clartts_ar", "speecht5"),
    )
}
# FastAPI application instance for the service, created with the title "Arabic TTS API".
app = FastAPI(title="Arabic TTS API")
def _ensure_valid_tokens(token_batch: dict):
    """Reject tokenizer output that is too short to synthesise.

    Looks at the size of the last axis of ``token_batch["input_ids"]`` and
    returns ``None`` when it is at least 2.

    Raises:
        ValueError: if that axis holds fewer than 2 entries.
    """
    input_ids = token_batch["input_ids"]
    token_count = input_ids.shape[-1]
    # NOTE(review): the threshold of 2 presumably treats a lone token as a
    # special/end marker rather than real text - confirm with the tokenizers.
    if token_count >= 2:
        return
    raise ValueError("Input produced no valid tokens – provide more Arabic text.")
# Bounded cache: model weights are large, and cache_dir comes from the request,
# so the key space must not be allowed to grow without limit.
@lru_cache(maxsize=4)
def _load_vits(repo_id: str, cache_dir: str):
    """Load a VITS model and its tokenizer, memoised per (repo_id, cache_dir).

    Without memoisation every call re-instantiates the model via
    ``from_pretrained``; the cached objects are shared between calls, so
    callers must not mutate them in ways that would affect other requests.

    Args:
        repo_id: Hugging Face repository id of the VITS checkpoint.
        cache_dir: Directory passed to ``from_pretrained`` as ``cache_dir``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = VitsModel.from_pretrained(repo_id, cache_dir=cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(repo_id, cache_dir=cache_dir)
    return model, tokenizer
# Bounded cache: three large objects per entry, keyed on request-supplied cache_dir.
@lru_cache(maxsize=2)
def _load_speecht5_components(repo_id: str, cache_dir: str):
    """Load the SpeechT5 processor, acoustic model and HiFi-GAN vocoder once per key."""
    processor = SpeechT5Processor.from_pretrained(repo_id, cache_dir=cache_dir)
    model = SpeechT5ForTextToSpeech.from_pretrained(repo_id, cache_dir=cache_dir)
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=cache_dir)
    return processor, model, vocoder


def _load_speecht5(repo_id: str, cache_dir: str):
    """Return everything needed to run SpeechT5 synthesis.

    The processor, model and vocoder are memoised so they are not rebuilt on
    every call. The speaker embedding is deliberately fetched on each call
    (as before): its loader falls back to a zero vector on failure, and
    caching here would make a transient failure permanent.

    Args:
        repo_id: Hugging Face repository id of the SpeechT5 TTS checkpoint.
        cache_dir: Directory passed to the loaders as ``cache_dir``.

    Returns:
        A ``(processor, model, vocoder, speaker_embedding)`` tuple.
    """
    processor, model, vocoder = _load_speecht5_components(repo_id, cache_dir)
    spk = _load_speecht5_speaker_embedding(cache_dir)
    return processor, model, vocoder, spk
def _load_speecht5_speaker_embedding(cache_dir: str) -> torch.Tensor:
    """Return a speaker embedding for SpeechT5, falling back to zeros.

    Tries to download ``validation/000000.xvector.npy`` from the
    ``Matthijs/cmu-arctic-xvectors`` dataset repo and load it as a float32
    tensor; a 1-D vector gets a leading batch axis added. This is
    best-effort: on any failure a ``(1, 512)`` float32 zero tensor is
    returned instead, and the failure is logged rather than silently hidden.

    NOTE(review): confirm that this filename actually exists in the dataset
    repo - if it does not, the zero fallback is used on every call.

    Args:
        cache_dir: Directory passed to ``hf_hub_download`` as ``cache_dir``.

    Returns:
        The speaker embedding tensor (float32).
    """
    try:
        xvector_path = hf_hub_download(
            repo_id="Matthijs/cmu-arctic-xvectors",
            filename="validation/000000.xvector.npy",
            repo_type="dataset",
            cache_dir=cache_dir,
        )
        # Cast so the success path has the same dtype as the zero fallback.
        vec = torch.from_numpy(np.load(xvector_path)).to(torch.float32)
        if vec.ndim == 1:
            vec = vec.unsqueeze(0)
        return vec
    except Exception:
        # Deliberately broad: synthesis should still work without the x-vector,
        # but the degradation must be visible in the logs.
        logging.getLogger(__name__).warning(
            "Could not load SpeechT5 speaker embedding; using a zero vector",
            exc_info=True,
        )
        return torch.zeros((1, 512), dtype=torch.float32)
def _parse_tts_payload(payload: dict) -> tuple:
    """Validate the request body and return ``(text, model_id, sample_rate, cache_dir)``.

    Every malformed field is reported as HTTP 400 instead of escaping as an
    unhandled exception.
    """
    raw_text = payload.get("text")
    if raw_text is not None and not isinstance(raw_text, str):
        raise HTTPException(status_code=400, detail="'text' must be a string")
    text = (raw_text or "").strip()
    if not text:
        raise HTTPException(status_code=400, detail="'text' is required")

    model_id = payload.get("model_id")
    # The isinstance guard matters: an unhashable JSON value (list/object)
    # would make the membership test itself raise TypeError.
    if not isinstance(model_id, str) or model_id not in SUPPORTED:
        raise HTTPException(status_code=400, detail=f"Unsupported or missing model_id: {model_id}")

    raw_rate = payload.get("sample_rate")
    if raw_rate is None:
        sample_rate = 16000
    else:
        try:
            sample_rate = int(raw_rate)
        except (TypeError, ValueError, OverflowError):
            raise HTTPException(status_code=400, detail="'sample_rate' must be an integer") from None
    if sample_rate <= 0:
        raise HTTPException(status_code=400, detail="'sample_rate' must be positive")

    # NOTE(security): cache_dir is client-controlled, so a caller can choose
    # where model files are written on the server. Consider pinning it to a
    # server-side setting.
    raw_cache_dir = payload.get("cache_dir") or "models_cache"
    if not isinstance(raw_cache_dir, str):
        raise HTTPException(status_code=400, detail="'cache_dir' must be a string")
    try:
        cache_dir = str(Path(raw_cache_dir).expanduser())
    except RuntimeError:
        # expanduser() raises when the home directory cannot be resolved.
        raise HTTPException(status_code=400, detail="'cache_dir' could not be resolved") from None

    return text, model_id, sample_rate, cache_dir


def _synthesize_vits(text: str, model_id: str, cache_dir: str, device: str, fallback_rate: int):
    """Run a VITS model on *text* and return ``(waveform, sampling_rate)``."""
    model, tokenizer = _load_vits(model_id, cache_dir)
    model.to(device).eval()
    inputs = tokenizer(text, return_tensors="pt")
    _ensure_valid_tokens(inputs)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.inference_mode():
        outputs = model(**inputs)
    waveform = outputs.waveform.squeeze(0).cpu().numpy()
    # The requested rate is only a fallback for configs without `sampling_rate`.
    return waveform, getattr(model.config, "sampling_rate", fallback_rate)


def _synthesize_speecht5(text: str, model_id: str, cache_dir: str, device: str):
    """Run SpeechT5 plus its vocoder on *text* and return ``(waveform, sampling_rate)``."""
    processor, model, vocoder, speaker = _load_speecht5(model_id, cache_dir)
    model.to(device).eval()
    vocoder.to(device).eval()
    inputs = processor(text=text, return_tensors="pt")
    _ensure_valid_tokens(inputs)
    input_ids = inputs["input_ids"].to(device)
    speaker_embedding = speaker.to(device)
    with torch.inference_mode():
        speech = model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)
    # Model config keeps precedence (as before); the vocoder config is the
    # next source because the vocoder produces the waveform; 16000 is last.
    rate = getattr(model.config, "sampling_rate", None) or getattr(vocoder.config, "sampling_rate", 16000)
    return speech.cpu().numpy(), rate


def _encode_wav(waveform, sample_rate) -> bytes:
    """Serialise *waveform* as an in-memory WAV file and return its bytes."""
    buffer = io.BytesIO()
    sf.write(buffer, waveform, int(sample_rate), format="WAV", closefd=False)
    return buffer.getvalue()


async def tts(payload: dict):
    """Synthesise speech for ``payload["text"]`` and return it as a WAV response.

    Payload keys read: ``text`` (required), ``model_id`` (required, must be a
    key of ``SUPPORTED``), ``sample_rate`` (optional, default 16000; only used
    as a fallback when a VITS config has no ``sampling_rate``) and
    ``cache_dir`` (optional, default ``"models_cache"``).

    NOTE(review): no route decorator is visible on this coroutine, so as shown
    it is never registered on ``app`` - confirm where it is mounted.
    NOTE(review): synthesis is synchronous and runs inside this coroutine, so
    it occupies the event loop for the duration of the request.

    Returns:
        A ``Response`` with media type ``audio/wav``.

    Raises:
        HTTPException: 400 for invalid input (including ``ValueError`` raised
            during synthesis, e.g. too few tokens); 500 for any other failure.
    """
    text, model_id, sample_rate, cache_dir = _parse_tts_payload(payload)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        engine = SUPPORTED[model_id]["engine"]
        if engine == "vits":
            waveform, sr = _synthesize_vits(text, model_id, cache_dir, device, sample_rate)
        elif engine == "speecht5":
            waveform, sr = _synthesize_speecht5(text, model_id, cache_dir, device)
        else:
            raise HTTPException(status_code=400, detail="Engine not supported via API")
        return Response(content=_encode_wav(waveform, sr), media_type="audio/wav")
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) from being
        # rewrapped as a 500 by the generic handler below.
        raise
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        logging.getLogger(__name__).exception("TTS synthesis failed for model %s", model_id)
        raise HTTPException(status_code=500, detail=f"TTS failed: {e}") from e