userhugginggit committed on
Commit
fd09ecb
·
verified ·
1 Parent(s): 41b15cc

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +66 -138
server.py CHANGED
@@ -1,11 +1,6 @@
1
  #!/usr/bin/env python3
2
  """
3
- Faster Qwen3-TTS Demo Server (CPU Optimizado + Parches Anti-CUDA y Anti-None)
4
-
5
- Usage:
6
- python demo/server.py
7
- python demo/server.py --model Qwen/Qwen3-TTS-12Hz-1.7B-Base --port 7860
8
- python demo/server.py --no-preload # skip startup model load
9
  """
10
 
11
  import argparse
@@ -35,73 +30,12 @@ from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
35
  torch.set_num_threads(4)
36
  sys.path.insert(0, str(Path(__file__).parent.parent))
37
 
38
- # ==============================================================================
39
- # 🛡️ ESCUDO TOTAL ANTI-CUDA Y ANTI-NONE
40
- # ==============================================================================
41
- import site
42
-
43
- def _apply_anti_cuda_shield():
44
- # 1. Parche físico para el ValueError de la librería
45
- try:
46
- for p in site.getsitepackages():
47
- model_py = os.path.join(p, "faster_qwen3_tts", "model.py")
48
- if os.path.exists(model_py):
49
- with open(model_py, "r") as f: code = f.read()
50
- if 'raise ValueError("CUDA graphs require CUDA device")' in code:
51
- code = code.replace('raise ValueError("CUDA graphs require CUDA device")', 'pass')
52
- with open(model_py, "w") as f: f.write(code)
53
- except Exception: pass
54
-
55
- # 2. Neutralizar validaciones internas de CUDA
56
- if hasattr(torch.cuda, '_lazy_init'):
57
- torch.cuda._lazy_init = lambda *args, **kwargs: None
58
- torch.cuda.is_available = lambda: False
59
- torch.cuda.current_device = lambda: 0
60
- torch.cuda.device_count = lambda: 1
61
- torch.cuda.get_device_name = lambda x: "CPU"
62
-
63
- # 3. Interceptar .cuda()
64
- torch.Tensor.cuda = lambda self, *args, **kwargs: self
65
- torch.nn.Module.cuda = lambda self, *args, **kwargs: self
66
-
67
- # 4. Interceptar y redirigir .to('cuda') hacia .to('cpu')
68
- _orig_tensor_to = torch.Tensor.to
69
- def _tensor_to_mock(self, *args, **kwargs):
70
- new_args = tuple('cpu' if isinstance(a, str) and 'cuda' in a else a for a in args)
71
- if 'device' in kwargs and isinstance(kwargs['device'], str) and 'cuda' in kwargs['device']:
72
- kwargs['device'] = 'cpu'
73
- return _orig_tensor_to(self, *new_args, **kwargs)
74
- torch.Tensor.to = _tensor_to_mock
75
-
76
- _orig_module_to = torch.nn.Module.to
77
- def _module_to_mock(self, *args, **kwargs):
78
- new_args = tuple('cpu' if isinstance(a, str) and 'cuda' in a else a for a in args)
79
- if 'device' in kwargs and isinstance(kwargs['device'], str) and 'cuda' in kwargs['device']:
80
- kwargs['device'] = 'cpu'
81
- return _orig_module_to(self, *new_args, **kwargs)
82
- torch.nn.Module.to = _module_to_mock
83
-
84
- _apply_anti_cuda_shield()
85
-
86
  try:
87
- from faster_qwen3_tts import FasterQwen3TTS
88
- import faster_qwen3_tts.model as fq_model
89
-
90
- # Clon del PredictorGraph para CPU
91
- class CPU_PredictorGraph:
92
- def __init__(self, model, *args, **kwargs):
93
- self.model = model
94
- self.device = torch.device("cpu")
95
- def __call__(self, *args, **kwargs): return self.model(*args, **kwargs)
96
- def forward(self, *args, **kwargs): return self.model(*args, **kwargs)
97
- def warmup(self, *args, **kwargs): pass
98
- def __getattr__(self, name): return getattr(self.model, name)
99
-
100
- fq_model.PredictorGraph = CPU_PredictorGraph
101
  except ImportError:
