"""Minimal FastAPI text-to-speech service backed by Hugging Face VITS models.

Endpoints:
    GET  /speakers -- list the voice keys accepted by ``/tts``.
    POST /tts      -- synthesize ``text`` with ``speaker`` and return a WAV file.
"""

import io
import threading
import wave
from typing import Optional

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
# StreamingResponse is no longer used in this module; kept in case other code
# imports it from here.
from fastapi.responses import Response, StreamingResponse  # noqa: F401
from pydantic import BaseModel
from transformers import AutoTokenizer, VitsModel

app = FastAPI()

# Hugging Face model id shared by every "en_vctk_<i>" voice.
_VCTK_MODEL_ID = "kakao-enterprise/vits-vctk"
# How many VCTK speaker indices (0 .. N-1) are exposed as voices.
_NUM_VCTK_VOICES = 20

# Public voice key -> (Hugging Face model id, speaker index).
# A speaker index of None means the model is used without a speaker_id.
SPEAKERS = {
    # Single-speaker models
    "en_us_male": ("facebook/mms-tts-eng", None),
    # NOTE(review): LJ Speech is a single US-English female speaker, so the
    # "en_uk" part of this key looks inaccurate -- confirm before renaming,
    # since clients may already depend on the key.
    "en_uk_female": ("kakao-enterprise/vits-ljs", None),
    # NOTE(review): several MMS checkpoints for non-Latin scripts expect
    # romanized (uroman) input; also confirm every id below exists on the Hub
    # (e.g. mms-tts-jpn / mms-tts-zho).
    "spanish": ("facebook/mms-tts-spa", None),
    "french": ("facebook/mms-tts-fra", None),
    "german": ("facebook/mms-tts-deu", None),
    "italian": ("facebook/mms-tts-ita", None),
    "portuguese": ("facebook/mms-tts-por", None),
    "japanese": ("facebook/mms-tts-jpn", None),
    "korean": ("facebook/mms-tts-kor", None),
    "chinese": ("facebook/mms-tts-zho", None),
    "arabic": ("facebook/mms-tts-ara", None),
    "hindi": ("facebook/mms-tts-hin", None),
    "russian": ("facebook/mms-tts-rus", None),
    # Multi-speaker VCTK -- numeric speaker_id (0-108), mixed
    # UK/US/Scottish/Irish accents; only the first few are exposed.
    **{f"en_vctk_{i}": (_VCTK_MODEL_ID, i) for i in range(_NUM_VCTK_VOICES)},
}

# model id -> (model, tokenizer); filled lazily by get_model().
_cache = {}
# Serializes first-time loads so concurrent requests (FastAPI runs sync
# endpoints in a thread pool) do not load the same model twice.
_cache_lock = threading.Lock()


def get_model(model_id: str):
    """Return the ``(model, tokenizer)`` pair for *model_id*, loading it once.

    The model is put in eval mode before being cached. Later calls return the
    cached pair without taking the lock.
    """
    cached = _cache.get(model_id)
    if cached is not None:
        return cached
    with _cache_lock:
        # Re-check: another thread may have finished loading while we waited.
        if model_id not in _cache:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = VitsModel.from_pretrained(model_id)
            model.eval()
            _cache[model_id] = (model, tokenizer)
        return _cache[model_id]


class TTSRequest(BaseModel):
    """Body of ``POST /tts``: the text to speak and the voice key to use."""

    text: str
    speaker: str = "en_us_male"


@app.get("/speakers")
def list_speakers():
    """Return every accepted voice key plus a usage hint."""
    return {
        "speakers": list(SPEAKERS.keys()),
        "tip": "Try en_vctk_0 through en_vctk_19 for different English accents",
    }


def _has_speakable_tokens(input_ids: torch.Tensor, pad_token_id: Optional[int]) -> bool:
    """Return True if *input_ids* holds at least one non-padding token.

    VITS tokenizers can discard characters outside their vocabulary and may
    intersperse blank/pad tokens, so text made only of unsupported characters
    can tokenize to nothing, or to padding only.
    """
    if pad_token_id is None:
        return input_ids.numel() > 0
    return bool((input_ids != pad_token_id).any())


def _to_wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
    """Encode float samples as a mono 16-bit PCM WAV file.

    Samples are clipped to [-1, 1] before scaling to int16.
    """
    # WAV PCM data is little-endian; "<i2" makes that explicit on any host.
    pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype("<i2")
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(pcm.tobytes())
    return buf.getvalue()


@app.post("/tts")
def text_to_speech(req: TTSRequest):
    """Synthesize ``req.text`` with the voice ``req.speaker``; return WAV audio.

    Raises:
        HTTPException(400): unknown speaker, blank text, or text containing no
            characters the selected model can pronounce.
    """
    if req.speaker not in SPEAKERS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown speaker {req.speaker!r}. Call /speakers for the list.",
        )
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="text must not be empty.")

    model_id, speaker_id = SPEAKERS[req.speaker]
    model, tokenizer = get_model(model_id)

    inputs = tokenizer(req.text, return_tensors="pt")
    if not _has_speakable_tokens(inputs["input_ids"], tokenizer.pad_token_id):
        detail = "text contains no characters this voice can pronounce."
        if getattr(tokenizer, "is_uroman", False):
            detail += " This model expects romanized (uroman) input."
        raise HTTPException(status_code=400, detail=detail)

    # speaker_id is passed (as a plain int) only for voices that define one.
    extra = {} if speaker_id is None else {"speaker_id": speaker_id}
    with torch.no_grad():
        waveform = model(**inputs, **extra).waveform

    audio = waveform.squeeze().cpu().numpy()
    wav = _to_wav_bytes(audio, model.config.sampling_rate)

    # The whole file is already in memory: send it as one body with a
    # Content-Length. StreamingResponse over a BytesIO would iterate it line
    # by line (splitting on b"\n" bytes), producing many small arbitrary chunks.
    return Response(
        content=wav,
        media_type="audio/wav",
        headers={"Content-Disposition": "inline; filename=tts.wav"},
    )