File size: 4,335 Bytes
e3ff092
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import io
from pathlib import Path
from functools import lru_cache

import numpy as np
import torch
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from transformers import (
    VitsModel,
    AutoTokenizer,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
)
from huggingface_hub import hf_hub_download

# Supported models map (subset matching the Streamlit app).
# Maps a Hugging Face repo id to the synthesis engine the /tts endpoint
# dispatches on: "vits" repos pair a VitsModel with its tokenizer;
# "speecht5" repos use the SpeechT5 processor/model plus the shared
# HiFi-GAN vocoder and a speaker x-vector.
SUPPORTED = {
    "facebook/mms-tts-ara": {"engine": "vits"},
    "wasmdashai/vits-ar-sa-A": {"engine": "vits"},
    "MBZUAI/speecht5_tts_clartts_ar": {"engine": "speecht5"},
}

app = FastAPI(title="Arabic TTS API")


def _ensure_valid_tokens(token_batch: dict):
    seq_len = token_batch["input_ids"].shape[-1]
    if seq_len < 2:
        raise ValueError("Input produced no valid tokens – provide more Arabic text.")


@lru_cache(maxsize=8)
def _load_vits(repo_id: str, cache_dir: str):
    """Load and memoize a VITS model/tokenizer pair from the Hub.

    Caching avoids re-reading weights from disk on repeated requests for
    the same repo/cache-dir combination.
    """
    tok = AutoTokenizer.from_pretrained(repo_id, cache_dir=cache_dir)
    net = VitsModel.from_pretrained(repo_id, cache_dir=cache_dir)
    return net, tok


@lru_cache(maxsize=8)
def _load_speecht5(repo_id: str, cache_dir: str):
    """Load and memoize the SpeechT5 synthesis pipeline pieces.

    Returns:
        A (processor, model, vocoder, speaker_embedding) tuple. The
        HiFi-GAN vocoder repo and the speaker x-vector loader are fixed,
        shared across all SpeechT5 model repos.
    """
    proc = SpeechT5Processor.from_pretrained(repo_id, cache_dir=cache_dir)
    tts_model = SpeechT5ForTextToSpeech.from_pretrained(repo_id, cache_dir=cache_dir)
    gan = SpeechT5HifiGan.from_pretrained(
        "microsoft/speecht5_hifigan", cache_dir=cache_dir
    )
    return proc, tts_model, gan, _load_speecht5_speaker_embedding(cache_dir)


def _load_speecht5_speaker_embedding(cache_dir: str) -> torch.Tensor:
    try:
        xvector_path = hf_hub_download(
            repo_id="Matthijs/cmu-arctic-xvectors",
            filename="validation/000000.xvector.npy",
            repo_type="dataset",
            cache_dir=cache_dir,
        )
        arr = np.load(xvector_path)
        vec = torch.from_numpy(arr)
        if vec.ndim == 1:
            vec = vec.unsqueeze(0)
        return vec
    except Exception:
        return torch.zeros((1, 512), dtype=torch.float32)


@app.post("/tts")
async def tts(payload: dict):
    """Synthesize Arabic speech and return it as a WAV audio response.

    Expected JSON payload keys:
        text (str, required): the Arabic text to synthesize.
        model_id (str, required): one of the keys in SUPPORTED.
        sample_rate (int, optional): fallback rate used only when the
            model config does not declare one (default 16000).
        cache_dir (str, optional): Hugging Face cache directory
            (default "models_cache").

    Raises:
        HTTPException 400: missing/invalid input or too-short text.
        HTTPException 500: any unexpected synthesis failure.
    """
    text = (payload.get("text") or "").strip()
    model_id = payload.get("model_id")
    sample_rate = int(payload.get("sample_rate", 16000))
    cache_dir = str(Path(payload.get("cache_dir") or "models_cache").expanduser())

    if not text:
        raise HTTPException(status_code=400, detail="'text' is required")
    if model_id not in SUPPORTED:
        raise HTTPException(status_code=400, detail=f"Unsupported or missing model_id: {model_id}")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        engine = SUPPORTED[model_id]["engine"]
        if engine == "vits":
            model, tokenizer = _load_vits(model_id, cache_dir)
            model.to(device).eval()
            inputs = tokenizer(text, return_tensors="pt")
            _ensure_valid_tokens(inputs)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            with torch.inference_mode():
                outputs = model(**inputs)
                waveform = outputs.waveform.squeeze(0).cpu().numpy()
            sr = getattr(model.config, "sampling_rate", sample_rate)
        elif engine == "speecht5":
            processor, model, vocoder, speaker = _load_speecht5(model_id, cache_dir)
            model.to(device)
            vocoder.to(device)
            inputs = processor(text=text, return_tensors="pt")
            _ensure_valid_tokens(inputs)
            input_ids = inputs["input_ids"].to(device)
            speaker_embedding = speaker.to(device)
            with torch.inference_mode():
                speech = model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)
            waveform = speech.cpu().numpy()
            # Consistency fix: use the same fallback as the VITS path so
            # the client's requested sample_rate is honored when the
            # config lacks one (previous hard-coded 16000 equals the
            # payload default, so existing callers see no change).
            sr = getattr(model.config, "sampling_rate", sample_rate)
        else:
            # Defensive: the SUPPORTED check above should prevent this.
            raise HTTPException(status_code=400, detail="Engine not supported via API")

        wav_io = io.BytesIO()
        # closefd=False keeps the BytesIO readable after soundfile is done.
        sf.write(wav_io, waveform, int(sr), format="WAV", closefd=False)
        return Response(content=wav_io.getvalue(), media_type="audio/wav")
    except HTTPException:
        # Bug fix: the deliberate 400 raised above was previously caught
        # by the broad handler below and rewrapped as a 500. Re-raise
        # HTTPExceptions untouched.
        raise
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS failed: {e}") from e