| """ |
| Enhanced MMS-TTS API – Production-ready Text-to-Speech for 1,100+ languages. |
| Features: long text batching, auto language detection (55 languages), streaming. |
| """ |
|
|
import asyncio
import io
import logging
import re
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from langdetect import DetectorFactory, detect
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
from transformers import AutoTokenizer, VitsModel
|
|
| |
# Configure root logging once at import time so INFO-level messages (model
# loads, detection fallbacks) are emitted, then grab this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# langdetect's detector is randomized; pinning the factory seed makes detect()
# give the same answer for the same input on every call.
DetectorFactory.seed = 0
|
|
| |
# Inference device string, passed to `.to(...)` for models and input tensors.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

SAMPLE_RATE = 16_000    # Hz; frame rate written into the WAV header and used for durations
MAX_TEXT_LEN = 1_000    # maximum characters per synthesized text chunk
SILENCE_DURATION = 0.1  # seconds of silence inserted between consecutive chunks
|
|
| |
# Display names for the language codes advertised by GET /languages.
# Keys are the ISO 639-3 style codes used in "facebook/mms-tts-<code>" model ids.
# Insertion order is kept identical to the original table.
LANGUAGE_MAP: Dict[str, str] = dict(
    asm="Assamese", ben="Bengali", eng="English", hin="Hindi",
    tam="Tamil", tel="Telugu", mar="Marathi", guj="Gujarati",
    kan="Kannada", mal="Malayalam", pan="Punjabi", urd="Urdu",
    ori="Odia", sat="Santali", snd="Sindhi", mai="Maithili",
    awa="Awadhi", bho="Bhojpuri", mag="Magahi", raj="Rajasthani",
    hne="Chhattisgarhi", kha="Khasi", lus="Mizo", grt="Garo",
    brx="Bodo", mni="Manipuri", nep="Nepali", sin="Sinhala",
    tuk="Turkmen", uig="Uyghur", rus="Russian", fra="French",
    deu="German", spa="Spanish", ita="Italian", por="Portuguese",
    jpn="Japanese", kor="Korean", zho="Chinese", ara="Arabic",
    fas="Persian", tur="Turkish", pol="Polish", ukr="Ukrainian",
    vie="Vietnamese", tha="Thai", ind="Indonesian", msa="Malay",
    swa="Swahili", amh="Amharic", yor="Yoruba", ibo="Igbo",
    hau="Hausa", sna="Shona", zul="Zulu", som="Somali",
    khm="Khmer", mya="Burmese", lao="Lao", mon="Mongolian",
    kaz="Kazakh", uzb="Uzbek", tgl="Tagalog", ceb="Cebuano",
    hmn="Hmong", nav="Navajo", smo="Samoan", tah="Tahitian",
    haw="Hawaiian", mlt="Maltese", est="Estonian", lav="Latvian",
    lit="Lithuanian", slv="Slovenian", hrv="Croatian", srp="Serbian",
    bos="Bosnian", mkd="Macedonian", bul="Bulgarian", ron="Romanian",
    hun="Hungarian", fin="Finnish", swe="Swedish", nor="Norwegian",
    dan="Danish", isl="Icelandic", gle="Irish", cym="Welsh",
    eus="Basque", cat="Catalan", glg="Galician", heb="Hebrew",
    hye="Armenian", kat="Georgian", san="Sanskrit",
)
|
|
| |
# Translation from langdetect's output codes (ISO 639-1 style, plus "zh-cn" /
# "zh-tw") to the ISO 639-3 style codes used in "facebook/mms-tts-<code>" ids.
# detect_language() passes unmapped codes through unchanged, and a 2-letter code
# can never name an MMS model, so every code langdetect can emit needs an entry.
LANGDETECT_TO_MMS: Dict[str, str] = {
    "en": "eng", "hi": "hin", "bn": "ben", "as": "asm", "ta": "tam",
    "te": "tel", "mr": "mar", "gu": "guj", "kn": "kan", "ml": "mal",
    "pa": "pan", "ur": "urd", "or": "ori", "ne": "nep", "si": "sin",
    "ru": "rus", "fr": "fra", "de": "deu", "es": "spa", "it": "ita",
    "pt": "por", "ja": "jpn", "ko": "kor", "zh-cn": "zho", "ar": "ara",
    "fa": "fas", "tr": "tur", "pl": "pol", "uk": "ukr", "vi": "vie",
    "th": "tha", "id": "ind", "ms": "msa", "sw": "swa", "am": "amh",
    "yo": "yor", "ig": "ibo", "ha": "hau", "sn": "sna", "zu": "zul",
    "so": "som", "km": "khm", "my": "mya", "lo": "lao", "mn": "mon",
    "kk": "kaz", "uz": "uzb", "tl": "tgl", "ceb": "ceb", "hmn": "hmn",
    "nv": "nav", "sm": "smo", "ty": "tah", "haw": "haw", "mt": "mlt",
    "et": "est", "lv": "lav", "lt": "lit", "sl": "slv", "hr": "hrv",
    "sr": "srp", "bs": "bos", "mk": "mkd", "bg": "bul", "ro": "ron",
    "hu": "hun", "fi": "fin", "sv": "swe", "no": "nor", "da": "dan",
    "is": "isl", "ga": "gle", "cy": "cym", "eu": "eus", "ca": "cat",
    "gl": "glg", "he": "heb", "hy": "hye", "ka": "kat", "sa": "san",
    # Remaining langdetect outputs that previously fell through as 2-letter
    # codes. NOTE(review): mapped to their ISO 639-3 codes; whether an
    # mms-tts checkpoint exists for each is not verified here.
    "af": "afr", "cs": "ces", "el": "ell", "nl": "nld", "sk": "slk",
    "sq": "sqi", "zh-tw": "zho",
}
|
|
| |
# Per-language caches filled lazily by get_model_and_tokenizer(). A language
# counts as loaded once it is present in model_cache; tokenizer_cache is always
# written first, so a reader that sees the model also finds its tokenizer.
model_cache: Dict[str, Any] = {}
tokenizer_cache: Dict[str, Any] = {}

