fbmmstts / app.py
sidmaz666's picture
Update app.py
5a284f4 verified
"""
Enhanced MMS-TTS API – Production-ready Text-to-Speech for 1,100+ languages.
Features: long text batching, auto language detection (55 languages), streaming.
"""
import asyncio
import io
import logging
import re
from contextlib import asynccontextmanager
from typing import List, Optional, Dict, Any, Union

import numpy as np
import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import StreamingResponse
from langdetect import detect, DetectorFactory
from pydantic import BaseModel, Field, field_validator, model_validator, ValidationInfo
from transformers import VitsModel, AutoTokenizer
# ------------------------- Logging -------------------------
# Configure the root logger at INFO level; this module logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set seed for langdetect to make results consistent
DetectorFactory.seed = 0
# ------------------------- Configuration -------------------------
# Inference device: the GPU when torch reports CUDA as available, else the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Sample rate in Hz: written into the WAV header, used to size the silence
# between chunks and to compute the reported duration.
# NOTE(review): assumed to match the sampling rate of the loaded models - confirm.
SAMPLE_RATE = 16000
MAX_TEXT_LEN = 1000 # characters per chunk (safe for MMS token limit)
SILENCE_DURATION = 0.1 # seconds of silence between chunks
# Language mapping (ISO 639-3 -> human name)
# This table is what the /languages endpoint lists. It is not consulted when a
# model is loaded, so codes missing from it can still be requested.
# NOTE(review): the codes here are not checked against the model hub - confirm
# that each one has a published facebook/mms-tts-<code> checkpoint.
LANGUAGE_MAP: Dict[str, str] = {
"asm": "Assamese",
"ben": "Bengali",
"eng": "English",
"hin": "Hindi",
"tam": "Tamil",
"tel": "Telugu",
"mar": "Marathi",
"guj": "Gujarati",
"kan": "Kannada",
"mal": "Malayalam",
"pan": "Punjabi",
"urd": "Urdu",
"ori": "Odia",
"sat": "Santali",
"snd": "Sindhi",
"mai": "Maithili",
"awa": "Awadhi",
"bho": "Bhojpuri",
"mag": "Magahi",
"raj": "Rajasthani",
"hne": "Chhattisgarhi",
"kha": "Khasi",
"lus": "Mizo",
"grt": "Garo",
"brx": "Bodo",
"mni": "Manipuri",
"nep": "Nepali",
"sin": "Sinhala",
"tuk": "Turkmen",
"uig": "Uyghur",
"rus": "Russian",
"fra": "French",
"deu": "German",
"spa": "Spanish",
"ita": "Italian",
"por": "Portuguese",
"jpn": "Japanese",
"kor": "Korean",
"zho": "Chinese",
"ara": "Arabic",
"fas": "Persian",
"tur": "Turkish",
"pol": "Polish",
"ukr": "Ukrainian",
"vie": "Vietnamese",
"tha": "Thai",
"ind": "Indonesian",
"msa": "Malay",
"swa": "Swahili",
"amh": "Amharic",
"yor": "Yoruba",
"ibo": "Igbo",
"hau": "Hausa",
"sna": "Shona",
"zul": "Zulu",
"som": "Somali",
"khm": "Khmer",
"mya": "Burmese",
"lao": "Lao",
"mon": "Mongolian",
"kaz": "Kazakh",
"uzb": "Uzbek",
"tgl": "Tagalog",
"ceb": "Cebuano",
"hmn": "Hmong",
"nav": "Navajo",
"smo": "Samoan",
"tah": "Tahitian",
"haw": "Hawaiian",
"mlt": "Maltese",
"est": "Estonian",
"lav": "Latvian",
"lit": "Lithuanian",
"slv": "Slovenian",
"hrv": "Croatian",
"srp": "Serbian",
"bos": "Bosnian",
"mkd": "Macedonian",
"bul": "Bulgarian",
"ron": "Romanian",
"hun": "Hungarian",
"fin": "Finnish",
"swe": "Swedish",
"nor": "Norwegian",
"dan": "Danish",
"isl": "Icelandic",
"gle": "Irish",
"cym": "Welsh",
"eus": "Basque",
"cat": "Catalan",
"glg": "Galician",
"heb": "Hebrew",
"hye": "Armenian",
"kat": "Georgian",
"san": "Sanskrit",
}
# Map langdetect ISO 639-1 codes to MMS ISO 639-3 codes.
# detect_language() passes an unmapped code through unchanged, and a two-letter
# code can never match a "facebook/mms-tts-<iso639-3>" model id, so every code
# langdetect can emit needs an entry here. Keys langdetect never produces are
# harmless.
LANGDETECT_TO_MMS = {
    "en": "eng", "hi": "hin", "bn": "ben", "as": "asm", "ta": "tam",
    "te": "tel", "mr": "mar", "gu": "guj", "kn": "kan", "ml": "mal",
    "pa": "pan", "ur": "urd", "or": "ori", "ne": "nep", "si": "sin",
    "ru": "rus", "fr": "fra", "de": "deu", "es": "spa", "it": "ita",
    "pt": "por", "ja": "jpn", "ko": "kor", "zh-cn": "zho", "ar": "ara",
    "fa": "fas", "tr": "tur", "pl": "pol", "uk": "ukr", "vi": "vie",
    "th": "tha", "id": "ind", "ms": "msa", "sw": "swa", "am": "amh",
    "yo": "yor", "ig": "ibo", "ha": "hau", "sn": "sna", "zu": "zul",
    "so": "som", "km": "khm", "my": "mya", "lo": "lao", "mn": "mon",
    "kk": "kaz", "uz": "uzb", "tl": "tgl", "ceb": "ceb", "hmn": "hmn",
    "nv": "nav", "sm": "smo", "ty": "tah", "haw": "haw", "mt": "mlt",
    "et": "est", "lv": "lav", "lt": "lit", "sl": "slv", "hr": "hrv",
    "sr": "srp", "bs": "bos", "mk": "mkd", "bg": "bul", "ro": "ron",
    "hu": "hun", "fi": "fin", "sv": "swe", "no": "nor", "da": "dan",
    "is": "isl", "ga": "gle", "cy": "cym", "eu": "eus", "ca": "cat",
    "gl": "glg", "he": "heb", "hy": "hye", "ka": "kat", "sa": "san",
    # Codes langdetect reports that previously had no mapping.
    # langdetect distinguishes zh-cn / zh-tw; both map to the same code here.
    "zh-tw": "zho",
    "af": "afr", "cs": "ces", "el": "ell", "nl": "nld", "sk": "slk",
    "sq": "sqi",
}
# ------------------------- Model Cache -------------------------
# Per-language caches keyed by language code. Entries are never evicted.
model_cache: Dict[str, Any] = {}
tokenizer_cache: Dict[str, Any] = {}


