|
|
""" |
|
|
Forge-TTS v2.0.0 — XTTS-v2 Only |
|
|
CPU-optimized TTS API with Polish voice cloning. |
|
|
Single backend: Coqui XTTS-v2 via idiap fork (coqui-tts>=0.27.0). |
|
|
|
|
|
Features: |
|
|
- Speaker latent caching (LRU, keyed by WAV hash) |
|
|
- Text chunking + audio concatenation |
|
|
- SSE streaming endpoint |
|
|
- Multipart WAV upload for cloning convenience |
|
|
- Configurable via env vars |
|
|
""" |
|
|
from __future__ import annotations |
|
|
|
|
|
import asyncio |
|
|
import base64 |
|
|
import hashlib |
|
|
import io |
|
|
import json |
|
|
import os |
|
|
import re |
|
|
import tempfile |
|
|
import threading |
|
|
import time |
|
|
from dataclasses import dataclass |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
|
|
|
import numpy as np |
|
|
import soundfile as sf |
|
|
import torch |
|
|
from fastapi import FastAPI, File, Form, HTTPException, UploadFile |
|
|
from fastapi.responses import StreamingResponse |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _env_bool(name: str, default: bool = False) -> bool:
    """Read a boolean flag from the environment.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is not set at all.

    Returns:
        ``True`` when the variable, stripped and lower-cased, is one of
        ``1``, ``true``, ``yes``, ``y`` or ``on``; ``False`` for any other
        value that is set (including the empty string); ``default`` when the
        variable is unset.
    """
    raw = os.environ.get(name)
    if raw is not None:
        return raw.strip().lower() in {"1", "true", "yes", "y", "on"}
    return default
|
|
|
|
|
|
|
|
def _env_float(name: str, default: float) -> float:
    """Read a float from the environment.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is unset or empty.

    Returns:
        ``float(value)`` for a non-empty value, otherwise ``default``.

    Raises:
        ValueError: If the variable is set to text ``float()`` cannot parse.
    """
    raw = os.getenv(name)
    if not raw:
        return default
    return float(raw)
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class Settings:
    # Immutable runtime configuration.  Every default below is computed from
    # the process environment once, when this class body executes at import
    # time; changing the environment afterwards does not affect it.

    # --- Model selection ---
    model_name: str = os.getenv("XTTS_MODEL_NAME", "tts_models/multilingual/multi-dataset/xtts_v2")
    # Language used when a request does not name one.
    default_language: str = os.getenv("XTTS_DEFAULT_LANGUAGE", "pl")
    # Built-in speaker used when a request carries no reference recording.
    default_speaker: str = os.getenv("XTTS_DEFAULT_SPEAKER", "Claribel Dervla")

    # --- Generation parameters ---
    # NOTE(review): none of these five values is read anywhere else in the
    # visible source (the synthesis call passes only text, language and the
    # speaker) — confirm whether they are meant to be forwarded to the model.
    temperature: float = _env_float("XTTS_TEMPERATURE", 0.65)
    speed: float = _env_float("XTTS_SPEED", 1.0)
    top_p: float = _env_float("XTTS_TOP_P", 0.85)
    top_k: int = int(os.getenv("XTTS_TOP_K", "50"))
    repetition_penalty: float = _env_float("XTTS_REPETITION_PENALTY", 5.0)

    # --- Optional model transforms applied right after loading ---
    torch_compile: bool = _env_bool("XTTS_TORCH_COMPILE", False)
    use_fp16: bool = _env_bool("XTTS_USE_FP16", False)

    # --- Text chunking and audio joining ---
    # Upper bounds used when packing sentences into chunks.
    chunk_max_chars: int = int(os.getenv("CHUNK_MAX_CHARS", "250"))
    chunk_max_words: int = int(os.getenv("CHUNK_MAX_WORDS", "40"))
    # Silence inserted between consecutive chunks, in milliseconds.
    join_silence_ms: int = int(os.getenv("JOIN_SILENCE_MS", "60"))

    # --- Speaker cache ---
    # Maximum number of entries kept by SpeakerCache.
    speaker_cache_size: int = int(os.getenv("SPEAKER_CACHE_SIZE", "8"))

    # --- CPU threading ---
    # Taken from OMP_NUM_THREADS; also handed to torch.set_num_threads().
    num_threads: int = int(os.getenv("OMP_NUM_THREADS", "2"))
