#!/usr/bin/env python3
"""
TTS Demo Server
Qwen3-TTS (Clone / Custom / VoiceDesign) → local RTX 3090 via faster_qwen3_tts
Kokoro FR → local (hexgrad/Kokoro-82M)
F5-TTS FR → local (RASPIAUDIO checkpoint)
Chatterbox → local (ResembleAI)
Fish-Speech → local (fishaudio/fish-speech-1.5)
"""
import argparse
import asyncio
import base64
from collections import OrderedDict
import hashlib
import io
import json
import os
import re
import sys
import tempfile
import threading
import time
from pathlib import Path
from typing import Optional
import numpy as np
import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, Response, StreamingResponse
# ── Voxtral TTS (vLLM-Omni, Mistral AI) ──────────────────────────────────────
# Base URL of an external vLLM server hosting Voxtral; this demo only reports
# these values via /status — no Voxtral request is issued in this file.
_VOXTRAL_URL = os.environ.get("VOXTRAL_URL", "http://localhost:8000")
_VOXTRAL_MODEL = "mistralai/Voxtral-4B-TTS-2603"
# ── Fish-Speech ───────────────────────────────────────────────────────────────
# Expected locations of the Fish-Speech git checkout and downloaded weights;
# _get_fish_engine() raises with setup instructions when they are missing.
FISH_SPEECH_REPO = Path("/tmp/fish-speech")
FISH_SPEECH_MODEL = Path("/root/fish-speech-model")
# Patch torchaudio nightly: list_audio_backends removed in recent builds
import torchaudio as _torchaudio
if not hasattr(_torchaudio, "list_audio_backends"):
    # Dependencies may still probe this API; answer with a plausible static list.
    _torchaudio.list_audio_backends = lambda: ["ffmpeg", "sox"]
# Lazily-built shared Fish-Speech engine (see _get_fish_engine) + init lock.
_fish_engine = None
_fish_lock = threading.Lock()
def _get_fish_engine():
    """Return the shared Fish-Speech TTS engine, constructing it on first use.

    Double-checked locking: the unlocked fast path returns the cached engine,
    and the locked slow path ensures only one thread ever builds it.

    Raises:
        RuntimeError: if the Fish-Speech repo or model weights are missing.
    """
    global _fish_engine
    cached = _fish_engine
    if cached is not None:
        return cached
    with _fish_lock:
        if _fish_engine is not None:
            return _fish_engine
        if not FISH_SPEECH_REPO.exists():
            raise RuntimeError("Fish-Speech repo not found at /tmp/fish-speech. Run: git clone https://github.com/fishaudio/fish-speech /tmp/fish-speech && cd /tmp/fish-speech && git checkout v1.5.1")
        if not FISH_SPEECH_MODEL.exists():
            raise RuntimeError("Fish-Speech model not found at /root/fish-speech-model. Download with: python3 -c \"from huggingface_hub import snapshot_download; snapshot_download('fishaudio/fish-speech-1.5', local_dir='/root/fish-speech-model')\"")
        # Make the checkout importable, then pull in its inference pieces.
        sys.path.insert(0, str(FISH_SPEECH_REPO))
        from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
        from fish_speech.models.vqgan.inference import load_model as load_decoder_model
        from fish_speech.inference_engine import TTSInferenceEngine
        import torch as _torch
        run_device = "cuda" if _torch.cuda.is_available() else "cpu"
        run_precision = _torch.bfloat16
        compile_model = run_device == "cuda"  # torch.compile only on GPU
        semantic_queue = launch_thread_safe_queue(
            checkpoint_path=str(FISH_SPEECH_MODEL),
            device=run_device,
            precision=run_precision,
            compile=compile_model,
        )
        vocoder = load_decoder_model(
            config_name="firefly_gan_vq",
            checkpoint_path=str(FISH_SPEECH_MODEL / "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"),
            device=run_device,
        )
        _fish_engine = TTSInferenceEngine(
            llama_queue=semantic_queue,
            decoder_model=vocoder,
            precision=run_precision,
            compile=compile_model,
        )
    return _fish_engine
# ── Enabled engines (CPU/GPU gating) ─────────────────────────────────────────
_ENABLED_ENV = os.environ.get("ENABLED_ENGINES", "all").lower()
ENABLED_ENGINES: set[str] | None = (
None if _ENABLED_ENV == "all"
else {e.strip() for e in _ENABLED_ENV.split(",") if e.strip()}
)
def _engine_enabled(name: str) -> bool:
return ENABLED_ENGINES is None or name in ENABLED_ENGINES
# ── Kokoro TTS (multilingual) ─────────────────────────────────────────────────
# One KPipeline per Kokoro language code, created lazily by _get_kokoro().
_kokoro_pipelines: dict[str, object] = {}
_kokoro_lock = threading.Lock()
# Voice id → UI label shown in the demo (★ marks featured voices).
KOKORO_VOICES_FR = {
    # French
    "ff_siwis": "Siwis — FR Femme ★",
    # American English — Female
    "af_heart": "Heart — EN Femme ★",
    "af_bella": "Bella — EN Femme",
    "af_nicole": "Nicole — EN Femme (ASMR)",
    "af_sarah": "Sarah — EN Femme",
    "af_sky": "Sky — EN Femme",
    # American English — Male
    "am_echo": "Echo — EN Homme",
    "am_michael": "Michael — EN Homme",
    "am_adam": "Adam — EN Homme",
    # British English — Female
    "bf_emma": "Emma — EN(UK) Femme",
    "bf_isabella": "Isabella — EN(UK) Femme",
    # British English — Male
    "bm_george": "George — EN(UK) Homme",
    "bm_lewis": "Lewis — EN(UK) Homme",
}
# Two-letter voice-id prefix → Kokoro pipeline language code.
_VOICE_LANG_CODE = {
    "ff": "f",  # French female
    "af": "a",  # American English female
    "am": "a",  # American English male
    "bf": "b",  # British English female
    "bm": "b",  # British English male
}
def _get_kokoro(voice: str = "ff_siwis"):
    """Return the cached Kokoro KPipeline for the language implied by *voice*.

    The first two characters of the voice id (e.g. "ff", "af", "bm") pick the
    pipeline language; unknown prefixes fall back to French ("f").
    """
    voice_prefix = voice[:2] if len(voice) >= 2 else "ff"
    lang_code = _VOICE_LANG_CODE.get(voice_prefix, "f")
    with _kokoro_lock:
        pipeline = _kokoro_pipelines.get(lang_code)
        if pipeline is None:
            from kokoro import KPipeline
            pipeline = KPipeline(lang_code=lang_code)
            _kokoro_pipelines[lang_code] = pipeline
        return pipeline