def get_model_and_tokenizer(language_code: str):
    """Load or retrieve from cache the VitsModel and tokenizer for a given language.

    Args:
        language_code: Suffix appended to "facebook/mms-tts-" to form the
            model id.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has been moved to DEVICE and
        put in eval mode.

    Raises:
        HTTPException: 404 when the model or tokenizer cannot be loaded, for
            any reason. The underlying exception is logged and chained.
    """
    if language_code not in model_cache:
        model_id = f"facebook/mms-tts-{language_code}"
        logger.info(f"Loading model: {model_id}")
        try:
            model = VitsModel.from_pretrained(model_id)
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model.to(DEVICE)
            model.eval()
        except Exception as e:
            logger.error(f"Failed to load model for {language_code}: {e}")
            raise HTTPException(
                status_code=404,
                detail=f"Language code '{language_code}' not supported.",
            ) from e
        # Synthesis runs in worker threads, and membership in model_cache is
        # what marks a language as ready. Store the tokenizer first so a
        # concurrent caller that sees the model entry always finds it too.
        tokenizer_cache[language_code] = tokenizer
        model_cache[language_code] = model
        logger.info(f"Successfully loaded model for {language_code}")
    return model_cache[language_code], tokenizer_cache[language_code]
def detect_language(text: str) -> str:
    """Guess the language of *text* and return the code used for model lookup.

    The langdetect result is translated through LANGDETECT_TO_MMS; a code with
    no entry in that table is returned unchanged. If detection raises, a
    warning is logged and "eng" is returned.
    """
    try:
        detected = detect(text)
        if detected in LANGDETECT_TO_MMS:
            return LANGDETECT_TO_MMS[detected]
        return detected
    except Exception as e:
        logger.warning(f"Language detection failed: {e}, falling back to English")
    return "eng"