|
|
|
|
|
|
|
|
# Single shared, immutable configuration instance used throughout the module.
S = Settings()

# Apply the configured CPU thread budget to PyTorch at import time:
# intra-op parallelism gets num_threads, inter-op gets half of it (at least 1).
torch.set_num_threads(S.num_threads)
torch.set_num_interop_threads(max(1, S.num_threads // 2))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Sentence boundary: whitespace that follows one of . ! ? : ;  or a run of
# newlines.  The punctuation itself stays attached to the preceding sentence
# because it is matched by a look-behind.
_SENT_SPLIT_RE = re.compile(r"(?<=[\.\!\?\:\;])\s+|\n+")
# Any run of whitespace characters; used to collapse such runs to one space.
_WS_RE = re.compile(r"\s+")
|
|
|
|
|
|
|
|
def normalize_text(text: str) -> str:
    """Trim *text* and collapse every internal whitespace run to one space."""
    trimmed = text.strip()
    return _WS_RE.sub(" ", trimmed)
|
|
|
|
|
|
|
|
def split_text_into_chunks(
    text: str,
    max_chars: int = S.chunk_max_chars,
    max_words: int = S.chunk_max_words,
) -> List[str]:
    """Split *text* into chunks that respect both size limits.

    The text is whitespace-normalized, cut into sentences with
    ``_SENT_SPLIT_RE`` and the sentences are packed greedily, in order, into
    chunks.  A sentence that is by itself larger than a limit is first broken
    on word boundaries (and a single word longer than ``max_chars`` is cut
    every ``max_chars`` characters), so no returned chunk exceeds either
    limit.  Sentences that already fit are never altered.

    Args:
        text: Raw input text.
        max_chars: Maximum chunk length in characters, counting the single
            spaces that join the pieces.  Values below 1 are treated as 1.
        max_words: Maximum number of whitespace-separated words per chunk.
            Values below 1 are treated as 1.

    Returns:
        The chunks in reading order; an empty list when *text* is empty or
        consists only of whitespace.
    """
    text = normalize_text(text)
    if not text:
        return []

    # Zero or negative limits would make the hard-cut loop below spin forever.
    max_chars = max(1, max_chars)
    max_words = max(1, max_words)

    def pieces_of(sentence: str) -> List[str]:
        """Return [sentence] if it fits, else word-wrapped pieces that do."""
        words = sentence.split()
        if len(sentence) <= max_chars and len(words) <= max_words:
            return [sentence]

        pieces: List[str] = []
        line: List[str] = []
        line_len = 0
        for word in words:
            # A single word longer than the limit cannot be wrapped; cut it.
            while len(word) > max_chars:
                if line:
                    pieces.append(" ".join(line))
                    line, line_len = [], 0
                pieces.append(word[:max_chars])
                word = word[max_chars:]
            needed = len(word) + (1 if line else 0)  # +1 for the joining space
            if line and (line_len + needed > max_chars or len(line) >= max_words):
                pieces.append(" ".join(line))
                line, line_len = [], 0
                needed = len(word)
            line.append(word)
            line_len += needed
        if line:
            pieces.append(" ".join(line))
        return pieces

    sents = [s.strip() for s in _SENT_SPLIT_RE.split(text) if s.strip()]
    chunks: List[str] = []
    cur: List[str] = []
    cur_chars = 0  # length of " ".join(cur) plus one trailing separator
    cur_words = 0

    def flush() -> None:
        """Emit the chunk being built, if any, and start a new one."""
        nonlocal cur, cur_chars, cur_words
        if cur:
            chunks.append(" ".join(cur).strip())
        cur, cur_chars, cur_words = [], 0, 0

    for sent in sents:
        for piece in pieces_of(sent):
            w = len(piece.split())
            c = len(piece)
            if cur and (cur_chars + c > max_chars or cur_words + w > max_words):
                flush()
            cur.append(piece)
            cur_chars += c + 1
            cur_words += w

    flush()
    return chunks
|
|
|
|
|
|
|
|
def wav_bytes_from_audio(audio: np.ndarray, sr: int) -> bytes:
    """Encode *audio* as an in-memory 16-bit PCM WAV file.

    Args:
        audio: Samples; converted to ``float32`` before being written.
        sr: Sample rate written into the WAV header.

    Returns:
        The complete WAV file contents.
    """
    samples = np.asarray(audio, dtype=np.float32)
    with io.BytesIO() as buffer:
        sf.write(buffer, samples, sr, format="WAV", subtype="PCM_16")
        return buffer.getvalue()
|
|
|
|
|
|
|
|
def concat_audio(chunks: List[np.ndarray], sr: int, silence_ms: int = S.join_silence_ms) -> np.ndarray:
    """Join audio chunks into one ``float32`` array, separated by silence.

    Args:
        chunks: Audio segments in playback order.
        sr: Sample rate, used to turn ``silence_ms`` into a sample count.
        silence_ms: Gap inserted between consecutive chunks; no gap is
            inserted when this is zero or negative.

    Returns:
        A single zero sample when ``chunks`` is empty, the lone chunk as
        ``float32`` when there is exactly one, otherwise the concatenation
        of all chunks with the gap between each pair.
    """
    segments = [np.asarray(segment, dtype=np.float32) for segment in chunks]
    if not segments:
        return np.zeros((1,), dtype=np.float32)
    if len(segments) == 1:
        return segments[0]
    if not silence_ms > 0:
        return np.concatenate(segments)

    gap = np.zeros(int(sr * silence_ms / 1000), dtype=np.float32)
    interleaved = [segments[0]]
    for segment in segments[1:]:
        interleaved.append(gap)
        interleaved.append(segment)
    return np.concatenate(interleaved)
|
|
|
|
|
|
|
|
def b64encode_bytes(b: bytes) -> str:
    """Return the standard Base64 encoding of *b* as an ASCII string."""
    encoded = base64.b64encode(b)
    return str(encoded, "ascii")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpeakerCache:
    """Thread-safe, bounded LRU cache of speaker latents keyed by WAV content.

    Entries are keyed by a truncated SHA-256 digest of the reference WAV
    bytes.  Both a hit in :meth:`get` and a repeated :meth:`put` count as a
    use, so the entry evicted when the cache is full is always the least
    recently used one.
    """

    def __init__(self, maxsize: int = S.speaker_cache_size):
        """Create an empty cache holding at most *maxsize* entries.

        A ``maxsize`` of zero or less disables caching: ``put`` stores
        nothing and ``get`` always misses.
        """
        # A plain dict preserves insertion order, so as long as every use
        # re-inserts the key at the end, the first key is the LRU entry.
        self._cache: Dict[str, Tuple] = {}
        self._maxsize = maxsize
        self._lock = threading.Lock()

    def _key(self, wav_bytes: bytes) -> str:
        """Return the cache key: first 16 hex digits of the SHA-256 digest."""
        return hashlib.sha256(wav_bytes).hexdigest()[:16]

    def get(self, wav_bytes: bytes) -> Optional[Tuple]:
        """Return the latents cached for *wav_bytes*, or ``None`` on a miss.

        A hit marks the entry as most recently used.
        """
        key = self._key(wav_bytes)
        with self._lock:
            if key not in self._cache:
                return None
            latents = self._cache.pop(key)
            self._cache[key] = latents  # re-insert at the end: most recent
            return latents

    def put(self, wav_bytes: bytes, latents: Tuple) -> None:
        """Store *latents* for *wav_bytes*, evicting LRU entries if needed.

        If the key is already present the stored value is kept (not
        replaced) and the entry is only marked as most recently used.
        """
        if self._maxsize <= 0:
            return
        key = self._key(wav_bytes)
        with self._lock:
            if key in self._cache:
                self._cache[key] = self._cache.pop(key)
                return
            while len(self._cache) >= self._maxsize:
                oldest = next(iter(self._cache))
                del self._cache[oldest]
            self._cache[key] = latents
|
|
|
|
|
|
|
|
# Process-wide speaker latent cache, sized from S.speaker_cache_size.
# NOTE(review): no get()/put() call on it is visible in this file — confirm
# that it is actually consulted somewhere.
_speaker_cache = SpeakerCache()

# Guards the one-time model load performed by _get_model().
_model_lock = threading.Lock()
# Held around every call into the model so only one synthesis runs at a time.
_infer_lock = threading.Lock()

# Loaded TTS instance; stays None until _get_model() succeeds.
_tts_model = None
# Message of the load failure, if any; once set, _get_model() returns None
# without trying again.
_tts_error: Optional[str] = None
|
|
|
|
|
|
|
|
def _get_model():
    """Return the process-wide TTS instance, loading it on first use.

    The load is attempted at most once.  A failure is recorded in
    ``_tts_error`` and never retried: every later call returns ``None``.
    When enabled in the settings, the inner ``torch.nn.Module`` is converted
    with ``.half()`` and/or wrapped with ``torch.compile``; a failure of
    either step is only reported and does not fail the load.

    Returns:
        The loaded ``TTS`` object, or ``None`` when loading failed.
    """
    global _tts_model, _tts_error

    def tune(engine) -> None:
        """Apply the optional FP16 / torch.compile transforms to *engine*."""
        net = getattr(getattr(engine, "synthesizer", None), "tts_model", None)
        if not isinstance(net, torch.nn.Module):
            return
        if S.use_fp16:
            try:
                net = net.half()
                engine.synthesizer.tts_model = net
                print("[XTTS] FP16 enabled")
            except Exception as exc:
                print(f"[XTTS] FP16 failed: {exc}")
        if S.torch_compile:
            try:
                net = torch.compile(net)
                engine.synthesizer.tts_model = net
                print("[XTTS] torch.compile enabled")
            except Exception as exc:
                print(f"[XTTS] torch.compile failed: {exc}")

    # Cheap unlocked test first; the same condition is tested again under
    # the lock so that only one thread performs the load.
    if _tts_error is None and _tts_model is None:
        with _model_lock:
            if _tts_error is None and _tts_model is None:
                try:
                    from TTS.api import TTS

                    print(f"[XTTS] Loading {S.model_name} ...")
                    started = time.time()
                    engine = TTS(model_name=S.model_name, progress_bar=False, gpu=False)
                    tune(engine)
                    _tts_model = engine
                    print(f"[XTTS] Model loaded in {time.time() - started:.1f}s")
                except Exception as exc:
                    _tts_error = str(exc)
                    print(f"[XTTS] FAILED to load: {exc}")

    # A recorded failure always wins over a model reference.
    return None if _tts_error is not None else _tts_model
|
|
|
|
|
|
|
|
def _get_sample_rate() -> int:
    """Return the model's output sample rate.

    Falls back to 22050 when the model is unavailable, has no (truthy)
    ``synthesizer`` attribute, or the synthesizer lacks
    ``output_sample_rate``.
    """
    fallback = 22050
    model = _get_model()
    synthesizer = getattr(model, "synthesizer", None) if model is not None else None
    if not synthesizer:
        return fallback
    return getattr(synthesizer, "output_sample_rate", fallback)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _synthesize(text: str, language: str, speaker_wav_bytes: Optional[bytes] = None) -> Tuple[np.ndarray, int, float]:
    """Synthesize one piece of text and return ``(audio, sample_rate, seconds)``.

    Args:
        text: Text to speak.
        language: Language code handed to the model.
        speaker_wav_bytes: Reference recording for voice cloning.  When given
            (and non-empty) it is written to a temporary ``.wav`` file whose
            path is passed as ``speaker_wav``; otherwise the configured
            default speaker is used.

    Returns:
        The audio as a ``float32`` array, the sample rate reported by
        ``_get_sample_rate()``, and the elapsed wall-clock time in seconds.
        The timer starts before the inference lock is acquired, so the
        figure includes any time spent waiting for it.

    Raises:
        HTTPException: 503 when the model is not available.
    """
    engine = _get_model()
    if engine is None:
        raise HTTPException(503, f"XTTS unavailable: {_tts_error or 'model not loaded'}")

    started = time.time()

    # One inference at a time; the reference file lives only for this call.
    with _infer_lock:
        ref_path = None
        try:
            if speaker_wav_bytes:
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
                    handle.write(speaker_wav_bytes)
                    handle.flush()
                    ref_path = handle.name
                voice = {"speaker_wav": ref_path}
            else:
                voice = {"speaker": S.default_speaker}
            audio_np = engine.tts(text=text, language=language, **voice)
        finally:
            # Best-effort cleanup: a file that cannot be removed is ignored.
            if ref_path:
                try:
                    os.remove(ref_path)
                except OSError:
                    pass

    sample_rate = _get_sample_rate()
    elapsed = time.time() - started
    return np.asarray(audio_np, dtype=np.float32), sample_rate, elapsed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SynthRequest(BaseModel):
    # JSON body accepted by POST /v1/xtts/synthesize.

    # Text to speak; between 1 and 5000 characters.
    text: str = Field(..., min_length=1, max_length=5000)
    # Language code.  When omitted, the endpoint substitutes
    # S.default_language (which is "pl" unless overridden by the environment).
    language: Optional[str] = Field(None, description="Language code (default: pl)")
    # Optional Base64 reference recording; without it the default speaker is used.
    speaker_wav_b64: Optional[str] = Field(None, description="Base64-encoded WAV for voice cloning")
|
|
|
|
|
|
|
|
class StreamRequest(BaseModel):
    # JSON body accepted by POST /v1/xtts/stream; same fields as SynthRequest.

    # Text to speak; between 1 and 5000 characters.
    text: str = Field(..., min_length=1, max_length=5000)
    # Language code; the endpoint falls back to S.default_language when omitted.
    language: Optional[str] = None
    # Optional Base64 reference recording; without it the default speaker is used.
    speaker_wav_b64: Optional[str] = None
|
|
|
|
|
|
|
|
class AudioResponse(BaseModel):
    # Response body of the non-streaming synthesis endpoints.

    # Base64 of the complete WAV file.
    audio_b64: str
    # Sample rate of the returned audio, in Hz.
    sample_rate: int
    # Length of the returned audio in seconds (sample count / sample rate).
    duration_s: float
    # Sum of the per-chunk synthesis times, in seconds.
    generation_time_s: float
    # The request text, echoed back unchanged.
    text: str
|
|
|
|
|
|
|
|
class HealthResponse(BaseModel):
    # Response body of GET /health.  The defaults are taken from the settings
    # when this class is defined.

    # "ok", or "degraded: <error>" when the model failed to load.
    status: str = "ok"
    version: str = "2.0.0"
    model: str = S.model_name
    language: str = S.default_language
    # False once a model load error has been recorded.
    xtts_available: bool = True
    speaker_cache_size: int = S.speaker_cache_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The ASGI application; all routes below are registered on it.
app = FastAPI(title="Forge-TTS API", version="2.0.0")
|
|
|
|
|
|
|
|
@app.get("/health", response_model=HealthResponse)
def health():
    """Report service health.

    The service counts as available as long as no model load error has been
    recorded; this does not by itself prove the model is already loaded.
    """
    if _tts_error is None:
        return HealthResponse(xtts_available=True, status="ok")
    return HealthResponse(xtts_available=False, status=f"degraded: {_tts_error}")
|
|
|
|
|
|
|
|
@app.post("/v1/xtts/synthesize", response_model=AudioResponse)
def xtts_synthesize(req: SynthRequest):
    """Synthesize the whole request text and return it as one WAV file.

    The text is split into chunks, each chunk is synthesized in order, and
    the results are joined with a short silence.

    Raises:
        HTTPException: 400 when ``speaker_wav_b64`` is not valid Base64,
            decodes to fewer than 44 bytes or to more than 10 MB, or when
            the text is empty after normalization; 503 (from
            ``_synthesize``) when the model is unavailable.
    """
    speaker_bytes = None
    if req.speaker_wav_b64:
        try:
            speaker_bytes = base64.b64decode(req.speaker_wav_b64)
        except (TypeError, ValueError) as e:
            raise HTTPException(400, f"Invalid base64 speaker_wav: {e}") from e
        # Same limits as the multipart /v1/xtts/clone endpoint.  Without the
        # lower bound, a payload that decodes to nothing would be silently
        # ignored and the default voice used instead of the requested clone.
        if len(speaker_bytes) < 44:
            raise HTTPException(400, "WAV file too small or empty")
        if len(speaker_bytes) > 10 * 1024 * 1024:
            raise HTTPException(400, "WAV file too large (max 10MB)")

    lang = req.language or S.default_language
    chunks = split_text_into_chunks(req.text)
    if not chunks:
        raise HTTPException(400, "Empty text after normalization")

    audio_parts = []
    total_gen = 0.0
    sr = 22050  # replaced by the rate reported for each synthesized chunk

    for chunk_text in chunks:
        audio, sr, gen_t = _synthesize(chunk_text, lang, speaker_bytes)
        audio_parts.append(audio)
        total_gen += gen_t

    full_audio = concat_audio(audio_parts, sr)
    wav_bytes = wav_bytes_from_audio(full_audio, sr)

    return AudioResponse(
        audio_b64=b64encode_bytes(wav_bytes),
        sample_rate=sr,
        duration_s=round(len(full_audio) / sr, 3),
        generation_time_s=round(total_gen, 3),
        text=req.text,
    )
|
|
|
|
|
|
|
|
@app.post("/v1/xtts/stream")
async def xtts_stream(req: StreamRequest):
    """Synthesize the request text chunk by chunk as a server-sent event stream.

    Each event is a ``data:`` line carrying a JSON object with the chunk
    index, the chunk count, the chunk text, the chunk audio as a Base64 WAV
    file, the sample rate and the generation time.  If a chunk fails, one
    event with ``error`` and ``chunk_index`` is sent and no further chunks
    are synthesized.  The stream always ends with ``data: [DONE]``.

    Raises:
        HTTPException: 400 when ``speaker_wav_b64`` is not valid Base64,
            decodes to fewer than 44 bytes or to more than 10 MB, or when
            the text yields no chunks.
    """
    speaker_bytes = None
    if req.speaker_wav_b64:
        try:
            speaker_bytes = base64.b64decode(req.speaker_wav_b64)
        except (TypeError, ValueError) as e:
            raise HTTPException(400, f"Invalid base64: {e}") from e
        # Same limits as the multipart /v1/xtts/clone endpoint.  Without the
        # lower bound, a payload that decodes to nothing would be silently
        # ignored and the default voice used instead of the requested clone.
        if len(speaker_bytes) < 44:
            raise HTTPException(400, "WAV file too small or empty")
        if len(speaker_bytes) > 10 * 1024 * 1024:
            raise HTTPException(400, "WAV file too large (max 10MB)")

    chunks = split_text_into_chunks(req.text)
    if not chunks:
        raise HTTPException(400, "Empty text after chunking")

    lang = req.language or S.default_language

    async def generate():
        """Yield one SSE frame per chunk, then the terminating frame."""
        for i, chunk_text in enumerate(chunks):
            try:
                # Synthesis blocks, so it runs in a worker thread.
                audio, sr, gen_t = await asyncio.to_thread(
                    _synthesize, chunk_text, lang, speaker_bytes
                )
                wav_bytes = wav_bytes_from_audio(audio, sr)
                payload = {
                    "chunk_index": i,
                    "total_chunks": len(chunks),
                    "text": chunk_text,
                    "audio_b64": b64encode_bytes(wav_bytes),
                    "sample_rate": sr,
                    "generation_time_s": round(gen_t, 3),
                }
                yield f"data: {json.dumps(payload)}\n\n"
            except Exception as e:
                # The response has already started, so the failure is
                # reported in-band instead of as an HTTP status.
                yield f"data: {json.dumps({'error': str(e), 'chunk_index': i})}\n\n"
                break
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
|
|
|
|
|
|
|
|
@app.post("/v1/xtts/clone", response_model=AudioResponse)
async def xtts_clone(
    text: str = Form(..., min_length=1, max_length=5000),
    language: str = Form(default=S.default_language),
    speaker_wav: UploadFile = File(..., description="WAV file for voice cloning"),
):
    """Convenience endpoint: multipart form with WAV file upload (not base64).

    The uploaded recording is used as the voice reference for every chunk of
    the text; the chunks are joined and returned as one WAV file.

    Raises:
        HTTPException: 400 when the upload is smaller than 44 bytes or
            larger than 10 MB, or when the text is empty after
            normalization; 503 (from ``_synthesize``) when the model is
            unavailable.
    """
    wav_bytes = await speaker_wav.read()
    if len(wav_bytes) < 44:
        raise HTTPException(400, "WAV file too small or empty")
    if len(wav_bytes) > 10 * 1024 * 1024:
        raise HTTPException(400, "WAV file too large (max 10MB)")

    chunks = split_text_into_chunks(text)
    if not chunks:
        raise HTTPException(400, "Empty text after normalization")

    audio_parts = []
    total_gen = 0.0
    sr = 22050  # replaced by the rate reported for each synthesized chunk

    for chunk_text in chunks:
        # This handler runs on the event loop; synthesis blocks for a long
        # time, so it is pushed to a worker thread to keep other requests
        # (e.g. /health) responsive.
        audio, sr, gen_t = await asyncio.to_thread(
            _synthesize, chunk_text, language, wav_bytes
        )
        audio_parts.append(audio)
        total_gen += gen_t

    full_audio = concat_audio(audio_parts, sr)
    wav_out = wav_bytes_from_audio(full_audio, sr)

    return AudioResponse(
        audio_b64=b64encode_bytes(wav_out),
        sample_rate=sr,
        duration_s=round(len(full_audio) / sr, 3),
        generation_time_s=round(total_gen, 3),
        text=text,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Print the configuration banner, then load the model eagerly."""
    rule = "=" * 60
    banner = (
        "\n" + rule,
        "Forge-TTS v2.0.0 — XTTS-v2 Only",
        rule,
        f"Model: {S.model_name}",
        f"Language: {S.default_language}",
        f"Threads: {S.num_threads}",
        f"FP16: {S.use_fp16}",
        f"Compile: {S.torch_compile}",
        rule + "\n",
    )
    for line in banner:
        print(line)

    # Warm-up: pay the model load cost at startup instead of on the first request.
    _get_model()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 7860.
    # uvicorn is imported only on this path.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|