# app.py — TTS model served via a FastAPI app
# Author: IslamVuSo (commit cc21735)
# Standard library
import io
import re  # regex used by normalize_text for punctuation stripping

# Third-party
import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from transformers import VitsModel, AutoTokenizer
# Text normalization applied to every request before tokenization.
_PUNCTUATION_TABLE = str.maketrans("", "", ".,?!-")


def normalize_text(text: str) -> str:
    """Return *text* lowercased and stripped, with the punctuation
    marks ``. , ? ! -`` removed (one C-level pass via str.translate)."""
    cleaned = text.translate(_PUNCTUATION_TABLE)
    return cleaned.strip().lower()
class TextIn(BaseModel):
    """Request body for POST /generate-tts/: the raw text to synthesize."""
    text: str
print("Loading model and tokenizer...")
# facebook/mms-tts-che is a Meta MMS VITS text-to-speech checkpoint
# ("che" presumably the Chechen language code — verify against the model card).
# Loaded once at import time so every request reuses the in-memory model.
processor = AutoTokenizer.from_pretrained("facebook/mms-tts-che")
model = VitsModel.from_pretrained("facebook/mms-tts-che")
print("Model and tokenizer loaded.")
app = FastAPI()
@app.post("/generate-tts/")
async def generate_tts(data: TextIn):
    """Synthesize speech for ``data.text`` and stream it back as WAV audio.

    Returns:
        StreamingResponse with ``audio/wav`` on success,
        a 400 JSON error when the text is empty after normalization,
        or a 500 JSON error if synthesis fails.
    """
    try:
        # 3. Normalize the text first
        text = normalize_text(data.text)
        if not text:
            # BUGFIX: returning a `(dict, 400)` tuple from a FastAPI handler
            # does NOT set the status code — the tuple is serialized as a
            # JSON array with status 200. Use an explicit JSONResponse.
            return JSONResponse(
                status_code=400,
                content={"error": "Input text is empty after normalization."},
            )
        print(f"Processing normalized text: {text}")
        inputs = processor(text, return_tensors="pt")
        with torch.no_grad():  # inference only — no gradients needed
            output = model(**inputs)
        speech = output.waveform
        print("Audio generated. Creating buffer...")
        buffer = io.BytesIO()
        speech_np = speech.cpu().numpy().squeeze()
        if speech_np.ndim > 1:
            # Batched/multi-channel output: keep only the first waveform.
            speech_np = speech_np[0]
        sf.write(buffer,
                 speech_np,
                 samplerate=model.config.sampling_rate,
                 format='wav')
        buffer.seek(0)
        print("Returning audio file.")
        return StreamingResponse(buffer, media_type="audio/wav")
    except Exception as e:
        # Broad catch is intentional at this endpoint boundary: we return a
        # JSON error body (now with a real 500 status) instead of letting an
        # unhandled traceback escape.
        print(f"Error occurred: {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.get("/")
def read_root():
    """Health-check endpoint: confirms the server process is up."""
    payload = {"status": "TTS server is running"}
    return payload
if __name__ == "__main__":
    # Bind all interfaces on port 7860 (the conventional Hugging Face
    # Spaces port) when run directly as a script.
    uvicorn.run(app, host="0.0.0.0", port=7860)