"""
TTS Demo Server

Qwen3-TTS (Clone / Custom / VoiceDesign) → local RTX 3090 via faster_qwen3_tts
Kokoro FR   → local (hexgrad/Kokoro-82M)
F5-TTS FR   → local (RASPIAUDIO checkpoint)
Chatterbox  → local (ResembleAI)
Fish-Speech → local (fishaudio/fish-speech-1.5)
"""
|
|
| import argparse |
| import asyncio |
| import base64 |
| from collections import OrderedDict |
| import hashlib |
| import io |
| import json |
| import os |
| import re |
| import sys |
| import tempfile |
| import threading |
| import time |
| from pathlib import Path |
| from typing import Optional |
|
|
| import numpy as np |
| import soundfile as sf |
| import torch |
| import uvicorn |
| from fastapi import FastAPI, File, Form, HTTPException, UploadFile |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import FileResponse, JSONResponse, Response, StreamingResponse |
|
|
| |
# Voxtral sidecar: a separate TTS server this demo can proxy to.
_VOXTRAL_URL = os.environ.get("VOXTRAL_URL", "http://localhost:8000")
_VOXTRAL_MODEL = "mistralai/Voxtral-4B-TTS-2603"

# Fish-Speech runs from a local git checkout plus a locally downloaded model dir.
FISH_SPEECH_REPO = Path("/tmp/fish-speech")
FISH_SPEECH_MODEL = Path("/root/fish-speech-model")

# Compatibility shim: some torchaudio builds lack list_audio_backends(), which
# downstream libraries probe at import time — provide a plausible stub.
import torchaudio as _torchaudio
if not hasattr(_torchaudio, "list_audio_backends"):
    _torchaudio.list_audio_backends = lambda: ["ffmpeg", "sox"]

# Lazily constructed Fish-Speech engine singleton, guarded by _fish_lock.
_fish_engine = None
_fish_lock = threading.Lock()
|
|
def _get_fish_engine():
    """Lazily build and cache the Fish-Speech TTS inference engine (singleton).

    Uses double-checked locking so concurrent first callers construct the
    engine only once. Raises RuntimeError with setup instructions when the
    repo checkout or model directory is missing.
    """
    global _fish_engine
    if _fish_engine is not None:
        return _fish_engine
    with _fish_lock:
        if _fish_engine is not None:  # another thread finished while we waited
            return _fish_engine
        if not FISH_SPEECH_REPO.exists():
            raise RuntimeError("Fish-Speech repo not found at /tmp/fish-speech. Run: git clone https://github.com/fishaudio/fish-speech /tmp/fish-speech && cd /tmp/fish-speech && git checkout v1.5.1")
        if not FISH_SPEECH_MODEL.exists():
            raise RuntimeError("Fish-Speech model not found at /root/fish-speech-model. Download with: python3 -c \"from huggingface_hub import snapshot_download; snapshot_download('fishaudio/fish-speech-1.5', local_dir='/root/fish-speech-model')\"")
        # Imports are deferred (and the checkout put on sys.path) so the server
        # can start without Fish-Speech installed when this engine is disabled.
        sys.path.insert(0, str(FISH_SPEECH_REPO))
        from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
        from fish_speech.models.vqgan.inference import load_model as load_decoder_model
        from fish_speech.inference_engine import TTSInferenceEngine
        import torch as _torch
        device = "cuda" if _torch.cuda.is_available() else "cpu"
        precision = _torch.bfloat16
        # torch.compile is only enabled on CUDA.
        use_compile = device == "cuda"
        llama_queue = launch_thread_safe_queue(
            checkpoint_path=str(FISH_SPEECH_MODEL),
            device=device,
            precision=precision,
            compile=use_compile,
        )
        decoder_model = load_decoder_model(
            config_name="firefly_gan_vq",
            checkpoint_path=str(FISH_SPEECH_MODEL / "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"),
            device=device,
        )
        _fish_engine = TTSInferenceEngine(
            llama_queue=llama_queue,
            decoder_model=decoder_model,
            precision=precision,
            compile=use_compile,
        )
    return _fish_engine
|
|
|
|
| |
| _ENABLED_ENV = os.environ.get("ENABLED_ENGINES", "all").lower() |
| ENABLED_ENGINES: set[str] | None = ( |
| None if _ENABLED_ENV == "all" |
| else {e.strip() for e in _ENABLED_ENV.split(",") if e.strip()} |
| ) |
|
|
| def _engine_enabled(name: str) -> bool: |
| return ENABLED_ENGINES is None or name in ENABLED_ENGINES |
|
|
| |
# One KPipeline per language group, created on demand; guarded by _kokoro_lock.
_kokoro_pipelines: dict[str, object] = {}
_kokoro_lock = threading.Lock()

# Voice id -> human-readable label shown in the UI (★ marks recommended voices).
KOKORO_VOICES_FR = {
    # French
    "ff_siwis": "Siwis — FR Femme ★",
    # American English, female
    "af_heart": "Heart — EN Femme ★",
    "af_bella": "Bella — EN Femme",
    "af_nicole": "Nicole — EN Femme (ASMR)",
    "af_sarah": "Sarah — EN Femme",
    "af_sky": "Sky — EN Femme",
    # American English, male
    "am_echo": "Echo — EN Homme",
    "am_michael": "Michael — EN Homme",
    "am_adam": "Adam — EN Homme",
    # British English, female
    "bf_emma": "Emma — EN(UK) Femme",
    "bf_isabella": "Isabella — EN(UK) Femme",
    # British English, male
    "bm_george": "George — EN(UK) Homme",
    "bm_lewis": "Lewis — EN(UK) Homme",
}

# Voice-id prefix -> Kokoro language code ("f" French, "a" US, "b" UK English).
_VOICE_LANG_CODE = {
    "ff": "f",
    "af": "a",
    "am": "a",
    "bf": "b",
    "bm": "b",
}


def _get_kokoro(voice: str = "ff_siwis"):
    """Return (lazily creating) the Kokoro pipeline for *voice*'s language group."""
    if len(voice) >= 2:
        prefix = voice[:2]
    else:
        prefix = "ff"  # too short to carry a prefix — fall back to French
    code = _VOICE_LANG_CODE.get(prefix, "f")
    with _kokoro_lock:
        pipeline = _kokoro_pipelines.get(code)
        if pipeline is None:
            from kokoro import KPipeline  # deferred: heavy optional dependency
            pipeline = KPipeline(lang_code=code)
            _kokoro_pipelines[code] = pipeline
    return pipeline
|
|
| |
# Lazily created Chatterbox model singleton, guarded by _chatterbox_lock.
_chatterbox_model = None
_chatterbox_lock = threading.Lock()