# Characters allowed in a language code before it is interpolated into a
# Hugging Face model id. This rejects "/", "." and other path/repo syntax
# coming from untrusted request input.
_LANGUAGE_CODE_RE = re.compile(r"[A-Za-z0-9_-]{1,64}")


def get_model_and_tokenizer(language_code: str):
    """Return the cached ``(model, tokenizer)`` pair for *language_code*.

    On first use this loads ``facebook/mms-tts-<language_code>`` with
    ``VitsModel`` / ``AutoTokenizer``, moves the model to ``DEVICE``, switches it
    to eval mode, and caches both objects.

    This is reached from worker threads (synthesis runs via
    ``asyncio.to_thread``). Two concurrent first requests for the same language
    may each load the model; the last write wins and both results are valid.

    Raises:
        HTTPException: 404 if the code is malformed or loading fails.
            NOTE(review): every load failure, including network errors or
            out-of-memory, is reported as 404.
    """
    if not isinstance(language_code, str) or not _LANGUAGE_CODE_RE.fullmatch(language_code):
        raise HTTPException(status_code=404, detail=f"Language code '{language_code}' not supported.")

    if language_code not in model_cache:
        model_id = f"facebook/mms-tts-{language_code}"
        logger.info(f"Loading model: {model_id}")
        try:
            model = VitsModel.from_pretrained(model_id)
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model.to(DEVICE)
            model.eval()
        except Exception as e:
            logger.error(f"Failed to load model for {language_code}: {e}")
            raise HTTPException(
                status_code=404, detail=f"Language code '{language_code}' not supported."
            ) from e
        # Publish the tokenizer before the model: model_cache membership is the
        # "ready" flag checked above, so this ordering closes the KeyError window.
        tokenizer_cache[language_code] = tokenizer
        model_cache[language_code] = model
        logger.info(f"Successfully loaded model for {language_code}")

    return model_cache[language_code], tokenizer_cache[language_code]
|
|
def detect_language(text: str) -> str:
    """Guess the language of *text* and return it as an MMS language code.

    langdetect's result (e.g. ``"en"``, ``"zh-cn"``) is translated through
    LANGDETECT_TO_MMS. A code with no mapping is returned unchanged. If
    detection raises for any reason, the function logs a warning and falls
    back to English (``"eng"``).
    """
    try:
        raw_code = detect(text)
    except Exception as err:
        logger.warning(f"Language detection failed: {err}, falling back to English")
        return "eng"
    return LANGDETECT_TO_MMS.get(raw_code, raw_code)
