""" Forge-TTS v2.0.0 — XTTS-v2 Only CPU-optimized TTS API with Polish voice cloning. Single backend: Coqui XTTS-v2 via idiap fork (coqui-tts>=0.27.0). Features: - Speaker latent caching (LRU, keyed by WAV hash) - Text chunking + audio concatenation - SSE streaming endpoint - Multipart WAV upload for cloning convenience - Configurable via env vars """ from __future__ import annotations import asyncio import base64 import hashlib import io import json import os import re import tempfile import threading import time from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import numpy as np import soundfile as sf import torch from fastapi import FastAPI, File, Form, HTTPException, UploadFile from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field # --------------------------------------------------------------------------- # Settings (env-configurable) # --------------------------------------------------------------------------- def _env_bool(name: str, default: bool = False) -> bool: v = os.getenv(name) if v is None: return default return v.strip().lower() in {"1", "true", "yes", "y", "on"} def _env_float(name: str, default: float) -> float: v = os.getenv(name) return float(v) if v else default @dataclass(frozen=True) class Settings: # Model model_name: str = os.getenv("XTTS_MODEL_NAME", "tts_models/multilingual/multi-dataset/xtts_v2") default_language: str = os.getenv("XTTS_DEFAULT_LANGUAGE", "pl") default_speaker: str = os.getenv("XTTS_DEFAULT_SPEAKER", "Claribel Dervla") # Built-in XTTS speaker # Generation params temperature: float = _env_float("XTTS_TEMPERATURE", 0.65) speed: float = _env_float("XTTS_SPEED", 1.0) top_p: float = _env_float("XTTS_TOP_P", 0.85) top_k: int = int(os.getenv("XTTS_TOP_K", "50")) repetition_penalty: float = _env_float("XTTS_REPETITION_PENALTY", 5.0) # Optimizations torch_compile: bool = _env_bool("XTTS_TORCH_COMPILE", False) use_fp16: bool = _env_bool("XTTS_USE_FP16", False) # 
Chunking chunk_max_chars: int = int(os.getenv("CHUNK_MAX_CHARS", "250")) chunk_max_words: int = int(os.getenv("CHUNK_MAX_WORDS", "40")) join_silence_ms: int = int(os.getenv("JOIN_SILENCE_MS", "60")) # Speaker cache speaker_cache_size: int = int(os.getenv("SPEAKER_CACHE_SIZE", "8")) # Runtime num_threads: int = int(os.getenv("OMP_NUM_THREADS", "2")) S = Settings() # Conservative CPU threading torch.set_num_threads(S.num_threads) torch.set_num_interop_threads(max(1, S.num_threads // 2)) # --------------------------------------------------------------------------- # Text utilities (kept from v1) # --------------------------------------------------------------------------- _SENT_SPLIT_RE = re.compile(r"(?<=[\.\!\?\:\;])\s+|\n+") _WS_RE = re.compile(r"\s+") def normalize_text(text: str) -> str: return _WS_RE.sub(" ", text.strip()) def split_text_into_chunks( text: str, max_chars: int = S.chunk_max_chars, max_words: int = S.chunk_max_words, ) -> List[str]: text = normalize_text(text) if not text: return [] sents = [s.strip() for s in _SENT_SPLIT_RE.split(text) if s.strip()] chunks: List[str] = [] cur: List[str] = [] cur_chars = 0 cur_words = 0 def flush(): nonlocal cur, cur_chars, cur_words if cur: chunks.append(" ".join(cur).strip()) cur, cur_chars, cur_words = [], 0, 0 for sent in sents: w = len(sent.split()) c = len(sent) if cur and (cur_chars + c > max_chars or cur_words + w > max_words): flush() cur.append(sent) cur_chars += c + 1 cur_words += w flush() return chunks def wav_bytes_from_audio(audio: np.ndarray, sr: int) -> bytes: buf = io.BytesIO() sf.write(buf, np.asarray(audio, dtype=np.float32), sr, format="WAV", subtype="PCM_16") return buf.getvalue() def concat_audio(chunks: List[np.ndarray], sr: int, silence_ms: int = S.join_silence_ms) -> np.ndarray: if not chunks: return np.zeros((1,), dtype=np.float32) if len(chunks) == 1: return np.asarray(chunks[0], dtype=np.float32) silence = np.zeros(int(sr * silence_ms / 1000), dtype=np.float32) if silence_ms > 0 else 
None parts = [] for i, ch in enumerate(chunks): parts.append(np.asarray(ch, dtype=np.float32)) if silence is not None and i < len(chunks) - 1: parts.append(silence) return np.concatenate(parts) def b64encode_bytes(b: bytes) -> str: return base64.b64encode(b).decode("ascii") # --------------------------------------------------------------------------- # Speaker latent cache (keyed by SHA-256 of WAV bytes) # --------------------------------------------------------------------------- class SpeakerCache: def __init__(self, maxsize: int = S.speaker_cache_size): self._cache: Dict[str, Tuple] = {} self._order: List[str] = [] self._maxsize = maxsize self._lock = threading.Lock() def _key(self, wav_bytes: bytes) -> str: return hashlib.sha256(wav_bytes).hexdigest()[:16] def get(self, wav_bytes: bytes) -> Optional[Tuple]: key = self._key(wav_bytes) with self._lock: return self._cache.get(key) def put(self, wav_bytes: bytes, latents: Tuple) -> None: key = self._key(wav_bytes) with self._lock: if key in self._cache: return if len(self._order) >= self._maxsize: evict = self._order.pop(0) self._cache.pop(evict, None) self._cache[key] = latents self._order.append(key) _speaker_cache = SpeakerCache() # --------------------------------------------------------------------------- # Model manager (lazy, thread-safe) # --------------------------------------------------------------------------- _model_lock = threading.Lock() _infer_lock = threading.Lock() _tts_model = None _tts_error: Optional[str] = None def _get_model(): global _tts_model, _tts_error if _tts_error is not None: return None if _tts_model is not None: return _tts_model with _model_lock: if _tts_error is not None: return None if _tts_model is not None: return _tts_model try: from TTS.api import TTS print(f"[XTTS] Loading {S.model_name} ...") t0 = time.time() tts = TTS(model_name=S.model_name, progress_bar=False, gpu=False) # Optional optimizations inner = getattr(getattr(tts, "synthesizer", None), "tts_model", None) if 
isinstance(inner, torch.nn.Module): if S.use_fp16: try: inner = inner.half() tts.synthesizer.tts_model = inner print("[XTTS] FP16 enabled") except Exception as e: print(f"[XTTS] FP16 failed: {e}") if S.torch_compile: try: inner = torch.compile(inner) tts.synthesizer.tts_model = inner print("[XTTS] torch.compile enabled") except Exception as e: print(f"[XTTS] torch.compile failed: {e}") _tts_model = tts print(f"[XTTS] Model loaded in {time.time() - t0:.1f}s") except Exception as e: _tts_error = str(e) print(f"[XTTS] FAILED to load: {e}") return None return _tts_model def _get_sample_rate() -> int: tts = _get_model() if tts is None: return 22050 synth = getattr(tts, "synthesizer", None) return getattr(synth, "output_sample_rate", 22050) if synth else 22050 # --------------------------------------------------------------------------- # Core synthesis function # --------------------------------------------------------------------------- def _synthesize(text: str, language: str, speaker_wav_bytes: Optional[bytes] = None) -> Tuple[np.ndarray, int, float]: """Returns (audio_np, sample_rate, generation_time_s).""" tts = _get_model() if tts is None: raise HTTPException(503, f"XTTS unavailable: {_tts_error or 'model not loaded'}") t0 = time.time() with _infer_lock: tmp_path = None try: if speaker_wav_bytes: # Voice cloning mode: use provided speaker WAV with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: tmp.write(speaker_wav_bytes) tmp.flush() tmp_path = tmp.name audio_np = tts.tts( text=text, language=language, speaker_wav=tmp_path, ) else: # Default speaker mode: use built-in speaker audio_np = tts.tts( text=text, language=language, speaker=S.default_speaker, ) finally: if tmp_path: try: os.remove(tmp_path) except OSError: pass sr = _get_sample_rate() gen_time = time.time() - t0 return np.asarray(audio_np, dtype=np.float32), sr, gen_time # --------------------------------------------------------------------------- # Pydantic models # 
class SynthRequest(BaseModel):
    """JSON body of ``POST /v1/xtts/synthesize``."""

    text: str = Field(..., min_length=1, max_length=5000)
    language: Optional[str] = Field(None, description="Language code (default: pl)")
    speaker_wav_b64: Optional[str] = Field(None, description="Base64-encoded WAV for voice cloning")


