userhugginggit committed on
Commit
16dfc68
·
verified ·
1 Parent(s): f2be83b

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +92 -277
server.py CHANGED
@@ -1,6 +1,6 @@
1
  #!/usr/bin/env python3
2
  """
3
- Faster Qwen3-TTS Demo Server (Forzado Absoluto a CPU)
4
  """
5
 
6
  import argparse
@@ -26,59 +26,49 @@ from fastapi import FastAPI, File, Form, HTTPException, UploadFile
26
  from fastapi.middleware.cors import CORSMiddleware
27
  from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
28
 
29
- # OPTIMIZACIÓN CPU: Limita el uso excesivo de hilos
30
  torch.set_num_threads(4)
31
  sys.path.insert(0, str(Path(__file__).parent.parent))
32
 
33
  # ==============================================================================
34
- # 🛡️ ESCUDO ANTI-CUDA (FORZAR CPU A NIVEL GLOBAL)
35
  # ==============================================================================
36
  import site
37
 
38
- def _apply_anti_cuda_shield():
39
- # 1. Eliminar bloqueo físico de ValueError en la librería original
40
  try:
41
  for p in site.getsitepackages():
42
  model_py = os.path.join(p, "faster_qwen3_tts", "model.py")
43
  if os.path.exists(model_py):
44
  with open(model_py, "r") as f: code = f.read()
45
- if 'raise ValueError("CUDA graphs require CUDA device")' in code:
46
- code = code.replace('raise ValueError("CUDA graphs require CUDA device")', 'pass')
47
- with open(model_py, "w") as f: f.write(code)
48
  except Exception: pass
49
 
50
- # 2. Neutralizar las alertas internas de compilación CUDA de PyTorch
51
- if hasattr(torch.cuda, '_lazy_init'):
52
- torch.cuda._lazy_init = lambda *args, **kwargs: None
53
  torch.cuda.is_available = lambda: False
54
  torch.cuda.current_device = lambda: 0
55
  torch.cuda.device_count = lambda: 1
56
- torch.cuda.get_device_name = lambda x: "CPU"
57
 
58
- # 3. Interceptar llamadas directas .cuda() en Tensors y Models
59
  torch.Tensor.cuda = lambda self, *args, **kwargs: self
60
  torch.nn.Module.cuda = lambda self, *args, **kwargs: self
61
 
62
- # 4. Interceptar y redirigir comandos .to('cuda') hacia .to('cpu')
63
- _orig_tensor_to = torch.Tensor.to
64
- def _tensor_to_mock(self, *args, **kwargs):
65
  new_args = tuple('cpu' if isinstance(a, str) and 'cuda' in a else a for a in args)
66
  if 'device' in kwargs and isinstance(kwargs['device'], str) and 'cuda' in kwargs['device']:
67
  kwargs['device'] = 'cpu'
68
- return _orig_tensor_to(self, *new_args, **kwargs)
69
- torch.Tensor.to = _tensor_to_mock
70
 
71
- _orig_module_to = torch.nn.Module.to
72
- def _module_to_mock(self, *args, **kwargs):
73
- new_args = tuple('cpu' if isinstance(a, str) and 'cuda' in a else a for a in args)
74
- if 'device' in kwargs and isinstance(kwargs['device'], str) and 'cuda' in kwargs['device']:
75
- kwargs['device'] = 'cpu'
76
- return _orig_module_to(self, *new_args, **kwargs)
77
- torch.nn.Module.to = _module_to_mock
78
 
79
- _apply_anti_cuda_shield()
80
 
81
- # 5. Aplicar clon de PredictorGraph
82
  try:
83
  from faster_qwen3_tts import FasterQwen3TTS
84
  import faster_qwen3_tts.model as fq_model
@@ -86,7 +76,7 @@ try:
86
  class CPU_PredictorGraph:
87
  def __init__(self, model, *args, **kwargs):
88
  self.model = model
89
- self.device = "cpu"
90
  def __call__(self, *args, **kwargs): return self.model(*args, **kwargs)
91
  def forward(self, *args, **kwargs): return self.model(*args, **kwargs)