# ── Chatterbox TTS (ResembleAI) ───────────────────────────────────────────────
_chatterbox_model = None
_chatterbox_lock = threading.Lock()
def _get_chatterbox():
    """Lazily load and cache the Chatterbox TTS model (GPU when available)."""
    global _chatterbox_model
    if _chatterbox_model is not None:
        return _chatterbox_model
    with _chatterbox_lock:
        if _chatterbox_model is None:
            from chatterbox.tts import ChatterboxTTS
            target_device = "cuda" if torch.cuda.is_available() else "cpu"
            _chatterbox_model = ChatterboxTTS.from_pretrained(device=target_device)
    return _chatterbox_model
# ── F5-TTS French (RASPIAUDIO checkpoint) ────────────────────────────────────
_f5_model = None
_f5_lock = threading.Lock()
F5_REPO = "RASPIAUDIO/F5-French-MixedSpeakers-reduced"
def _get_f5():
    """Lazily download the French F5-TTS checkpoint and cache the loaded model."""
    global _f5_model
    if _f5_model is not None:
        return _f5_model
    with _f5_lock:
        if _f5_model is None:
            from f5_tts.api import F5TTS
            from huggingface_hub import hf_hub_download
            ckpt_path = hf_hub_download(F5_REPO, "model_last_reduced.pt")
            vocab_path = hf_hub_download(F5_REPO, "vocab.txt")
            _f5_model = F5TTS(model="F5TTS_v1_Base", ckpt_file=ckpt_path, vocab_file=vocab_path)
    return _f5_model
# ── Qwen3-TTS (local) ────────────────────────────────────────────────────────
# Make the repository root importable so faster_qwen3_tts resolves when the
# demo is run straight from a repo checkout.
sys.path.insert(0, str(Path(__file__).parent.parent))
try:
    from faster_qwen3_tts import FasterQwen3TTS
except ImportError:
    # Qwen3 endpoints check FasterQwen3TTS is None and return 503.
    print("Warning: faster_qwen3_tts not found — Qwen3 engine disabled. Install with: pip install faster-qwen3-tts")
    FasterQwen3TTS = None  # type: ignore
# Optional ASR loader used for /transcribe; None disables that endpoint.
try:
    from nano_parakeet import from_pretrained as _parakeet_from_pretrained
except ImportError:
    _parakeet_from_pretrained = None
# Every Qwen3-TTS checkpoint this demo knows how to serve.
_ALL_MODELS = [
    "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
    "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
    "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
    "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
    "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
]
# ACTIVE_MODELS env var (comma-separated) optionally restricts the list.
_active_models_env = os.environ.get("ACTIVE_MODELS", "")
if _active_models_env:
    _allowed = {m.strip() for m in _active_models_env.split(",") if m.strip()}
    AVAILABLE_MODELS = [m for m in _ALL_MODELS if m in _allowed]
else:
    AVAILABLE_MODELS = list(_ALL_MODELS)
BASE_DIR = Path(__file__).resolve().parent
# Preset voice-clone reference assets; fetched on demand from the GitHub repo
# by _fetch_preset_assets() into ASSET_DIR.
_ASSET_DIR = Path(os.environ.get("ASSET_DIR", "/tmp/faster-qwen3-tts-assets"))
PRESET_TRANSCRIPTS = _ASSET_DIR / "samples" / "parity" / "icl_transcripts.txt"
PRESET_REFS = [
    ("ref_audio_3", _ASSET_DIR / "ref_audio_3.wav", "Clone 1"),
    ("ref_audio_2", _ASSET_DIR / "ref_audio_2.wav", "Clone 2"),
    ("ref_audio", _ASSET_DIR / "ref_audio.wav", "Clone 3"),
]
_GITHUB_RAW = "https://raw.githubusercontent.com/andimarafioti/faster-qwen3-tts/main"
_PRESET_REMOTE = {
    "ref_audio": f"{_GITHUB_RAW}/ref_audio.wav",
    "ref_audio_2": f"{_GITHUB_RAW}/ref_audio_2.wav",
    "ref_audio_3": f"{_GITHUB_RAW}/ref_audio_3.wav",
}
_TRANSCRIPT_REMOTE = f"{_GITHUB_RAW}/samples/parity/icl_transcripts.txt"
def _fetch_preset_assets():
    """Best-effort download of preset reference audio and transcripts.

    Missing files are fetched into _ASSET_DIR; any failure only prints a
    warning so the server still starts without the presets.
    """
    import urllib.request
    _ASSET_DIR.mkdir(parents=True, exist_ok=True)
    PRESET_TRANSCRIPTS.parent.mkdir(parents=True, exist_ok=True)
    if not PRESET_TRANSCRIPTS.exists():
        try:
            urllib.request.urlretrieve(_TRANSCRIPT_REMOTE, PRESET_TRANSCRIPTS)
        except Exception as e:
            print(f"Warning: could not fetch transcripts: {e}")
    for preset_key, wav_path, _label in PRESET_REFS:
        if wav_path.exists() or preset_key not in _PRESET_REMOTE:
            continue
        try:
            urllib.request.urlretrieve(_PRESET_REMOTE[preset_key], wav_path)
            print(f"Downloaded {wav_path.name}")
        except Exception as e:
            print(f"Warning: could not fetch {preset_key}: {e}")
