from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import VitsModel, AutoTokenizer
import torch
import soundfile as sf
import io
import base64
import os

app = FastAPI(
    title="MMS-TTS Vietnamese API",
    description="A simple API for Vietnamese Text-to-Speech using facebook/mms-tts-vie."
)

# Define the request body model
class TTSRequest(BaseModel):
    text: str
    speaker_id: int = 0  # MMS-TTS models are single-speaker; kept for potential future multi-speaker models
    speed_factor: float = 1.0  # Optional: adjust speech speed (accepted, but not yet applied by this endpoint)

# Global variables to hold the loaded model and tokenizer
# Avoid reloading them for every request, improving performance.
model = None
tokenizer = None

# NOTE: @app.on_event("startup") is deprecated in recent FastAPI releases in
# favor of lifespan handlers, but it still works and is kept here for simplicity.
@app.on_event("startup")
async def startup_event():
    """
    Load the TTS model and tokenizer when the FastAPI application starts up.
    This ensures they are ready for immediate use and not reloaded per request.
    """
    global model, tokenizer
    try:
        print("Loading MMS-TTS model 'facebook/mms-tts-vie'...")
        model = VitsModel.from_pretrained("facebook/mms-tts-vie")
        tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-vie")
        print("MMS-TTS model and tokenizer loaded successfully.")
    except Exception as e:
        # Log the full exception for debugging
        import traceback
        traceback.print_exc()
        print(f"ERROR: Failed to load MMS-TTS model or tokenizer on startup: {e}")
        model = None # Ensure they are clearly unset if loading fails
        tokenizer = None

@app.get("/")
async def read_root():
    """Basic health check endpoint."""
    return {"message": "MMS-TTS Vietnamese API is running!", "model_loaded": model is not None}


@app.post("/synthesize_speech")
async def synthesize_speech(request: TTSRequest):
    """
    Synthesizes speech from the given text and returns it as a Base64 encoded WAV byte array.
    """
    if model is None or tokenizer is None:
        raise HTTPException(
            status_code=503,
            detail="TTS model is not loaded. Please try again later or check server logs."
        )

    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Field 'text' must not be empty.")

    try:
        # Tokenize the input text
        # Ensure the text is properly handled by the tokenizer for Vietnamese
        inputs = tokenizer(request.text, return_tensors="pt")

        # Generate speech
        with torch.no_grad():
            # speaker_id is usually not needed for single-speaker models like mms-tts-vie
            # if model supports it, you might pass speaker_id=request.speaker_id
            audio_values = model(**inputs).waveform

        # Convert to WAV bytes in memory
        # The sampling rate is critical and should match the model's config
        samplerate = model.config.sampling_rate
        output_buffer = io.BytesIO()
        sf.write(output_buffer, audio_values.numpy().squeeze(), samplerate, format='WAV')
        output_buffer.seek(0) # Rewind the buffer to the beginning

        # Encode the WAV bytes to Base64 string
        audio_base64 = base64.b64encode(output_buffer.read()).decode('utf-8')

        return {"audio_base64": audio_base64}

    except Exception as e:
        import traceback
        traceback.print_exc() # Print full traceback to console for debugging
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during speech synthesis: {e}"
        )
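
# ---------------------------------------------------------------------------
# Example usage (illustrative sketch; assumes the server is started with
# `uvicorn main:app --host 0.0.0.0 --port 8000`, where the module name `main`
# is a placeholder for this file's actual name, and that the `requests`
# package is installed on the client side):
#
#   import base64, requests
#
#   resp = requests.post(
#       "http://localhost:8000/synthesize_speech",
#       json={"text": "Xin chào thế giới"},
#   )
#   resp.raise_for_status()
#   wav_bytes = base64.b64decode(resp.json()["audio_base64"])
#   with open("output.wav", "wb") as f:
#       f.write(wav_bytes)
#
# The same request via curl:
#
#   curl -X POST http://localhost:8000/synthesize_speech \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Xin chào"}'
# ---------------------------------------------------------------------------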