def _get_chatterbox():
    """Return the shared ChatterboxTTS model, loading it on first use.

    Double-checked locking: the unlocked fast path avoids lock contention
    once the model exists; the locked re-check prevents duplicate loads.
    """
    global _chatterbox_model
    model = _chatterbox_model
    if model is None:
        with _chatterbox_lock:
            model = _chatterbox_model
            if model is None:
                from chatterbox.tts import ChatterboxTTS  # deferred heavy import
                device = "cuda" if torch.cuda.is_available() else "cpu"
                model = ChatterboxTTS.from_pretrained(device=device)
                _chatterbox_model = model
    return model
|
|
| |
# Lazily created F5-TTS model singleton, guarded by _f5_lock.
_f5_model = None
_f5_lock = threading.Lock()
# HF repo carrying the French fine-tune checkpoint + vocab.
F5_REPO = "RASPIAUDIO/F5-French-MixedSpeakers-reduced"


def _get_f5():
    """Return the shared F5TTS model, downloading its checkpoint on first use.

    Double-checked locking keeps concurrent first callers from loading twice.
    """
    global _f5_model
    model = _f5_model
    if model is None:
        with _f5_lock:
            model = _f5_model
            if model is None:
                from f5_tts.api import F5TTS  # deferred heavy import
                from huggingface_hub import hf_hub_download
                ckpt = hf_hub_download(F5_REPO, "model_last_reduced.pt")
                vocab = hf_hub_download(F5_REPO, "vocab.txt")
                model = F5TTS(model="F5TTS_v1_Base", ckpt_file=ckpt, vocab_file=vocab)
                _f5_model = model
    return model
|
|
| |
# Make the repo root importable so the local faster_qwen3_tts package is found.
sys.path.insert(0, str(Path(__file__).parent.parent))

# Optional engines: an ImportError simply disables the corresponding endpoints.
try:
    from faster_qwen3_tts import FasterQwen3TTS
except ImportError:
    print("Warning: faster_qwen3_tts not found — Qwen3 engine disabled. Install with: pip install faster-qwen3-tts")
    FasterQwen3TTS = None

# Optional ASR model used by /transcribe.
try:
    from nano_parakeet import from_pretrained as _parakeet_from_pretrained
except ImportError:
    _parakeet_from_pretrained = None
|
|
# Full catalogue of Qwen3-TTS checkpoints this demo knows how to serve.
_ALL_MODELS = [
    "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
    "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
    "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
    "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
    "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
]

# ACTIVE_MODELS env var (comma-separated model ids) narrows the offered list.
_active_models_env = os.environ.get("ACTIVE_MODELS", "")
if _active_models_env:
    _allowed = {m.strip() for m in _active_models_env.split(",") if m.strip()}
    AVAILABLE_MODELS = [m for m in _ALL_MODELS if m in _allowed]
else:
    AVAILABLE_MODELS = list(_ALL_MODELS)

# Preset voice-clone assets (reference WAVs + transcripts) live under ASSET_DIR.
BASE_DIR = Path(__file__).resolve().parent
_ASSET_DIR = Path(os.environ.get("ASSET_DIR", "/tmp/faster-qwen3-tts-assets"))
PRESET_TRANSCRIPTS = _ASSET_DIR / "samples" / "parity" / "icl_transcripts.txt"
# (preset id, on-disk path, UI label)
PRESET_REFS = [
    ("ref_audio_3", _ASSET_DIR / "ref_audio_3.wav", "Clone 1"),
    ("ref_audio_2", _ASSET_DIR / "ref_audio_2.wav", "Clone 2"),
    ("ref_audio", _ASSET_DIR / "ref_audio.wav", "Clone 3"),
]

# Missing assets are fetched from the project's GitHub raw URLs at startup.
_GITHUB_RAW = "https://raw.githubusercontent.com/andimarafioti/faster-qwen3-tts/main"
_PRESET_REMOTE = {
    "ref_audio": f"{_GITHUB_RAW}/ref_audio.wav",
    "ref_audio_2": f"{_GITHUB_RAW}/ref_audio_2.wav",
    "ref_audio_3": f"{_GITHUB_RAW}/ref_audio_3.wav",
}
_TRANSCRIPT_REMOTE = f"{_GITHUB_RAW}/samples/parity/icl_transcripts.txt"
|
|
|
|
def _fetch_preset_assets():
    """Best-effort download of any missing preset transcript and reference WAVs.

    Network failures are logged as warnings and never abort startup.
    """
    import urllib.request

    _ASSET_DIR.mkdir(parents=True, exist_ok=True)
    PRESET_TRANSCRIPTS.parent.mkdir(parents=True, exist_ok=True)
    if not PRESET_TRANSCRIPTS.exists():
        try:
            urllib.request.urlretrieve(_TRANSCRIPT_REMOTE, PRESET_TRANSCRIPTS)
        except Exception as exc:
            print(f"Warning: could not fetch transcripts: {exc}")
    for key, target, _label in PRESET_REFS:
        if target.exists() or key not in _PRESET_REMOTE:
            continue
        try:
            urllib.request.urlretrieve(_PRESET_REMOTE[key], target)
            print(f"Downloaded {target.name}")
        except Exception as exc:
            print(f"Warning: could not fetch {key}: {exc}")
|
|
# Registry of bundled reference voices: preset id -> metadata dict with keys
# {id, label, filename, path, ref_text, audio_b64}; filled by _load_preset_refs().
_preset_refs: dict[str, dict] = {}
|
|
|
|
def _load_preset_transcripts():
    """Parse the transcripts file into {preset_key: transcript_text}.

    Each useful line looks like ``key (optional note): transcript``; the key is
    whatever precedes an opening parenthesis. Lines without a colon are skipped.
    Returns an empty dict when the file is absent.
    """
    if not PRESET_TRANSCRIPTS.exists():
        return {}
    result: dict[str, str] = {}
    for raw_line in PRESET_TRANSCRIPTS.read_text(encoding="utf-8").splitlines():
        head, sep, body = raw_line.partition(":")
        if not sep:
            continue  # not a key/transcript line
        result[head.split("(")[0].strip()] = body.strip()
    return result
|
|
|
|
def _load_preset_refs():
    """Register every preset reference WAV present on disk into _preset_refs.

    Missing files are silently skipped; the audio is also mirrored into the
    shared reference cache and embedded as base64 for the UI preview player.
    """
    transcripts = _load_preset_transcripts()
    for key, path, label in PRESET_REFS:
        if not path.exists():
            continue
        raw = path.read_bytes()
        _preset_refs[key] = {
            "id": key,
            "label": label,
            "filename": path.name,
            "path": _get_cached_ref_path(raw),
            "ref_text": transcripts.get(key, ""),
            "audio_b64": base64.b64encode(raw).decode(),
        }