92
  def warmup(self, *args, **kwargs): pass
@@ -94,13 +84,12 @@ try:
94
 
95
  fq_model.PredictorGraph = CPU_PredictorGraph
96
  except ImportError:
97
- print("Error: faster_qwen3_tts not found.")
98
  sys.exit(1)
99
  # ==============================================================================
100
 
101
  from nano_parakeet import from_pretrained as _parakeet_from_pretrained
102
 
103
- _ALL_MODELS =[
104
  "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
105
  "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
106
  "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
@@ -108,248 +97,117 @@ _ALL_MODELS =[
108
  "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
109
  ]
110
 
111
- _active_models_env = os.environ.get("ACTIVE_MODELS", "")
112
- if _active_models_env:
113
- _allowed = {m.strip() for m in _active_models_env.split(",") if m.strip()}
114
- AVAILABLE_MODELS =[m for m in _ALL_MODELS if m in _allowed]
115
- else:
116
- AVAILABLE_MODELS = list(_ALL_MODELS)
117
 
118
  _ASSET_DIR = Path(os.environ.get("ASSET_DIR", "/tmp/faster-qwen3-tts-assets"))
119
- PRESET_TRANSCRIPTS = _ASSET_DIR / "samples" / "parity" / "icl_transcripts.txt"
120
- PRESET_REFS =[
121
  ("ref_audio_3", _ASSET_DIR / "ref_audio_3.wav", "Clone 1"),
122
  ("ref_audio_2", _ASSET_DIR / "ref_audio_2.wav", "Clone 2"),
123
  ("ref_audio", _ASSET_DIR / "ref_audio.wav", "Clone 3"),
124
  ]
125
 
126
- _GITHUB_RAW = "https://raw.githubusercontent.com/andimarafioti/faster-qwen3-tts/main"
127
- _PRESET_REMOTE = {
128
- "ref_audio": f"{_GITHUB_RAW}/ref_audio.wav",
129
- "ref_audio_2": f"{_GITHUB_RAW}/ref_audio_2.wav",
130
- "ref_audio_3": f"{_GITHUB_RAW}/ref_audio_3.wav",
131
- }
132
- _TRANSCRIPT_REMOTE = f"{_GITHUB_RAW}/samples/parity/icl_transcripts.txt"
133
-
134
- def _fetch_preset_assets() -> None:
135
- import urllib.request
136
- _ASSET_DIR.mkdir(parents=True, exist_ok=True)
137
- PRESET_TRANSCRIPTS.parent.mkdir(parents=True, exist_ok=True)
138
- if not PRESET_TRANSCRIPTS.exists():
139
- try:
140
- urllib.request.urlretrieve(_TRANSCRIPT_REMOTE, PRESET_TRANSCRIPTS)
141
- except Exception: pass
142
- for key, path, _ in PRESET_REFS:
143
- if not path.exists() and key in _PRESET_REMOTE:
144
- try:
145
- urllib.request.urlretrieve(_PRESET_REMOTE[key], path)
146
- except Exception: pass
147
-
148
- _preset_refs: dict[str, dict] = {}
149
-
150
- def _load_preset_transcripts() -> dict[str, str]:
151
- if not PRESET_TRANSCRIPTS.exists(): return {}
152
- transcripts = {}
153
- for line in PRESET_TRANSCRIPTS.read_text(encoding="utf-8").splitlines():
154
- if ":" not in line: continue
155
- key_part, text = line.split(":", 1)
156
- key = key_part.split("(")[0].strip()
157
- transcripts[key] = text.strip()
158
- return transcripts
159
-
160
- def _load_preset_refs() -> None:
161
- transcripts = _load_preset_transcripts()
162
- for key, path, label in PRESET_REFS:
163
- if not path.exists(): continue
164
- content = path.read_bytes()
165
- cached_path = _get_cached_ref_path(content)
166
- _preset_refs[key] = {
167
- "id": key,
168
- "label": label,
169
- "filename": path.name,
170
- "path": cached_path,
171
- "ref_text": transcripts.get(key, ""),
172
- "audio_b64": base64.b64encode(content).decode(),
173
- }
174
-
175
- def _prime_preset_voice_cache(model: FasterQwen3TTS) -> None:
176
- if not _preset_refs: return
177
- for preset in _preset_refs.values():
178
- try:
179
- model._prepare_generation(
180
- text="Hello.",
181
- ref_audio=preset["path"],
182
- ref_text=preset["ref_text"],
183
- language="English",
184
- xvec_only=True,
185
- non_streaming_mode=True,
186
- )
187
- except Exception:
188
- continue
189
 
190
- app = FastAPI(title="Faster Qwen3-TTS Demo (CPU)")
 
 
 
 
 
 
191
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
192
 
193
- _model_cache: OrderedDict[str, FasterQwen3TTS] = OrderedDict()
194
- _model_cache_max: int = int(os.environ.get("MODEL_CACHE_SIZE", "1"))
195
  _active_model_name: str | None = None
196
  _loading = False
197
- _ref_cache: dict[str, str] = {}
198
- _ref_cache_lock = threading.Lock()
199
  _parakeet = None
200
  _generation_lock = asyncio.Lock()
201
- _generation_waiters: int = 0
202
-
203
- MAX_TEXT_CHARS = 1000
204
- MAX_AUDIO_BYTES = 10 * 1024 * 1024
205
-
206
- def _to_wav_b64(audio: np.ndarray, sr: int) -> str:
207
- buf = io.BytesIO()
208
- sf.write(buf, audio.astype(np.float32).squeeze(), sr, format="WAV", subtype="PCM_16")
209
- return base64.b64encode(buf.getvalue()).decode()
210
-
211
- def _concat_audio(audio_list) -> np.ndarray:
212
- if isinstance(audio_list, np.ndarray): return audio_list.astype(np.float32).squeeze()
213
- parts =[np.array(a, dtype=np.float32).squeeze() for a in audio_list if len(a) > 0]
214
- return np.concatenate(parts) if parts else np.zeros(0, dtype=np.float32)
215
-
216
- def _get_cached_ref_path(content: bytes) -> str:
217
- digest = hashlib.sha1(content).hexdigest()
218
- with _ref_cache_lock:
219
- cached = _ref_cache.get(digest)
220
- if cached and os.path.exists(cached): return cached
221
- path = Path(tempfile.gettempdir()) / f"faster_qwen3_tts_ref_{digest}.wav"
222
- if not path.exists(): path.write_bytes(content)
223
- _ref_cache[digest] = str(path)
224
- return str(path)
225
-
226
- _fetch_preset_assets()
227
- _load_preset_refs()
228
 
229
  @app.get("/")
230
  async def root(): return FileResponse(Path(__file__).parent / "index.html")
231
 
232
- @app.post("/transcribe")
233
- async def transcribe_audio(audio: UploadFile = File(...)):
234
- if _parakeet is None: raise HTTPException(status_code=503, detail="Transcription model not loaded")
235
- content = await audio.read()
236
- def run():
237
- wav, sr = sf.read(io.BytesIO(content), dtype="float32", always_2d=False)
238
- if wav.ndim > 1: wav = wav.mean(axis=1)
239
- wav_t = torch.from_numpy(wav)
240
- if sr != 16000: wav_t = torchaudio.functional.resample(wav_t.unsqueeze(0), sr, 16000).squeeze(0)
241
- return _parakeet.transcribe(wav_t)
242
- return {"text": await asyncio.to_thread(run)}
243
-
244
  @app.get("/status")
245
  async def get_status():
246
- speakers =[]
247
- model_type = None
248
- active = _model_cache.get(_active_model_name) if _active_model_name else None
249
- if active is not None:
250
- try:
251
- model_type = active.model.model.tts_model_type
252
- speakers = active.model.get_supported_speakers() or[]
253
- except Exception: pass
254
  return {
255
- "loaded": active is not None,
256
- "model": _active_model_name,
257
- "loading": _loading,
258
- "available_models": AVAILABLE_MODELS,
259
- "model_type": model_type,
260
- "speakers": speakers,
261
- "transcription_available": _parakeet is not None,
262
- "preset_refs": [{"id": p["id"], "label": p["label"], "ref_text": p["ref_text"]} for p in _preset_refs.values()],
263
- "queue_depth": _generation_waiters,
264
- "cached_models": list(_model_cache.keys()),
265
  }
266
 
267
- @app.get("/preset_ref/{preset_id}")
268
- async def get_preset_ref(preset_id: str):
269
- preset = _preset_refs.get(preset_id)
270
- if not preset: raise HTTPException(status_code=404, detail="Preset not found")
271
- return preset
272
-
273
  @app.post("/load")
274
  async def load_model(model_id: str = Form(...)):
275
  global _active_model_name, _loading
276
  if model_id in _model_cache:
277
  _active_model_name = model_id
278
- _model_cache.move_to_end(model_id)
279
- return {"status": "already_loaded", "model": model_id}
280
  _loading = True
281
- def _do_load():
282
  global _active_model_name, _loading
283
  try:
284
- if len(_model_cache) >= _model_cache_max: _model_cache.popitem(last=False)
285
- new_model = FasterQwen3TTS.from_pretrained(model_id, device="cpu", dtype=torch.float32)
286
- _model_cache[model_id] = new_model
287
- _model_cache.move_to_end(model_id)
288
  _active_model_name = model_id
289
- _prime_preset_voice_cache(new_model)
290
  finally: _loading = False
291
- async with _generation_lock: await asyncio.to_thread(_do_load)
292
- return {"status": "loaded", "model": model_id}
293
 
294
  @app.post("/generate/stream")
295
  async def generate_stream(
296
- text: str = Form(...), language: str = Form("English"), mode: str = Form("voice_clone"),
297
- ref_text: str = Form(""), speaker: str = Form(""), instruct: str = Form(""),
298
- xvec_only: bool = Form(True), chunk_size: int = Form(8), temperature: float = Form(0.9),
299
- top_k: int = Form(50), repetition_penalty: float = Form(1.05),
300
  ref_preset: str = Form(""), ref_audio: UploadFile = File(None),
 
301
  ):
302
- if not _active_model_name or _active_model_name not in _model_cache:
303
- raise HTTPException(status_code=400, detail="Model not loaded.")
304
-
305
  tmp_path = None
306
- tmp_is_cached = False
307
- if ref_preset and ref_preset in _preset_refs:
308
- preset = _preset_refs[ref_preset]
309
- tmp_path, tmp_is_cached = preset["path"], True
310
- if not ref_text: ref_text = preset["ref_text"]
311
- elif ref_audio and ref_audio.filename:
312
- content = await ref_audio.read()
313
- tmp_path, tmp_is_cached = _get_cached_ref_path(content), True
314
 
315
  loop = asyncio.get_event_loop()
316
  queue = asyncio.Queue()
317
 
318
- def run_generation():
319
  try:
320
- model = _model_cache.get(_active_model_name)
321
  t0 = time.perf_counter()
322
  total_audio_s = 0.0
323
 
324
- if mode == "voice_clone":
325
- gen = model.generate_voice_clone_streaming(
326
- text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
327
- xvec_only=xvec_only, chunk_size=chunk_size, temperature=temperature,
328
- top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360
329
- )
330
- elif mode == "custom":
331
- gen = model.generate_custom_voice_streaming(
332
- text=text, speaker=speaker, language=language, instruct=instruct,
333
- chunk_size=chunk_size, temperature=temperature, top_k=top_k,
334
- repetition_penalty=repetition_penalty, max_new_tokens=360
335
- )
336
- else:
337
- gen = model.generate_voice_design_streaming(
338
- text=text, instruct=instruct, language=language, chunk_size=chunk_size,
339
- temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360
340
- )
341
 
342
  ttfa_ms, total_gen_ms = None, 0.0
343
  for chunk, sr, timing in gen:
344
- total_gen_ms += timing.get('prefill_ms', 0) + timing.get('decode_ms', 0)
 
 
 
 
 
345
  if ttfa_ms is None: ttfa_ms = total_gen_ms
346
- chunk_audio = _concat_audio(chunk)
 
347
  total_audio_s += len(chunk_audio) / sr
 
 
348
  rtf = total_audio_s / (total_gen_ms / 1000) if total_gen_ms > 0 else 0.0
 
 
 
 
349
  payload = {
350
- "type": "chunk", "audio_b64": _to_wav_b64(chunk_audio, sr), "sample_rate": sr,
351
- "ttfa_ms": round(ttfa_ms), "rtf": round(rtf, 3), "total_audio_s": round(total_audio_s, 3),
352
- "elapsed_ms": round((time.perf_counter() - t0) * 1000, 3)
353
  }
354
  loop.call_soon_threadsafe(queue.put_nowait, json.dumps(payload))
355
 
@@ -358,75 +216,32 @@ async def generate_stream(
358
  loop.call_soon_threadsafe(queue.put_nowait, json.dumps({"type": "error", "message": str(e)}))
359
  finally:
360
  loop.call_soon_threadsafe(queue.put_nowait, None)
361
- if tmp_path and os.path.exists(tmp_path) and not tmp_is_cached: os.unlink(tmp_path)
362
 
363
  async def sse():
364
- global _generation_waiters
365
- _generation_waiters += 1
366
- try:
367
- async with _generation_lock:
368
- _generation_waiters -= 1
369
- thread = threading.Thread(target=run_generation, daemon=True)
370
- thread.start()
371
- while True:
372
- msg = await queue.get()
373
- if msg is None: break
374
- yield f"data: {msg}\n\n"
375
- finally: pass
376
 
377
  return StreamingResponse(sse(), media_type="text/event-stream")
378
 
379
- @app.post("/generate")
380
- async def generate_non_streaming(
381
- text: str = Form(...), language: str = Form("English"), mode: str = Form("voice_clone"),
382
- ref_text: str = Form(""), speaker: str = Form(""), instruct: str = Form(""),
383
- xvec_only: bool = Form(True), temperature: float = Form(0.9), top_k: int = Form(50),
384
- repetition_penalty: float = Form(1.05), ref_preset: str = Form(""), ref_audio: UploadFile = File(None),
385
- ):
386
- model = _model_cache.get(_active_model_name)
387
- if not model: raise HTTPException(status_code=400, detail="Model not loaded.")
388
-
389
- tmp_path = None
390
- if ref_preset and ref_preset in _preset_refs: tmp_path = _preset_refs[ref_preset]["path"]
391
- elif ref_audio: tmp_path = _get_cached_ref_path(await ref_audio.read())
392
-
393
- def run():
394
- t0 = time.perf_counter()
395
- if mode == "voice_clone":
396
- audio_list, sr = model.generate_voice_clone(text=text, language=language, ref_audio=tmp_path, ref_text=ref_text, xvec_only=xvec_only, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360)
397
- elif mode == "custom":
398
- audio_list, sr = model.generate_custom_voice(text=text, speaker=speaker, language=language, instruct=instruct, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360)
399
- else:
400
- audio_list, sr = model.generate_voice_design(text=text, instruct=instruct, language=language, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360)
401
- elapsed = time.perf_counter() - t0
402
- audio = _concat_audio(audio_list)
403
- return audio, sr, elapsed, len(audio)/sr
404
-
405
- async with _generation_lock:
406
- audio, sr, elapsed, dur = await asyncio.to_thread(run)
407
- return JSONResponse({"audio_b64": _to_wav_b64(audio, sr), "sample_rate": sr, "metrics": {"total_ms": round(elapsed * 1000), "audio_duration_s": round(dur, 3)}})
408
-
409
  def main():
410
- parser = argparse.ArgumentParser()
411
- parser.add_argument("--model", default="Qwen/Qwen3-TTS-12Hz-0.6B-Base")
412
- parser.add_argument("--port", type=int, default=7860)
413
- parser.add_argument("--host", default="0.0.0.0")
414
- args = parser.parse_args()
415
-
 
 
 
416
  global _active_model_name, _parakeet
417
-
418
- print(f"Loading model: {args.model}")
419
- _startup_model = FasterQwen3TTS.from_pretrained(args.model, device="cpu", dtype=torch.float32)
420
- _model_cache[args.model] = _startup_model
421
  _active_model_name = args.model
422
- _prime_preset_voice_cache(_startup_model)
423
-
424
- print("Loading transcription model (nano-parakeet)…")
425
  _parakeet = _parakeet_from_pretrained(device="cpu")
426
- print("Transcription model ready on CPU.")
427
 
428
- print(f"Server ready on CPU. Port: {args.port}")
429
- uvicorn.run(app, host=args.host, port=args.port)
430
 
431
  if __name__ == "__main__":
432
  main()
 
1
  #!/usr/bin/env python3
2
  """
3
+ Faster Qwen3-TTS Demo Server (CPU Edición Ultra-Resistente)
4
  """
5
 
6
  import argparse
 
26
  from fastapi.middleware.cors import CORSMiddleware
27
  from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
28
 
29
+ # OPTIMIZACIÓN CPU
30
  torch.set_num_threads(4)
31
  sys.path.insert(0, str(Path(__file__).parent.parent))
32
 
33
  # ==============================================================================
34
+ # 🛡️ ESCUDO TOTAL ANTI-CUDA Y ANTI-NONE
35
  # ==============================================================================
36
  import site
37
 
38
+ def _apply_shield():
39
+ # 1. Parche físico
40
  try:
41
  for p in site.getsitepackages():
42
  model_py = os.path.join(p, "faster_qwen3_tts", "model.py")
43
  if os.path.exists(model_py):
44
  with open(model_py, "r") as f: code = f.read()
45
+ code = code.replace('raise ValueError("CUDA graphs require CUDA device")', 'pass')
46
+ with open(model_py, "w") as f: f.write(code)
 
47
  except Exception: pass
48
 
49
+ # 2. Mock de CUDA
 
 
50
  torch.cuda.is_available = lambda: False
51
  torch.cuda.current_device = lambda: 0
52
  torch.cuda.device_count = lambda: 1
53
+ if hasattr(torch.cuda, '_lazy_init'): torch.cuda._lazy_init = lambda *args, **kwargs: None
54
 
55
+ # 3. Forzado de Tensors y Modules a CPU
56
  torch.Tensor.cuda = lambda self, *args, **kwargs: self
57
  torch.nn.Module.cuda = lambda self, *args, **kwargs: self
58
 
59
+ def _mock_to(self, *args, **kwargs):
 
 
60
  new_args = tuple('cpu' if isinstance(a, str) and 'cuda' in a else a for a in args)
61
  if 'device' in kwargs and isinstance(kwargs['device'], str) and 'cuda' in kwargs['device']:
62
  kwargs['device'] = 'cpu'
63
+ return _orig_to(self, *new_args, **kwargs)
 
64
 
65
+ _orig_to = torch.Tensor.to
66
+ torch.Tensor.to = _mock_to
67
+ _orig_mod_to = torch.nn.Module.to
68
+ torch.nn.Module.to = _mock_to
 
 
 
69
 
70
+ _apply_shield()
71
 
 
72
  try:
73
  from faster_qwen3_tts import FasterQwen3TTS
74
  import faster_qwen3_tts.model as fq_model
 
76
  class CPU_PredictorGraph:
77
  def __init__(self, model, *args, **kwargs):
78
  self.model = model
79
+ self.device = torch.device("cpu")
80
  def __call__(self, *args, **kwargs): return self.model(*args, **kwargs)
81
  def forward(self, *args, **kwargs): return self.model(*args, **kwargs)
82
  def warmup(self, *args, **kwargs): pass
 
84
 
85
  fq_model.PredictorGraph = CPU_PredictorGraph
86
  except ImportError:
 
87
  sys.exit(1)
88
  # ==============================================================================
89
 
90
  from nano_parakeet import from_pretrained as _parakeet_from_pretrained
91
 
92
+ _ALL_MODELS = [
93
  "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
94
  "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
95
  "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
 
97
  "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
98
  ]
99
 
100
+ # Configuración de modelos activos
101
+ _active_env = os.environ.get("ACTIVE_MODELS", "Qwen/Qwen3-TTS-12Hz-0.6B-Base,Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice")
102
+ AVAILABLE_MODELS = [m.strip() for m in _active_env.split(",") if m.strip()]
 
 
 
103
 
104
  _ASSET_DIR = Path(os.environ.get("ASSET_DIR", "/tmp/faster-qwen3-tts-assets"))
105
+ PRESET_REFS = [
 
106
  ("ref_audio_3", _ASSET_DIR / "ref_audio_3.wav", "Clone 1"),
107
  ("ref_audio_2", _ASSET_DIR / "ref_audio_2.wav", "Clone 2"),
108
  ("ref_audio", _ASSET_DIR / "ref_audio.wav", "Clone 3"),
109
  ]
110
 
111
+ _preset_refs: dict = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
+ def _get_cached_ref_path(content: bytes) -> str:
114
+ digest = hashlib.sha1(content).hexdigest()
115
+ path = Path(tempfile.gettempdir()) / f"fq3_ref_{digest}.wav"
116
+ if not path.exists(): path.write_bytes(content)
117
+ return str(path)
118
+
119
+ app = FastAPI(title="Faster Qwen3-TTS CPU")
120
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
121
 
122
+ _model_cache: OrderedDict = OrderedDict()
 
123
  _active_model_name: str | None = None
124
  _loading = False
 
 
125
  _parakeet = None
126
  _generation_lock = asyncio.Lock()
127
+ _generation_waiters = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
@app.get("/")
async def root():
    """Serve the ``index.html`` file that sits next to this module."""
    index_page = Path(__file__).parent / "index.html"
    return FileResponse(index_page)
131
 
 
 
 
 
 
 
 
 
 
 
 
 
132
@app.get("/status")
async def get_status():
    """Report the active model, loading state and what the UI can offer.

    Returns a dict with:
      - ``loaded``: whether the active model is present in the cache
      - ``model``: name of the active model (may be ``None``)
      - ``loading``: whether a load is currently in progress
      - ``available_models``: the configured model ids
      - ``speakers``: speakers reported by the active model (always a list)
      - ``preset_refs``: ``id``/``label`` pairs of the registered presets
    """
    active = _model_cache.get(_active_model_name) if _active_model_name else None

    speakers = []
    if active is not None:
        # NOTE(review): the previous revision guarded this call with
        # try/except and ``or []``, which suggests it can raise or return
        # None for some model types — keep the status endpoint from failing
        # in either case.
        try:
            speakers = active.model.get_supported_speakers() or []
        except Exception:
            speakers = []

    return {
        "loaded": active is not None,
        "model": _active_model_name,
        "loading": _loading,
        "available_models": AVAILABLE_MODELS,
        "speakers": speakers,
        "preset_refs": [
            {"id": k, "label": v["label"]} for k, v in _preset_refs.items()
        ],
    }
141
 
 
 
 
 
 
 
142
@app.post("/load")
async def load_model(model_id: str = Form(...)):
    """Make *model_id* the active model, loading it on CPU if necessary.

    A model already in the cache is simply marked active. Otherwise the
    oldest cached model is evicted (the cache holds one entry), the requested
    one is loaded in a worker thread while the generation lock is held, and
    it becomes the active model.
    """
    global _active_model_name, _loading

    already_cached = model_id in _model_cache
    if already_cached:
        _active_model_name = model_id
        return {"status": "ok"}

    _loading = True

    def _swap_in_model():
        global _active_model_name, _loading
        try:
            # Evict before loading so only one model is held at a time.
            if _model_cache:
                _model_cache.popitem(last=False)
            loaded = FasterQwen3TTS.from_pretrained(
                model_id, device="cpu", dtype=torch.float32
            )
            _model_cache[model_id] = loaded
            _active_model_name = model_id
        finally:
            _loading = False

    async with _generation_lock:
        await asyncio.to_thread(_swap_in_model)

    return {"status": "loaded"}
159
 
160
  @app.post("/generate/stream")
161
  async def generate_stream(
162
+ text: str = Form(...), mode: str = Form("voice_clone"),
 
 
 
163
  ref_preset: str = Form(""), ref_audio: UploadFile = File(None),
164
+ chunk_size: int = Form(8), temperature: float = Form(0.9)
165
  ):
166
+ model = _model_cache.get(_active_model_name)
167
+ if not model: raise HTTPException(status_code=400, detail="Carga el modelo primero")
168
+
169
  tmp_path = None
170
+ if ref_preset and ref_preset in _preset_refs: tmp_path = _preset_refs[ref_preset]["path"]
171
+ elif ref_audio: tmp_path = _get_cached_ref_path(await ref_audio.read())
 
 
 
 
 
 
172
 
173
  loop = asyncio.get_event_loop()
174
  queue = asyncio.Queue()
175
 
176
+ def run_gen():
177
  try:
 
178
  t0 = time.perf_counter()
179
  total_audio_s = 0.0
180
 
181
+ gen = model.generate_voice_clone_streaming(
182
+ text=text, ref_audio=tmp_path, chunk_size=chunk_size,
183
+ temperature=temperature, max_new_tokens=360
184
+ ) if mode == "voice_clone" else model.generate_voice_design_streaming(
185
+ text=text, chunk_size=chunk_size, temperature=temperature, max_new_tokens=360
186
+ )
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  ttfa_ms, total_gen_ms = None, 0.0
189
  for chunk, sr, timing in gen:
190
+ # 🛡️ PROTECCIÓN ANTI-NONE: Si timing es None o faltan keys, usamos 0
191
+ timing = timing or {}
192
+ prefill = timing.get('prefill_ms') or 0.0
193
+ decode = timing.get('decode_ms') or 0.0
194
+
195
+ total_gen_ms += (float(prefill) + float(decode))
196
  if ttfa_ms is None: ttfa_ms = total_gen_ms
197
+
198
+ chunk_audio = np.concatenate([np.array(a).squeeze() for a in chunk]) if isinstance(chunk, list) else chunk.squeeze()
199
  total_audio_s += len(chunk_audio) / sr
200
+
201
+ # RTF Safe
202
  rtf = total_audio_s / (total_gen_ms / 1000) if total_gen_ms > 0 else 0.0
203
+
204
+ buf = io.BytesIO()
205
+ sf.write(buf, chunk_audio.astype(np.float32), sr, format="WAV", subtype="PCM_16")
206
+
207
  payload = {
208
+ "type": "chunk", "audio_b64": base64.b64encode(buf.getvalue()).decode(),
209
+ "sample_rate": sr, "ttfa_ms": round(ttfa_ms), "rtf": round(rtf, 3),
210
+ "total_audio_s": round(total_audio_s, 3)
211
  }
212
  loop.call_soon_threadsafe(queue.put_nowait, json.dumps(payload))
213
 
 
216
  loop.call_soon_threadsafe(queue.put_nowait, json.dumps({"type": "error", "message": str(e)}))
217
  finally:
218
  loop.call_soon_threadsafe(queue.put_nowait, None)
 
219
 
220
  async def sse():
221
+ async with _generation_lock:
222
+ threading.Thread(target=run_gen, daemon=True).start()
223
+ while True:
224
+ msg = await queue.get()
225
+ if msg is None: break
226
+ yield f"data: {msg}\n\n"
 
 
 
 
 
 
227
 
228
  return StreamingResponse(sse(), media_type="text/event-stream")
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
def main():
    """Parse CLI options, load the models on CPU and start the HTTP server.

    Options:
      --model  model id to load at start-up
      --port   TCP port to listen on (default 7860)
      --host   interface to bind (default "0.0.0.0", the previously
               hard-coded value)
    """
    global _active_model_name, _parakeet

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Qwen/Qwen3-TTS-12Hz-0.6B-Base")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--host", default="0.0.0.0")
    args = parser.parse_args()

    # Initial load
    print("Iniciando en CPU...")
    startup_model = FasterQwen3TTS.from_pretrained(
        args.model, device="cpu", dtype=torch.float32
    )
    _model_cache[args.model] = startup_model
    _active_model_name = args.model

    _parakeet = _parakeet_from_pretrained(device="cpu")

    uvicorn.run(app, host=args.host, port=args.port)
 
245
 
246
  if __name__ == "__main__":
247
  main()