Spaces:
Sleeping
Sleeping
Update server.py
Browse files
server.py
CHANGED
|
@@ -1,6 +1,11 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
Faster Qwen3-TTS Demo Server (CPU
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import argparse
|
|
@@ -35,44 +40,54 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
| 35 |
# ==============================================================================
|
| 36 |
import site
|
| 37 |
|
| 38 |
-
def
|
| 39 |
-
# 1. Parche físico
|
| 40 |
try:
|
| 41 |
for p in site.getsitepackages():
|
| 42 |
model_py = os.path.join(p, "faster_qwen3_tts", "model.py")
|
| 43 |
if os.path.exists(model_py):
|
| 44 |
with open(model_py, "r") as f: code = f.read()
|
| 45 |
-
|
| 46 |
-
|
|
|
|
| 47 |
except Exception: pass
|
| 48 |
|
| 49 |
-
# 2.
|
|
|
|
|
|
|
| 50 |
torch.cuda.is_available = lambda: False
|
| 51 |
torch.cuda.current_device = lambda: 0
|
| 52 |
torch.cuda.device_count = lambda: 1
|
| 53 |
-
|
| 54 |
|
| 55 |
-
# 3.
|
| 56 |
torch.Tensor.cuda = lambda self, *args, **kwargs: self
|
| 57 |
torch.nn.Module.cuda = lambda self, *args, **kwargs: self
|
| 58 |
|
| 59 |
-
|
|
|
|
|
|
|
| 60 |
new_args = tuple('cpu' if isinstance(a, str) and 'cuda' in a else a for a in args)
|
| 61 |
if 'device' in kwargs and isinstance(kwargs['device'], str) and 'cuda' in kwargs['device']:
|
| 62 |
kwargs['device'] = 'cpu'
|
| 63 |
-
return
|
|
|
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
|
| 71 |
|
| 72 |
try:
|
| 73 |
from faster_qwen3_tts import FasterQwen3TTS
|
| 74 |
import faster_qwen3_tts.model as fq_model
|
| 75 |
|
|
|
|
| 76 |
class CPU_PredictorGraph:
|
| 77 |
def __init__(self, model, *args, **kwargs):
|
| 78 |
self.model = model
|
|
@@ -84,12 +99,13 @@ try:
|
|
| 84 |
|
| 85 |
fq_model.PredictorGraph = CPU_PredictorGraph
|
| 86 |
except ImportError:
|
|
|
|
| 87 |
sys.exit(1)
|
| 88 |
# ==============================================================================
|
| 89 |
|
| 90 |
from nano_parakeet import from_pretrained as _parakeet_from_pretrained
|
| 91 |
|
| 92 |
-
_ALL_MODELS =
|
| 93 |
"Qwen/Qwen3-TTS-12Hz-0.6B-Base",
|
| 94 |
"Qwen/Qwen3-TTS-12Hz-1.7B-Base",
|
| 95 |
"Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
|
|
@@ -97,151 +113,343 @@ _ALL_MODELS = [
|
|
| 97 |
"Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
|
| 98 |
]
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
| 103 |
|
|
|
|
| 104 |
_ASSET_DIR = Path(os.environ.get("ASSET_DIR", "/tmp/faster-qwen3-tts-assets"))
|
| 105 |
-
|
|
|
|
| 106 |
("ref_audio_3", _ASSET_DIR / "ref_audio_3.wav", "Clone 1"),
|
| 107 |
("ref_audio_2", _ASSET_DIR / "ref_audio_2.wav", "Clone 2"),
|
| 108 |
("ref_audio", _ASSET_DIR / "ref_audio.wav", "Clone 3"),
|
| 109 |
]
|
| 110 |
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
-
app = FastAPI(title="Faster Qwen3-TTS
|
| 120 |
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
| 121 |
|
| 122 |
-
_model_cache: OrderedDict = OrderedDict()
|
|
|
|
| 123 |
_active_model_name: str | None = None
|
| 124 |
_loading = False
|
|
|
|
|
|
|
| 125 |
_parakeet = None
|
| 126 |
_generation_lock = asyncio.Lock()
|
| 127 |
-
_generation_waiters = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
@app.get("/")
|
| 130 |
async def root(): return FileResponse(Path(__file__).parent / "index.html")
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
@app.get("/status")
|
| 133 |
async def get_status():
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
return {
|
| 137 |
"loaded": active is not None, "model": _active_model_name, "loading": _loading,
|
| 138 |
-
"available_models": AVAILABLE_MODELS, "speakers": speakers,
|
| 139 |
-
"
|
|
|
|
|
|
|
| 140 |
}
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
@app.post("/load")
|
| 143 |
async def load_model(model_id: str = Form(...)):
|
| 144 |
global _active_model_name, _loading
|
| 145 |
if model_id in _model_cache:
|
| 146 |
_active_model_name = model_id
|
| 147 |
-
|
|
|
|
| 148 |
_loading = True
|
| 149 |
-
def
|
| 150 |
global _active_model_name, _loading
|
| 151 |
try:
|
| 152 |
-
if len(_model_cache) >=
|
| 153 |
-
|
| 154 |
-
_model_cache[model_id] =
|
|
|
|
| 155 |
_active_model_name = model_id
|
|
|
|
| 156 |
finally: _loading = False
|
| 157 |
-
async with _generation_lock: await asyncio.to_thread(
|
| 158 |
-
return {"status": "loaded"}
|
| 159 |
|
| 160 |
@app.post("/generate/stream")
|
| 161 |
async def generate_stream(
|
| 162 |
-
text: str = Form(...), mode: str = Form("voice_clone"),
|
|
|
|
|
|
|
|
|
|
| 163 |
ref_preset: str = Form(""), ref_audio: UploadFile = File(None),
|
| 164 |
-
chunk_size: int = Form(8), temperature: float = Form(0.9)
|
| 165 |
):
|
| 166 |
-
|
| 167 |
-
|
|
|
|
| 168 |
|
| 169 |
tmp_path = None
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
loop = asyncio.get_event_loop()
|
| 174 |
queue = asyncio.Queue()
|
| 175 |
|
| 176 |
-
def
|
| 177 |
try:
|
|
|
|
| 178 |
t0 = time.perf_counter()
|
| 179 |
total_audio_s = 0.0
|
|
|
|
| 180 |
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
ttfa_ms, total_gen_ms = None, 0.0
|
|
|
|
| 189 |
for chunk, sr, timing in gen:
|
| 190 |
-
# 🛡️ PROTECCIÓN ANTI-NONE
|
| 191 |
timing = timing or {}
|
| 192 |
-
prefill = timing.get('prefill_ms')
|
| 193 |
-
decode = timing.get('decode_ms')
|
| 194 |
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
if ttfa_ms is None: ttfa_ms = total_gen_ms
|
| 197 |
|
| 198 |
-
chunk_audio =
|
| 199 |
total_audio_s += len(chunk_audio) / sr
|
| 200 |
-
|
| 201 |
-
# RTF Safe
|
| 202 |
rtf = total_audio_s / (total_gen_ms / 1000) if total_gen_ms > 0 else 0.0
|
| 203 |
|
| 204 |
-
buf = io.BytesIO()
|
| 205 |
-
sf.write(buf, chunk_audio.astype(np.float32), sr, format="WAV", subtype="PCM_16")
|
| 206 |
-
|
| 207 |
payload = {
|
| 208 |
-
"type": "chunk", "audio_b64":
|
| 209 |
-
"
|
| 210 |
-
"total_audio_s": round(total_audio_s, 3)
|
|
|
|
| 211 |
}
|
| 212 |
loop.call_soon_threadsafe(queue.put_nowait, json.dumps(payload))
|
| 213 |
|
| 214 |
-
loop.call_soon_threadsafe(queue.put_nowait, json.dumps({
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
except Exception as e:
|
| 216 |
loop.call_soon_threadsafe(queue.put_nowait, json.dumps({"type": "error", "message": str(e)}))
|
| 217 |
finally:
|
| 218 |
loop.call_soon_threadsafe(queue.put_nowait, None)
|
|
|
|
| 219 |
|
| 220 |
async def sse():
|
| 221 |
-
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
while True:
|
| 224 |
msg = await queue.get()
|
| 225 |
if msg is None: break
|
| 226 |
yield f"data: {msg}\n\n"
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
-
return StreamingResponse(sse(), media_type="text/event-stream")
|
| 229 |
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
_model_cache[args.model] = m
|
| 240 |
-
global _active_model_name, _parakeet
|
| 241 |
-
_active_model_name = args.model
|
| 242 |
-
_parakeet = _parakeet_from_pretrained(device="cpu")
|
| 243 |
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
if __name__ == "__main__":
|
| 247 |
main()
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Faster Qwen3-TTS Demo Server (CPU Optimizado + Parches Anti-CUDA y Anti-None)
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python demo/server.py
|
| 7 |
+
python demo/server.py --model Qwen/Qwen3-TTS-12Hz-1.7B-Base --port 7860
|
| 8 |
+
python demo/server.py --no-preload # skip startup model load
|
| 9 |
"""
|
| 10 |
|
| 11 |
import argparse
|
|
|
|
| 40 |
# ==============================================================================
|
| 41 |
import site
|
| 42 |
|
| 43 |
+
def _apply_anti_cuda_shield():
    """Monkey-patch torch (and the installed library) so CUDA requests land on the CPU.

    Steps, in order:
      1. Rewrite ``faster_qwen3_tts/model.py`` in every site-packages directory,
         replacing ``raise ValueError("CUDA graphs require CUDA device")`` with
         ``pass``. This is best-effort: a failure is reported and ignored.
      2. Replace the ``torch.cuda`` probes with fixed answers.
      3. Turn ``.cuda()`` on tensors and modules into a no-op returning ``self``.
      4. Wrap ``.to()`` on tensors and modules so CUDA devices (given as a string
         or as a ``torch.device``) are replaced by ``'cpu'``.
    """
    cuda_guard = 'raise ValueError("CUDA graphs require CUDA device")'

    # 1. On-disk patch of the library's CUDA-only guard.
    try:
        for site_dir in site.getsitepackages():
            model_py = os.path.join(site_dir, "faster_qwen3_tts", "model.py")
            if not os.path.exists(model_py):
                continue
            with open(model_py, "r", encoding="utf-8") as f:
                code = f.read()
            if cuda_guard in code:
                with open(model_py, "w", encoding="utf-8") as f:
                    f.write(code.replace(cuda_guard, "pass"))
    except Exception as exc:
        # Best-effort: the runtime patches below are still applied.
        print(f"Warning: could not patch faster_qwen3_tts/model.py: {exc}")

    # 2. Neutralise torch's internal CUDA checks.
    if hasattr(torch.cuda, "_lazy_init"):
        torch.cuda._lazy_init = lambda *args, **kwargs: None
    torch.cuda.is_available = lambda: False
    torch.cuda.current_device = lambda: 0
    torch.cuda.device_count = lambda: 1
    # The device argument is optional so a bare get_device_name() call works too.
    torch.cuda.get_device_name = lambda device=None: "CPU"

    # 3. .cuda() becomes a no-op.
    torch.Tensor.cuda = lambda self, *args, **kwargs: self
    torch.nn.Module.cuda = lambda self, *args, **kwargs: self

    # 4. .to(<cuda>) is redirected to .to('cpu').
    def _cpu_if_cuda(value):
        """Return 'cpu' for a CUDA device spec, otherwise the value unchanged."""
        if isinstance(value, str) and "cuda" in value:
            return "cpu"
        if isinstance(value, torch.device) and value.type == "cuda":
            return "cpu"
        return value

    def _redirecting(orig_to):
        """Build a ``.to`` replacement that rewrites CUDA targets before delegating."""
        def _to_mock(self, *args, **kwargs):
            new_args = tuple(_cpu_if_cuda(a) for a in args)
            if "device" in kwargs:
                kwargs["device"] = _cpu_if_cuda(kwargs["device"])
            return orig_to(self, *new_args, **kwargs)
        return _to_mock

    torch.Tensor.to = _redirecting(torch.Tensor.to)
    torch.nn.Module.to = _redirecting(torch.nn.Module.to)
|
| 83 |
|
| 84 |
+
_apply_anti_cuda_shield()
|
| 85 |
|
| 86 |
try:
|
| 87 |
from faster_qwen3_tts import FasterQwen3TTS
|
| 88 |
import faster_qwen3_tts.model as fq_model
|
| 89 |
|
| 90 |
+
# Clon del PredictorGraph para CPU
|
| 91 |
class CPU_PredictorGraph:
|
| 92 |
def __init__(self, model, *args, **kwargs):
|
| 93 |
self.model = model
|
|
|
|
| 99 |
|
| 100 |
fq_model.PredictorGraph = CPU_PredictorGraph
|
| 101 |
except ImportError:
|
| 102 |
+
print("Error: faster_qwen3_tts not found.")
|
| 103 |
sys.exit(1)
|
| 104 |
# ==============================================================================
|
| 105 |
|
| 106 |
from nano_parakeet import from_pretrained as _parakeet_from_pretrained
|
| 107 |
|
| 108 |
+
_ALL_MODELS =[
|
| 109 |
"Qwen/Qwen3-TTS-12Hz-0.6B-Base",
|
| 110 |
"Qwen/Qwen3-TTS-12Hz-1.7B-Base",
|
| 111 |
"Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
|
|
|
|
| 113 |
"Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
|
| 114 |
]
|
| 115 |
|
| 116 |
+
# ACTIVE_MODELS is an optional comma-separated allow-list; when it is unset or
# empty every entry of _ALL_MODELS is offered, otherwise only the listed ones
# (kept in _ALL_MODELS order).
_active_models_env = os.environ.get("ACTIVE_MODELS", "")
if not _active_models_env:
    AVAILABLE_MODELS = list(_ALL_MODELS)
else:
    _allowed = {name.strip() for name in _active_models_env.split(",") if name.strip()}
    AVAILABLE_MODELS = [name for name in _ALL_MODELS if name in _allowed]
|
| 122 |
|
| 123 |
+
BASE_DIR = Path(__file__).resolve().parent
|
| 124 |
_ASSET_DIR = Path(os.environ.get("ASSET_DIR", "/tmp/faster-qwen3-tts-assets"))
|
| 125 |
+
PRESET_TRANSCRIPTS = _ASSET_DIR / "samples" / "parity" / "icl_transcripts.txt"
|
| 126 |
+
PRESET_REFS =[
|
| 127 |
("ref_audio_3", _ASSET_DIR / "ref_audio_3.wav", "Clone 1"),
|
| 128 |
("ref_audio_2", _ASSET_DIR / "ref_audio_2.wav", "Clone 2"),
|
| 129 |
("ref_audio", _ASSET_DIR / "ref_audio.wav", "Clone 3"),
|
| 130 |
]
|
| 131 |
|
| 132 |
+
_GITHUB_RAW = "https://raw.githubusercontent.com/andimarafioti/faster-qwen3-tts/main"
|
| 133 |
+
_PRESET_REMOTE = {
|
| 134 |
+
"ref_audio": f"{_GITHUB_RAW}/ref_audio.wav",
|
| 135 |
+
"ref_audio_2": f"{_GITHUB_RAW}/ref_audio_2.wav",
|
| 136 |
+
"ref_audio_3": f"{_GITHUB_RAW}/ref_audio_3.wav",
|
| 137 |
+
}
|
| 138 |
+
_TRANSCRIPT_REMOTE = f"{_GITHUB_RAW}/samples/parity/icl_transcripts.txt"
|
| 139 |
+
|
| 140 |
+
def _fetch_preset_assets() -> None:
    """Download the preset transcript file and reference clips that are missing locally.

    Creates ``_ASSET_DIR`` and the transcript's parent directory, then fetches
    ``PRESET_TRANSCRIPTS`` and each ``PRESET_REFS`` path that does not exist yet
    and has a URL in ``_PRESET_REMOTE``. Every download is best-effort: a failure
    is reported and skipped, never raised.
    """
    import urllib.request

    def _download(url: str, dest: Path) -> None:
        """Fetch *url* into *dest* via a temporary ``.part`` file."""
        # Downloading to a side file and renaming keeps an interrupted transfer
        # from leaving a truncated file that later exists() checks would accept.
        partial = dest.with_name(dest.name + ".part")
        try:
            urllib.request.urlretrieve(url, partial)
            os.replace(partial, dest)
        except Exception as exc:
            print(f"Warning: could not download {url}: {exc}")
            try:
                partial.unlink()
            except OSError:
                pass

    _ASSET_DIR.mkdir(parents=True, exist_ok=True)
    PRESET_TRANSCRIPTS.parent.mkdir(parents=True, exist_ok=True)
    if not PRESET_TRANSCRIPTS.exists():
        _download(_TRANSCRIPT_REMOTE, PRESET_TRANSCRIPTS)
    for key, path, _ in PRESET_REFS:
        if key in _PRESET_REMOTE and not path.exists():
            _download(_PRESET_REMOTE[key], path)
|
| 151 |
+
|
| 152 |
+
_preset_refs: dict[str, dict] = {}
|
| 153 |
+
|
| 154 |
+
def _load_preset_transcripts() -> dict[str, str]:
    """Parse ``PRESET_TRANSCRIPTS`` into a ``{key: transcript}`` mapping.

    Each usable line has the form ``key[(...)]: text``; the key is the part
    before the first ``(`` of the text preceding the first colon. Lines without
    a colon are ignored. Returns an empty dict when the file does not exist.
    """
    if not PRESET_TRANSCRIPTS.exists():
        return {}
    parsed = {}
    raw = PRESET_TRANSCRIPTS.read_text(encoding="utf-8")
    for entry in raw.splitlines():
        head, colon, body = entry.partition(":")
        if not colon:
            continue
        parsed[head.split("(")[0].strip()] = body.strip()
    return parsed
|
| 163 |
+
|
| 164 |
+
def _load_preset_refs() -> None:
    """Fill ``_preset_refs`` with one record per preset clip found on disk.

    Presets whose audio file is missing are skipped. Each record carries the
    preset id, label, original filename, the path returned by
    ``_get_cached_ref_path``, the transcript (empty string when unknown) and
    the base64-encoded file content.
    """
    transcripts = _load_preset_transcripts()
    present = ((key, wav, label) for key, wav, label in PRESET_REFS if wav.exists())
    for key, wav, label in present:
        raw = wav.read_bytes()
        cached = _get_cached_ref_path(raw)
        _preset_refs[key] = dict(
            id=key,
            label=label,
            filename=wav.name,
            path=cached,
            ref_text=transcripts.get(key, ""),
            audio_b64=base64.b64encode(raw).decode(),
        )
|
| 178 |
+
|
| 179 |
+
def _prime_preset_voice_cache(model: FasterQwen3TTS) -> None:
    """Call ``model._prepare_generation`` once for every loaded preset voice.

    Each call uses the fixed text ``"Hello."`` with the preset's audio path and
    transcript. Any exception from a preset is swallowed and the next preset is
    tried. Does nothing when no presets are loaded.
    """
    for entry in _preset_refs.values():
        try:
            warmup = {
                "text": "Hello.",
                "ref_audio": entry["path"],
                "ref_text": entry["ref_text"],
                "language": "English",
                "xvec_only": True,
                "non_streaming_mode": True,
            }
            model._prepare_generation(**warmup)
        except Exception:
            pass
|
| 188 |
|
| 189 |
+
app = FastAPI(title="Faster Qwen3-TTS Demo")
|
| 190 |
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
| 191 |
|
| 192 |
+
_model_cache: OrderedDict[str, FasterQwen3TTS] = OrderedDict()
|
| 193 |
+
_model_cache_max: int = int(os.environ.get("MODEL_CACHE_SIZE", "1"))
|
| 194 |
_active_model_name: str | None = None
|
| 195 |
_loading = False
|
| 196 |
+
_ref_cache: dict[str, str] = {}
|
| 197 |
+
_ref_cache_lock = threading.Lock()
|
| 198 |
_parakeet = None
|
| 199 |
_generation_lock = asyncio.Lock()
|
| 200 |
+
_generation_waiters: int = 0
|
| 201 |
+
|
| 202 |
+
MAX_TEXT_CHARS = 1000
|
| 203 |
+
MAX_AUDIO_BYTES = 10 * 1024 * 1024
|
| 204 |
+
_AUDIO_TOO_LARGE_MSG = "Audio file too large. Please upload a shorter recording."
|
| 205 |
+
|
| 206 |
+
def _to_wav_b64(audio: np.ndarray, sr: int) -> str:
    """Encode *audio* as a 16-bit PCM WAV at rate *sr* and return it base64-encoded.

    The samples are converted to float32 if needed and squeezed when the array
    has more than one dimension before being written.
    """
    samples = audio if audio.dtype == np.float32 else audio.astype(np.float32)
    if samples.ndim > 1:
        samples = samples.squeeze()
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, samples, sr, format="WAV", subtype="PCM_16")
    wav_bytes = wav_buffer.getvalue()
    return base64.b64encode(wav_bytes).decode()
|
| 212 |
+
|
| 213 |
+
def _concat_audio(audio_list) -> np.ndarray:
    """Return the audio in *audio_list* as a single float32 array.

    Args:
        audio_list: Either one ``np.ndarray`` or an iterable of array-likes.
            Items with ``len(item) == 0`` are dropped.

    Returns:
        The squeezed input (for an ndarray) or the concatenation of the squeezed
        items, always with at least one dimension; an empty float32 array when
        nothing remains.
    """
    # squeeze() collapses a one-sample chunk to a 0-d array, which np.concatenate
    # rejects and len() cannot measure; atleast_1d restores the sample axis.
    if isinstance(audio_list, np.ndarray):
        return np.atleast_1d(audio_list.astype(np.float32).squeeze())
    parts = [
        np.atleast_1d(np.array(a, dtype=np.float32).squeeze())
        for a in audio_list
        if len(a) > 0
    ]
    return np.concatenate(parts) if parts else np.zeros(0, dtype=np.float32)
|
| 217 |
+
|
| 218 |
+
def _get_cached_ref_path(content: bytes) -> str:
    """Return the path of a temp-dir WAV file holding *content*, keyed by its SHA-1.

    The digest-to-path mapping lives in ``_ref_cache`` and is accessed under
    ``_ref_cache_lock``. A remembered path is reused only while the file still
    exists; otherwise the file is (re)written if absent and the mapping updated.
    """
    digest = hashlib.sha1(content).hexdigest()
    with _ref_cache_lock:
        known = _ref_cache.get(digest)
        if known and os.path.exists(known):
            return known
        target = Path(tempfile.gettempdir()) / f"faster_qwen3_tts_ref_{digest}.wav"
        if not target.exists():
            target.write_bytes(content)
        location = str(target)
        _ref_cache[digest] = location
    return location
|
| 227 |
+
|
| 228 |
+
_fetch_preset_assets()
|
| 229 |
+
_load_preset_refs()
|
| 230 |
|
| 231 |
@app.get("/")
|
| 232 |
async def root(): return FileResponse(Path(__file__).parent / "index.html")
|
| 233 |
|
| 234 |
+
@app.post("/transcribe")
async def transcribe_audio(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the loaded transcription model.

    The upload is decoded to float32, averaged across channels when it has more
    than one, resampled to 16 kHz if needed and passed to ``_parakeet.transcribe``
    in a worker thread.

    Returns:
        ``{"text": <transcription>}``.

    Raises:
        HTTPException: 503 when no transcription model is loaded; 400 when the
            upload exceeds ``MAX_AUDIO_BYTES`` or cannot be decoded as audio.
    """
    if _parakeet is None:
        raise HTTPException(status_code=503, detail="Transcription model not loaded")
    content = await audio.read()
    if len(content) > MAX_AUDIO_BYTES:
        raise HTTPException(status_code=400, detail=_AUDIO_TOO_LARGE_MSG)

    def run():
        try:
            wav, sr = sf.read(io.BytesIO(content), dtype="float32", always_2d=False)
        except RuntimeError as exc:
            # An unreadable upload is a client error, not a server failure.
            raise HTTPException(status_code=400, detail="Could not decode audio file.") from exc
        if wav.ndim > 1:
            wav = wav.mean(axis=1)
        wav_t = torch.from_numpy(wav)
        if sr != 16000:
            wav_t = torchaudio.functional.resample(wav_t.unsqueeze(0), sr, 16000).squeeze(0)
        return _parakeet.transcribe(wav_t)

    return {"text": await asyncio.to_thread(run)}
|
| 246 |
+
|
| 247 |
@app.get("/status")
async def get_status():
    """Report server state: active/cached models, speakers, presets and queue depth.

    ``model_type`` and ``speakers`` are read from the active model; if reading
    them raises, whatever was obtained before the error is kept.
    """
    active = _model_cache.get(_active_model_name) if _active_model_name else None
    model_type = None
    speakers = []
    if active is not None:
        try:
            model_type = active.model.model.tts_model_type
            speakers = active.model.get_supported_speakers() or []
        except Exception:
            pass

    preset_summaries = [
        {"id": p["id"], "label": p["label"], "ref_text": p["ref_text"]}
        for p in _preset_refs.values()
    ]
    return {
        "loaded": active is not None,
        "model": _active_model_name,
        "loading": _loading,
        "available_models": AVAILABLE_MODELS,
        "model_type": model_type,
        "speakers": speakers,
        "transcription_available": _parakeet is not None,
        "preset_refs": preset_summaries,
        "queue_depth": _generation_waiters,
        "cached_models": list(_model_cache.keys()),
    }
|
| 264 |
|
| 265 |
+
@app.get("/preset_ref/{preset_id}")
async def get_preset_ref(preset_id: str):
    """Return the full stored record for one preset, or respond 404 if unknown."""
    preset = _preset_refs.get(preset_id)
    if preset:
        return preset
    raise HTTPException(status_code=404, detail="Preset not found")
|
| 270 |
+
|
| 271 |
@app.post("/load")
async def load_model(model_id: str = Form(...)):
    """Make *model_id* the active model, loading it on the CPU if it is not cached.

    A cached model is activated immediately. Otherwise the load runs in a worker
    thread while holding ``_generation_lock``; the oldest cache entries are
    evicted first so the cache stays within ``_model_cache_max``.

    Returns:
        ``{"status": "already_loaded" | "loaded", "model": model_id}``.
    """
    # NOTE(review): model_id is not checked against AVAILABLE_MODELS, so any
    # repository id a client sends is passed to from_pretrained — confirm intended.
    global _active_model_name, _loading
    if model_id in _model_cache:
        _active_model_name = model_id
        _model_cache.move_to_end(model_id)
        return {"status": "already_loaded", "model": model_id}
    _loading = True

    def _do_load():
        global _active_model_name, _loading
        _loading = True
        try:
            if model_id in _model_cache:
                # Another request loaded it while this one waited for the lock.
                _model_cache.move_to_end(model_id)
                _active_model_name = model_id
                return
            # Guarded loop: popitem() on an empty cache would raise KeyError
            # (e.g. when MODEL_CACHE_SIZE is 0).
            while _model_cache and len(_model_cache) >= _model_cache_max:
                _model_cache.popitem(last=False)
            new_model = FasterQwen3TTS.from_pretrained(model_id, device="cpu", dtype=torch.float32)
            _model_cache[model_id] = new_model
            _model_cache.move_to_end(model_id)
            _active_model_name = model_id
            _prime_preset_voice_cache(new_model)
        finally:
            _loading = False

    async with _generation_lock:
        await asyncio.to_thread(_do_load)
    return {"status": "loaded", "model": model_id}
|
| 291 |
|
| 292 |
@app.post("/generate/stream")
async def generate_stream(
    text: str = Form(...), language: str = Form("English"), mode: str = Form("voice_clone"),
    ref_text: str = Form(""), speaker: str = Form(""), instruct: str = Form(""),
    xvec_only: bool = Form(True), chunk_size: int = Form(8), temperature: float = Form(0.9),
    top_k: int = Form(50), repetition_penalty: float = Form(1.05),
    ref_preset: str = Form(""), ref_audio: UploadFile = File(None),
):
    """Stream synthesized audio as server-sent events.

    ``mode`` selects the generator: ``"voice_clone"``, ``"custom"``, or anything
    else for voice design. The reference audio comes from ``ref_preset`` when it
    names a known preset, otherwise from the uploaded ``ref_audio``.

    Each event is a JSON object: ``"chunk"`` messages carry base64 WAV audio and
    running metrics, followed by one ``"done"`` message, or an ``"error"``
    message if generation raised. Generation runs in a worker thread and is
    serialized through ``_generation_lock``.

    Raises:
        HTTPException: 400 when no model is loaded, the text exceeds
            ``MAX_TEXT_CHARS``, or the upload exceeds ``MAX_AUDIO_BYTES``.
    """
    if not _active_model_name or _active_model_name not in _model_cache:
        raise HTTPException(status_code=400, detail="Model not loaded. Click 'Load' first.")
    if len(text) > MAX_TEXT_CHARS:
        raise HTTPException(status_code=400, detail="Text too long.")

    tmp_path = None
    tmp_is_cached = False
    if ref_preset and ref_preset in _preset_refs:
        preset = _preset_refs[ref_preset]
        tmp_path, tmp_is_cached = preset["path"], True
        if not ref_text:
            ref_text = preset["ref_text"]
    elif ref_audio and ref_audio.filename:
        content = await ref_audio.read()
        if len(content) > MAX_AUDIO_BYTES:
            raise HTTPException(status_code=400, detail=_AUDIO_TOO_LARGE_MSG)
        tmp_path, tmp_is_cached = _get_cached_ref_path(content), True

    loop = asyncio.get_running_loop()
    queue = asyncio.Queue()
    # Set when the SSE consumer goes away so the worker stops at the next chunk.
    cancelled = threading.Event()

    def emit(message):
        """Hand a message from the worker thread to the event loop's queue."""
        loop.call_soon_threadsafe(queue.put_nowait, message)

    def run_generation():
        try:
            model = _model_cache.get(_active_model_name)
            if model is None:
                # The active model may have changed while this request queued.
                raise RuntimeError("Model not loaded. Click 'Load' first.")
            t0 = time.perf_counter()
            total_audio_s = 0.0
            voice_clone_ms = 0.0
            rtf = 0.0

            if mode == "voice_clone":
                gen = model.generate_voice_clone_streaming(
                    text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
                    xvec_only=xvec_only, chunk_size=chunk_size, temperature=temperature,
                    top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360
                )
            elif mode == "custom":
                gen = model.generate_custom_voice_streaming(
                    text=text, speaker=speaker, language=language, instruct=instruct,
                    chunk_size=chunk_size, temperature=temperature, top_k=top_k,
                    repetition_penalty=repetition_penalty, max_new_tokens=360
                )
            else:
                gen = model.generate_voice_design_streaming(
                    text=text, instruct=instruct, language=language, chunk_size=chunk_size,
                    temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360
                )

            ttfa_ms, total_gen_ms = None, 0.0

            for chunk, sr, timing in gen:
                if cancelled.is_set():
                    return
                # Guard against a missing timing dict or missing/None entries.
                timing = timing or {}
                prefill = timing.get('prefill_ms')
                decode = timing.get('decode_ms')
                prefill_val = float(prefill) if prefill is not None else 0.0
                decode_val = float(decode) if decode is not None else 0.0

                total_gen_ms += (prefill_val + decode_val)
                if ttfa_ms is None:
                    ttfa_ms = total_gen_ms

                chunk_audio = _concat_audio(chunk)
                total_audio_s += len(chunk_audio) / sr
                rtf = total_audio_s / (total_gen_ms / 1000) if total_gen_ms > 0 else 0.0

                payload = {
                    "type": "chunk", "audio_b64": _to_wav_b64(chunk_audio, sr), "sample_rate": sr,
                    "ttfa_ms": round(ttfa_ms), "voice_clone_ms": round(voice_clone_ms),
                    "rtf": round(rtf, 3), "total_audio_s": round(total_audio_s, 3),
                    "elapsed_ms": round((time.perf_counter() - t0) * 1000, 3)
                }
                emit(json.dumps(payload))

            emit(json.dumps({
                "type": "done", "ttfa_ms": round(ttfa_ms or 0), "voice_clone_ms": round(voice_clone_ms),
                "rtf": round(rtf, 3),
                "total_audio_s": round(total_audio_s, 3), "total_ms": round((time.perf_counter() - t0) * 1000)
            }))
        except Exception as e:
            emit(json.dumps({"type": "error", "message": str(e)}))
        finally:
            # None is the end-of-stream sentinel for sse().
            emit(None)
            if tmp_path and os.path.exists(tmp_path) and not tmp_is_cached:
                os.unlink(tmp_path)

    async def sse():
        global _generation_waiters
        _generation_waiters += 1
        lock_acquired = False
        try:
            await _generation_lock.acquire()
            lock_acquired = True
            _generation_waiters -= 1
            threading.Thread(target=run_generation, daemon=True).start()
            while True:
                msg = await queue.get()
                if msg is None:
                    break
                yield f"data: {msg}\n\n"
        finally:
            # NOTE(review): the lock is released right away, so after a client
            # disconnect the worker may still finish its current chunk while the
            # next request starts.
            cancelled.set()
            if lock_acquired:
                _generation_lock.release()
            else:
                _generation_waiters -= 1

    return StreamingResponse(sse(), media_type="text/event-stream", headers={"Cache-Control": "no-cache"})
|
| 399 |
|
| 400 |
+
@app.post("/generate")
async def generate_non_streaming(
    text: str = Form(...), language: str = Form("English"), mode: str = Form("voice_clone"),
    ref_text: str = Form(""), speaker: str = Form(""), instruct: str = Form(""),
    xvec_only: bool = Form(True), temperature: float = Form(0.9), top_k: int = Form(50),
    repetition_penalty: float = Form(1.05), ref_preset: str = Form(""), ref_audio: UploadFile = File(None),
):
    """Synthesize the whole utterance and return it in a single JSON response.

    ``mode`` selects the generator: ``"voice_clone"``, ``"custom"``, or anything
    else for voice design. Input validation mirrors ``/generate/stream``: the
    same text and upload size limits, the same preset-transcript fallback, and
    an upload is only used when it has a filename.

    Returns:
        JSON with ``audio_b64`` (base64 WAV), ``sample_rate`` and ``metrics``
        (``total_ms``, ``audio_duration_s``, ``rtf``).

    Raises:
        HTTPException: 400 when no model is loaded, the text exceeds
            ``MAX_TEXT_CHARS``, or the upload exceeds ``MAX_AUDIO_BYTES``.
    """
    model = _model_cache.get(_active_model_name)
    if model is None:
        raise HTTPException(status_code=400, detail="Model not loaded.")
    if len(text) > MAX_TEXT_CHARS:
        raise HTTPException(status_code=400, detail="Text too long.")

    tmp_path = None
    if ref_preset and ref_preset in _preset_refs:
        preset = _preset_refs[ref_preset]
        tmp_path = preset["path"]
        if not ref_text:
            ref_text = preset["ref_text"]
    elif ref_audio and ref_audio.filename:
        content = await ref_audio.read()
        if len(content) > MAX_AUDIO_BYTES:
            raise HTTPException(status_code=400, detail=_AUDIO_TOO_LARGE_MSG)
        tmp_path = _get_cached_ref_path(content)

    def run():
        t0 = time.perf_counter()
        if mode == "voice_clone":
            audio_list, sr = model.generate_voice_clone(
                text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
                xvec_only=xvec_only, temperature=temperature, top_k=top_k,
                repetition_penalty=repetition_penalty, max_new_tokens=360,
            )
        elif mode == "custom":
            audio_list, sr = model.generate_custom_voice(
                text=text, speaker=speaker, language=language, instruct=instruct,
                temperature=temperature, top_k=top_k,
                repetition_penalty=repetition_penalty, max_new_tokens=360,
            )
        else:
            audio_list, sr = model.generate_voice_design(
                text=text, instruct=instruct, language=language,
                temperature=temperature, top_k=top_k,
                repetition_penalty=repetition_penalty, max_new_tokens=360,
            )
        elapsed = time.perf_counter() - t0
        audio = _concat_audio(audio_list)
        return audio, sr, elapsed, len(audio) / sr

    async with _generation_lock:
        audio, sr, elapsed, dur = await asyncio.to_thread(run)
    rtf = dur / elapsed if elapsed > 0 else 0.0
    return JSONResponse({
        "audio_b64": _to_wav_b64(audio, sr),
        "sample_rate": sr,
        "metrics": {"total_ms": round(elapsed * 1000), "audio_duration_s": round(dur, 3), "rtf": round(rtf, 3)},
    })
|
| 430 |
+
|
| 431 |
+
def main():
    """Parse the command line, optionally preload the models, and start uvicorn.

    Unless ``--no-preload`` is given, the TTS model named by ``--model`` is
    loaded on the CPU, cached, made active and used to prime the preset voices,
    and the transcription model is loaded as well.
    """
    # NOTE(review): everything from the model load through the "Ready" message is
    # assumed to sit under the --no-preload check — confirm against the file.
    global _active_model_name, _parakeet

    cli = argparse.ArgumentParser(description="Faster Qwen3-TTS Demo Server")
    cli.add_argument("--model", default="Qwen/Qwen3-TTS-12Hz-0.6B-Base", help="Model to preload at startup")
    cli.add_argument("--port", type=int, default=int(os.environ.get("PORT", 7860)))
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--no-preload", action="store_true", help="Skip model loading at startup")
    args = cli.parse_args()

    preload = not args.no_preload
    if preload:
        print(f"Loading model: {args.model}")
        _startup_model = FasterQwen3TTS.from_pretrained(args.model, device="cpu", dtype=torch.float32)
        _model_cache[args.model] = _startup_model
        _active_model_name = args.model
        _prime_preset_voice_cache(_startup_model)

        print("Loading transcription model (nano-parakeet)…")
        _parakeet = _parakeet_from_pretrained(device="cpu")
        print("Transcription model ready on CPU.")
        print(f"Ready. Open http://localhost:{args.port}")

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
| 453 |
|
| 454 |
if __name__ == "__main__":
|
| 455 |
main()
|