|
|
|
|
def _prime_preset_voice_cache(model):
    """Warm *model*'s voice-preparation caches for every registered preset.

    Runs both x-vector-only and full preparation so either later request path
    hits a warm cache. Purely an optimization: any failure is ignored.
    """
    if not _preset_refs:
        return
    for entry in _preset_refs.values():
        for xvec_flag in (True, False):
            try:
                model._prepare_generation(
                    text="Hello.",
                    ref_audio=entry["path"],
                    ref_text=entry["ref_text"],
                    language="English",
                    xvec_only=xvec_flag,
                    non_streaming_mode=True,
                )
            except Exception:
                continue  # best-effort priming only
|
|
|
|
app = FastAPI(title="Faster Qwen3-TTS Demo")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# LRU cache of loaded Qwen3 models (most-recently used last in the OrderedDict).
_model_cache: OrderedDict[str, FasterQwen3TTS] = OrderedDict()
_model_cache_max: int = int(os.environ.get("MODEL_CACHE_SIZE", "2"))
_active_model_name: str | None = None
_loading = False  # surfaced on /status while a model load is in flight
# sha1(audio bytes) -> temp-file path for reference audio; guarded by its lock.
_ref_cache: dict[str, str] = {}
_ref_cache_lock = threading.Lock()
_parakeet = None  # optional nano_parakeet ASR model; None disables /transcribe
# Only one generation may touch the GPU at a time; the waiter count feeds
# the /status queue_depth field.
_generation_lock = asyncio.Lock()
_generation_waiters: int = 0

# Request guards.
MAX_TEXT_CHARS = 15000
MAX_AUDIO_BYTES = 10 * 1024 * 1024
_AUDIO_TOO_LARGE_MSG = (
    "Audio file too large ({size_mb:.1f} MB). "
    "Please upload a shorter recording (under 1 minute)."
)
|
|
|
|
| |
|
|
# (regex, replacement) pairs expanding common French abbreviations so the TTS
# model pronounces them in full; applied in order by preprocess_french().
FRENCH_ABBREVS = [
    (r'\bM\.\s+', 'Monsieur '),
    (r'\bMme\.?\s+', 'Madame '),
    (r'\bMlle\.?\s+', 'Mademoiselle '),
    (r'\bDr\.?\s+', 'Docteur '),
    (r'\bPr\.?\s+', 'Professeur '),
    (r'\bSt\.?\s+', 'Saint '),
    (r'\betc\.(?!\w)', 'et cetera'),
    (r'\bn°\s*(\d+)', r'numéro \1'),
    (r'\b(\d{1,2})\s*h\s*(\d{2})\b', r'\1 heures \2'),
    (r'\b(\d{1,2})\s*h\b', r'\1 heures'),
    (r'\bp\.\s*(\d+)\b', r'page \1'),
]

# Style instruction handed to instruct-capable engines for audiobook narration.
FRENCH_NARRATOR_PROMPT = (
    "Narrateur professionnel de livres audio français, voix grave, chaude et captivante. "
    "Débit naturellement mesuré — ni précipité ni traînant. "
    "Respectez scrupuleusement la ponctuation : "
    "légère pause aux virgules, souffle marqué aux points, "
    "pause longue et respirée aux doubles sauts de paragraphe. "
    "Aux guillemets « », adoptez un registre légèrement plus direct et personnel pour le dialogue, "
    "puis revenez au ton narratif après le ». "
    "Les points de suspension (...) et les tirets (—) appellent une vraie pause respiratoire. "
    "Ton légèrement plus grave et plus riche que la conversation ordinaire, "
    "comme un conteur au coin du feu. "
    "Restez cohérent du début à la fin — même timbre, même rythme de fond."
)


def preprocess_french(text: str) -> str:
    """French TTS prosody preprocessing (arXiv:2508.17494).

    Expands abbreviations and inserts natural pause markers for the TTS model.
    The rules run strictly in order: abbreviations first, then guillemet
    spacing, paragraph-break pauses, em-dash spacing, ellipsis normalization,
    whitespace collapsing, and finally a guaranteed space after ! and ?.
    """
    rules = FRENCH_ABBREVS + [
        (r'«\s*', '« '),          # one space inside opening guillemet
        (r'\s*»', ' »'),          # one space before closing guillemet
        (r'\n\s*\n', ' ... '),    # blank line -> explicit long pause
        (r'\s*—\s*', ' — '),      # breathing room around em dashes
        (r'\.{3,}', '...'),       # canonical three-dot ellipsis
        (r'[ \t]+', ' '),         # collapse horizontal whitespace
        (r'([!?])', r'\1 '),      # ensure a pause after ! and ?
    ]
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    return text.strip()
|
|
|
|
| |
|
|
def _to_wav_b64(audio: np.ndarray, sr: int) -> str:
    """Encode a waveform as a base64 string of 16-bit PCM WAV bytes.

    The input is cast to float32 and squeezed to 1-D before encoding.
    """
    samples = audio if audio.dtype == np.float32 else audio.astype(np.float32)
    if samples.ndim > 1:
        samples = samples.squeeze()
    stream = io.BytesIO()
    sf.write(stream, samples, sr, format="WAV", subtype="PCM_16")
    return base64.b64encode(stream.getvalue()).decode()
|
|
|
|
def _concat_audio(audio_list) -> np.ndarray:
    """Flatten a chunk list (or a single ndarray) into one 1-D float32 waveform.

    Accepts either a single ndarray or an iterable of array-likes; empty chunks
    are dropped and an empty input yields a zero-length float32 array.

    Fix: every squeezed chunk is forced back to at least 1-D — previously a
    single-sample chunk squeezed to a 0-d array, which np.concatenate rejects
    ("zero-dimensional arrays cannot be concatenated").
    """
    if isinstance(audio_list, np.ndarray):
        return np.atleast_1d(audio_list.astype(np.float32).squeeze())
    parts = [
        np.atleast_1d(np.array(a, dtype=np.float32).squeeze())
        for a in audio_list
        if len(a) > 0
    ]
    return np.concatenate(parts) if parts else np.zeros(0, dtype=np.float32)
|
|
|
|
def _get_cached_ref_path(content: bytes) -> str:
    """Persist *content* to a digest-named temp WAV and return its path, memoized.

    Identical bytes always map to the same file, so repeated requests with the
    same reference audio reuse one on-disk copy. Thread-safe via _ref_cache_lock.
    """
    digest = hashlib.sha1(content).hexdigest()
    with _ref_cache_lock:
        hit = _ref_cache.get(digest)
        if hit and os.path.exists(hit):
            return hit
        target = Path(tempfile.gettempdir()) / f"qwen3tts_ref_{digest}.wav"
        if not target.exists():
            target.write_bytes(content)
        _ref_cache[digest] = str(target)
        return str(target)
