from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import VitsModel, AutoTokenizer
import torch
import soundfile as sf
import io
import base64
import os  # NOTE(review): unused in this file as shown — confirm before removing

app = FastAPI(
    title="MMS-TTS Vietnamese API",
    description="A simple API for Vietnamese Text-to-Speech using facebook/mms-tts-vie."
)


class TTSRequest(BaseModel):
    """Request body for the /synthesize_speech endpoint."""
    text: str
    # mms-tts-vie is single-speaker; kept only for forward compatibility with
    # potential multi-speaker checkpoints. Not passed to the model below.
    speaker_id: int = 0
    # NOTE(review): accepted but never applied during synthesis — wire up or remove.
    speed_factor: float = 1.0


# Loaded once at startup so every request reuses the same weights instead of
# reloading the model per call.
model = None
tokenizer = None


# NOTE(review): on_event is deprecated in newer FastAPI in favor of lifespan
# handlers; kept as-is for compatibility with the current app structure.
@app.on_event("startup")
async def startup_event():
    """
    Load the TTS model and tokenizer when the FastAPI application starts up.

    On failure, logs the full traceback and leaves `model`/`tokenizer` as None
    so the synthesis endpoint can report 503 instead of crashing.
    """
    global model, tokenizer
    try:
        print("Loading MMS-TTS model 'facebook/mms-tts-vie'...")
        model = VitsModel.from_pretrained("facebook/mms-tts-vie")
        tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-vie")
        print("MMS-TTS model and tokenizer loaded successfully.")
    except Exception as e:
        # Log the full exception for debugging.
        import traceback
        traceback.print_exc()
        print(f"ERROR: Failed to load MMS-TTS model or tokenizer on startup: {e}")
        model = None  # Ensure they are clearly unset if loading fails
        tokenizer = None


@app.get("/")
async def read_root():
    """Basic health check endpoint."""
    return {"message": "MMS-TTS Vietnamese API is running!", "model_loaded": model is not None}


@app.post("/synthesize_speech")
async def synthesize_speech(request: TTSRequest):
    """
    Synthesize speech from `request.text` and return it as a Base64-encoded
    WAV byte array under the key "audio_base64".

    Raises:
        HTTPException 503: model/tokenizer failed to load at startup.
        HTTPException 400: the input text is empty or whitespace-only.
        HTTPException 500: any error during tokenization or synthesis.
    """
    if model is None or tokenizer is None:
        raise HTTPException(
            status_code=503,
            detail="TTS model is not loaded. Please try again later or check server logs."
        )
    # Fail fast on empty input instead of surfacing an opaque 500 from the model.
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Input text must not be empty.")
    try:
        # Tokenize the input text for the Vietnamese model.
        inputs = tokenizer(request.text, return_tensors="pt")

        # Generate speech. speaker_id is not passed: mms-tts-vie is
        # single-speaker (see TTSRequest notes).
        with torch.no_grad():
            audio_values = model(**inputs).waveform

        # Convert to WAV bytes in memory. The sampling rate must match the
        # model's configuration or playback speed will be wrong.
        samplerate = model.config.sampling_rate
        output_buffer = io.BytesIO()
        # .cpu() makes this safe even if the model/tensor ever lives on a GPU;
        # .numpy() alone raises for non-CPU tensors.
        sf.write(output_buffer, audio_values.cpu().numpy().squeeze(), samplerate, format='WAV')
        output_buffer.seek(0)  # Rewind the buffer to the beginning

        # Encode the WAV bytes to a Base64 string for JSON transport.
        audio_base64 = base64.b64encode(output_buffer.read()).decode('utf-8')
        return {"audio_base64": audio_base64}
    except Exception as e:
        import traceback
        traceback.print_exc()  # Full traceback to console for debugging
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during speech synthesis: {e}"
        )