|
|
# Sentence boundaries. ASCII ".!?" must be followed by whitespace, so "3.14"
# is not split. The Devanagari danda / double danda, CJK full-width
# terminators and the Arabic question mark / Urdu full stop also end a
# sentence when no whitespace follows, because those scripts often omit it.
_SENTENCE_BOUNDARY_RE = re.compile(
    r"(?<=[.!?])\s+|(?<=[\u0964\u0965\u3002\uff01\uff1f\u061f\u06d4])\s*"
)


def split_text_into_chunks(text: str, max_len: int = MAX_TEXT_LEN) -> List[str]:
    """Split *text* into chunks of at most *max_len* characters.

    Sentences are packed greedily into a chunk, joined by single spaces, until
    the next one would not fit. A sentence longer than *max_len* is hard-wrapped
    at whitespace. A single word longer than *max_len* is cut. Leading and
    trailing whitespace of each sentence is dropped, so blank input yields
    ``[]``.

    Raises:
        ValueError: if *max_len* is smaller than 1.
    """
    import textwrap  # only needed here, for hard-wrapping oversized sentences

    if max_len < 1:
        raise ValueError("max_len must be a positive integer")

    chunks: List[str] = []
    current = ""
    for sentence in _SENTENCE_BOUNDARY_RE.split(text):
        sentence = sentence.strip()
        if not sentence:
            continue
        if len(sentence) <= max_len:
            pieces = [sentence]
        else:
            # Without this, one long sentence (or unpunctuated text) would
            # become a single chunk far larger than max_len.
            pieces = textwrap.wrap(
                sentence, width=max_len, break_long_words=True, break_on_hyphens=False
            )
        for piece in pieces:
            if not current:
                current = piece
            elif len(current) + 1 + len(piece) <= max_len:
                current += " " + piece
            else:
                chunks.append(current)
                current = piece
    if current:
        chunks.append(current)
    return chunks
|
|
def synthesize_single_chunk(text: str, language_code: str, speed: float = 1.0) -> np.ndarray:
    """Synthesize one text chunk with the cached MMS model for *language_code*.

    Args:
        text: The chunk to speak.
        language_code: MMS language code selecting the model and tokenizer.
        speed: Rate factor. When it is not 1.0, the waveform is linearly
            resampled to ``int(len / speed)`` samples.

    Returns:
        The waveform of the first batch item as a NumPy array.

    NOTE(review): resampling while keeping the output sample rate fixed also
    shifts pitch (faster speech sounds higher).
    """
    model, tokenizer = get_model_and_tokenizer(language_code)
    encoded = tokenizer(text, return_tensors="pt")
    model_inputs = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}
    with torch.inference_mode():
        result = model(**model_inputs)
    audio = result.waveform[0].cpu().numpy()

    if speed == 1.0:
        return audio

    src_len = len(audio)
    target_positions = np.linspace(0, src_len - 1, int(src_len / speed))
    return np.interp(target_positions, np.arange(src_len), audio)
|
|
def synthesize_batch(texts: List[str], language_code: str, speed: float = 1.0) -> np.ndarray:
    """Synthesize each chunk in *texts* and concatenate the results.

    ``SILENCE_DURATION`` seconds of silence (at ``SAMPLE_RATE``) are inserted
    between consecutive chunks, not before the first or after the last.

    Returns:
        The concatenated waveform. An empty list yields an empty array instead
        of the ValueError ``np.concatenate([])`` would raise.
    """
    if not texts:
        return np.zeros(0)

    # One shared silence buffer. It is only read, so reusing it is safe.
    gap = np.zeros(int(SAMPLE_RATE * SILENCE_DURATION))
    pieces: List[np.ndarray] = []
    for index, chunk in enumerate(texts):
        if index:
            pieces.append(gap)
        pieces.append(synthesize_single_chunk(chunk, language_code, speed))
    return np.concatenate(pieces)
