Spaces:
Running
Running
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import VitsModel, AutoTokenizer | |
| import torch | |
| import soundfile as sf | |
| import io | |
| import base64 | |
| import os | |
# Application instance; endpoints and the startup hook attach to this.
app = FastAPI(
    title="MMS-TTS Vietnamese API",
    description="A simple API for Vietnamese Text-to-Speech using facebook/mms-tts-vie.",
)
class TTSRequest(BaseModel):
    """Request body for the synthesis endpoint."""

    # Text to synthesize (Vietnamese).
    text: str
    # Accepted for forward compatibility with multi-speaker checkpoints;
    # mms-tts-vie itself is single-speaker.
    speaker_id: int = 0
    # Accepted but not currently applied during synthesis.
    speed_factor: float = 1.0
# Module-level cache for the loaded model and tokenizer, populated at
# startup so they are not reloaded on every request.
model, tokenizer = None, None
@app.on_event("startup")
async def startup_event():
    """
    Load the TTS model and tokenizer when the FastAPI application starts.

    Registered as a startup hook (the original function was defined but never
    attached to the app, so the model was never loaded). Loading happens once
    here so requests do not pay the model-load cost. On failure the globals
    are left as None and the endpoints answer 503 instead of crashing.
    """
    global model, tokenizer
    try:
        print("Loading MMS-TTS model 'facebook/mms-tts-vie'...")
        model = VitsModel.from_pretrained("facebook/mms-tts-vie")
        tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-vie")
        print("MMS-TTS model and tokenizer loaded successfully.")
    except Exception as e:
        # Log the full traceback for debugging, but keep the app alive so
        # the health endpoint can still report the unloaded state.
        import traceback
        traceback.print_exc()
        print(f"ERROR: Failed to load MMS-TTS model or tokenizer on startup: {e}")
        model = None  # Ensure they are clearly unset if loading fails
        tokenizer = None
@app.get("/")
async def read_root():
    """Basic health check endpoint.

    Registered on GET / (the original function had no route decorator, so it
    was unreachable). Reports whether the TTS model finished loading.
    """
    return {"message": "MMS-TTS Vietnamese API is running!", "model_loaded": model is not None}
@app.post("/synthesize")
async def synthesize_speech(request: TTSRequest):
    """
    Synthesize Vietnamese speech from text.

    Registered on POST /synthesize (the original function had no route
    decorator, so it was unreachable).

    Returns:
        {"audio_base64": "<Base64-encoded WAV bytes>"}

    Raises:
        HTTPException 400: request text is empty or whitespace-only.
        HTTPException 503: model/tokenizer not loaded (startup failed).
        HTTPException 500: any error during synthesis.

    NOTE(review): request.speaker_id and request.speed_factor are accepted
    but not used — mms-tts-vie is single-speaker and no speed adjustment is
    applied here.
    """
    if model is None or tokenizer is None:
        raise HTTPException(
            status_code=503,
            detail="TTS model is not loaded. Please try again later or check server logs."
        )
    # Reject empty input up front with a clear 400 rather than letting the
    # tokenizer/model fail and surface as an opaque 500.
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Input text must not be empty.")
    try:
        # Tokenize the input text for the VITS model.
        inputs = tokenizer(request.text, return_tensors="pt")

        # Inference only — no gradients needed.
        with torch.no_grad():
            audio_values = model(**inputs).waveform

        # Serialize the waveform to WAV in memory. The sampling rate must
        # match the model's configured rate or pitch/speed will be wrong.
        samplerate = model.config.sampling_rate
        output_buffer = io.BytesIO()
        sf.write(output_buffer, audio_values.numpy().squeeze(), samplerate, format='WAV')
        output_buffer.seek(0)  # Rewind the buffer before reading it back

        # Base64-encode so the binary audio fits inside a JSON response.
        audio_base64 = base64.b64encode(output_buffer.read()).decode('utf-8')
        return {"audio_base64": audio_base64}
    except Exception as e:
        import traceback
        traceback.print_exc()  # Print full traceback to console for debugging
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during speech synthesis: {e}"
        )