from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import VitsModel, AutoTokenizer
import torch
import soundfile as sf
import io
import base64
import os
# Application instance; this metadata is surfaced in the auto-generated
# OpenAPI docs (/docs).
app = FastAPI(
    title="MMS-TTS Vietnamese API",
    description="A simple API for Vietnamese Text-to-Speech using facebook/mms-tts-vie.",
)
class TTSRequest(BaseModel):
    """Request body for the /synthesize_speech endpoint."""

    # Vietnamese text to synthesize.
    text: str
    # MMS-TTS models are single-speaker by default; kept for potential
    # future multi-speaker models.
    speaker_id: int = 0
    # Optional speech-speed adjustment.
    speed_factor: float = 1.0
# Module-level slots for the loaded model and tokenizer so they are created
# once at startup rather than reloaded for every request (see startup_event).
model = None
tokenizer = None
@app.on_event("startup")
async def startup_event():
"""
Load the TTS model and tokenizer when the FastAPI application starts up.
This ensures they are ready for immediate use and not reloaded per request.
"""
global model, tokenizer
try:
print("Loading MMS-TTS model 'facebook/mms-tts-vie'...")
model = VitsModel.from_pretrained("facebook/mms-tts-vie")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-vie")
print("MMS-TTS model and tokenizer loaded successfully.")
except Exception as e:
# Log the full exception for debugging
import traceback
traceback.print_exc()
print(f"ERROR: Failed to load MMS-TTS model or tokenizer on startup: {e}")
model = None # Ensure they are clearly unset if loading fails
tokenizer = None
@app.get("/")
async def read_root():
"""Basic health check endpoint."""
return {"message": "MMS-TTS Vietnamese API is running!", "model_loaded": model is not None}
@app.post("/synthesize_speech")
async def synthesize_speech(request: TTSRequest):
"""
Synthesizes speech from the given text and returns it as a Base64 encoded WAV byte array.
"""
if model is None or tokenizer is None:
raise HTTPException(
status_code=503,
detail="TTS model is not loaded. Please try again later or check server logs."
)
try:
# Tokenize the input text
# Ensure the text is properly handled by the tokenizer for Vietnamese
inputs = tokenizer(request.text, return_tensors="pt")
# Generate speech
with torch.no_grad():
# speaker_id is usually not needed for single-speaker models like mms-tts-vie
# if model supports it, you might pass speaker_id=request.speaker_id
audio_values = model(**inputs).waveform
# Convert to WAV bytes in memory
# The sampling rate is critical and should match the model's config
samplerate = model.config.sampling_rate
output_buffer = io.BytesIO()
sf.write(output_buffer, audio_values.numpy().squeeze(), samplerate, format='WAV')
output_buffer.seek(0) # Rewind the buffer to the beginning
# Encode the WAV bytes to Base64 string
audio_base64 = base64.b64encode(output_buffer.read()).decode('utf-8')
return {"audio_base64": audio_base64}
except Exception as e:
import traceback
traceback.print_exc() # Print full traceback to console for debugging
raise HTTPException(
status_code=500,
detail=f"An error occurred during speech synthesis: {e}"
) |