# TextToSpeed-Api — main.py
# (source: repo upload by Osi30, commit 1caa8b9 "Add application file")
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import VitsModel, AutoTokenizer
import torch
import soundfile as sf
import io
import base64
import os
# FastAPI application instance; title/description appear in the auto-generated
# OpenAPI docs (/docs).
app = FastAPI(
    title="MMS-TTS Vietnamese API",
    description="A simple API for Vietnamese Text-to-Speech using facebook/mms-tts-vie.",
)
# Request body model for the /synthesize_speech endpoint.
class TTSRequest(BaseModel):
    """Payload for a speech-synthesis request."""
    text: str  # Vietnamese text to synthesize
    speaker_id: int = 0 # MMS-TTS models are single-speaker by default, but keeping this for potential future multi-speaker models
    speed_factor: float = 1.0 # Optional: Adjust speech speed
# Global variables to hold the loaded model and tokenizer.
# They are populated once in startup_event() so each request reuses them
# instead of paying the load cost; both stay None if loading failed, which
# the endpoints use to report a 503 / "model_loaded": False.
model = None
tokenizer = None
@app.on_event("startup")
async def startup_event():
    """Load the MMS-TTS model and tokenizer once, at application startup.

    Populates the module-level ``model`` and ``tokenizer`` globals so request
    handlers never reload them. If loading fails, both globals are reset to
    None and the error is printed; the API stays up and the synthesis
    endpoint reports the condition via a 503.
    """
    global model, tokenizer
    try:
        print("Loading MMS-TTS model 'facebook/mms-tts-vie'...")
        model = VitsModel.from_pretrained("facebook/mms-tts-vie")
        tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-vie")
        print("MMS-TTS model and tokenizer loaded successfully.")
    except Exception as e:
        import traceback
        traceback.print_exc()  # full stack trace to the console for debugging
        print(f"ERROR: Failed to load MMS-TTS model or tokenizer on startup: {e}")
        # Make the failed state unambiguous for the request handlers.
        model, tokenizer = None, None
@app.get("/")
async def read_root():
    """Health check: reports liveness and whether the TTS model is loaded."""
    status = {
        "message": "MMS-TTS Vietnamese API is running!",
        "model_loaded": model is not None,
    }
    return status
@app.post("/synthesize_speech")
async def synthesize_speech(request: TTSRequest):
    """Synthesize Vietnamese speech for ``request.text``.

    Returns a JSON object with:
      - ``audio_base64``: the generated audio as a Base64-encoded WAV file.
      - ``sampling_rate``: the audio sample rate, taken from the model config
        (clients need it for correct playback).

    Raises:
      HTTPException 503: the model failed to load at startup.
      HTTPException 400: the input text is empty or whitespace-only.
      HTTPException 500: any unexpected error during synthesis.
    """
    if model is None or tokenizer is None:
        raise HTTPException(
            status_code=503,
            detail="TTS model is not loaded. Please try again later or check server logs."
        )
    # Reject empty input early with a client error instead of letting the
    # model fail deep inside synthesis and surface as an opaque 500.
    if not request.text or not request.text.strip():
        raise HTTPException(status_code=400, detail="Input text must not be empty.")
    try:
        # Tokenize the input text (the MMS tokenizer handles Vietnamese).
        inputs = tokenizer(request.text, return_tensors="pt")

        # Honor speed_factor (previously accepted but ignored) through the
        # model's speaking_rate attribute (higher = faster speech).
        # NOTE(review): this mutates shared model state, so the original
        # value is restored in `finally` — confirm acceptable for the
        # expected (single-worker) deployment.
        original_rate = getattr(model, "speaking_rate", None)
        try:
            if original_rate is not None and request.speed_factor != 1.0:
                model.speaking_rate = original_rate * request.speed_factor
            with torch.no_grad():
                # speaker_id is not used: facebook/mms-tts-vie is single-speaker.
                audio_values = model(**inputs).waveform
        finally:
            if original_rate is not None:
                model.speaking_rate = original_rate

        # The sampling rate must match the model's config or the audio will
        # play back at the wrong speed/pitch.
        samplerate = model.config.sampling_rate
        output_buffer = io.BytesIO()
        sf.write(output_buffer, audio_values.numpy().squeeze(), samplerate, format='WAV')
        output_buffer.seek(0)  # rewind so read() returns the whole file

        # Encode the WAV bytes to a Base64 string for the JSON response.
        audio_base64 = base64.b64encode(output_buffer.read()).decode('utf-8')
        return {"audio_base64": audio_base64, "sampling_rate": samplerate}
    except HTTPException:
        raise  # don't wrap our own HTTP errors in a 500
    except Exception as e:
        import traceback
        traceback.print_exc()  # print full traceback to console for debugging
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during speech synthesis: {e}"
        )