#!/usr/bin/env python3
"""
Qwen3-TTS Demo Server (Oficial)
CPU Optimized & RAM Safe

FastAPI server exposing voice-clone / custom-voice / voice-design TTS
generation (streaming SSE and one-shot JSON), plus audio transcription
via nano-parakeet. Designed for CPU-only deployment: only one model is
kept resident at a time and generation is serialized behind a lock.
"""
import argparse
import asyncio
import base64
from collections import OrderedDict
import hashlib
import io
import json
import os
import sys
import tempfile
import threading
import time
import gc
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse

# CPU optimization: cap torch intra-op parallelism.
torch.set_num_threads(4)

# Allow running from any directory
sys.path.insert(0, str(Path(__file__).parent.parent))

try:
    from qwen_tts import Qwen3TTSModel
except ImportError:
    print("Error: qwen-tts no está instalado.")
    sys.exit(1)

from nano_parakeet import from_pretrained as _parakeet_from_pretrained

_ALL_MODELS = [
    "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
    "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
    "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
    "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
    "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
]

# ACTIVE_MODELS (comma-separated) optionally restricts the selectable models.
_active_models_env = os.environ.get("ACTIVE_MODELS", "")
if _active_models_env:
    _allowed = {m.strip() for m in _active_models_env.split(",") if m.strip()}
    AVAILABLE_MODELS = [m for m in _ALL_MODELS if m in _allowed]
else:
    AVAILABLE_MODELS = list(_ALL_MODELS)

BASE_DIR = Path(__file__).resolve().parent
_ASSET_DIR = Path(os.environ.get("ASSET_DIR", "/tmp/qwen3-tts-assets"))
PRESET_TRANSCRIPTS = _ASSET_DIR / "samples" / "parity" / "icl_transcripts.txt"

# Exactly the original clone voices, restored.
PRESET_REFS = [
    ("ref_audio_3", _ASSET_DIR / "ref_audio_3.wav", "Clone 1"),
    ("ref_audio_2", _ASSET_DIR / "ref_audio_2.wav", "Clone 2"),
    ("ref_audio", _ASSET_DIR / "ref_audio.wav", "Clone 3"),
]

_GITHUB_RAW = "https://raw.githubusercontent.com/andimarafioti/faster-qwen3-tts/main"
_PRESET_REMOTE = {
    "ref_audio": f"{_GITHUB_RAW}/ref_audio.wav",
    "ref_audio_2": f"{_GITHUB_RAW}/ref_audio_2.wav",
    "ref_audio_3": f"{_GITHUB_RAW}/ref_audio_3.wav",
}
_TRANSCRIPT_REMOTE = f"{_GITHUB_RAW}/samples/parity/icl_transcripts.txt"


def _fetch_preset_assets() -> None:
    """Download preset wav files and transcripts from GitHub if not present locally."""
    import urllib.request

    _ASSET_DIR.mkdir(parents=True, exist_ok=True)
    PRESET_TRANSCRIPTS.parent.mkdir(parents=True, exist_ok=True)
    if not PRESET_TRANSCRIPTS.exists():
        try:
            urllib.request.urlretrieve(_TRANSCRIPT_REMOTE, PRESET_TRANSCRIPTS)
        except Exception as e:
            # Best-effort: missing transcripts only degrade preset metadata.
            print(f"Warning: could not fetch transcripts: {e}")
    for key, path, _ in PRESET_REFS:
        if not path.exists() and key in _PRESET_REMOTE:
            try:
                urllib.request.urlretrieve(_PRESET_REMOTE[key], path)
                print(f"Downloaded {path.name}")
            except Exception as e:
                print(f"Warning: could not fetch {key}: {e}")


# Preset reference voices, keyed by preset id; populated by _load_preset_refs().
_preset_refs: dict[str, dict] = {}


def _load_preset_transcripts() -> dict[str, str]:
    """Parse the transcripts file into {preset_key: transcript_text}.

    Expected line format is ``key(optional note): transcript``; lines without
    a colon are skipped.
    """
    if not PRESET_TRANSCRIPTS.exists():
        return {}
    transcripts = {}
    for line in PRESET_TRANSCRIPTS.read_text(encoding="utf-8").splitlines():
        if ":" not in line:
            continue
        key_part, text = line.split(":", 1)
        key = key_part.split("(")[0].strip()
        transcripts[key] = text.strip()
    return transcripts


