tts-api / app.py
rashmika8352's picture
Update app.py
9d7001a verified
import io
import threading
import wave

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, VitsModel
# FastAPI application object; the @app.get / @app.post route decorators in
# this module register their endpoints on it.
app = FastAPI()
# Public speaker catalogue: request key -> (Hugging Face model id, speaker_id).
# speaker_id is None for single-speaker checkpoints and an integer voice index
# for the multi-speaker VCTK checkpoint. Insertion order is the order that
# GET /speakers reports.
SPEAKERS = {
    # Single-speaker checkpoints: no speaker_id is passed to the model.
    "en_us_male": ("facebook/mms-tts-eng", None),
    # NOTE(review): vits-ljs appears to be trained on LJ Speech, a US-English
    # female voice, so "uk" in this key looks inaccurate. The key is left
    # unchanged because renaming it would break existing clients.
    "en_uk_female": ("kakao-enterprise/vits-ljs", None),
    "spanish": ("facebook/mms-tts-spa", None),
    "french": ("facebook/mms-tts-fra", None),
    "german": ("facebook/mms-tts-deu", None),
    "italian": ("facebook/mms-tts-ita", None),
    "portuguese": ("facebook/mms-tts-por", None),
    "japanese": ("facebook/mms-tts-jpn", None),
    "korean": ("facebook/mms-tts-kor", None),
    "chinese": ("facebook/mms-tts-zho", None),
    "arabic": ("facebook/mms-tts-ara", None),
    "hindi": ("facebook/mms-tts-hin", None),
    "russian": ("facebook/mms-tts-rus", None),
    # Multi-speaker VCTK checkpoint (mixed UK/US/Scottish/Irish accents).
    # The model takes a numeric speaker_id in 0-108. Only the first 20 voices
    # are exposed here, as en_vctk_0 .. en_vctk_19.
    **{
        f"en_vctk_{voice}": ("kakao-enterprise/vits-vctk", voice)
        for voice in range(20)
    },
}
# Loaded (model, tokenizer) pairs keyed by model id, so each checkpoint is
# loaded once per process and then reused by every request.
_cache = {}
# FastAPI runs sync ``def`` endpoints in a worker threadpool, so two requests
# can reach get_model at the same time; this lock ensures a given model_id is
# loaded by only one of them.
_cache_lock = threading.Lock()


def get_model(model_id: str):
    """Return the cached ``(model, tokenizer)`` pair for *model_id*.

    On first use the tokenizer and model are loaded with ``from_pretrained``,
    the model is put in eval mode, and the pair is cached for later calls.

    Args:
        model_id: Model id as listed in SPEAKERS, passed to ``from_pretrained``.

    Returns:
        Tuple ``(model, tokenizer)``. Repeated calls with the same id return
        the same tuple object.
    """
    # Fast path: no locking once the model is cached.
    cached = _cache.get(model_id)
    if cached is not None:
        return cached
    with _cache_lock:
        # Re-check under the lock: another thread may have finished loading
        # this model while we were waiting for it.
        if model_id not in _cache:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = VitsModel.from_pretrained(model_id)
            model.eval()
            _cache[model_id] = (model, tokenizer)
        return _cache[model_id]
class TTSRequest(BaseModel):
    """JSON request body accepted by ``POST /tts``."""

    # Text to synthesize; the /tts handler rejects empty or whitespace-only
    # text with HTTP 400.
    text: str
    # Key into SPEAKERS (see GET /speakers); defaults to "en_us_male".
    speaker: str = "en_us_male"
@app.get("/speakers")
def list_speakers():
    """List the speaker keys accepted by ``POST /tts``.

    Returns:
        dict with ``"speakers"`` (every SPEAKERS key, in declaration order)
        and ``"tip"`` (a hint naming the range of VCTK accent voices).
    """
    names = list(SPEAKERS)
    # Build the hint from SPEAKERS itself so it stays accurate when VCTK
    # entries are added or removed, instead of hard-coding the range.
    vctk = [name for name in names if name.startswith("en_vctk_")]
    if vctk:
        tip = f"Try {vctk[0]} through {vctk[-1]} for different English accents"
    else:
        tip = "Pass any of these names as the 'speaker' field of POST /tts"
    return {"speakers": names, "tip": tip}
def _pcm16_wav_bytes(waveform, sampling_rate: int) -> bytes:
    """Encode a float waveform as a mono, 16-bit PCM WAV file.

    Args:
        waveform: Array-like of float samples. It is flattened to a single
            channel, and values outside [-1.0, 1.0] are clipped before scaling.
        sampling_rate: Frame rate (Hz) written to the WAV header.

    Returns:
        The complete WAV file (header and frames) as bytes.
    """
    samples = np.clip(np.asarray(waveform).reshape(-1), -1.0, 1.0)
    # WAV stores PCM samples little-endian; "<i2" keeps the output correct
    # even on big-endian hosts, where native int16 bytes would be swapped.
    pcm = (samples * 32767).astype("<i2")
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # bytes per sample -> 16-bit PCM
        wf.setframerate(sampling_rate)
        wf.writeframes(pcm.tobytes())
    return buf.getvalue()


@app.post("/tts")
def text_to_speech(req: TTSRequest):
    """Synthesize ``req.text`` with the voice selected by ``req.speaker``.

    Returns:
        A response whose body is a mono 16-bit PCM WAV file (``audio/wav``),
        served inline as ``tts.wav``.

    Raises:
        HTTPException: 400 if the speaker is unknown, the text is empty or
            whitespace-only, or the model's tokenizer leaves no symbols to
            synthesize.
    """
    if req.speaker not in SPEAKERS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown speaker {req.speaker!r}. Call /speakers for the list.",
        )
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="text must not be empty.")

    model_id, speaker_id = SPEAKERS[req.speaker]
    model, tokenizer = get_model(model_id)
    inputs = tokenizer(req.text, return_tensors="pt")

    # VITS tokenizers drop characters outside the model's vocabulary (e.g.
    # non-Roman script for MMS checkpoints that expect uroman-romanized text).
    # With add_blank, n symbols are interleaved with blanks into 2n + 1 ids,
    # so a lone id means nothing pronounceable survived; reject that rather
    # than synthesizing noise.
    # NOTE(review): relies on VitsTokenizer's add_blank interleaving - confirm
    # against the installed transformers version.
    n_ids = inputs["input_ids"].shape[-1]
    n_symbols = (n_ids - 1) // 2 if getattr(tokenizer, "add_blank", False) else n_ids
    if n_symbols <= 0:
        detail = "text contains no characters supported by this speaker's model."
        if getattr(tokenizer, "is_uroman", False):
            detail += " This model expects romanized (uroman) input."
        raise HTTPException(status_code=400, detail=detail)

    with torch.no_grad():
        if speaker_id is None:
            waveform = model(**inputs).waveform
        else:
            # Multi-speaker checkpoint: pick the voice by index (batch of 1).
            waveform = model(**inputs, speaker_id=torch.tensor([speaker_id])).waveform

    wav = _pcm16_wav_bytes(waveform.cpu().numpy(), model.config.sampling_rate)
    # The whole file is already in memory, so send it as a plain Response:
    # clients get a Content-Length, and the buffer is not re-chunked line by line.
    return Response(
        content=wav,
        media_type="audio/wav",
        headers={"Content-Disposition": "inline; filename=tts.wav"},
    )