# Preset reference voices keyed by preset id; filled by _load_preset_refs().
_preset_refs: dict[str, dict] = {}
def _load_preset_transcripts():
    """Parse the preset transcript file into {key: transcript}.

    Lines look like "ref_audio (note): transcript text"; the key is whatever
    precedes the first "(" or ":". Returns {} when the file is absent.
    """
    if not PRESET_TRANSCRIPTS.exists():
        return {}
    parsed = {}
    for raw_line in PRESET_TRANSCRIPTS.read_text(encoding="utf-8").splitlines():
        if ":" not in raw_line:
            continue
        head, _sep, body = raw_line.partition(":")
        parsed[head.split("(")[0].strip()] = body.strip()
    return parsed
def _load_preset_refs():
    """Register every locally-available preset reference voice in _preset_refs.

    Each entry bundles the cached on-disk path, its transcript (when known),
    and the raw audio as base64 so the UI can play it back directly.
    """
    transcripts = _load_preset_transcripts()
    for ref_id, wav_path, label in PRESET_REFS:
        if not wav_path.exists():
            continue
        raw = wav_path.read_bytes()
        _preset_refs[ref_id] = {
            "id": ref_id,
            "label": label,
            "filename": wav_path.name,
            "path": _get_cached_ref_path(raw),
            "ref_text": transcripts.get(ref_id, ""),
            "audio_b64": base64.b64encode(raw).decode(),
        }
def _prime_preset_voice_cache(model):
    """Warm the model's voice-conditioning cache for every preset reference.

    Calls _prepare_generation for both xvec modes so the first real request
    skips that cost. Failures are swallowed: priming is purely an optimization.
    """
    for preset in _preset_refs.values():
        for xvec_flag in (True, False):
            try:
                model._prepare_generation(
                    text="Hello.",
                    ref_audio=preset["path"],
                    ref_text=preset["ref_text"],
                    language="English",
                    xvec_only=xvec_flag,
                    non_streaming_mode=True,
                )
            except Exception:
                continue
app = FastAPI(title="Faster Qwen3-TTS Demo")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# LRU cache of loaded Qwen3 models; the oldest entry is evicted when full.
_model_cache: OrderedDict[str, FasterQwen3TTS] = OrderedDict()
_model_cache_max: int = int(os.environ.get("MODEL_CACHE_SIZE", "2"))
# Key into _model_cache currently served; None until /load succeeds.
_active_model_name: str | None = None
# True while a /load request is in flight (reported via /status).
_loading = False
# sha1 digest → temp-file path, deduplicating uploaded reference audio.
_ref_cache: dict[str, str] = {}
_ref_cache_lock = threading.Lock()
# ASR model for /transcribe; presumably assigned at startup from
# _parakeet_from_pretrained — the assignment is not visible in this chunk.
_parakeet = None
# Serializes all GPU generation; waiter count doubles as the queue depth.
_generation_lock = asyncio.Lock()
_generation_waiters: int = 0
MAX_TEXT_CHARS = 15000
MAX_AUDIO_BYTES = 10 * 1024 * 1024
_AUDIO_TOO_LARGE_MSG = (
    "Audio file too large ({size_mb:.1f} MB). "
    "Please upload a shorter recording (under 1 minute)."
)
# ─── French TTS preprocessing ─────────────────────────────────────────────────
# (pattern, replacement) pairs expanding common French abbreviations so the
# TTS model pronounces them in full.
FRENCH_ABBREVS = [
    (r'\bM\.\s+', 'Monsieur '),
    (r'\bMme\.?\s+', 'Madame '),
    (r'\bMlle\.?\s+', 'Mademoiselle '),
    (r'\bDr\.?\s+', 'Docteur '),
    (r'\bPr\.?\s+', 'Professeur '),
    (r'\bSt\.?\s+', 'Saint '),
    (r'\betc\.(?!\w)', 'et cetera'),
    (r'\bn°\s*(\d+)', r'numéro \1'),
    (r'\b(\d{1,2})\s*h\s*(\d{2})\b', r'\1 heures \2'),
    (r'\b(\d{1,2})\s*h\b', r'\1 heures'),
    (r'\bp\.\s*(\d+)\b', r'page \1'),
]
FRENCH_NARRATOR_PROMPT = (
    "Narrateur professionnel de livres audio français, voix grave, chaude et captivante. "
    "Débit naturellement mesuré — ni précipité ni traînant. "
    "Respectez scrupuleusement la ponctuation : "
    "légère pause aux virgules, souffle marqué aux points, "
    "pause longue et respirée aux doubles sauts de paragraphe. "
    "Aux guillemets « », adoptez un registre légèrement plus direct et personnel pour le dialogue, "
    "puis revenez au ton narratif après le ». "
    "Les points de suspension (...) et les tirets (—) appellent une vraie pause respiratoire. "
    "Ton légèrement plus grave et plus riche que la conversation ordinaire, "
    "comme un conteur au coin du feu. "
    "Restez cohérent du début à la fin — même timbre, même rythme de fond."
)
def preprocess_french(text: str) -> str:
    """French TTS prosody preprocessing (arXiv:2508.17494).
    Expands abbreviations and inserts natural pause markers for the TTS model.
    """
    for pattern, repl in FRENCH_ABBREVS:
        text = re.sub(pattern, repl, text)
    # Ordered pause/spacing rules; whitespace is normalized before the
    # post-sentence pad so "! " becomes a deliberate double space.
    cleanup_rules = (
        (r'«\s*', '« '),        # one space after an opening guillemet
        (r'\s*»', ' »'),        # one space before a closing guillemet
        (r'\n\s*\n', ' ... '),  # paragraph break → strong pause marker
        (r'\s*—\s*', ' — '),    # em-dash spaced for a breathing pause
        (r'\.{3,}', '...'),     # collapse long ellipsis runs
        (r'[ \t]+', ' '),       # normalize runs of spaces/tabs
        (r'([!?])', r'\1 '),    # pad after ! and ? for a sentence pause
    )
    for pattern, repl in cleanup_rules:
        text = re.sub(pattern, repl, text)
    return text.strip()