|
|
| |
class TTSRequest(BaseModel):
    """Request body for speech synthesis.

    Either ``language`` must be provided or ``auto_detect`` must be true.
    """

    text: Union[str, List[str]] = Field(
        ...,
        description="Text to be spoken. Can be a single string (auto-chunked if long) or an array of strings."
    )
    language: Optional[str] = Field(
        None,
        description="ISO 639-3 language code (e.g., 'eng', 'asm', 'ben'). Required if auto_detect=False."
    )
    auto_detect: bool = Field(
        False,
        description="Automatically detect language from the first text chunk. Overrides 'language' if set to True."
    )
    voice: str = Field(
        "default",
        description="Voice identifier. Currently only 'default' is supported."
    )
    speed: float = Field(
        1.0,
        ge=0.5,
        le=2.0,
        description="Speech speed multiplier (0.5 = half speed, 2.0 = double speed)."
    )

    @field_validator("text")
    @classmethod
    def check_text(cls, v):
        """Reject a blank string, or a list whose entries are all blank (or an empty list)."""
        if isinstance(v, str) and len(v.strip()) == 0:
            raise ValueError("Text cannot be empty.")
        if isinstance(v, list) and all(len(t.strip()) == 0 for t in v):
            raise ValueError("All text entries are empty.")
        return v

    @model_validator(mode="after")
    def check_language(self) -> "TTSRequest":
        """Require a non-blank ``language`` unless auto-detection is requested.

        This runs after all fields are parsed. A field validator on
        ``language`` is skipped when the field is omitted (defaults are not
        validated). It also cannot see ``auto_detect``, which is declared
        after ``language``.
        """
        if not self.auto_detect and not (self.language and self.language.strip()):
            raise ValueError("Either provide 'language' or set 'auto_detect=True'.")
        return self
|
|
class TTSResponse(BaseModel):
    # Metadata payload returned by POST /tts/info. Field order is preserved
    # because it fixes the JSON key order and the OpenAPI schema layout.
    audio_format: str = Field(default="audio/wav")
    language_detected: Optional[str] = Field(
        default=None, description="Detected language code (if auto_detect=True)."
    )
    language_used: str = Field(description="Language code used for synthesis.")
    num_chunks: int = Field(description="Number of audio chunks concatenated.")
    duration_seconds: float = Field(description="Total duration of generated audio in seconds.")