|
|
|
|
| |
|
|
# Fetch and register preset reference voices at import time (best-effort).
_fetch_preset_assets()
_load_preset_refs()


@app.get("/")
async def root():
    """Serve the single-page demo UI."""
    return FileResponse(Path(__file__).parent / "index.html")
|
|
|
|
@app.post("/transcribe")
async def transcribe_audio(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the optional Parakeet ASR model.

    Returns {"text": ...}; 503 when no ASR model is loaded, 400 when the
    upload exceeds MAX_AUDIO_BYTES.
    """
    if _parakeet is None:
        raise HTTPException(status_code=503, detail="Transcription model not loaded")
    content = await audio.read()
    if len(content) > MAX_AUDIO_BYTES:
        raise HTTPException(status_code=400, detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content)/1024/1024))
    def run():
        # Blocking decode/resample/transcribe, run off the event loop.
        import torchaudio
        wav, sr = sf.read(io.BytesIO(content), dtype="float32", always_2d=False)
        if wav.ndim > 1:  # downmix multi-channel to mono
            wav = wav.mean(axis=1)
        wav_t = torch.from_numpy(wav)
        if sr != 16000:  # resample to the model's expected 16 kHz
            wav_t = torchaudio.functional.resample(wav_t.unsqueeze(0), sr, 16000).squeeze(0)
        # NOTE(review): hard-codes .cuda() — assumes a CUDA device is present;
        # confirm before running this endpoint on a CPU-only deployment.
        return _parakeet.transcribe(wav_t.cuda())
    text = await asyncio.to_thread(run)
    return {"text": text}
|
|
|
|
@app.get("/status")
async def get_status():
    """Report server state: active model, queue depth, engines, and preset voices."""
    speakers = []
    model_type = None
    active = _model_cache.get(_active_model_name) if _active_model_name else None
    if active is not None:
        try:
            model_type = active.model.model.tts_model_type
            speakers = active.model.get_supported_speakers() or []
        except Exception:
            # Speaker metadata is optional — not every checkpoint exposes it.
            speakers = []
    return {
        "loaded": active is not None,
        "model": _active_model_name,
        "loading": _loading,
        "available_models": AVAILABLE_MODELS,
        "model_type": model_type,
        "speakers": speakers,
        "transcription_available": _parakeet is not None,
        "preset_refs": [{"id": p["id"], "label": p["label"], "ref_text": p["ref_text"]} for p in _preset_refs.values()],
        "queue_depth": _generation_waiters,
        "cached_models": list(_model_cache.keys()),
        "kokoro_voices": KOKORO_VOICES_FR,
        "voxtral_url": _VOXTRAL_URL,
        "voxtral_model": _VOXTRAL_MODEL,
    }
|
|
|
|
@app.post("/load")
async def load_model(model_id: str = Form(...)):
    """Load (or re-activate) a Qwen3-TTS model, evicting the LRU cache entry if full.

    Holds the generation lock for the whole load so synthesis can't race it.
    """
    if not _engine_enabled("qwen3") or FasterQwen3TTS is None:
        raise HTTPException(status_code=503, detail="Qwen3 engine not available. Install faster-qwen3-tts and use a GPU server.")
    global _active_model_name, _loading
    if model_id in _model_cache:
        _active_model_name = model_id
        _model_cache.move_to_end(model_id)  # mark as most-recently used
        return {"status": "already_loaded", "model": model_id}
    _loading = True  # surfaced via /status while we wait for the lock + load
    def _do_load():
        global _active_model_name, _loading
        try:
            # Evict BEFORE loading so the old model's GPU memory is freed
            # for the incoming one.
            if len(_model_cache) >= _model_cache_max:
                evicted, _ = _model_cache.popitem(last=False)
                print(f"Model cache full — evicted: {evicted}")
            new_model = FasterQwen3TTS.from_pretrained(model_id, device="cuda", dtype=torch.bfloat16)
            print("Capturing CUDA graphs…")
            new_model._warmup(prefill_len=100)
            _model_cache[model_id] = new_model
            _model_cache.move_to_end(model_id)
            _active_model_name = model_id
            # Pre-warm voice caches for the bundled preset references.
            _prime_preset_voice_cache(new_model)
            print("CUDA graphs captured — model ready.")
        finally:
            _loading = False
    async with _generation_lock:
        await asyncio.to_thread(_do_load)
    return {"status": "loaded", "model": model_id}
|
|
|
|
@app.post("/generate/stream")
async def generate_stream(
    text: str = Form(...),
    language: str = Form("English"),
    mode: str = Form("voice_clone"),
    ref_text: str = Form(""),
    speaker: str = Form(""),
    instruct: str = Form(""),
    xvec_only: bool = Form(True),
    chunk_size: int = Form(8),
    temperature: float = Form(0.7),
    top_k: int = Form(30),
    repetition_penalty: float = Form(1.1),
    ref_preset: str = Form(""),
    ref_audio: UploadFile = File(None),
    seed: Optional[int] = Form(None),
):
    """Stream Qwen3-TTS audio over Server-Sent Events.

    Emits an optional `queued` event, then `chunk` events carrying base64 WAV
    audio plus running metrics (TTFA, RTF), then a final `done` summary; any
    failure surfaces as an `error` event. `mode` selects voice_clone / custom /
    voice-design generation on the currently active model.
    """
    if not _engine_enabled("qwen3") or FasterQwen3TTS is None:
        raise HTTPException(status_code=503, detail="Qwen3 engine not available. Install faster-qwen3-tts and use a GPU server.")
    if not _active_model_name or _active_model_name not in _model_cache:
        raise HTTPException(status_code=400, detail="Modèle non chargé. Cliquez sur 'Load' d'abord.")
    if len(text) > MAX_TEXT_CHARS:
        raise HTTPException(status_code=400, detail=f"Texte trop long ({len(text)} chars). Max {MAX_TEXT_CHARS}.")

    # Resolve the reference audio: a named preset wins over an uploaded file.
    tmp_path = None
    tmp_is_cached = False  # cached files are shared — never deleted here
    if ref_preset and ref_preset in _preset_refs:
        preset = _preset_refs[ref_preset]
        tmp_path = preset["path"]
        tmp_is_cached = True
        if not ref_text:
            ref_text = preset["ref_text"]
    elif ref_audio and ref_audio.filename:
        content = await ref_audio.read()
        if len(content) > MAX_AUDIO_BYTES:
            raise HTTPException(status_code=400, detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content)/1024/1024))
        tmp_path = _get_cached_ref_path(content)
        tmp_is_cached = True

    loop = asyncio.get_event_loop()
    # Worker thread pushes JSON event strings; None is the end-of-stream sentinel.
    queue: asyncio.Queue[str | None] = asyncio.Queue()

    def run_generation():
        """Blocking generation loop, executed on a daemon thread."""
        try:
            model = _model_cache.get(_active_model_name)
            if model is None:
                raise RuntimeError("No model loaded.")
            if seed is not None:
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed_all(seed)
            t0 = time.perf_counter()
            total_audio_s = 0.0
            voice_clone_ms = 0.0

            if mode == "voice_clone":
                gen = model.generate_voice_clone_streaming(
                    text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
                    xvec_only=xvec_only, chunk_size=chunk_size, temperature=temperature,
                    top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=1800,
                )
            elif mode == "custom":
                if not speaker:
                    raise ValueError("Speaker ID required for custom voice")
                gen = model.generate_custom_voice_streaming(
                    text=text, speaker=speaker, language=language, instruct=instruct,
                    chunk_size=chunk_size, temperature=temperature, top_k=top_k,
                    repetition_penalty=repetition_penalty, max_new_tokens=1800,
                )
            else:
                gen = model.generate_voice_design_streaming(
                    text=text, instruct=instruct, language=language, chunk_size=chunk_size,
                    temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty,
                    max_new_tokens=1800,
                )

            ttfa_ms = None   # time-to-first-audio from the model's own timings
            total_gen_ms = 0.0
            # The first chunk is handled separately so the wall-clock overhead of
            # voice preparation (work the model timings don't count) can be measured.
            first_audio = next(gen, None)
            if first_audio is not None:
                audio_chunk, sr, timing = first_audio
                wall_first_ms = (time.perf_counter() - t0) * 1000
                model_ms = timing.get("prefill_ms", 0) + timing.get("decode_ms", 0)
                voice_clone_ms = max(0.0, wall_first_ms - model_ms)
                total_gen_ms += timing.get('prefill_ms', 0) + timing.get('decode_ms', 0)
                if ttfa_ms is None:
                    ttfa_ms = total_gen_ms
                audio_chunk = _concat_audio(audio_chunk)
                dur = len(audio_chunk) / sr
                total_audio_s += dur
                rtf = total_audio_s / (total_gen_ms / 1000) if total_gen_ms > 0 else 0.0
                loop.call_soon_threadsafe(queue.put_nowait, json.dumps({
                    "type": "chunk", "audio_b64": _to_wav_b64(audio_chunk, sr),
                    "sample_rate": sr, "ttfa_ms": round(ttfa_ms), "voice_clone_ms": round(voice_clone_ms),
                    "rtf": round(rtf, 3), "total_audio_s": round(total_audio_s, 3),
                    "elapsed_ms": round(time.perf_counter() - t0, 3) * 1000,
                }))

            # Remaining chunks: accumulate timings and forward each as an event.
            for audio_chunk, sr, timing in gen:
                total_gen_ms += timing.get('prefill_ms', 0) + timing.get('decode_ms', 0)
                if ttfa_ms is None:
                    ttfa_ms = total_gen_ms
                audio_chunk = _concat_audio(audio_chunk)
                dur = len(audio_chunk) / sr
                total_audio_s += dur
                rtf = total_audio_s / (total_gen_ms / 1000) if total_gen_ms > 0 else 0.0
                loop.call_soon_threadsafe(queue.put_nowait, json.dumps({
                    "type": "chunk", "audio_b64": _to_wav_b64(audio_chunk, sr),
                    "sample_rate": sr, "ttfa_ms": round(ttfa_ms), "voice_clone_ms": round(voice_clone_ms),
                    "rtf": round(rtf, 3), "total_audio_s": round(total_audio_s, 3),
                    "elapsed_ms": round(time.perf_counter() - t0, 3) * 1000,
                }))

            rtf = total_audio_s / (total_gen_ms / 1000) if total_gen_ms > 0 else 0.0
            loop.call_soon_threadsafe(queue.put_nowait, json.dumps({
                "type": "done", "ttfa_ms": round(ttfa_ms) if ttfa_ms else 0,
                "voice_clone_ms": round(voice_clone_ms), "rtf": round(rtf, 3),
                "total_audio_s": round(total_audio_s, 3),
                "total_ms": round((time.perf_counter() - t0) * 1000),
            }))

        except Exception as e:
            import traceback
            loop.call_soon_threadsafe(queue.put_nowait, json.dumps({
                "type": "error", "message": str(e), "detail": traceback.format_exc()
            }))
        finally:
            loop.call_soon_threadsafe(queue.put_nowait, None)  # end-of-stream sentinel
            # Only delete throwaway temp refs; cached refs are shared across requests.
            if tmp_path and os.path.exists(tmp_path) and not tmp_is_cached:
                os.unlink(tmp_path)

    async def sse():
        """Serialize events; at most one generation holds the GPU lock at a time."""
        global _generation_waiters
        lock_acquired = False
        _generation_waiters += 1
        people_ahead = _generation_waiters - 1 + (1 if _generation_lock.locked() else 0)
        try:
            if people_ahead > 0:
                yield f"data: {json.dumps({'type': 'queued', 'position': people_ahead})}\n\n"
            await _generation_lock.acquire()
            lock_acquired = True
            _generation_waiters -= 1
            threading.Thread(target=run_generation, daemon=True).start()
            while True:
                msg = await queue.get()
                if msg is None:
                    break
                yield f"data: {msg}\n\n"
        except asyncio.CancelledError:
            pass  # client disconnected mid-stream
        finally:
            if lock_acquired:
                _generation_lock.release()
            else:
                _generation_waiters -= 1  # gave up while still queued

    return StreamingResponse(sse(), media_type="text/event-stream",
                             headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
|
|
|
@app.post("/generate")
async def generate_non_streaming(
    text: str = Form(...),
    language: str = Form("English"),
    mode: str = Form("voice_clone"),
    ref_text: str = Form(""),
    speaker: str = Form(""),
    instruct: str = Form(""),
    xvec_only: bool = Form(True),
    temperature: float = Form(0.7),
    top_k: int = Form(30),
    repetition_penalty: float = Form(1.1),
    ref_preset: str = Form(""),
    ref_audio: UploadFile = File(None),
    seed: Optional[int] = Form(None),
):
    """Synthesize a full utterance with Qwen3-TTS and return one base64 WAV.

    Same modes and reference resolution as /generate/stream, but blocking:
    the JSON response carries the whole waveform plus total_ms / rtf metrics.
    """
    if not _engine_enabled("qwen3") or FasterQwen3TTS is None:
        raise HTTPException(status_code=503, detail="Qwen3 engine not available. Install faster-qwen3-tts and use a GPU server.")
    if not _active_model_name or _active_model_name not in _model_cache:
        raise HTTPException(status_code=400, detail="Modèle non chargé.")
    if len(text) > MAX_TEXT_CHARS:
        raise HTTPException(status_code=400, detail=f"Texte trop long ({len(text)} chars).")

    # Reference resolution: a named preset beats an uploaded file; both are cached.
    tmp_path = None
    tmp_is_cached = False
    if ref_preset and ref_preset in _preset_refs:
        preset = _preset_refs[ref_preset]
        tmp_path = preset["path"]
        tmp_is_cached = True
        if not ref_text:
            ref_text = preset["ref_text"]
    elif ref_audio and ref_audio.filename:
        content = await ref_audio.read()
        if len(content) > MAX_AUDIO_BYTES:
            raise HTTPException(status_code=400, detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content)/1024/1024))
        tmp_path = _get_cached_ref_path(content)
        tmp_is_cached = True

    def run():
        """Blocking synthesis, run via asyncio.to_thread under the generation lock."""
        model = _model_cache.get(_active_model_name)
        if model is None:
            raise RuntimeError("No model loaded.")
        if seed is not None:
            # Seed every RNG source for reproducible output.
            import random as _rnd
            _rnd.seed(seed)
            np.random.seed(seed % (2**31))  # numpy seeds must fit in 32 bits
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
            # Force deterministic cuDNN for this request; restored in `finally`.
            _prev_det = torch.backends.cudnn.deterministic
            _prev_bench = torch.backends.cudnn.benchmark
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            _prev_det = _prev_bench = None
        t0 = time.perf_counter()
        try:
            if mode == "voice_clone":
                audio_list, sr = model.generate_voice_clone(
                    text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
                    xvec_only=xvec_only, temperature=temperature, top_k=top_k,
                    repetition_penalty=repetition_penalty, max_new_tokens=1800,
                )
            elif mode == "custom":
                if not speaker:
                    raise ValueError("Speaker ID required")
                audio_list, sr = model.generate_custom_voice(
                    text=text, speaker=speaker, language=language, instruct=instruct,
                    temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=1800,
                )
            else:
                audio_list, sr = model.generate_voice_design(
                    text=text, instruct=instruct, language=language,
                    temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=1800,
                )
        finally:
            # Restore the process-global cuDNN flags if we changed them.
            if _prev_det is not None:
                torch.backends.cudnn.deterministic = _prev_det
                torch.backends.cudnn.benchmark = _prev_bench
        elapsed = time.perf_counter() - t0
        audio = _concat_audio(audio_list)
        dur = len(audio) / sr
        return audio, sr, elapsed, dur

    global _generation_waiters
    _generation_waiters += 1
    lock_acquired = False
    try:
        await _generation_lock.acquire()
        lock_acquired = True
        _generation_waiters -= 1
        audio, sr, elapsed, dur = await asyncio.to_thread(run)
        rtf = dur / elapsed if elapsed > 0 else 0.0
        return JSONResponse({
            "audio_b64": _to_wav_b64(audio, sr),
            "sample_rate": sr,
            "metrics": {"total_ms": round(elapsed * 1000), "audio_duration_s": round(dur, 3), "rtf": round(rtf, 3)},
        })
    finally:
        if lock_acquired:
            _generation_lock.release()
        else:
            _generation_waiters -= 1  # bailed out while still waiting for the lock
        # Only delete throwaway temp refs; cached refs are shared across requests.
        if tmp_path and os.path.exists(tmp_path) and not tmp_is_cached:
            os.unlink(tmp_path)
|
|
|
|
@app.post("/generate/kokoro_fr")
async def generate_kokoro_fr(
    text: str = Form(...),
    voice: str = Form("ff_siwis"),
    speed: float = Form(1.0),
):
    """Synthesize speech with Kokoro-82M and return it as a WAV response.

    Unknown voice ids fall back to the French default; speed is clamped to
    [0.5, 2.0]. Errors surface as HTTP 500 with the exception text.
    """
    if not _engine_enabled("kokoro"):
        raise HTTPException(status_code=503, detail="Kokoro engine not enabled on this server.")
    if voice not in KOKORO_VOICES_FR:
        voice = "ff_siwis"
    speed = max(0.5, min(2.0, speed))

    def _run() -> bytes:
        """Blocking synthesis on a worker thread; returns WAV bytes."""
        pipeline = _get_kokoro(voice)
        chunks = []
        for _gs, _ps, audio in pipeline(text, voice=voice, speed=speed):
            chunks.append(audio.numpy() if hasattr(audio, "numpy") else audio)
        if not chunks:
            raise ValueError("Kokoro: no audio generated")
        combined = np.concatenate(chunks)
        buf = io.BytesIO()
        sf.write(buf, combined, 24000, format="WAV")  # Kokoro outputs 24 kHz audio
        buf.seek(0)
        return buf.read()

    try:
        # asyncio.to_thread matches the other endpoints (/transcribe, /generate)
        # and avoids the deprecated get_event_loop()/run_in_executor pattern.
        wav_bytes = await asyncio.to_thread(_run)
        return Response(content=wav_bytes, media_type="audio/wav")
    except Exception as e:
        # Chain the cause so the original traceback isn't lost in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@app.post("/generate/f5_fr")
async def generate_f5_fr(
    text: str = Form(...),
    ref_wav: UploadFile = File(None),
    ref_text: str = Form(""),
    speed: float = Form(1.0),
    nfe_step: int = Form(32),
    cross_fade_duration: float = Form(0.15),
    seed: Optional[int] = Form(None),
):
    """Synthesize speech with the F5-TTS French checkpoint; returns a WAV response.

    Reference-audio priority: the uploaded file, then a local narrator_ref.wav
    next to this script, then the example clip bundled with the f5_tts package;
    with none available the request fails with HTTP 400.
    """
    if not _engine_enabled("f5"):
        raise HTTPException(status_code=503, detail="F5-TTS engine not enabled on this server.")
    ref_path = None
    cleanup_ref = False  # only delete reference files we created ourselves
    if ref_wav and ref_wav.filename:
        ref_bytes = await ref_wav.read()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(ref_bytes)
            ref_path = tmp.name
        cleanup_ref = True
    else:
        # Fall back to a bundled reference voice.
        import glob as _glob
        bundled_candidates = _glob.glob(
            "/usr/local/lib/python3*/dist-packages/f5_tts/infer/examples/basic/basic_ref_en.wav"
        )
        for candidate in [
            Path(__file__).parent / "narrator_ref.wav",
        ] + [Path(p) for p in sorted(bundled_candidates, reverse=True)]:
            if candidate.exists():
                ref_path = str(candidate)
                if "basic_ref_en" in ref_path and not ref_text:
                    # Known transcript of the packaged example clip.
                    ref_text = "Some call me nature, others call me mother nature."
                break
        if ref_path is None:
            raise HTTPException(
                status_code=400,
                detail="F5-TTS nécessite un fichier audio de référence. Veuillez en uploader un via 'Réf audio'."
            )

    def _run():
        """Blocking inference: synthesize into a temp WAV and return its bytes."""
        model = _get_f5()
        out_tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        out_path = out_tmp.name
        out_tmp.close()
        try:
            if seed is not None:
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed_all(seed)
            model.infer(
                ref_file=ref_path, ref_text=ref_text or "", gen_text=text,
                file_wave=out_path,
                # Clamp user-supplied knobs to safe ranges.
                speed=max(0.5, min(2.0, speed)),
                nfe_step=max(8, min(64, nfe_step)),
                cross_fade_duration=max(0.0, min(0.5, cross_fade_duration)),
            )
            with open(out_path, "rb") as f:
                return f.read()
        finally:
            if os.path.exists(out_path):
                os.unlink(out_path)
            if cleanup_ref and ref_path and os.path.exists(ref_path):
                os.unlink(ref_path)

    try:
        wav_bytes = await asyncio.get_event_loop().run_in_executor(None, _run)
        return Response(content=wav_bytes, media_type="audio/wav")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.post("/generate/chatterbox")
async def generate_chatterbox(
    text: str = Form(...),
    ref_wav: UploadFile = File(None),
    exaggeration: float = Form(0.5),
    cfg_weight: float = Form(0.5),
    temperature: float = Form(0.8),
    seed: Optional[int] = Form(None),
):
    """Synthesize speech with ResembleAI Chatterbox; returns a WAV response.

    An uploaded ref_wav enables voice cloning via audio_prompt_path; errors
    surface as HTTP 500 with the exception text.
    """
    if not _engine_enabled("chatterbox"):
        raise HTTPException(status_code=503, detail="Chatterbox engine not enabled on this server.")
    ref_path = None
    cleanup_ref = False  # only delete the temp ref we created ourselves
    if ref_wav and ref_wav.filename:
        ref_bytes = await ref_wav.read()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(ref_bytes)
            ref_path = tmp.name
        cleanup_ref = True

    def _run():
        """Blocking synthesis on a worker thread; returns WAV bytes."""
        import torchaudio
        if seed is not None:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        model = _get_chatterbox()
        kwargs = dict(audio_prompt_path=ref_path, exaggeration=exaggeration, cfg_weight=cfg_weight)
        try:
            wav = model.generate(text, temperature=max(0.1, min(2.0, temperature)), **kwargs)
        except TypeError:
            # Older chatterbox releases don't accept a `temperature` argument.
            wav = model.generate(text, **kwargs)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as out_tmp:
            out_path = out_tmp.name
        try:
            torchaudio.save(out_path, wav, model.sr)
            with open(out_path, "rb") as f:
                return f.read()
        finally:
            if os.path.exists(out_path):
                os.unlink(out_path)

    try:
        wav_bytes = await asyncio.get_event_loop().run_in_executor(None, _run)
        return Response(content=wav_bytes, media_type="audio/wav")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if cleanup_ref and ref_path and os.path.exists(ref_path):
            os.unlink(ref_path)
|
|
|
|
def _fish_split_text(text: str, max_words: int = 200) -> list[str]:
    """Split long text into paragraph groups to avoid truncation and improve quality.

    Paragraphs (non-blank lines) are greedily packed into groups of at most
    *max_words* words, joined by blank lines. Single-paragraph input — or input
    that packs into a single group — is returned unchanged as ``[text]``.
    """
    paragraphs = [seg.strip() for seg in text.split("\n") if seg.strip()]
    if len(paragraphs) <= 1:
        return [text]
    groups: list[str] = []
    bucket: list[str] = []
    bucket_words = 0
    for para in paragraphs:
        n_words = len(para.split())
        if bucket and bucket_words + n_words > max_words:
            # Current group is full — flush it and start a new one.
            groups.append("\n\n".join(bucket))
            bucket = [para]
            bucket_words = n_words
        else:
            bucket.append(para)
            bucket_words += n_words
    if bucket:
        groups.append("\n\n".join(bucket))
    if len(groups) > 1:
        return groups
    return [text]
|
|
|
|
@app.post("/generate/fish")
async def generate_fish(
    text: str = Form(...),
    ref_wav: UploadFile = File(None),
    ref_text: str = Form(""),
    prev_wav: UploadFile = File(None),
    temperature: float = Form(0.8),
    top_p: float = Form(0.8),
    repetition_penalty: float = Form(1.1),
    max_new_tokens: int = Form(1024),
    chunk_length: int = Form(200),
    latency: str = Form("normal"),
    normalize: bool = Form(True),
    seed: Optional[int] = Form(None),
    auto_split: bool = Form(False),
    rolling_ref_secs: float = Form(6.0),
):
    """Generate speech with the local Fish-Speech engine.

    Form fields:
        text: text to synthesize.
        ref_wav: optional voice-cloning reference WAV upload.
        ref_text: transcript paired with the reference audio (passed to the
            engine; empty string when unknown).
        prev_wav: optional "rolling reference" WAV from a previous call; when
            present it takes priority over ref_wav so consecutive requests
            keep a consistent voice.
        temperature, top_p, repetition_penalty, max_new_tokens, chunk_length,
            latency, normalize, seed: sampling controls, clamped server-side
            before being forwarded to the engine.
        auto_split: when true and the text is longer than 600 characters,
            split it into paragraph groups (see _fish_split_text) and
            synthesize each group in sequence.
        rolling_ref_secs: length in seconds (clamped to 3-10) of the output
            tail returned in the X-Rolling-Ref-B64 header, for use as the
            next request's prev_wav.

    Returns:
        The combined WAV bytes; the base64-encoded tail WAV is exposed via
        the X-Rolling-Ref-B64 response header.

    Raises:
        HTTPException 503 when the engine is disabled, 500 on any
        generation failure.
    """
    if not _engine_enabled("fish"):
        raise HTTPException(status_code=503, detail="Fish-Speech engine not enabled on this server.")
    ref_path = None
    prev_path = None
    cleanup_ref = False
    cleanup_prev = False

    # Persist uploads to temp files (delete=False so the worker thread can
    # read them after the handles close); removed in the finally block below.
    if ref_wav and ref_wav.filename:
        ref_bytes = await ref_wav.read()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(ref_bytes)
            ref_path = tmp.name
        cleanup_ref = True

    if prev_wav and prev_wav.filename:
        prev_bytes = await prev_wav.read()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(prev_bytes)
            prev_path = tmp.name
        cleanup_prev = True

    def _run():
        # Blocking inference; executed on a worker thread via run_in_executor.
        sys.path.insert(0, str(FISH_SPEECH_REPO))
        from fish_speech.utils.schema import ServeTTSRequest, ServeReferenceAudio
        engine = _get_fish_engine()

        # Rolling reference (prev_wav) wins over the original reference.
        active_ref_path = prev_path or ref_path
        references = []
        if active_ref_path:
            with open(active_ref_path, "rb") as f:
                ref_bytes_data = f.read()
            references = [ServeReferenceAudio(audio=ref_bytes_data, text=ref_text or "")]

        # Long texts are optionally split into paragraph groups; each group
        # becomes its own ServeTTSRequest sharing the same sampling params.
        texts = _fish_split_text(text) if auto_split and len(text) > 600 else [text]

        common = dict(
            references=references,
            temperature=max(0.1, min(1.0, temperature)),
            top_p=max(0.1, min(1.0, top_p)),
            repetition_penalty=max(0.9, min(2.0, repetition_penalty)),
            max_new_tokens=max(64, min(8192, max_new_tokens)),
            chunk_length=max(100, min(600, chunk_length)),
            latency=latency if latency in ("normal", "balanced") else "normal",
            normalize=normalize,
            seed=seed,
            format="wav",
            streaming=False,
        )

        all_audio = []
        sample_rate = 44100  # fallback; overwritten by the engine's "header" event
        for txt in texts:
            req = ServeTTSRequest(text=txt, **common)
            # engine.inference yields results tagged with a code; "header"
            # carries (sample_rate, ...), "segment"/"final" carry (_, samples).
            for result in engine.inference(req):
                if result.code == "header":
                    if isinstance(result.audio, tuple):
                        sample_rate = result.audio[0]
                elif result.code in ("segment", "final"):
                    if isinstance(result.audio, tuple):
                        all_audio.append(result.audio[1])

        if not all_audio:
            raise ValueError("Fish-Speech: no audio generated")
        combined = np.concatenate(all_audio)

        # Extract the output tail so the client can feed it back as prev_wav
        # on the next request (keeps the cloned voice stable across calls).
        tail_secs = max(3.0, min(10.0, rolling_ref_secs))
        tail_samples = int(tail_secs * sample_rate)
        tail_audio = combined[-tail_samples:] if len(combined) > tail_samples else combined

        buf = io.BytesIO()
        sf.write(buf, combined, sample_rate, format="WAV")
        tail_buf = io.BytesIO()
        sf.write(tail_buf, tail_audio, sample_rate, format="WAV")

        return buf.getvalue(), tail_buf.getvalue(), sample_rate

    try:
        wav_bytes, tail_bytes, _ = await asyncio.get_event_loop().run_in_executor(None, _run)
        tail_b64 = base64.b64encode(tail_bytes).decode()
        return Response(
            content=wav_bytes,
            media_type="audio/wav",
            headers={"X-Rolling-Ref-B64": tail_b64, "Access-Control-Expose-Headers": "X-Rolling-Ref-B64"},
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always remove the temp reference files, even on failure.
        if cleanup_ref and ref_path and os.path.exists(ref_path):
            os.unlink(ref_path)
        if cleanup_prev and prev_path and os.path.exists(prev_path):
            os.unlink(prev_path)
|
|
|
|
@app.post("/generate/voxtral")
async def generate_voxtral(
    text: str = Form(...),
    ref_wav: UploadFile = File(None),
    ref_text: str = Form(""),
):
    """Proxy a speech request to the Voxtral TTS server (vLLM-Omni at VOXTRAL_URL).

    A running `python3 voxtral_server.py` instance is required. An uploaded
    reference WAV enables voice cloning; when none is provided, the bundled
    narrator_reference.wav next to this script is used if it exists.
    """
    if not _engine_enabled("voxtral"):
        raise HTTPException(status_code=503, detail="Voxtral engine not enabled on this server.")

    # httpx is an optional dependency; surface a clear 503 when missing.
    try:
        import httpx as _httpx
    except ImportError:
        raise HTTPException(status_code=503, detail="httpx not installed. Run: pip install httpx")

    # French text normalization before synthesis.
    processed = preprocess_french(text)

    # Resolve the cloning reference as a data: URI — uploaded file first,
    # bundled narrator fallback second.
    ref_b64: str | None = None
    if ref_wav and ref_wav.filename:
        uploaded = await ref_wav.read()
        if len(uploaded) > MAX_AUDIO_BYTES:
            raise HTTPException(status_code=400, detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(uploaded) / 1024 / 1024))
        ref_b64 = "data:audio/wav;base64," + base64.b64encode(uploaded).decode()
    else:
        fallback = Path(__file__).parent / "narrator_reference.wav"
        if fallback.exists():
            ref_b64 = "data:audio/wav;base64," + base64.b64encode(fallback.read_bytes()).decode()

    payload: dict = {
        "model": _VOXTRAL_MODEL,
        "input": processed,
        "response_format": "wav",
    }
    if ref_b64:
        payload["ref_audio"] = ref_b64
        payload["ref_text"] = ref_text

    def _call_remote():
        # Blocking HTTP round-trip; runs on a worker thread.
        try:
            resp = _httpx.post(f"{_VOXTRAL_URL}/v1/audio/speech", json=payload, timeout=120.0)
            resp.raise_for_status()
        except _httpx.ConnectError:
            raise RuntimeError(
                f"Voxtral server not reachable at {_VOXTRAL_URL}. "
                "Start it with: python3 voxtral_server.py"
            )
        return resp.content

    try:
        wav_bytes = await asyncio.to_thread(_call_remote)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return Response(content=wav_bytes, media_type="audio/wav")
|
|
|
|
| |
|
|
def main():
    """CLI entry point: parse arguments, optionally preload models, serve HTTP."""
    parser = argparse.ArgumentParser(description="Faster Qwen3-TTS Demo Server")
    parser.add_argument("--model", default="Qwen/Qwen3-TTS-12Hz-0.6B-Base",
                        help="Model to preload at startup")
    parser.add_argument("--port", type=int, default=int(os.environ.get("PORT", 7860)))
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--no-preload", action="store_true",
                        help="Skip model loading at startup")
    args = parser.parse_args()

    if not args.no_preload:
        global _active_model_name, _parakeet

        # Warm load of the TTS model so the first request isn't slow.
        print(f"Loading model: {args.model}")
        tts_model = FasterQwen3TTS.from_pretrained(args.model, device="cuda", dtype=torch.bfloat16)
        print("Capturing CUDA graphs…")
        tts_model._warmup(prefill_len=100)
        _model_cache[args.model] = tts_model
        _active_model_name = args.model
        _prime_preset_voice_cache(tts_model)
        print("TTS model ready.")

        # NOTE(review): the transcription preload is assumed to be gated by
        # --no-preload along with the TTS model — confirm against the original
        # layout if --no-preload is expected to still load the transcriber.
        if _parakeet_from_pretrained:
            print("Loading transcription model (nano-parakeet)…")
            _parakeet = _parakeet_from_pretrained(device="cuda")
            print("Transcription model ready.")

    print(f"Ready. Open http://localhost:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()
|
|