def _split_oversized(sentence: str, max_len: int) -> List[str]:
    """Break a sentence longer than *max_len* into pieces of at most *max_len* characters."""
    if len(sentence) <= max_len:
        return [sentence]
    pieces: List[str] = []
    current = ""
    for word in sentence.split():
        # Last resort: a single whitespace-free run longer than the limit is
        # cut by character count.
        while len(word) > max_len:
            if current:
                pieces.append(current)
                current = ""
            pieces.append(word[:max_len])
            word = word[max_len:]
        if not word:
            continue
        if not current:
            current = word
        elif len(current) + 1 + len(word) <= max_len:
            current = f"{current} {word}"
        else:
            pieces.append(current)
            current = word
    if current:
        pieces.append(current)
    return pieces


def split_text_into_chunks(text: str, max_len: Optional[int] = None) -> List[str]:
    """Split long text into chunks of at most *max_len* characters.

    Text is cut at sentence boundaries: whitespace following ``.``, ``!``,
    ``?`` or a danda (U+0964 / U+0965). Consecutive sentences are packed
    together, joined by a single space, while they fit. A sentence that is
    itself longer than *max_len* is broken at whitespace, and by character
    count as a last resort, so no returned chunk exceeds the limit.

    Args:
        text: Text to split. Surrounding whitespace is ignored.
        max_len: Maximum characters per chunk. ``None`` (the default) means the
            module-level MAX_TEXT_LEN, read at call time.

    Returns:
        The non-empty chunks in reading order; an empty list for blank text.

    Raises:
        ValueError: If *max_len* is less than 1.
    """
    if max_len is None:
        max_len = MAX_TEXT_LEN
    if max_len < 1:
        raise ValueError("max_len must be a positive integer.")

    chunks: List[str] = []
    current = ""
    for sentence in re.split(r'(?<=[.!?\u0964\u0965])\s+', text.strip()):
        for piece in _split_oversized(sentence, max_len):
            if not piece:
                continue
            if not current:
                current = piece
            elif len(current) + 1 + len(piece) <= max_len:
                current = f"{current} {piece}"
            else:
                chunks.append(current)
                current = piece
    if current:
        chunks.append(current)
    return chunks
def synthesize_single_chunk(text: str, language_code: str, speed: float = 1.0) -> np.ndarray:
    """Run the language's model on one piece of text and return its waveform.

    When *speed* is not exactly 1.0 the waveform is resampled by linear
    interpolation to ``int(len / speed)`` samples. This is plain resampling,
    not time-stretching.
    """
    model, tokenizer = get_model_and_tokenizer(language_code)

    encoded = tokenizer(text, return_tensors="pt")
    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(DEVICE)

    with torch.inference_mode():
        result = model(**model_inputs)
    samples = result.waveform[0].cpu().numpy()

    if speed == 1.0:
        return samples

    source_len = len(samples)
    target_len = int(source_len / speed)
    positions = np.linspace(0, source_len - 1, target_len)
    return np.interp(positions, np.arange(source_len), samples)
def synthesize_batch(texts: List[str], language_code: str, speed: float = 1.0) -> np.ndarray:
    """Synthesize a batch of text chunks and concatenate with silence.

    Args:
        texts: Chunks to speak, in order. Empty and whitespace-only entries are
            skipped; they carry nothing to speak.
        language_code: Language code passed on to synthesize_single_chunk.
        speed: Speed multiplier passed on to synthesize_single_chunk.

    Returns:
        One waveform holding every chunk's audio, with SILENCE_DURATION seconds
        of zeros between neighbouring chunks. An empty array is returned when
        there is nothing to speak.
    """
    spoken = [txt for txt in texts if txt and txt.strip()]
    if not spoken:
        # np.concatenate raises on an empty sequence; return empty audio instead.
        return np.zeros(0, dtype=np.float32)

    # The gap is identical between every pair of chunks, so build it once.
    silence = np.zeros(int(SAMPLE_RATE * SILENCE_DURATION))
    audios = []
    for i, txt in enumerate(spoken):
        if i:
            audios.append(silence)
        audios.append(synthesize_single_chunk(txt, language_code, speed))
    return np.concatenate(audios)