|
|
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Startup only logs a message; models are loaded lazily by
    ``get_model_and_tokenizer`` on first use. On shutdown the model and
    tokenizer caches are cleared before ``torch.cuda.empty_cache()`` runs.
    ``empty_cache`` only releases allocator blocks that are no longer
    occupied, so GPU memory held by still-cached models could not be returned
    otherwise.
    """
    logger.info("Enhanced MMS-TTS API starting. Models will be loaded on first request.")
    try:
        yield
    finally:
        model_cache.clear()
        tokenizer_cache.clear()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
|
|
# The ASGI application. Models are not loaded here; the lifespan hook only
# logs at startup and releases GPU cache at shutdown.
app = FastAPI(
    lifespan=lifespan,
    version="2.0.0",
    title="Enhanced MMS-TTS API",
    description="Production-grade Text-to-Speech for 1,100+ languages using Facebook's MMS models.",
)
|
|
| |
@app.get("/health", summary="Health Check", tags=["System"])
async def health_check() -> Dict[str, str]:
    """Liveness probe reporting the inference device and output sample rate.

    Every value is a string (the sample rate included) to match the declared
    ``Dict[str, str]`` return type.
    """
    report: Dict[str, str] = {"status": "ok"}
    report["device"] = DEVICE
    report["sample_rate"] = str(SAMPLE_RATE)
    return report
|
|
@app.get("/languages", summary="List Supported Languages", tags=["Metadata"])
async def get_languages() -> Dict[str, Any]:
    """List every entry of LANGUAGE_MAP as ``{"code", "name"}`` objects, sorted by code."""
    entries = []
    for code in sorted(LANGUAGE_MAP):
        entries.append({"code": code, "name": LANGUAGE_MAP[code]})
    return {"total_languages": len(entries), "languages": entries}
|
|
def _resolve_language_and_chunks(request: TTSRequest) -> tuple[Optional[str], str, List[str]]:
    """Determine the synthesis language and the text chunks to speak.

    Shared by ``/tts`` and ``/tts/info`` so both endpoints agree.

    Returns:
        ``(detected_lang, language_code, chunks)``. ``detected_lang`` is
        ``None`` unless ``auto_detect`` was requested.

    Raises:
        HTTPException: 400 if no language can be determined or no non-blank
            text remains after chunking.
    """
    detected_lang: Optional[str] = None
    if request.auto_detect:
        if isinstance(request.text, str):
            sample_text = request.text
        else:
            sample_text = next((t for t in request.text if t.strip()), "")
        if not sample_text.strip():
            raise HTTPException(status_code=400, detail="Cannot auto-detect language from empty text.")
        detected_lang = detect_language(sample_text)
        language_code = detected_lang
    else:
        language_code = request.language

    if not language_code:
        # Defensive: without a code, synthesis would look up a model named
        # after None / "" and fail with a misleading 404.
        raise HTTPException(status_code=400, detail="Either provide 'language' or set 'auto_detect=True'.")

    if isinstance(request.text, str):
        chunks = split_text_into_chunks(request.text)
    else:
        chunks = []
        for part in request.text:
            if len(part) > MAX_TEXT_LEN:
                chunks.extend(split_text_into_chunks(part))
            else:
                chunks.append(part)
    # A list may contain blank entries as long as one is non-blank. Drop them
    # so the tokenizer is never handed empty input.
    chunks = [chunk for chunk in chunks if chunk.strip()]
    if not chunks:
        raise HTTPException(status_code=400, detail="No valid text chunks after processing.")
    return detected_lang, language_code, chunks


def _encode_wav(audio: np.ndarray) -> io.BytesIO:
    """Encode a float waveform as 16-bit mono PCM WAV at ``SAMPLE_RATE``.

    Samples are scaled by 32767 and clipped to the int16 range, so the input
    is assumed to lie roughly within [-1, 1] (TODO confirm for all MMS
    models). Returns a buffer rewound to position 0.
    """
    import wave

    pcm = (audio * 32767).clip(-32768, 32767).astype("int16")
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(SAMPLE_RATE)
        wav_file.writeframes(pcm.tobytes())
    buffer.seek(0)
    return buffer


@app.post("/tts", summary="Convert Text to Speech", tags=["Synthesis"])
async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks) -> StreamingResponse:
    """Synthesize speech for the request and return it as a WAV attachment.

    Metadata is exposed in the ``X-Language-Used``, ``X-Detected-Language``,
    ``X-Num-Chunks`` and ``X-Duration-Seconds`` headers. All audio is
    synthesized before the response starts.

    Raises:
        HTTPException: 400s from request resolution and 404s from model
            loading pass through unchanged. Any other failure becomes a 500.
    """
    try:
        detected_lang, language_code, chunks = _resolve_language_and_chunks(request)

        # Model inference blocks; run it off the event loop.
        audio_np = await asyncio.to_thread(synthesize_batch, chunks, language_code, request.speed)

        buffer = _encode_wav(audio_np)
        duration = len(audio_np) / SAMPLE_RATE

        headers = {
            "Content-Disposition": "attachment; filename=speech.wav",
            "Content-Type": "audio/wav",
            "X-Language-Used": language_code,
            "X-Detected-Language": detected_lang or "",
            "X-Num-Chunks": str(len(chunks)),
            "X-Duration-Seconds": f"{duration:.2f}",
        }
        return StreamingResponse(buffer, media_type="audio/wav", headers=headers, background=background_tasks)

    except HTTPException:
        raise
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"Synthesis error: {type(e).__name__}: {e}")
        raise HTTPException(status_code=500, detail=f"Speech synthesis failed: {str(e)}") from e


@app.post("/tts/info", summary="Get synthesis info without audio", tags=["Synthesis"])
async def tts_info(request: TTSRequest) -> TTSResponse:
    """Return metadata about the synthesis without generating audio.

    Language resolution and chunking match ``/tts`` exactly. No audio is
    produced, so ``duration_seconds`` is always 0.0.
    """
    try:
        detected_lang, language_code, chunks = _resolve_language_and_chunks(request)
    except HTTPException:
        # Re-raise so 4xx errors are not turned into 500s by the handler below.
        raise
    except Exception as e:
        logger.exception(f"Info request failed: {type(e).__name__}: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return TTSResponse(
        language_detected=detected_lang,
        language_used=language_code,
        num_chunks=len(chunks),
        duration_seconds=0.0,
    )