# bororo_tts / app.py
# Source: Hugging Face Space by Alicehy — commit f81c461 ("use /tmp cache")
# app.py
import io
import os
import pathlib

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from scipy.io.wavfile import write
from transformers import VitsModel, AutoTokenizer
# All Hugging Face artifacts go under /tmp: on Spaces the default home
# directory is read-only, so every cache-related env var is pointed here.
CACHE = "/tmp/hf-cache"
os.makedirs(CACHE, exist_ok=True)
for _cache_var in ("HF_HOME", "TRANSFORMERS_CACHE", "XDG_CACHE_HOME"):
    os.environ[_cache_var] = CACHE

# Load the Bororo MMS-TTS checkpoint once at startup; eval() disables
# training-mode layers (dropout etc.) for inference.
MODEL = VitsModel.from_pretrained("facebook/mms-tts-bor", cache_dir=CACHE).eval()
TOK = AutoTokenizer.from_pretrained("facebook/mms-tts-bor", cache_dir=CACHE)
app = FastAPI()


class Inp(BaseModel):
    """Request body for POST /api/tts."""

    # Text to synthesize (expected to be in Bororo, the model's language).
    text: str
@app.post("/api/tts")
def tts(inp: Inp):
    """Synthesize speech for ``inp.text`` and stream it back as a WAV file.

    Returns 16-bit PCM mono audio at the model's native sampling rate.

    Raises:
        HTTPException: 400 when the input text is empty or whitespace-only
            (previously this fed zero-length tensors into the model and
            surfaced as an opaque 500).
    """
    if not inp.text.strip():
        raise HTTPException(status_code=400, detail="text must be non-empty")

    with torch.inference_mode():
        tokens = TOK(inp.text, return_tensors="pt")
        audio = MODEL(**tokens).waveform.squeeze().cpu().numpy()

    # Peak-normalize to [-1, 1]; the epsilon guards against division by
    # zero if the model ever emits pure silence.
    audio = audio / (np.max(np.abs(audio)) + 1e-9)

    # Convert float samples to 16-bit PCM and wrap them in an in-memory WAV.
    pcm = (audio * 32767).astype(np.int16)
    buf = io.BytesIO()
    write(buf, MODEL.config.sampling_rate, pcm)
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav")
app.mount("/", StaticFiles(directory="public", html=True), name="ui")