"""Arabic text-to-speech HTTP API (FastAPI)."""
import io
from pathlib import Path
from functools import lru_cache
import numpy as np
import torch
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from transformers import (
VitsModel,
AutoTokenizer,
SpeechT5Processor,
SpeechT5ForTextToSpeech,
SpeechT5HifiGan,
)
from huggingface_hub import hf_hub_download
# Model registry: repo id -> synthesis engine (subset matching the Streamlit app).
SUPPORTED = {
    "facebook/mms-tts-ara": {"engine": "vits"},
    "wasmdashai/vits-ar-sa-A": {"engine": "vits"},
    "MBZUAI/speecht5_tts_clartts_ar": {"engine": "speecht5"},
}

app = FastAPI(title="Arabic TTS API")
def _ensure_valid_tokens(token_batch: dict):
seq_len = token_batch["input_ids"].shape[-1]
if seq_len < 2:
raise ValueError("Input produced no valid tokens – provide more Arabic text.")
@lru_cache(maxsize=8)
def _load_vits(repo_id: str, cache_dir: str):
    """Load (and memoize) a VITS model/tokenizer pair for the given repo."""
    tokenizer = AutoTokenizer.from_pretrained(repo_id, cache_dir=cache_dir)
    model = VitsModel.from_pretrained(repo_id, cache_dir=cache_dir)
    return model, tokenizer
@lru_cache(maxsize=8)
def _load_speecht5(repo_id: str, cache_dir: str):
    """Load (and memoize) the SpeechT5 pipeline pieces.

    Returns (processor, model, vocoder, speaker_embedding).
    """
    return (
        SpeechT5Processor.from_pretrained(repo_id, cache_dir=cache_dir),
        SpeechT5ForTextToSpeech.from_pretrained(repo_id, cache_dir=cache_dir),
        SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=cache_dir),
        _load_speecht5_speaker_embedding(cache_dir),
    )
def _load_speecht5_speaker_embedding(cache_dir: str) -> torch.Tensor:
try:
xvector_path = hf_hub_download(
repo_id="Matthijs/cmu-arctic-xvectors",
filename="validation/000000.xvector.npy",
repo_type="dataset",
cache_dir=cache_dir,
)
arr = np.load(xvector_path)
vec = torch.from_numpy(arr)
if vec.ndim == 1:
vec = vec.unsqueeze(0)
return vec
except Exception:
return torch.zeros((1, 512), dtype=torch.float32)
def _synthesize_vits(model_id: str, text: str, cache_dir: str, device: str, default_sr: int):
    """Run VITS inference; return (waveform ndarray, sample rate)."""
    model, tokenizer = _load_vits(model_id, cache_dir)
    model.to(device).eval()
    inputs = tokenizer(text, return_tensors="pt")
    _ensure_valid_tokens(inputs)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.inference_mode():
        outputs = model(**inputs)
    waveform = outputs.waveform.squeeze(0).cpu().numpy()
    return waveform, getattr(model.config, "sampling_rate", default_sr)


def _synthesize_speecht5(model_id: str, text: str, cache_dir: str, device: str):
    """Run SpeechT5 inference; return (waveform ndarray, sample rate)."""
    processor, model, vocoder, speaker = _load_speecht5(model_id, cache_dir)
    # .eval() added for parity with the VITS path: disables dropout so
    # inference is deterministic.
    model.to(device).eval()
    vocoder.to(device).eval()
    inputs = processor(text=text, return_tensors="pt")
    _ensure_valid_tokens(inputs)
    input_ids = inputs["input_ids"].to(device)
    speaker_embedding = speaker.to(device)
    with torch.inference_mode():
        speech = model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)
    return speech.cpu().numpy(), getattr(model.config, "sampling_rate", 16000)


@app.post("/tts")
async def tts(payload: dict):
    """Synthesize Arabic speech from JSON payload and return a WAV response.

    Payload keys: 'text' (required), 'model_id' (required, must be in
    SUPPORTED), 'sample_rate' (optional fallback rate), 'cache_dir'
    (optional model cache location).

    Raises HTTPException 400 for bad input, 500 for synthesis failures.
    """
    text = (payload.get("text") or "").strip()
    model_id = payload.get("model_id")
    sample_rate = int(payload.get("sample_rate", 16000))
    cache_dir = str(Path(payload.get("cache_dir") or "models_cache").expanduser())
    if not text:
        raise HTTPException(status_code=400, detail="'text' is required")
    if model_id not in SUPPORTED:
        raise HTTPException(status_code=400, detail=f"Unsupported or missing model_id: {model_id}")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        engine = SUPPORTED[model_id]["engine"]
        if engine == "vits":
            waveform, sr = _synthesize_vits(model_id, text, cache_dir, device, sample_rate)
        elif engine == "speecht5":
            waveform, sr = _synthesize_speecht5(model_id, text, cache_dir, device)
        else:
            raise HTTPException(status_code=400, detail="Engine not supported via API")
        wav_io = io.BytesIO()
        sf.write(wav_io, waveform, int(sr), format="WAV", closefd=False)
        return Response(content=wav_io.getvalue(), media_type="audio/wav")
    except HTTPException:
        # BUG FIX: the broad Exception handler below previously swallowed the
        # 400 "Engine not supported" HTTPException and re-raised it as a 500.
        # Pass HTTPExceptions through unchanged.
        raise
    except ValueError as ve:
        # Tokenization validation errors map to a client error.
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS failed: {e}") from e