# ─── Helpers ──────────────────────────────────────────────────────────────────
def _to_wav_b64(audio: np.ndarray, sr: int) -> str:
    """Encode audio as 16-bit PCM WAV bytes and return them base64-encoded."""
    samples = audio if audio.dtype == np.float32 else audio.astype(np.float32)
    if samples.ndim > 1:
        # Drop singleton dimensions (e.g. (1, N) model output → (N,)).
        samples = samples.squeeze()
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, samples, sr, format="WAV", subtype="PCM_16")
    return base64.b64encode(wav_buffer.getvalue()).decode()
def _concat_audio(audio_list) -> np.ndarray:
if isinstance(audio_list, np.ndarray):
return audio_list.astype(np.float32).squeeze()
parts = [np.array(a, dtype=np.float32).squeeze() for a in audio_list if len(a) > 0]
return np.concatenate(parts) if parts else np.zeros(0, dtype=np.float32)
def _get_cached_ref_path(content: bytes) -> str:
    """Persist reference-audio bytes to a content-addressed temp file.

    The SHA-1 of the bytes keys both an in-memory cache and the on-disk
    filename, so identical uploads share a single file across requests.
    """
    digest = hashlib.sha1(content).hexdigest()
    with _ref_cache_lock:
        known = _ref_cache.get(digest)
        if known and os.path.exists(known):
            return known
        target = Path(tempfile.gettempdir()) / f"qwen3tts_ref_{digest}.wav"
        if not target.exists():
            target.write_bytes(content)
        _ref_cache[digest] = str(target)
        return str(target)
# ─── Routes ───────────────────────────────────────────────────────────────────
# Fetch and index preset reference voices at import time (best-effort network).
_fetch_preset_assets()
_load_preset_refs()
@app.get("/")
async def root():
    """Serve the single-page demo UI."""
    return FileResponse(Path(__file__).parent / "index.html")