def _load_preset_refs() -> None:
    """Load preset reference wavs from disk into _preset_refs (path + b64 audio)."""
    transcripts = _load_preset_transcripts()
    for key, path, label in PRESET_REFS:
        if not path.exists():
            continue
        content = path.read_bytes()
        cached_path = _get_cached_ref_path(content)
        _preset_refs[key] = {
            "id": key,
            "label": label,
            "filename": path.name,
            "path": cached_path,
            "ref_text": transcripts.get(key, ""),
            "audio_b64": base64.b64encode(content).decode(),
        }


app = FastAPI(title="Qwen3-TTS Demo")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Only one model is kept here at a time (see load_model's RAM protection).
_model_cache: OrderedDict[str, Qwen3TTSModel] = OrderedDict()
_active_model_name: str | None = None
_loading = False
# sha1 digest -> temp-file path for uploaded reference audio.
_ref_cache: dict[str, str] = {}
_ref_cache_lock = threading.Lock()
_parakeet = None
# Serializes generation and model loading; _generation_waiters backs /status queue_depth.
_generation_lock = asyncio.Lock()
_generation_waiters: int = 0

MAX_TEXT_CHARS = 1000
MAX_AUDIO_BYTES = 10 * 1024 * 1024
# Template: callers must fill in size_mb via str.format().
_AUDIO_TOO_LARGE_MSG = (
    "Audio file too large ({size_mb:.1f} MB). "
    "Voice cloning works best with short clips under 1 minute — please upload a shorter recording."
)


# ─── Helpers ──────────────────────────────────────────────────────────────────
def _to_wav_b64(audio: np.ndarray, sr: int) -> str:
    """Encode a mono float32 waveform as a base64 PCM_16 WAV string."""
    if audio.dtype != np.float32:
        audio = audio.astype(np.float32)
    if audio.ndim > 1:
        audio = audio.squeeze()
    buf = io.BytesIO()
    sf.write(buf, audio, sr, format="WAV", subtype="PCM_16")
    b64 = base64.b64encode(buf.getvalue()).decode()
    return b64


def _concat_audio(audio_list) -> np.ndarray:
    """Concatenate a list of audio chunks (or pass through a single array) as float32 mono."""
    if isinstance(audio_list, np.ndarray):
        return audio_list.astype(np.float32).squeeze()
    parts = [np.array(a, dtype=np.float32).squeeze() for a in audio_list if len(a) > 0]
    return np.concatenate(parts) if parts else np.zeros(0, dtype=np.float32)


def _get_cached_ref_path(content: bytes) -> str:
    """Write reference-audio bytes to a content-addressed temp file; reuse if present."""
    digest = hashlib.sha1(content).hexdigest()
    with _ref_cache_lock:
        cached = _ref_cache.get(digest)
        if cached and os.path.exists(cached):
            return cached
        tmp_dir = Path(tempfile.gettempdir())
        path = tmp_dir / f"qwen3_tts_ref_{digest}.wav"
        if not path.exists():
            path.write_bytes(content)
        _ref_cache[digest] = str(path)
        return str(path)


# ─── Routes ───────────────────────────────────────────────────────────────────
_fetch_preset_assets()
_load_preset_refs()


@app.get("/")
async def root():
    """Serve the demo frontend."""
    return FileResponse(Path(__file__).parent / "index.html")


@app.post("/transcribe")
async def transcribe_audio(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio clip with nano-parakeet (16 kHz mono)."""
    if _parakeet is None:
        raise HTTPException(status_code=503, detail="Transcription model not loaded")
    content = await audio.read()
    if len(content) > MAX_AUDIO_BYTES:
        raise HTTPException(
            status_code=400,
            detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content) / 1024 / 1024),
        )

    def run():
        wav, sr = sf.read(io.BytesIO(content), dtype="float32", always_2d=False)
        if wav.ndim > 1:
            wav = wav.mean(axis=1)  # downmix to mono
        wav_t = torch.from_numpy(wav)
        if sr != 16000:
            wav_t = torchaudio.functional.resample(wav_t.unsqueeze(0), sr, 16000).squeeze(0)
        return _parakeet.transcribe(wav_t)

    # Run off the event loop: transcription is CPU-bound.
    text = await asyncio.to_thread(run)
    return {"text": text}


@app.get("/status")
async def get_status():
    """Report loaded model, available models/speakers, presets and queue depth."""
    speakers = []
    model_type = None
    active = _model_cache.get(_active_model_name) if _active_model_name else None
    if active is not None:
        try:
            model_type = "official"
            speakers = active.get_supported_speakers() or []
        except Exception:
            speakers = []
    return {
        "loaded": active is not None,
        "model": _active_model_name,
        "loading": _loading,
        "available_models": AVAILABLE_MODELS,
        "model_type": model_type,
        "speakers": speakers,
        "transcription_available": _parakeet is not None,
        "preset_refs": [
            {"id": p["id"], "label": p["label"], "ref_text": p["ref_text"]}
            for p in _preset_refs.values()
        ],
        "queue_depth": _generation_waiters,
        "cached_models": list(_model_cache.keys()),
    }


@app.get("/preset_ref/{preset_id}")
async def get_preset_ref(preset_id: str):
    """Return one preset reference voice, including its base64 audio."""
    preset = _preset_refs.get(preset_id)
    if not preset:
        raise HTTPException(status_code=404, detail="Preset not found")
    return {
        "id": preset["id"],
        "label": preset["label"],
        "filename": preset["filename"],
        "ref_text": preset["ref_text"],
        "audio_b64": preset["audio_b64"],
    }


@app.post("/load")
async def load_model(model_id: str = Form(...)):
    """Load (or switch to) a TTS model on CPU, evicting any previous one first."""
    global _active_model_name, _loading
    if model_id in _model_cache:
        _active_model_name = model_id
        return {"status": "already_loaded", "model": model_id}
    _loading = True

    def _do_load():
        global _active_model_name, _loading
        try:
            # Critical RAM protection: if any previous model exists, destroy it
            # and force a garbage collection before loading the new one.
            if len(_model_cache) > 0:
                _model_cache.clear()
                gc.collect()
            new_model = Qwen3TTSModel.from_pretrained(
                model_id,
                device_map="cpu",
                dtype=torch.float32,
            )
            _model_cache[model_id] = new_model
            _active_model_name = model_id
            print(f"Modelo {model_id} cargado exitosamente en CPU.")
        finally:
            _loading = False

    # Hold the generation lock so no request generates during the swap.
    async with _generation_lock:
        await asyncio.to_thread(_do_load)
    return {"status": "loaded", "model": model_id}


@app.post("/generate/stream")
async def generate_stream(
    text: str = Form(...),
    language: str = Form("Spanish"),
    mode: str = Form("voice_clone"),
    ref_text: str = Form(""),
    speaker: str = Form(""),
    instruct: str = Form(""),
    xvec_only: bool = Form(True),
    chunk_size: int = Form(8),
    temperature: float = Form(0.9),
    top_k: int = Form(50),
    repetition_penalty: float = Form(1.05),
    ref_preset: str = Form(""),
    ref_audio: UploadFile = File(None),
):
    """Generate speech and stream the result to the client as SSE events.

    On CPU the whole utterance is synthesized and delivered as a single
    "chunk" event followed by a "done" event; a "queued" event is emitted
    first when other requests hold or wait on the generation lock.
    """
    if not _active_model_name or _active_model_name not in _model_cache:
        raise HTTPException(status_code=400, detail="Model not loaded. Click 'Load' first.")
    if len(text) > MAX_TEXT_CHARS:
        raise HTTPException(status_code=400, detail="Text too long.")

    # Resolve the reference audio: preset takes precedence over upload.
    tmp_path = None
    tmp_is_cached = False
    if ref_preset and ref_preset in _preset_refs:
        preset = _preset_refs[ref_preset]
        tmp_path = preset["path"]
        tmp_is_cached = True
        if not ref_text:
            ref_text = preset["ref_text"]
    elif ref_audio and ref_audio.filename:
        content = await ref_audio.read()
        if len(content) > MAX_AUDIO_BYTES:
            # BUGFIX: the template must be formatted with the actual size,
            # otherwise the client receives the literal "{size_mb:.1f}".
            raise HTTPException(
                status_code=400,
                detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content) / 1024 / 1024),
            )
        tmp_path = _get_cached_ref_path(content)
        tmp_is_cached = True

    loop = asyncio.get_running_loop()
    queue: asyncio.Queue[str | None] = asyncio.Queue()

    def run_generation():
        try:
            model = _model_cache.get(_active_model_name)
            t0 = time.perf_counter()
            # On CPU with the official library we process everything and send it
            # as one block, to keep the frontend stable and avoid "NoneType".
            if mode == "voice_clone":
                audio_list, sr = model.generate_voice_clone(
                    text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
                    x_vector_only_mode=xvec_only, temperature=temperature, top_k=top_k,
                    repetition_penalty=repetition_penalty, max_new_tokens=360, device="cpu"
                )
            elif mode == "custom":
                if not speaker:
                    raise ValueError("Speaker ID is required")
                audio_list, sr = model.generate_custom_voice(
                    text=text, speaker=speaker, language=language, instruct=instruct,
                    temperature=temperature, top_k=top_k,
                    repetition_penalty=repetition_penalty, max_new_tokens=360, device="cpu"
                )
            else:
                audio_list, sr = model.generate_voice_design(
                    text=text, instruct=instruct, language=language,
                    temperature=temperature, top_k=top_k,
                    repetition_penalty=repetition_penalty, max_new_tokens=360, device="cpu"
                )
            elapsed = time.perf_counter() - t0
            chunk_audio = _concat_audio(audio_list)
            dur = len(chunk_audio) / sr
            rtf = dur / elapsed if elapsed > 0 else 0.0
            ttfa_ms = round(elapsed * 1000)
            audio_b64 = _to_wav_b64(chunk_audio, sr)
            payload = {
                "type": "chunk", "audio_b64": audio_b64, "sample_rate": sr,
                "ttfa_ms": ttfa_ms, "voice_clone_ms": 0, "rtf": round(rtf, 3),
                "total_audio_s": round(dur, 3), "elapsed_ms": ttfa_ms
            }
            loop.call_soon_threadsafe(queue.put_nowait, json.dumps(payload))
            done_payload = {
                "type": "done", "ttfa_ms": ttfa_ms, "voice_clone_ms": 0,
                "rtf": round(rtf, 3), "total_audio_s": round(dur, 3),
                "total_ms": ttfa_ms
            }
            loop.call_soon_threadsafe(queue.put_nowait, json.dumps(done_payload))
        except Exception as e:
            import traceback
            err = {"type": "error", "message": str(e), "detail": traceback.format_exc()}
            loop.call_soon_threadsafe(queue.put_nowait, json.dumps(err))
        finally:
            # None sentinel terminates the SSE loop.
            loop.call_soon_threadsafe(queue.put_nowait, None)
            if tmp_path and os.path.exists(tmp_path) and not tmp_is_cached:
                os.unlink(tmp_path)

    async def sse():
        global _generation_waiters
        lock_acquired = False
        _generation_waiters += 1
        people_ahead = _generation_waiters - 1 + (1 if _generation_lock.locked() else 0)
        try:
            if people_ahead > 0:
                yield f"data: {json.dumps({'type': 'queued', 'position': people_ahead})}\n\n"
            await _generation_lock.acquire()
            lock_acquired = True
            _generation_waiters -= 1
            thread = threading.Thread(target=run_generation, daemon=True)
            thread.start()
            while True:
                msg = await queue.get()
                if msg is None:
                    break
                yield f"data: {msg}\n\n"
        except asyncio.CancelledError:
            # Client disconnected; the worker thread finishes on its own.
            pass
        finally:
            if lock_acquired:
                _generation_lock.release()
            else:
                _generation_waiters -= 1

    return StreamingResponse(
        sse(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


@app.post("/generate")
async def generate_non_streaming(
    text: str = Form(...),
    language: str = Form("Spanish"),
    mode: str = Form("voice_clone"),
    ref_text: str = Form(""),
    speaker: str = Form(""),
    instruct: str = Form(""),
    xvec_only: bool = Form(True),
    temperature: float = Form(0.9),
    top_k: int = Form(50),
    repetition_penalty: float = Form(1.05),
    ref_preset: str = Form(""),
    ref_audio: UploadFile = File(None),
):
    """One-shot generation: return the full audio and metrics as JSON."""
    if not _active_model_name or _active_model_name not in _model_cache:
        raise HTTPException(status_code=400, detail="Model not loaded. Click 'Load' first.")
    # Enforce the same input limits as /generate/stream.
    if len(text) > MAX_TEXT_CHARS:
        raise HTTPException(status_code=400, detail="Text too long.")

    tmp_path = None
    tmp_is_cached = False
    if ref_preset and ref_preset in _preset_refs:
        preset = _preset_refs[ref_preset]
        tmp_path = preset["path"]
        tmp_is_cached = True
    elif ref_audio and ref_audio.filename:
        content = await ref_audio.read()
        if len(content) > MAX_AUDIO_BYTES:
            raise HTTPException(
                status_code=400,
                detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content) / 1024 / 1024),
            )
        tmp_path = _get_cached_ref_path(content)
        tmp_is_cached = True

    def run():
        model = _model_cache.get(_active_model_name)
        t0 = time.perf_counter()
        if mode == "voice_clone":
            audio_list, sr = model.generate_voice_clone(
                text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
                x_vector_only_mode=xvec_only, temperature=temperature, top_k=top_k,
                repetition_penalty=repetition_penalty, max_new_tokens=360, device="cpu"
            )
        elif mode == "custom":
            audio_list, sr = model.generate_custom_voice(
                text=text, speaker=speaker, language=language, instruct=instruct,
                temperature=temperature, top_k=top_k,
                repetition_penalty=repetition_penalty, max_new_tokens=360, device="cpu"
            )
        else:
            audio_list, sr = model.generate_voice_design(
                text=text, instruct=instruct, language=language,
                temperature=temperature, top_k=top_k,
                repetition_penalty=repetition_penalty, max_new_tokens=360, device="cpu"
            )
        elapsed = time.perf_counter() - t0
        audio = _concat_audio(audio_list)
        dur = len(audio) / sr
        return audio, sr, elapsed, dur

    global _generation_waiters
    _generation_waiters += 1
    lock_acquired = False
    try:
        await _generation_lock.acquire()
        lock_acquired = True
        _generation_waiters -= 1
        audio, sr, elapsed, dur = await asyncio.to_thread(run)
        rtf = dur / elapsed if elapsed > 0 else 0.0
        return JSONResponse({
            "audio_b64": _to_wav_b64(audio, sr),
            "sample_rate": sr,
            "metrics": {
                "total_ms": round(elapsed * 1000),
                "audio_duration_s": round(dur, 3),
                "rtf": round(rtf, 3),
            },
        })
    finally:
        if lock_acquired:
            _generation_lock.release()
        else:
            _generation_waiters -= 1
        if tmp_path and os.path.exists(tmp_path) and not tmp_is_cached:
            os.unlink(tmp_path)


def main():
    """CLI entry point: optionally preload models, then run the uvicorn server."""
    parser = argparse.ArgumentParser(description="Qwen3-TTS Demo Server")
    parser.add_argument("--model", default="Qwen/Qwen3-TTS-12Hz-0.6B-Base",
                        help="Model to preload at startup")
    parser.add_argument("--port", type=int, default=int(os.environ.get("PORT", 7860)))
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--no-preload", action="store_true",
                        help="Skip model loading at startup")
    args = parser.parse_args()

    if not args.no_preload:
        global _active_model_name, _parakeet
        print(f"Loading model: {args.model}")
        _startup_model = Qwen3TTSModel.from_pretrained(args.model, device_map="cpu", dtype=torch.float32)
        _model_cache[args.model] = _startup_model
        _active_model_name = args.model
        print("Loading transcription model (nano-parakeet)…")
        _parakeet = _parakeet_from_pretrained(device="cpu")
        print("Transcription model ready on CPU.")

    print(f"Ready. Open http://localhost:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()