class StreamRequest(BaseModel):
    """JSON body of ``POST /v1/xtts/stream``."""

    text: str = Field(..., min_length=1, max_length=5000)
    language: Optional[str] = None
    speaker_wav_b64: Optional[str] = None


class AudioResponse(BaseModel):
    """Synthesized audio (base64 WAV) plus timing metadata."""

    audio_b64: str
    sample_rate: int
    duration_s: float
    generation_time_s: float
    text: str


class HealthResponse(BaseModel):
    """Payload of ``GET /health``."""

    status: str = "ok"
    version: str = "2.0.0"
    model: str = S.model_name
    language: str = S.default_language
    xtts_available: bool = True
    speaker_cache_size: int = S.speaker_cache_size


# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------

app = FastAPI(title="Forge-TTS API", version="2.0.0")


def _synthesize_text(text: str, language: str, speaker_wav_bytes: Optional[bytes]) -> AudioResponse:
    """Chunk ``text``, synthesize every chunk and return one joined WAV.

    Blocking: call it from a worker thread, never directly on the event loop.

    Raises:
        HTTPException: 400 when the text is empty after normalization; any
            HTTPException raised by ``_synthesize`` is propagated.
    """
    chunks = split_text_into_chunks(text)
    if not chunks:
        raise HTTPException(400, "Empty text after normalization")

    audio_parts = []
    total_gen = 0.0
    sr = 22050
    for chunk_text in chunks:
        audio, sr, gen_t = _synthesize(chunk_text, language, speaker_wav_bytes)
        audio_parts.append(audio)
        total_gen += gen_t

    full_audio = concat_audio(audio_parts, sr)
    wav_bytes = wav_bytes_from_audio(full_audio, sr)
    return AudioResponse(
        audio_b64=b64encode_bytes(wav_bytes),
        sample_rate=sr,
        duration_s=round(len(full_audio) / sr, 3),
        generation_time_s=round(total_gen, 3),
        text=text,
    )