# ------------------------- Pydantic Models -------------------------
class TTSRequest(BaseModel):
    """Request body accepted by the synthesis endpoints.

    Validation rules:
      * ``text`` must contain at least one non-blank string.
      * ``language`` must be a non-blank string unless ``auto_detect`` is true.
      * ``speed`` must lie in [0.5, 2.0].
    """

    text: Union[str, List[str]] = Field(
        ...,
        description="Text to be spoken. Can be a single string (auto-chunked if long) or an array of strings."
    )
    language: Optional[str] = Field(
        None,
        description="ISO 639-3 language code (e.g., 'eng', 'asm', 'ben'). Required if auto_detect=False."
    )
    auto_detect: bool = Field(
        False,
        description="Automatically detect language from the first text chunk. Overrides 'language' if set to True."
    )
    voice: str = Field(
        "default",
        description="Voice identifier. Currently only 'default' is supported."
    )
    speed: float = Field(
        1.0,
        ge=0.5,
        le=2.0,
        description="Speech speed multiplier (0.5 = half speed, 2.0 = double speed)."
    )

    @field_validator("text")
    @classmethod
    def check_text(cls, v):
        """Reject input that holds no speakable text."""
        if isinstance(v, str) and not v.strip():
            raise ValueError("Text cannot be empty.")
        # An empty list is rejected here too: it has no non-blank entry.
        if isinstance(v, list) and not any(t.strip() for t in v):
            raise ValueError("All text entries are empty.")
        return v

    # This must run after all fields are populated. A field validator on
    # "language" cannot enforce the rule: it is skipped when the field is
    # omitted (defaults are not validated), and "auto_detect" is declared after
    # "language", so it is not yet available while "language" is validated.
    @model_validator(mode="after")
    def check_language(self):
        """Require a usable language code whenever auto-detection is off."""
        has_language = self.language is not None and bool(self.language.strip())
        if not self.auto_detect and not has_language:
            raise ValueError("Either provide 'language' or set 'auto_detect=True'.")
        return self
class TTSResponse(BaseModel):
    """Metadata describing a synthesis job, as returned by the info endpoint."""

    audio_format: str = Field(default="audio/wav")
    language_detected: Optional[str] = Field(
        default=None,
        description="Detected language code (if auto_detect=True).",
    )
    language_used: str = Field(
        ...,
        description="Language code used for synthesis.",
    )
    num_chunks: int = Field(
        ...,
        description="Number of audio chunks concatenated.",
    )
    duration_seconds: float = Field(
        ...,
        description="Total duration of generated audio in seconds.",
    )