102
- print("Error: faster_qwen3_tts not found.")
103
  sys.exit(1)
104
- # ==============================================================================
105
 
106
  from nano_parakeet import from_pretrained as _parakeet_from_pretrained
107
 
@@ -116,12 +50,12 @@ _ALL_MODELS =[
116
  _active_models_env = os.environ.get("ACTIVE_MODELS", "")
117
  if _active_models_env:
118
  _allowed = {m.strip() for m in _active_models_env.split(",") if m.strip()}
119
- AVAILABLE_MODELS = [m for m in _ALL_MODELS if m in _allowed]
120
  else:
121
  AVAILABLE_MODELS = list(_ALL_MODELS)
122
 
123
  BASE_DIR = Path(__file__).resolve().parent
124
- _ASSET_DIR = Path(os.environ.get("ASSET_DIR", "/tmp/faster-qwen3-tts-assets"))
125
  PRESET_TRANSCRIPTS = _ASSET_DIR / "samples" / "parity" / "icl_transcripts.txt"
126
  PRESET_REFS =[
127
  ("ref_audio_3", _ASSET_DIR / "ref_audio_3.wav", "Clone 1"),
@@ -176,20 +110,10 @@ def _load_preset_refs() -> None:
176
  "audio_b64": base64.b64encode(content).decode(),
177
  }
178
 
179
- def _prime_preset_voice_cache(model: FasterQwen3TTS) -> None:
180
- if not _preset_refs: return
181
- for preset in _preset_refs.values():
182
- try:
183
- model._prepare_generation(
184
- text="Hello.", ref_audio=preset["path"], ref_text=preset["ref_text"],
185
- language="English", xvec_only=True, non_streaming_mode=True,
186
- )
187
- except Exception: continue
188
-
189
- app = FastAPI(title="Faster Qwen3-TTS Demo")
190
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
191
 
192
- _model_cache: OrderedDict[str, FasterQwen3TTS] = OrderedDict()
193
  _model_cache_max: int = int(os.environ.get("MODEL_CACHE_SIZE", "1"))
194
  _active_model_name: str | None = None
195
  _loading = False
@@ -220,7 +144,7 @@ def _get_cached_ref_path(content: bytes) -> str:
220
  with _ref_cache_lock:
221
  cached = _ref_cache.get(digest)
222
  if cached and os.path.exists(cached): return cached
223
- path = Path(tempfile.gettempdir()) / f"faster_qwen3_tts_ref_{digest}.wav"
224
  if not path.exists(): path.write_bytes(content)
225
  _ref_cache[digest] = str(path)
226
  return str(path)
@@ -251,8 +175,8 @@ async def get_status():
251
  active = _model_cache.get(_active_model_name) if _active_model_name else None
252
  if active is not None:
253
  try:
254
- model_type = active.model.model.tts_model_type
255
- speakers = active.model.get_supported_speakers() or[]
256
  except Exception: pass
257
  return {
258
  "loaded": active is not None, "model": _active_model_name, "loading": _loading,
@@ -280,11 +204,16 @@ async def load_model(model_id: str = Form(...)):
280
  global _active_model_name, _loading
281
  try:
282
  if len(_model_cache) >= _model_cache_max: _model_cache.popitem(last=False)
283
- new_model = FasterQwen3TTS.from_pretrained(model_id, device="cpu", dtype=torch.float32)
 
 
 
 
 
 
284
  _model_cache[model_id] = new_model
285
  _model_cache.move_to_end(model_id)
286
  _active_model_name = model_id
287
- _prime_preset_voice_cache(new_model)
288
  finally: _loading = False
289
  async with _generation_lock: await asyncio.to_thread(_do_load)
290
  return {"status": "loaded", "model": model_id}
@@ -319,59 +248,46 @@ async def generate_stream(
319
  try:
320
  model = _model_cache.get(_active_model_name)
321
  t0 = time.perf_counter()
322
- total_audio_s = 0.0
323
- voice_clone_ms = 0.0
324
 
 
325
  if mode == "voice_clone":
326
- gen = model.generate_voice_clone_streaming(
327
  text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
328
- xvec_only=xvec_only, chunk_size=chunk_size, temperature=temperature,
329
- top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360
330
  )
331
  elif mode == "custom":
332
- gen = model.generate_custom_voice_streaming(
333
  text=text, speaker=speaker, language=language, instruct=instruct,
334
- chunk_size=chunk_size, temperature=temperature, top_k=top_k,
335
- repetition_penalty=repetition_penalty, max_new_tokens=360
336
  )
337
  else:
338
- gen = model.generate_voice_design_streaming(
339
- text=text, instruct=instruct, language=language, chunk_size=chunk_size,
340
  temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360
341
  )
342
 
343
- ttfa_ms, total_gen_ms = None, 0.0
344
-
345
- for chunk, sr, timing in gen:
346
- # 🛡️ PROTECCIÓN ANTI-NONE APLICADA
347
- timing = timing or {}
348
- prefill = timing.get('prefill_ms')
349
- decode = timing.get('decode_ms')
350
-
351
- # Convertimos a float de forma segura (0.0 si es None)
352
- prefill_val = float(prefill) if prefill is not None else 0.0
353
- decode_val = float(decode) if decode is not None else 0.0
354
-
355
- total_gen_ms += (prefill_val + decode_val)
356
- if ttfa_ms is None: ttfa_ms = total_gen_ms
357
-
358
- chunk_audio = _concat_audio(chunk)
359
- total_audio_s += len(chunk_audio) / sr
360
- rtf = total_audio_s / (total_gen_ms / 1000) if total_gen_ms > 0 else 0.0
361
-
362
- payload = {
363
- "type": "chunk", "audio_b64": _to_wav_b64(chunk_audio, sr), "sample_rate": sr,
364
- "ttfa_ms": round(ttfa_ms), "voice_clone_ms": round(voice_clone_ms),
365
- "rtf": round(rtf, 3), "total_audio_s": round(total_audio_s, 3),
366
- "elapsed_ms": round((time.perf_counter() - t0) * 1000, 3)
367
- }
368
- loop.call_soon_threadsafe(queue.put_nowait, json.dumps(payload))
369
-
370
- loop.call_soon_threadsafe(queue.put_nowait, json.dumps({
371
- "type": "done", "ttfa_ms": round(ttfa_ms or 0), "voice_clone_ms": round(voice_clone_ms),
372
- "rtf": round(rtf, 3) if 'rtf' in locals() else 0.0,
373
- "total_audio_s": round(total_audio_s, 3), "total_ms": round((time.perf_counter() - t0) * 1000)
374
- }))
375
  except Exception as e:
376
  loop.call_soon_threadsafe(queue.put_nowait, json.dumps({"type": "error", "message": str(e)}))
377
  finally:
@@ -414,11 +330,18 @@ async def generate_non_streaming(
414
  def run():
415
  t0 = time.perf_counter()
416
  if mode == "voice_clone":
417
- audio_list, sr = model.generate_voice_clone(text=text, language=language, ref_audio=tmp_path, ref_text=ref_text, xvec_only=xvec_only, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360)
 
 
 
418
  elif mode == "custom":
419
- audio_list, sr = model.generate_custom_voice(text=text, speaker=speaker, language=language, instruct=instruct, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360)
 
 
420
  else:
421
- audio_list, sr = model.generate_voice_design(text=text, instruct=instruct, language=language, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360)
 
 
422
  elapsed = time.perf_counter() - t0
423
  audio = _concat_audio(audio_list)
424
  return audio, sr, elapsed, len(audio)/sr
@@ -429,7 +352,7 @@ async def generate_non_streaming(
429
  return JSONResponse({"audio_b64": _to_wav_b64(audio, sr), "sample_rate": sr, "metrics": {"total_ms": round(elapsed * 1000), "audio_duration_s": round(dur, 3), "rtf": round(rtf, 3)}})
430
 
431
  def main():
432
- parser = argparse.ArgumentParser(description="Faster Qwen3-TTS Demo Server")
433
  parser.add_argument("--model", default="Qwen/Qwen3-TTS-12Hz-0.6B-Base", help="Model to preload at startup")
434
  parser.add_argument("--port", type=int, default=int(os.environ.get("PORT", 7860)))
435
  parser.add_argument("--host", default="0.0.0.0")
@@ -438,11 +361,16 @@ def main():
438
 
439
  if not args.no_preload:
440
  global _active_model_name, _parakeet
441
- print(f"Loading model: {args.model}")
442
- _startup_model = FasterQwen3TTS.from_pretrained(args.model, device="cpu", dtype=torch.float32)
 
 
 
 
 
 
443
  _model_cache[args.model] = _startup_model
444
  _active_model_name = args.model
445
- _prime_preset_voice_cache(_startup_model)
446
 
447
  print("Loading transcription model (nano-parakeet)…")
448
  _parakeet = _parakeet_from_pretrained(device="cpu")
 
1
  #!/usr/bin/env python3
2
  """
3
+ Qwen3-TTS Demo Server (Librería Oficial - CPU Nativo)
 
 
 
 
 
4
  """
5
 
6
  import argparse
 
30
  torch.set_num_threads(4)
31
  sys.path.insert(0, str(Path(__file__).parent.parent))
32
 
33
+ # Importamos la librería OFICIAL de Alibaba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  try:
35
+ from qwen_tts import Qwen3TTSModel
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  except ImportError:
37
+ print("Error: qwen-tts no está instalado. Revisa tu requirements.txt")
38
  sys.exit(1)
 
39
 
40
  from nano_parakeet import from_pretrained as _parakeet_from_pretrained
41
 
 
50
  _active_models_env = os.environ.get("ACTIVE_MODELS", "")
51
  if _active_models_env:
52
  _allowed = {m.strip() for m in _active_models_env.split(",") if m.strip()}
53
+ AVAILABLE_MODELS =[m for m in _ALL_MODELS if m in _allowed]
54
  else:
55
  AVAILABLE_MODELS = list(_ALL_MODELS)
56
 
57
  BASE_DIR = Path(__file__).resolve().parent
58
+ _ASSET_DIR = Path(os.environ.get("ASSET_DIR", "/tmp/qwen3-tts-assets"))
59
  PRESET_TRANSCRIPTS = _ASSET_DIR / "samples" / "parity" / "icl_transcripts.txt"
60
  PRESET_REFS =[
61
  ("ref_audio_3", _ASSET_DIR / "ref_audio_3.wav", "Clone 1"),
 
110
  "audio_b64": base64.b64encode(content).decode(),
111
  }
112
 
113
+ app = FastAPI(title="Qwen3-TTS Demo Oficial")
 
 
 
 
 
 
 
 
 
 
114
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
115
 
116
+ _model_cache: OrderedDict[str, Qwen3TTSModel] = OrderedDict()
117
  _model_cache_max: int = int(os.environ.get("MODEL_CACHE_SIZE", "1"))
118
  _active_model_name: str | None = None
119
  _loading = False
 
144
  with _ref_cache_lock:
145
  cached = _ref_cache.get(digest)
146
  if cached and os.path.exists(cached): return cached
147
+ path = Path(tempfile.gettempdir()) / f"qwen3_tts_ref_{digest}.wav"
148
  if not path.exists(): path.write_bytes(content)
149
  _ref_cache[digest] = str(path)
150
  return str(path)
 
175
  active = _model_cache.get(_active_model_name) if _active_model_name else None
176
  if active is not None:
177
  try:
178
+ model_type = "official"
179
+ speakers = active.get_supported_speakers() or[]
180
  except Exception: pass
181
  return {
182
  "loaded": active is not None, "model": _active_model_name, "loading": _loading,
 
204
  global _active_model_name, _loading
205
  try:
206
  if len(_model_cache) >= _model_cache_max: _model_cache.popitem(last=False)
207
+
208
+ # Carga NATIVA de la librería oficial
209
+ new_model = Qwen3TTSModel.from_pretrained(
210
+ model_id,
211
+ device_map="cpu",
212
+ dtype=torch.float32
213
+ )
214
  _model_cache[model_id] = new_model
215
  _model_cache.move_to_end(model_id)
216
  _active_model_name = model_id
 
217
  finally: _loading = False
218
  async with _generation_lock: await asyncio.to_thread(_do_load)
219
  return {"status": "loaded", "model": model_id}
 
248
  try:
249
  model = _model_cache.get(_active_model_name)
250
  t0 = time.perf_counter()
 
 
251
 
252
+ # Generación estándar empaquetada en un solo bloque para evitar crasheos de chunks
253
  if mode == "voice_clone":
254
+ audio_list, sr = model.generate_voice_clone(
255
  text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
256
+ x_vector_only_mode=xvec_only, temperature=temperature, top_k=top_k,
257
+ repetition_penalty=repetition_penalty, max_new_tokens=360
258
  )
259
  elif mode == "custom":
260
+ audio_list, sr = model.generate_custom_voice(
261
  text=text, speaker=speaker, language=language, instruct=instruct,
262
+ temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360
 
263
  )
264
  else:
265
+ audio_list, sr = model.generate_voice_design(
266
+ text=text, instruct=instruct, language=language,
267
  temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360
268
  )
269
 
270
+ elapsed = time.perf_counter() - t0
271
+ chunk_audio = _concat_audio(audio_list)
272
+ dur = len(chunk_audio) / sr
273
+ rtf = dur / elapsed if elapsed > 0 else 0.0
274
+ ttfa_ms = round(elapsed * 1000)
275
+
276
+ # Enviamos el audio completo como un único "Chunk"
277
+ payload = {
278
+ "type": "chunk", "audio_b64": _to_wav_b64(chunk_audio, sr), "sample_rate": sr,
279
+ "ttfa_ms": ttfa_ms, "voice_clone_ms": 0, "rtf": round(rtf, 3),
280
+ "total_audio_s": round(dur, 3), "elapsed_ms": ttfa_ms
281
+ }
282
+ loop.call_soon_threadsafe(queue.put_nowait, json.dumps(payload))
283
+
284
+ # Enviamos señal de "Done"
285
+ done_payload = {
286
+ "type": "done", "ttfa_ms": ttfa_ms, "voice_clone_ms": 0,
287
+ "rtf": round(rtf, 3), "total_audio_s": round(dur, 3), "total_ms": ttfa_ms
288
+ }
289
+ loop.call_soon_threadsafe(queue.put_nowait, json.dumps(done_payload))
290
+
 
 
 
 
 
 
 
 
 
 
 
291
  except Exception as e:
292
  loop.call_soon_threadsafe(queue.put_nowait, json.dumps({"type": "error", "message": str(e)}))
293
  finally:
 
330
  def run():
331
  t0 = time.perf_counter()
332
  if mode == "voice_clone":
333
+ audio_list, sr = model.generate_voice_clone(
334
+ text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
335
+ x_vector_only_mode=xvec_only, temperature=temperature, top_k=top_k,
336
+ repetition_penalty=repetition_penalty, max_new_tokens=360)
337
  elif mode == "custom":
338
+ audio_list, sr = model.generate_custom_voice(
339
+ text=text, speaker=speaker, language=language, instruct=instruct,
340
+ temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360)
341
  else:
342
+ audio_list, sr = model.generate_voice_design(
343
+ text=text, instruct=instruct, language=language, temperature=temperature,
344
+ top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360)
345
  elapsed = time.perf_counter() - t0
346
  audio = _concat_audio(audio_list)
347
  return audio, sr, elapsed, len(audio)/sr
 
352
  return JSONResponse({"audio_b64": _to_wav_b64(audio, sr), "sample_rate": sr, "metrics": {"total_ms": round(elapsed * 1000), "audio_duration_s": round(dur, 3), "rtf": round(rtf, 3)}})
353
 
354
  def main():
355
+ parser = argparse.ArgumentParser(description="Qwen3-TTS Demo Server")
356
  parser.add_argument("--model", default="Qwen/Qwen3-TTS-12Hz-0.6B-Base", help="Model to preload at startup")
357
  parser.add_argument("--port", type=int, default=int(os.environ.get("PORT", 7860)))
358
  parser.add_argument("--host", default="0.0.0.0")
 
361
 
362
  if not args.no_preload:
363
  global _active_model_name, _parakeet
364
+ print(f"Loading official model: {args.model}")
365
+
366
+ _startup_model = Qwen3TTSModel.from_pretrained(
367
+ args.model,
368
+ device_map="cpu",
369
+ dtype=torch.float32
370
+ )
371
+
372
  _model_cache[args.model] = _startup_model
373
  _active_model_name = args.model
 
374
 
375
  print("Loading transcription model (nano-parakeet)…")
376
  _parakeet = _parakeet_from_pretrained(device="cpu")