@app.get("/health", response_model=HealthResponse)
def health():
    """Report service status; degraded when the model failed to load."""
    available = _tts_error is None
    return HealthResponse(
        xtts_available=available,
        status="ok" if available else f"degraded: {_tts_error}",
    )


@app.post("/v1/xtts/synthesize", response_model=AudioResponse)
def xtts_synthesize(req: SynthRequest):
    """Synthesize the whole text and return it as a single base64 WAV.

    Declared with plain ``def`` so the blocking inference does not run on
    the event loop.
    """
    speaker_bytes = None
    if req.speaker_wav_b64:
        try:
            speaker_bytes = base64.b64decode(req.speaker_wav_b64)
        except Exception as e:
            raise HTTPException(400, f"Invalid base64 speaker_wav: {e}")

    lang = req.language or S.default_language
    return _synthesize_text(req.text, lang, speaker_bytes)


@app.post("/v1/xtts/stream")
async def xtts_stream(req: StreamRequest):
    """Stream the synthesis chunk by chunk as server-sent events.

    Each event carries one JSON payload with the chunk's audio as a base64
    WAV. A failure emits a single error event and stops; the stream always
    ends with ``data: [DONE]``.
    """
    speaker_bytes = None
    if req.speaker_wav_b64:
        try:
            speaker_bytes = base64.b64decode(req.speaker_wav_b64)
        except Exception as e:
            raise HTTPException(400, f"Invalid base64: {e}")

    chunks = split_text_into_chunks(req.text)
    if not chunks:
        raise HTTPException(400, "Empty text after chunking")

    lang = req.language or S.default_language

    async def generate():
        for i, chunk_text in enumerate(chunks):
            try:
                # Run the blocking inference off the event loop.
                audio, sr, gen_t = await asyncio.to_thread(
                    _synthesize, chunk_text, lang, speaker_bytes
                )
                wav_bytes = wav_bytes_from_audio(audio, sr)
                payload = {
                    "chunk_index": i,
                    "total_chunks": len(chunks),
                    "text": chunk_text,
                    "audio_b64": b64encode_bytes(wav_bytes),
                    "sample_rate": sr,
                    "generation_time_s": round(gen_t, 3),
                }
                yield f"data: {json.dumps(payload)}\n\n"
            except Exception as e:
                # Headers are already sent, so report the error in-band.
                yield f"data: {json.dumps({'error': str(e), 'chunk_index': i})}\n\n"
                break
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")


@app.post("/v1/xtts/clone", response_model=AudioResponse)
async def xtts_clone(
    text: str = Form(..., min_length=1, max_length=5000),
    language: str = Form(default=S.default_language),
    speaker_wav: UploadFile = File(..., description="WAV file for voice cloning"),
):
    """Convenience endpoint: multipart form with WAV file upload (not base64)."""
    wav_bytes = await speaker_wav.read()
    if len(wav_bytes) < 44:
        raise HTTPException(400, "WAV file too small or empty")
    if len(wav_bytes) > 10 * 1024 * 1024:
        raise HTTPException(400, "WAV file too large (max 10MB)")

    # This handler is async (it awaits the upload), so the blocking
    # synthesis must run in a worker thread; calling it inline would stall
    # the event loop, and with it every other request, until it finished.
    return await asyncio.to_thread(_synthesize_text, text, language, wav_bytes)


# ---------------------------------------------------------------------------
# Startup
# ---------------------------------------------------------------------------

@app.on_event("startup")
async def startup_event():
    """Print the effective configuration and load the model eagerly."""
    print("\n" + "=" * 60)
    print("Forge-TTS v2.0.0 — XTTS-v2 Only")
    print("=" * 60)
    print(f"Model: {S.model_name}")
    print(f"Language: {S.default_language}")
    print(f"Threads: {S.num_threads}")
    print(f"FP16: {S.use_fp16}")
    print(f"Compile: {S.torch_compile}")
    print("=" * 60 + "\n")
    # Eager load to catch errors at startup
    _get_model()


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)