# ------------------------- Lifespan -------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: log at startup, tidy up GPU memory at shutdown.

    Nothing is loaded here. Everything after the ``yield`` runs when the
    application shuts down.
    """
    logger.info("Enhanced MMS-TTS API starting. Models will be loaded on first request.")
    yield
    # Shutdown: there is nothing to release when running on the CPU.
    if DEVICE != "cuda":
        return
    torch.cuda.empty_cache()
# The ASGI application object. The route decorators below register their
# handlers on it, and `lifespan` supplies its startup/shutdown behaviour.
app = FastAPI(
title="Enhanced MMS-TTS API",
description="Production-grade Text-to-Speech for 1,100+ languages using Facebook's MMS models.",
version="2.0.0",
lifespan=lifespan,
)
# ------------------------- API Endpoints -------------------------
@app.get("/health", summary="Health Check", tags=["System"])
async def health_check() -> Dict[str, str]:
    """Report that the service is up, plus the inference device and sample rate."""
    payload: Dict[str, str] = {}
    payload["status"] = "ok"
    payload["device"] = DEVICE
    # The declared return type is str -> str, so the integer rate is stringified.
    payload["sample_rate"] = str(SAMPLE_RATE)
    return payload
@app.get("/languages", summary="List Supported Languages", tags=["Metadata"])
async def get_languages() -> Dict[str, Any]:
    """Return every LANGUAGE_MAP entry, ordered by language code, with a count."""
    entries = []
    for code in sorted(LANGUAGE_MAP):
        entries.append({"code": code, "name": LANGUAGE_MAP[code]})
    return {"total_languages": len(entries), "languages": entries}
@app.post("/tts", summary="Convert Text to Speech", tags=["Synthesis"])
async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks) -> StreamingResponse:
    """Synthesize the requested text and return it as a mono 16-bit WAV file.

    The language is auto-detected from the first non-blank text when
    ``auto_detect`` is set; otherwise ``request.language`` is used. A string is
    chunked with split_text_into_chunks. For a list, blank entries are dropped
    and only entries longer than MAX_TEXT_LEN are split further. Synthesis
    runs in a worker thread so the event loop is not blocked.

    Response headers carry the language used, the detected language (empty
    when not auto-detected), the chunk count and the duration in seconds.

    Raises:
        HTTPException: 400 when no language can be determined or no text is
            left to speak; 500 for any unexpected synthesis failure.
            HTTPExceptions raised further down are passed through unchanged.
    """
    import wave

    try:
        # Determine language
        detected_lang = None
        if request.auto_detect:
            if isinstance(request.text, str):
                sample_text = request.text
            else:
                sample_text = next((t for t in request.text if t.strip()), "")
            if not sample_text:
                raise HTTPException(status_code=400, detail="Cannot auto-detect language from empty text.")
            detected_lang = detect_language(sample_text)
            language_code = detected_lang
        else:
            language_code = request.language
        # Fail with a clear client error here instead of letting a missing
        # code reach the model loader.
        if not language_code:
            raise HTTPException(
                status_code=400,
                detail="Either provide 'language' or set 'auto_detect=True'.",
            )

        # Prepare text chunks
        if isinstance(request.text, str):
            chunks = split_text_into_chunks(request.text)
        else:
            chunks = []
            for part in request.text:
                # Request validation only rejects a list that is entirely
                # blank, so individual blank entries are dropped here.
                if not part.strip():
                    continue
                if len(part) > MAX_TEXT_LEN:
                    chunks.extend(split_text_into_chunks(part))
                else:
                    chunks.append(part)
        if not chunks:
            raise HTTPException(status_code=400, detail="No valid text chunks after processing.")

        # Synthesize
        audio_np = await asyncio.to_thread(synthesize_batch, chunks, language_code, request.speed)

        # Convert to WAV bytes: scale to the int16 range and clip to it.
        audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype("int16")
        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)
            wav_file.setframerate(SAMPLE_RATE)
            wav_file.writeframes(audio_int16.tobytes())
        buffer.seek(0)

        duration = len(audio_np) / SAMPLE_RATE
        headers = {
            "Content-Disposition": "attachment; filename=speech.wav",
            "Content-Type": "audio/wav",
            "X-Language-Used": language_code,
            "X-Detected-Language": detected_lang or "",
            "X-Num-Chunks": str(len(chunks)),
            "X-Duration-Seconds": f"{duration:.2f}",
        }
        return StreamingResponse(buffer, media_type="audio/wav", headers=headers, background=background_tasks)
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Synthesis error: {type(e).__name__}: {e}")
        raise HTTPException(status_code=500, detail=f"Speech synthesis failed: {str(e)}") from e
@app.post("/tts/info", summary="Get synthesis info without audio", tags=["Synthesis"])
async def tts_info(request: TTSRequest) -> TTSResponse:
    """Return metadata about the synthesis without generating audio.

    The language and the text chunks are resolved as for actual synthesis, but
    no model is loaded and ``duration_seconds`` is always reported as 0.0.

    Raises:
        HTTPException: 400 when no language can be determined; 500 for any
            unexpected failure.
    """
    try:
        detected_lang = None
        if request.auto_detect:
            if isinstance(request.text, str):
                sample_text = request.text
            else:
                sample_text = next((t for t in request.text if t.strip()), "")
            if not sample_text:
                raise HTTPException(status_code=400, detail="Cannot auto-detect language from empty text.")
            detected_lang = detect_language(sample_text)
            language_code = detected_lang
        else:
            language_code = request.language
        # Without this, a missing language fails response validation below and
        # is reported as a server error rather than a client error.
        if not language_code:
            raise HTTPException(
                status_code=400,
                detail="Either provide 'language' or set 'auto_detect=True'.",
            )

        if isinstance(request.text, str):
            chunks = split_text_into_chunks(request.text)
        else:
            chunks = []
            for part in request.text:
                # Blank entries produce no audio, so they are not counted.
                if not part.strip():
                    continue
                if len(part) > MAX_TEXT_LEN:
                    chunks.extend(split_text_into_chunks(part))
                else:
                    chunks.append(part)

        return TTSResponse(
            language_detected=detected_lang,
            language_used=language_code,
            num_chunks=len(chunks),
            duration_seconds=0.0,  # Not computed without synthesis
        )
    except HTTPException:
        # Deliberate client errors keep their status code instead of being
        # rewrapped as 500 by the generic handler below.
        raise
    except Exception as e:
        logger.exception(f"Synthesis info error: {type(e).__name__}: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e