@app.post("/transcribe")
async def transcribe_audio(audio: UploadFile = File(...)):
if _parakeet is None:
raise HTTPException(status_code=503, detail="Transcription model not loaded")
content = await audio.read()
if len(content) > MAX_AUDIO_BYTES:
raise HTTPException(status_code=400, detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content)/1024/1024))
def run():
import torchaudio
wav, sr = sf.read(io.BytesIO(content), dtype="float32", always_2d=False)
if wav.ndim > 1:
wav = wav.mean(axis=1)
wav_t = torch.from_numpy(wav)
if sr != 16000:
wav_t = torchaudio.functional.resample(wav_t.unsqueeze(0), sr, 16000).squeeze(0)
return _parakeet.transcribe(wav_t.cuda())
text = await asyncio.to_thread(run)
return {"text": text}
@app.get("/status")
async def get_status():
speakers = []
model_type = None
active = _model_cache.get(_active_model_name) if _active_model_name else None
if active is not None:
try:
model_type = active.model.model.tts_model_type
speakers = active.model.get_supported_speakers() or []
except Exception:
speakers = []
return {
"loaded": active is not None,
"model": _active_model_name,
"loading": _loading,
"available_models": AVAILABLE_MODELS,
"model_type": model_type,
"speakers": speakers,
"transcription_available": _parakeet is not None,
"preset_refs": [{"id": p["id"], "label": p["label"], "ref_text": p["ref_text"]} for p in _preset_refs.values()],
"queue_depth": _generation_waiters,
"cached_models": list(_model_cache.keys()),
"kokoro_voices": KOKORO_VOICES_FR,
"voxtral_url": _VOXTRAL_URL,
"voxtral_model": _VOXTRAL_MODEL,
}
@app.post("/load")
async def load_model(model_id: str = Form(...)):
if not _engine_enabled("qwen3") or FasterQwen3TTS is None:
raise HTTPException(status_code=503, detail="Qwen3 engine not available. Install faster-qwen3-tts and use a GPU server.")
global _active_model_name, _loading
if model_id in _model_cache:
_active_model_name = model_id
_model_cache.move_to_end(model_id)
return {"status": "already_loaded", "model": model_id}
_loading = True
def _do_load():
global _active_model_name, _loading
try:
if len(_model_cache) >= _model_cache_max:
evicted, _ = _model_cache.popitem(last=False)
print(f"Model cache full — evicted: {evicted}")
new_model = FasterQwen3TTS.from_pretrained(model_id, device="cuda", dtype=torch.bfloat16)
print("Capturing CUDA graphs…")
new_model._warmup(prefill_len=100)
_model_cache[model_id] = new_model
_model_cache.move_to_end(model_id)
_active_model_name = model_id
_prime_preset_voice_cache(new_model)
print("CUDA graphs captured — model ready.")
finally:
_loading = False
async with _generation_lock:
await asyncio.to_thread(_do_load)
return {"status": "loaded", "model": model_id}
@app.post("/generate/stream")
async def generate_stream(
text: str = Form(...),
language: str = Form("English"),
mode: str = Form("voice_clone"),
ref_text: str = Form(""),
speaker: str = Form(""),
instruct: str = Form(""),
xvec_only: bool = Form(True),
chunk_size: int = Form(8),
temperature: float = Form(0.7),
top_k: int = Form(30),
repetition_penalty: float = Form(1.1),
ref_preset: str = Form(""),
ref_audio: UploadFile = File(None),
seed: Optional[int] = Form(None),
):
if not _engine_enabled("qwen3") or FasterQwen3TTS is None:
raise HTTPException(status_code=503, detail="Qwen3 engine not available. Install faster-qwen3-tts and use a GPU server.")
if not _active_model_name or _active_model_name not in _model_cache:
raise HTTPException(status_code=400, detail="Modèle non chargé. Cliquez sur 'Load' d'abord.")
if len(text) > MAX_TEXT_CHARS:
raise HTTPException(status_code=400, detail=f"Texte trop long ({len(text)} chars). Max {MAX_TEXT_CHARS}.")
tmp_path = None
tmp_is_cached = False
if ref_preset and ref_preset in _preset_refs:
preset = _preset_refs[ref_preset]
tmp_path = preset["path"]
tmp_is_cached = True
if not ref_text:
ref_text = preset["ref_text"]
elif ref_audio and ref_audio.filename:
content = await ref_audio.read()
if len(content) > MAX_AUDIO_BYTES:
raise HTTPException(status_code=400, detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content)/1024/1024))
tmp_path = _get_cached_ref_path(content)
tmp_is_cached = True
loop = asyncio.get_event_loop()
queue: asyncio.Queue[str | None] = asyncio.Queue()
def run_generation():
try:
model = _model_cache.get(_active_model_name)
if model is None:
raise RuntimeError("No model loaded.")
if seed is not None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
t0 = time.perf_counter()
total_audio_s = 0.0
voice_clone_ms = 0.0
if mode == "voice_clone":
gen = model.generate_voice_clone_streaming(
text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
xvec_only=xvec_only, chunk_size=chunk_size, temperature=temperature,
top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=1800,
)
elif mode == "custom":
if not speaker:
raise ValueError("Speaker ID required for custom voice")
gen = model.generate_custom_voice_streaming(
text=text, speaker=speaker, language=language, instruct=instruct,
chunk_size=chunk_size, temperature=temperature, top_k=top_k,
repetition_penalty=repetition_penalty, max_new_tokens=1800,
)
else:
gen = model.generate_voice_design_streaming(
text=text, instruct=instruct, language=language, chunk_size=chunk_size,
temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty,
max_new_tokens=1800,
)
ttfa_ms = None
total_gen_ms = 0.0
first_audio = next(gen, None)
if first_audio is not None:
audio_chunk, sr, timing = first_audio
wall_first_ms = (time.perf_counter() - t0) * 1000
model_ms = timing.get("prefill_ms", 0) + timing.get("decode_ms", 0)
voice_clone_ms = max(0.0, wall_first_ms - model_ms)
total_gen_ms += timing.get('prefill_ms', 0) + timing.get('decode_ms', 0)
if ttfa_ms is None:
ttfa_ms = total_gen_ms
audio_chunk = _concat_audio(audio_chunk)
dur = len(audio_chunk) / sr
total_audio_s += dur
rtf = total_audio_s / (total_gen_ms / 1000) if total_gen_ms > 0 else 0.0
loop.call_soon_threadsafe(queue.put_nowait, json.dumps({
"type": "chunk", "audio_b64": _to_wav_b64(audio_chunk, sr),
"sample_rate": sr, "ttfa_ms": round(ttfa_ms), "voice_clone_ms": round(voice_clone_ms),
"rtf": round(rtf, 3), "total_audio_s": round(total_audio_s, 3),
"elapsed_ms": round(time.perf_counter() - t0, 3) * 1000,
}))
for audio_chunk, sr, timing in gen:
total_gen_ms += timing.get('prefill_ms', 0) + timing.get('decode_ms', 0)
if ttfa_ms is None:
ttfa_ms = total_gen_ms
audio_chunk = _concat_audio(audio_chunk)
dur = len(audio_chunk) / sr
total_audio_s += dur
rtf = total_audio_s / (total_gen_ms / 1000) if total_gen_ms > 0 else 0.0
loop.call_soon_threadsafe(queue.put_nowait, json.dumps({
"type": "chunk", "audio_b64": _to_wav_b64(audio_chunk, sr),
"sample_rate": sr, "ttfa_ms": round(ttfa_ms), "voice_clone_ms": round(voice_clone_ms),
"rtf": round(rtf, 3), "total_audio_s": round(total_audio_s, 3),
"elapsed_ms": round(time.perf_counter() - t0, 3) * 1000,
}))
rtf = total_audio_s / (total_gen_ms / 1000) if total_gen_ms > 0 else 0.0
loop.call_soon_threadsafe(queue.put_nowait, json.dumps({
"type": "done", "ttfa_ms": round(ttfa_ms) if ttfa_ms else 0,
"voice_clone_ms": round(voice_clone_ms), "rtf": round(rtf, 3),
"total_audio_s": round(total_audio_s, 3),
"total_ms": round((time.perf_counter() - t0) * 1000),
}))
except Exception as e:
import traceback
loop.call_soon_threadsafe(queue.put_nowait, json.dumps({
"type": "error", "message": str(e), "detail": traceback.format_exc()
}))
finally:
loop.call_soon_threadsafe(queue.put_nowait, None)
if tmp_path and os.path.exists(tmp_path) and not tmp_is_cached:
os.unlink(tmp_path)
async def sse():
global _generation_waiters
lock_acquired = False
_generation_waiters += 1
people_ahead = _generation_waiters - 1 + (1 if _generation_lock.locked() else 0)
try:
if people_ahead > 0:
yield f"data: {json.dumps({'type': 'queued', 'position': people_ahead})}\n\n"
await _generation_lock.acquire()
lock_acquired = True
_generation_waiters -= 1
threading.Thread(target=run_generation, daemon=True).start()
while True:
msg = await queue.get()
if msg is None:
break
yield f"data: {msg}\n\n"
except asyncio.CancelledError:
pass
finally:
if lock_acquired:
_generation_lock.release()
else:
_generation_waiters -= 1
return StreamingResponse(sse(), media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
@app.post("/generate")
async def generate_non_streaming(
text: str = Form(...),
language: str = Form("English"),
mode: str = Form("voice_clone"),
ref_text: str = Form(""),
speaker: str = Form(""),
instruct: str = Form(""),
xvec_only: bool = Form(True),
temperature: float = Form(0.7),
top_k: int = Form(30),
repetition_penalty: float = Form(1.1),
ref_preset: str = Form(""),
ref_audio: UploadFile = File(None),
seed: Optional[int] = Form(None),
):
if not _engine_enabled("qwen3") or FasterQwen3TTS is None:
raise HTTPException(status_code=503, detail="Qwen3 engine not available. Install faster-qwen3-tts and use a GPU server.")
if not _active_model_name or _active_model_name not in _model_cache:
raise HTTPException(status_code=400, detail="Modèle non chargé.")
if len(text) > MAX_TEXT_CHARS:
raise HTTPException(status_code=400, detail=f"Texte trop long ({len(text)} chars).")
tmp_path = None
tmp_is_cached = False
if ref_preset and ref_preset in _preset_refs:
preset = _preset_refs[ref_preset]
tmp_path = preset["path"]
tmp_is_cached = True
if not ref_text:
ref_text = preset["ref_text"]
elif ref_audio and ref_audio.filename:
content = await ref_audio.read()
if len(content) > MAX_AUDIO_BYTES:
raise HTTPException(status_code=400, detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content)/1024/1024))
tmp_path = _get_cached_ref_path(content)
tmp_is_cached = True
def run():
model = _model_cache.get(_active_model_name)
if model is None:
raise RuntimeError("No model loaded.")
if seed is not None:
# Full determinism: seed all RNG sources used by faster-qwen3-tts
import random as _rnd
_rnd.seed(seed)
np.random.seed(seed % (2**31))
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
# cuDNN deterministic mode: same kernel algo across calls → same output
_prev_det = torch.backends.cudnn.deterministic
_prev_bench = torch.backends.cudnn.benchmark
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
else:
_prev_det = _prev_bench = None
t0 = time.perf_counter()
try:
if mode == "voice_clone":
audio_list, sr = model.generate_voice_clone(
text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
xvec_only=xvec_only, temperature=temperature, top_k=top_k,
repetition_penalty=repetition_penalty, max_new_tokens=1800,
)
elif mode == "custom":
if not speaker:
raise ValueError("Speaker ID required")
audio_list, sr = model.generate_custom_voice(
text=text, speaker=speaker, language=language, instruct=instruct,
temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=1800,
)
else:
audio_list, sr = model.generate_voice_design(
text=text, instruct=instruct, language=language,
temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=1800,
)
finally:
# Restore cuDNN settings after generation
if _prev_det is not None:
torch.backends.cudnn.deterministic = _prev_det
torch.backends.cudnn.benchmark = _prev_bench
elapsed = time.perf_counter() - t0
audio = _concat_audio(audio_list)
dur = len(audio) / sr
return audio, sr, elapsed, dur
global _generation_waiters
_generation_waiters += 1
lock_acquired = False
try:
await _generation_lock.acquire()
lock_acquired = True
_generation_waiters -= 1
audio, sr, elapsed, dur = await asyncio.to_thread(run)
rtf = dur / elapsed if elapsed > 0 else 0.0
return JSONResponse({
"audio_b64": _to_wav_b64(audio, sr),
"sample_rate": sr,
"metrics": {"total_ms": round(elapsed * 1000), "audio_duration_s": round(dur, 3), "rtf": round(rtf, 3)},
})
finally:
if lock_acquired:
_generation_lock.release()
else:
_generation_waiters -= 1
if tmp_path and os.path.exists(tmp_path) and not tmp_is_cached:
os.unlink(tmp_path)
@app.post("/generate/kokoro_fr")
async def generate_kokoro_fr(
text: str = Form(...),
voice: str = Form("ff_siwis"),
speed: float = Form(1.0),
):
if not _engine_enabled("kokoro"):
raise HTTPException(status_code=503, detail="Kokoro engine not enabled on this server.")
if voice not in KOKORO_VOICES_FR:
voice = "ff_siwis"
speed = max(0.5, min(2.0, speed))
def _run():
pipeline = _get_kokoro(voice)
chunks = []
for _gs, _ps, audio in pipeline(text, voice=voice, speed=speed):
chunks.append(audio.numpy() if hasattr(audio, "numpy") else audio)
if not chunks:
raise ValueError("Kokoro: no audio generated")
combined = np.concatenate(chunks)
buf = io.BytesIO()
sf.write(buf, combined, 24000, format="WAV")
buf.seek(0)
return buf.read()
try:
wav_bytes = await asyncio.get_event_loop().run_in_executor(None, _run)
return Response(content=wav_bytes, media_type="audio/wav")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate/f5_fr")
async def generate_f5_fr(
text: str = Form(...),
ref_wav: UploadFile = File(None),
ref_text: str = Form(""),
speed: float = Form(1.0),
nfe_step: int = Form(32),
cross_fade_duration: float = Form(0.15),
seed: Optional[int] = Form(None),
):
if not _engine_enabled("f5"):
raise HTTPException(status_code=503, detail="F5-TTS engine not enabled on this server.")
ref_path = None
cleanup_ref = False
if ref_wav and ref_wav.filename:
ref_bytes = await ref_wav.read()
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp.write(ref_bytes)
ref_path = tmp.name
cleanup_ref = True
else:
# Try local narrator ref, then bundled F5-TTS ref (check all Python versions)
import glob as _glob
bundled_candidates = _glob.glob(
"/usr/local/lib/python3*/dist-packages/f5_tts/infer/examples/basic/basic_ref_en.wav"
)
for candidate in [
Path(__file__).parent / "narrator_ref.wav",
] + [Path(p) for p in sorted(bundled_candidates, reverse=True)]:
if candidate.exists():
ref_path = str(candidate)
if "basic_ref_en" in ref_path and not ref_text:
ref_text = "Some call me nature, others call me mother nature."
break
if ref_path is None:
raise HTTPException(
status_code=400,
detail="F5-TTS nécessite un fichier audio de référence. Veuillez en uploader un via 'Réf audio'."
)
def _run():
model = _get_f5()
out_tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
out_path = out_tmp.name
out_tmp.close()
try:
if seed is not None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
model.infer(
ref_file=ref_path, ref_text=ref_text or "", gen_text=text,
file_wave=out_path,
speed=max(0.5, min(2.0, speed)),
nfe_step=max(8, min(64, nfe_step)),
cross_fade_duration=max(0.0, min(0.5, cross_fade_duration)),
)
with open(out_path, "rb") as f:
return f.read()
finally:
if os.path.exists(out_path):
os.unlink(out_path)
if cleanup_ref and ref_path and os.path.exists(ref_path):
os.unlink(ref_path)
try:
wav_bytes = await asyncio.get_event_loop().run_in_executor(None, _run)
return Response(content=wav_bytes, media_type="audio/wav")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate/chatterbox")
async def generate_chatterbox(
text: str = Form(...),
ref_wav: UploadFile = File(None),
exaggeration: float = Form(0.5),
cfg_weight: float = Form(0.5),
temperature: float = Form(0.8),
seed: Optional[int] = Form(None),
):
if not _engine_enabled("chatterbox"):
raise HTTPException(status_code=503, detail="Chatterbox engine not enabled on this server.")
ref_path = None
cleanup_ref = False
if ref_wav and ref_wav.filename:
ref_bytes = await ref_wav.read()
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp.write(ref_bytes)
ref_path = tmp.name
cleanup_ref = True
def _run():
import torchaudio
if seed is not None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
model = _get_chatterbox()
kwargs = dict(audio_prompt_path=ref_path, exaggeration=exaggeration, cfg_weight=cfg_weight)
try:
wav = model.generate(text, temperature=max(0.1, min(2.0, temperature)), **kwargs)
except TypeError:
wav = model.generate(text, **kwargs) # fallback if temperature not supported
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as out_tmp:
out_path = out_tmp.name
try:
torchaudio.save(out_path, wav, model.sr)
with open(out_path, "rb") as f:
return f.read()
finally:
if os.path.exists(out_path):
os.unlink(out_path)
try:
wav_bytes = await asyncio.get_event_loop().run_in_executor(None, _run)
return Response(content=wav_bytes, media_type="audio/wav")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
finally:
if cleanup_ref and ref_path and os.path.exists(ref_path):
os.unlink(ref_path)
def _fish_split_text(text: str, max_words: int = 200) -> list[str]:
"""Split long text into paragraph groups to avoid truncation and improve quality."""
paragraphs = [p.strip() for p in text.split("\n") if p.strip()]
if len(paragraphs) <= 1:
return [text]
chunks, current, current_words = [], [], 0
for p in paragraphs:
words = len(p.split())
if current_words + words > max_words and current:
chunks.append("\n\n".join(current))
current, current_words = [p], words
else:
current.append(p)
current_words += words
if current:
chunks.append("\n\n".join(current))
return chunks if len(chunks) > 1 else [text]
@app.post("/generate/fish")
async def generate_fish(
text: str = Form(...),
ref_wav: UploadFile = File(None),
ref_text: str = Form(""),
prev_wav: UploadFile = File(None), # rolling reference: last N seconds of previous chunk
temperature: float = Form(0.8),
top_p: float = Form(0.8),
repetition_penalty: float = Form(1.1),
max_new_tokens: int = Form(1024),
chunk_length: int = Form(200),
latency: str = Form("normal"),
normalize: bool = Form(True),
seed: Optional[int] = Form(None),
auto_split: bool = Form(False),
rolling_ref_secs: float = Form(6.0), # seconds to extract from tail of output for next ref
):
if not _engine_enabled("fish"):
raise HTTPException(status_code=503, detail="Fish-Speech engine not enabled on this server.")
ref_path = None
prev_path = None
cleanup_ref = False
cleanup_prev = False
if ref_wav and ref_wav.filename:
ref_bytes = await ref_wav.read()
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp.write(ref_bytes)
ref_path = tmp.name
cleanup_ref = True
# Rolling reference: previous chunk's tail audio (takes priority over static ref)
if prev_wav and prev_wav.filename:
prev_bytes = await prev_wav.read()
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp.write(prev_bytes)
prev_path = tmp.name
cleanup_prev = True
def _run():
sys.path.insert(0, str(FISH_SPEECH_REPO))
from fish_speech.utils.schema import ServeTTSRequest, ServeReferenceAudio
engine = _get_fish_engine()
# Rolling reference takes priority; fall back to static ref
active_ref_path = prev_path or ref_path
references = []
if active_ref_path:
with open(active_ref_path, "rb") as f:
ref_bytes_data = f.read()
references = [ServeReferenceAudio(audio=ref_bytes_data, text=ref_text or "")]
# Optionally split long text at paragraph boundaries
texts = _fish_split_text(text) if auto_split and len(text) > 600 else [text]
common = dict(
references=references,
temperature=max(0.1, min(1.0, temperature)),
top_p=max(0.1, min(1.0, top_p)),
repetition_penalty=max(0.9, min(2.0, repetition_penalty)),
max_new_tokens=max(64, min(8192, max_new_tokens)),
chunk_length=max(100, min(600, chunk_length)),
latency=latency if latency in ("normal", "balanced") else "normal",
normalize=normalize,
seed=seed,
format="wav",
streaming=False,
)
all_audio = []
sample_rate = 44100
for txt in texts:
req = ServeTTSRequest(text=txt, **common)
for result in engine.inference(req):
if result.code == "header":
if isinstance(result.audio, tuple):
sample_rate = result.audio[0]
elif result.code in ("segment", "final"):
if isinstance(result.audio, tuple):
all_audio.append(result.audio[1])
if not all_audio:
raise ValueError("Fish-Speech: no audio generated")
combined = np.concatenate(all_audio)
# Extract tail for rolling reference (last N seconds)
tail_secs = max(3.0, min(10.0, rolling_ref_secs))
tail_samples = int(tail_secs * sample_rate)
tail_audio = combined[-tail_samples:] if len(combined) > tail_samples else combined
buf = io.BytesIO()
sf.write(buf, combined, sample_rate, format="WAV")
tail_buf = io.BytesIO()
sf.write(tail_buf, tail_audio, sample_rate, format="WAV")
return buf.getvalue(), tail_buf.getvalue(), sample_rate
try:
wav_bytes, tail_bytes, _ = await asyncio.get_event_loop().run_in_executor(None, _run)
tail_b64 = base64.b64encode(tail_bytes).decode()
return Response(
content=wav_bytes,
media_type="audio/wav",
headers={"X-Rolling-Ref-B64": tail_b64, "Access-Control-Expose-Headers": "X-Rolling-Ref-B64"},
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
finally:
if cleanup_ref and ref_path and os.path.exists(ref_path):
os.unlink(ref_path)
if cleanup_prev and prev_path and os.path.exists(prev_path):
os.unlink(prev_path)
@app.post("/generate/voxtral")
async def generate_voxtral(
text: str = Form(...),
ref_wav: UploadFile = File(None),
ref_text: str = Form(""),
):
"""Generate speech via Voxtral TTS (vLLM-Omni server at VOXTRAL_URL).
Requires a running `python3 voxtral_server.py` instance.
Optionally upload a reference WAV for voice cloning; falls back to
narrator_reference.wav in the same directory as this script.
"""
if not _engine_enabled("voxtral"):
raise HTTPException(status_code=503, detail="Voxtral engine not enabled on this server.")
try:
import httpx as _httpx
except ImportError:
raise HTTPException(status_code=503, detail="httpx not installed. Run: pip install httpx")
# Preprocess French text for better prosody
processed = preprocess_french(text)
ref_b64: str | None = None
if ref_wav and ref_wav.filename:
ref_bytes = await ref_wav.read()
if len(ref_bytes) > MAX_AUDIO_BYTES:
raise HTTPException(status_code=400, detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(ref_bytes) / 1024 / 1024))
ref_b64 = "data:audio/wav;base64," + base64.b64encode(ref_bytes).decode()
else:
narrator_ref = Path(__file__).parent / "narrator_reference.wav"
if narrator_ref.exists():
ref_b64 = "data:audio/wav;base64," + base64.b64encode(narrator_ref.read_bytes()).decode()
payload: dict = {
"model": _VOXTRAL_MODEL,
"input": processed,
"response_format": "wav",
}
if ref_b64:
payload["ref_audio"] = ref_b64
payload["ref_text"] = ref_text
def _run():
try:
r = _httpx.post(f"{_VOXTRAL_URL}/v1/audio/speech", json=payload, timeout=120.0)
r.raise_for_status()
return r.content
except _httpx.ConnectError:
raise RuntimeError(
f"Voxtral server not reachable at {_VOXTRAL_URL}. "
"Start it with: python3 voxtral_server.py"
)
try:
wav_bytes = await asyncio.to_thread(_run)
return Response(content=wav_bytes, media_type="audio/wav")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# ─── Entry point ──────────────────────────────────────────────────────────────
def main():
    """CLI entry point: parse flags, optionally preload models, start uvicorn."""
    cli = argparse.ArgumentParser(description="Faster Qwen3-TTS Demo Server")
    cli.add_argument("--model", default="Qwen/Qwen3-TTS-12Hz-0.6B-Base",
                     help="Model to preload at startup")
    cli.add_argument("--port", type=int, default=int(os.environ.get("PORT", 7860)))
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--no-preload", action="store_true",
                     help="Skip model loading at startup")
    opts = cli.parse_args()

    if not opts.no_preload:
        global _active_model_name, _parakeet
        print(f"Loading model: {opts.model}")
        model = FasterQwen3TTS.from_pretrained(opts.model, device="cuda", dtype=torch.bfloat16)
        print("Capturing CUDA graphs…")
        model._warmup(prefill_len=100)
        _model_cache[opts.model] = model
        _active_model_name = opts.model
        _prime_preset_voice_cache(model)
        print("TTS model ready.")
        # Optional transcription model, only when its loader was importable.
        if _parakeet_from_pretrained:
            print("Loading transcription model (nano-parakeet)…")
            _parakeet = _parakeet_from_pretrained(device="cuda")
            print("Transcription model ready.")

    print(f"Ready. Open http://localhost:{opts.port}")
    uvicorn.run(app, host=opts.host, port=opts.port, log_level="info")


if __name__ == "__main__":
    main()