userhugginggit commited on
Commit
41b15cc
·
verified ·
1 Parent(s): 16dfc68

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +293 -85
server.py CHANGED
@@ -1,6 +1,11 @@
1
  #!/usr/bin/env python3
2
  """
3
- Faster Qwen3-TTS Demo Server (CPU Edición Ultra-Resistente)
 
 
 
 
 
4
  """
5
 
6
  import argparse
@@ -35,44 +40,54 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
35
  # ==============================================================================
36
  import site
37
 
38
- def _apply_shield():
39
- # 1. Parche físico
40
  try:
41
  for p in site.getsitepackages():
42
  model_py = os.path.join(p, "faster_qwen3_tts", "model.py")
43
  if os.path.exists(model_py):
44
  with open(model_py, "r") as f: code = f.read()
45
- code = code.replace('raise ValueError("CUDA graphs require CUDA device")', 'pass')
46
- with open(model_py, "w") as f: f.write(code)
 
47
  except Exception: pass
48
 
49
- # 2. Mock de CUDA
 
 
50
  torch.cuda.is_available = lambda: False
51
  torch.cuda.current_device = lambda: 0
52
  torch.cuda.device_count = lambda: 1
53
- if hasattr(torch.cuda, '_lazy_init'): torch.cuda._lazy_init = lambda *args, **kwargs: None
54
 
55
- # 3. Forzado de Tensors y Modules a CPU
56
  torch.Tensor.cuda = lambda self, *args, **kwargs: self
57
  torch.nn.Module.cuda = lambda self, *args, **kwargs: self
58
 
59
- def _mock_to(self, *args, **kwargs):
 
 
60
  new_args = tuple('cpu' if isinstance(a, str) and 'cuda' in a else a for a in args)
61
  if 'device' in kwargs and isinstance(kwargs['device'], str) and 'cuda' in kwargs['device']:
62
  kwargs['device'] = 'cpu'
63
- return _orig_to(self, *new_args, **kwargs)
 
64
 
65
- _orig_to = torch.Tensor.to
66
- torch.Tensor.to = _mock_to
67
- _orig_mod_to = torch.nn.Module.to
68
- torch.nn.Module.to = _mock_to
 
 
 
69
 
70
- _apply_shield()
71
 
72
  try:
73
  from faster_qwen3_tts import FasterQwen3TTS
74
  import faster_qwen3_tts.model as fq_model
75
 
 
76
  class CPU_PredictorGraph:
77
  def __init__(self, model, *args, **kwargs):
78
  self.model = model
@@ -84,12 +99,13 @@ try:
84
 
85
  fq_model.PredictorGraph = CPU_PredictorGraph
86
  except ImportError:
 
87
  sys.exit(1)
88
  # ==============================================================================
89
 
90
  from nano_parakeet import from_pretrained as _parakeet_from_pretrained
91
 
92
- _ALL_MODELS = [
93
  "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
94
  "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
95
  "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
@@ -97,151 +113,343 @@ _ALL_MODELS = [
97
  "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
98
  ]
99
 
100
- # Configuración de modelos activos
101
- _active_env = os.environ.get("ACTIVE_MODELS", "Qwen/Qwen3-TTS-12Hz-0.6B-Base,Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice")
102
- AVAILABLE_MODELS = [m.strip() for m in _active_env.split(",") if m.strip()]
 
 
 
103
 
 
104
  _ASSET_DIR = Path(os.environ.get("ASSET_DIR", "/tmp/faster-qwen3-tts-assets"))
105
- PRESET_REFS = [
 
106
  ("ref_audio_3", _ASSET_DIR / "ref_audio_3.wav", "Clone 1"),
107
  ("ref_audio_2", _ASSET_DIR / "ref_audio_2.wav", "Clone 2"),
108
  ("ref_audio", _ASSET_DIR / "ref_audio.wav", "Clone 3"),
109
  ]
110
 
111
- _preset_refs: dict = {}
112
-
113
- def _get_cached_ref_path(content: bytes) -> str:
114
- digest = hashlib.sha1(content).hexdigest()
115
- path = Path(tempfile.gettempdir()) / f"fq3_ref_{digest}.wav"
116
- if not path.exists(): path.write_bytes(content)
117
- return str(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- app = FastAPI(title="Faster Qwen3-TTS CPU")
120
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
121
 
122
- _model_cache: OrderedDict = OrderedDict()
 
123
  _active_model_name: str | None = None
124
  _loading = False
 
 
125
  _parakeet = None
126
  _generation_lock = asyncio.Lock()
127
- _generation_waiters = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  @app.get("/")
130
  async def root(): return FileResponse(Path(__file__).parent / "index.html")
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  @app.get("/status")
133
  async def get_status():
134
- active = _model_cache.get(_active_model_name)
135
- speakers = active.model.get_supported_speakers() if active else []
 
 
 
 
 
 
136
  return {
137
  "loaded": active is not None, "model": _active_model_name, "loading": _loading,
138
- "available_models": AVAILABLE_MODELS, "speakers": speakers,
139
- "preset_refs": [{"id": k, "label": v["label"]} for k,v in _preset_refs.items()]
 
 
140
  }
141
 
 
 
 
 
 
 
142
  @app.post("/load")
143
  async def load_model(model_id: str = Form(...)):
144
  global _active_model_name, _loading
145
  if model_id in _model_cache:
146
  _active_model_name = model_id
147
- return {"status": "ok"}
 
148
  _loading = True
149
- def _do():
150
  global _active_model_name, _loading
151
  try:
152
- if len(_model_cache) >= 1: _model_cache.popitem(last=False)
153
- m = FasterQwen3TTS.from_pretrained(model_id, device="cpu", dtype=torch.float32)
154
- _model_cache[model_id] = m
 
155
  _active_model_name = model_id
 
156
  finally: _loading = False
157
- async with _generation_lock: await asyncio.to_thread(_do)
158
- return {"status": "loaded"}
159
 
160
  @app.post("/generate/stream")
161
  async def generate_stream(
162
- text: str = Form(...), mode: str = Form("voice_clone"),
 
 
 
163
  ref_preset: str = Form(""), ref_audio: UploadFile = File(None),
164
- chunk_size: int = Form(8), temperature: float = Form(0.9)
165
  ):
166
- model = _model_cache.get(_active_model_name)
167
- if not model: raise HTTPException(status_code=400, detail="Carga el modelo primero")
 
168
 
169
  tmp_path = None
170
- if ref_preset and ref_preset in _preset_refs: tmp_path = _preset_refs[ref_preset]["path"]
171
- elif ref_audio: tmp_path = _get_cached_ref_path(await ref_audio.read())
 
 
 
 
 
 
 
172
 
173
  loop = asyncio.get_event_loop()
174
  queue = asyncio.Queue()
175
 
176
- def run_gen():
177
  try:
 
178
  t0 = time.perf_counter()
179
  total_audio_s = 0.0
 
180
 
181
- gen = model.generate_voice_clone_streaming(
182
- text=text, ref_audio=tmp_path, chunk_size=chunk_size,
183
- temperature=temperature, max_new_tokens=360
184
- ) if mode == "voice_clone" else model.generate_voice_design_streaming(
185
- text=text, chunk_size=chunk_size, temperature=temperature, max_new_tokens=360
186
- )
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  ttfa_ms, total_gen_ms = None, 0.0
 
189
  for chunk, sr, timing in gen:
190
- # 🛡️ PROTECCIÓN ANTI-NONE: Si timing es None o faltan keys, usamos 0
191
  timing = timing or {}
192
- prefill = timing.get('prefill_ms') or 0.0
193
- decode = timing.get('decode_ms') or 0.0
194
 
195
- total_gen_ms += (float(prefill) + float(decode))
 
 
 
 
196
  if ttfa_ms is None: ttfa_ms = total_gen_ms
197
 
198
- chunk_audio = np.concatenate([np.array(a).squeeze() for a in chunk]) if isinstance(chunk, list) else chunk.squeeze()
199
  total_audio_s += len(chunk_audio) / sr
200
-
201
- # RTF Safe
202
  rtf = total_audio_s / (total_gen_ms / 1000) if total_gen_ms > 0 else 0.0
203
 
204
- buf = io.BytesIO()
205
- sf.write(buf, chunk_audio.astype(np.float32), sr, format="WAV", subtype="PCM_16")
206
-
207
  payload = {
208
- "type": "chunk", "audio_b64": base64.b64encode(buf.getvalue()).decode(),
209
- "sample_rate": sr, "ttfa_ms": round(ttfa_ms), "rtf": round(rtf, 3),
210
- "total_audio_s": round(total_audio_s, 3)
 
211
  }
212
  loop.call_soon_threadsafe(queue.put_nowait, json.dumps(payload))
213
 
214
- loop.call_soon_threadsafe(queue.put_nowait, json.dumps({"type": "done", "ttfa_ms": round(ttfa_ms or 0)}))
 
 
 
 
215
  except Exception as e:
216
  loop.call_soon_threadsafe(queue.put_nowait, json.dumps({"type": "error", "message": str(e)}))
217
  finally:
218
  loop.call_soon_threadsafe(queue.put_nowait, None)
 
219
 
220
  async def sse():
221
- async with _generation_lock:
222
- threading.Thread(target=run_gen, daemon=True).start()
 
 
 
 
 
 
223
  while True:
224
  msg = await queue.get()
225
  if msg is None: break
226
  yield f"data: {msg}\n\n"
 
 
 
227
 
228
- return StreamingResponse(sse(), media_type="text/event-stream")
229
 
230
- def main():
231
- args = argparse.ArgumentParser()
232
- args.add_argument("--model", default="Qwen/Qwen3-TTS-12Hz-0.6B-Base")
233
- args.add_argument("--port", type=int, default=7860)
234
- args = args.parse_args()
235
-
236
- # Carga inicial
237
- print(f"Iniciando en CPU...")
238
- m = FasterQwen3TTS.from_pretrained(args.model, device="cpu", dtype=torch.float32)
239
- _model_cache[args.model] = m
240
- global _active_model_name, _parakeet
241
- _active_model_name = args.model
242
- _parakeet = _parakeet_from_pretrained(device="cpu")
243
 
244
- uvicorn.run(app, host="0.0.0.0", port=args.port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  if __name__ == "__main__":
247
  main()
 
1
  #!/usr/bin/env python3
2
  """
3
+ Faster Qwen3-TTS Demo Server (CPU Optimizado + Parches Anti-CUDA y Anti-None)
4
+
5
+ Usage:
6
+ python demo/server.py
7
+ python demo/server.py --model Qwen/Qwen3-TTS-12Hz-1.7B-Base --port 7860
8
+ python demo/server.py --no-preload # skip startup model load
9
  """
10
 
11
  import argparse
 
40
  # ==============================================================================
41
  import site
42
 
43
+ def _apply_anti_cuda_shield():
44
+ # 1. Parche físico para el ValueError de la librería
45
  try:
46
  for p in site.getsitepackages():
47
  model_py = os.path.join(p, "faster_qwen3_tts", "model.py")
48
  if os.path.exists(model_py):
49
  with open(model_py, "r") as f: code = f.read()
50
+ if 'raise ValueError("CUDA graphs require CUDA device")' in code:
51
+ code = code.replace('raise ValueError("CUDA graphs require CUDA device")', 'pass')
52
+ with open(model_py, "w") as f: f.write(code)
53
  except Exception: pass
54
 
55
+ # 2. Neutralizar validaciones internas de CUDA
56
+ if hasattr(torch.cuda, '_lazy_init'):
57
+ torch.cuda._lazy_init = lambda *args, **kwargs: None
58
  torch.cuda.is_available = lambda: False
59
  torch.cuda.current_device = lambda: 0
60
  torch.cuda.device_count = lambda: 1
61
+ torch.cuda.get_device_name = lambda x: "CPU"
62
 
63
+ # 3. Interceptar .cuda()
64
  torch.Tensor.cuda = lambda self, *args, **kwargs: self
65
  torch.nn.Module.cuda = lambda self, *args, **kwargs: self
66
 
67
+ # 4. Interceptar y redirigir .to('cuda') hacia .to('cpu')
68
+ _orig_tensor_to = torch.Tensor.to
69
+ def _tensor_to_mock(self, *args, **kwargs):
70
  new_args = tuple('cpu' if isinstance(a, str) and 'cuda' in a else a for a in args)
71
  if 'device' in kwargs and isinstance(kwargs['device'], str) and 'cuda' in kwargs['device']:
72
  kwargs['device'] = 'cpu'
73
+ return _orig_tensor_to(self, *new_args, **kwargs)
74
+ torch.Tensor.to = _tensor_to_mock
75
 
76
+ _orig_module_to = torch.nn.Module.to
77
+ def _module_to_mock(self, *args, **kwargs):
78
+ new_args = tuple('cpu' if isinstance(a, str) and 'cuda' in a else a for a in args)
79
+ if 'device' in kwargs and isinstance(kwargs['device'], str) and 'cuda' in kwargs['device']:
80
+ kwargs['device'] = 'cpu'
81
+ return _orig_module_to(self, *new_args, **kwargs)
82
+ torch.nn.Module.to = _module_to_mock
83
 
84
+ _apply_anti_cuda_shield()
85
 
86
  try:
87
  from faster_qwen3_tts import FasterQwen3TTS
88
  import faster_qwen3_tts.model as fq_model
89
 
90
+ # Clon del PredictorGraph para CPU
91
  class CPU_PredictorGraph:
92
  def __init__(self, model, *args, **kwargs):
93
  self.model = model
 
99
 
100
  fq_model.PredictorGraph = CPU_PredictorGraph
101
  except ImportError:
102
+ print("Error: faster_qwen3_tts not found.")
103
  sys.exit(1)
104
  # ==============================================================================
105
 
106
  from nano_parakeet import from_pretrained as _parakeet_from_pretrained
107
 
108
+ _ALL_MODELS =[
109
  "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
110
  "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
111
  "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
 
113
  "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
114
  ]
115
 
116
+ _active_models_env = os.environ.get("ACTIVE_MODELS", "")
117
+ if _active_models_env:
118
+ _allowed = {m.strip() for m in _active_models_env.split(",") if m.strip()}
119
+ AVAILABLE_MODELS = [m for m in _ALL_MODELS if m in _allowed]
120
+ else:
121
+ AVAILABLE_MODELS = list(_ALL_MODELS)
122
 
123
+ BASE_DIR = Path(__file__).resolve().parent
124
  _ASSET_DIR = Path(os.environ.get("ASSET_DIR", "/tmp/faster-qwen3-tts-assets"))
125
+ PRESET_TRANSCRIPTS = _ASSET_DIR / "samples" / "parity" / "icl_transcripts.txt"
126
+ PRESET_REFS =[
127
  ("ref_audio_3", _ASSET_DIR / "ref_audio_3.wav", "Clone 1"),
128
  ("ref_audio_2", _ASSET_DIR / "ref_audio_2.wav", "Clone 2"),
129
  ("ref_audio", _ASSET_DIR / "ref_audio.wav", "Clone 3"),
130
  ]
131
 
132
+ _GITHUB_RAW = "https://raw.githubusercontent.com/andimarafioti/faster-qwen3-tts/main"
133
+ _PRESET_REMOTE = {
134
+ "ref_audio": f"{_GITHUB_RAW}/ref_audio.wav",
135
+ "ref_audio_2": f"{_GITHUB_RAW}/ref_audio_2.wav",
136
+ "ref_audio_3": f"{_GITHUB_RAW}/ref_audio_3.wav",
137
+ }
138
+ _TRANSCRIPT_REMOTE = f"{_GITHUB_RAW}/samples/parity/icl_transcripts.txt"
139
+
140
+ def _fetch_preset_assets() -> None:
141
+ import urllib.request
142
+ _ASSET_DIR.mkdir(parents=True, exist_ok=True)
143
+ PRESET_TRANSCRIPTS.parent.mkdir(parents=True, exist_ok=True)
144
+ if not PRESET_TRANSCRIPTS.exists():
145
+ try: urllib.request.urlretrieve(_TRANSCRIPT_REMOTE, PRESET_TRANSCRIPTS)
146
+ except Exception: pass
147
+ for key, path, _ in PRESET_REFS:
148
+ if not path.exists() and key in _PRESET_REMOTE:
149
+ try: urllib.request.urlretrieve(_PRESET_REMOTE[key], path)
150
+ except Exception: pass
151
+
152
+ _preset_refs: dict[str, dict] = {}
153
+
154
+ def _load_preset_transcripts() -> dict[str, str]:
155
+ if not PRESET_TRANSCRIPTS.exists(): return {}
156
+ transcripts = {}
157
+ for line in PRESET_TRANSCRIPTS.read_text(encoding="utf-8").splitlines():
158
+ if ":" not in line: continue
159
+ key_part, text = line.split(":", 1)
160
+ key = key_part.split("(")[0].strip()
161
+ transcripts[key] = text.strip()
162
+ return transcripts
163
+
164
+ def _load_preset_refs() -> None:
165
+ transcripts = _load_preset_transcripts()
166
+ for key, path, label in PRESET_REFS:
167
+ if not path.exists(): continue
168
+ content = path.read_bytes()
169
+ cached_path = _get_cached_ref_path(content)
170
+ _preset_refs[key] = {
171
+ "id": key,
172
+ "label": label,
173
+ "filename": path.name,
174
+ "path": cached_path,
175
+ "ref_text": transcripts.get(key, ""),
176
+ "audio_b64": base64.b64encode(content).decode(),
177
+ }
178
+
179
+ def _prime_preset_voice_cache(model: FasterQwen3TTS) -> None:
180
+ if not _preset_refs: return
181
+ for preset in _preset_refs.values():
182
+ try:
183
+ model._prepare_generation(
184
+ text="Hello.", ref_audio=preset["path"], ref_text=preset["ref_text"],
185
+ language="English", xvec_only=True, non_streaming_mode=True,
186
+ )
187
+ except Exception: continue
188
 
189
+ app = FastAPI(title="Faster Qwen3-TTS Demo")
190
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
191
 
192
+ _model_cache: OrderedDict[str, FasterQwen3TTS] = OrderedDict()
193
+ _model_cache_max: int = int(os.environ.get("MODEL_CACHE_SIZE", "1"))
194
  _active_model_name: str | None = None
195
  _loading = False
196
+ _ref_cache: dict[str, str] = {}
197
+ _ref_cache_lock = threading.Lock()
198
  _parakeet = None
199
  _generation_lock = asyncio.Lock()
200
+ _generation_waiters: int = 0
201
+
202
+ MAX_TEXT_CHARS = 1000
203
+ MAX_AUDIO_BYTES = 10 * 1024 * 1024
204
+ _AUDIO_TOO_LARGE_MSG = "Audio file too large. Please upload a shorter recording."
205
+
206
+ def _to_wav_b64(audio: np.ndarray, sr: int) -> str:
207
+ if audio.dtype != np.float32: audio = audio.astype(np.float32)
208
+ if audio.ndim > 1: audio = audio.squeeze()
209
+ buf = io.BytesIO()
210
+ sf.write(buf, audio, sr, format="WAV", subtype="PCM_16")
211
+ return base64.b64encode(buf.getvalue()).decode()
212
+
213
+ def _concat_audio(audio_list) -> np.ndarray:
214
+ if isinstance(audio_list, np.ndarray): return audio_list.astype(np.float32).squeeze()
215
+ parts =[np.array(a, dtype=np.float32).squeeze() for a in audio_list if len(a) > 0]
216
+ return np.concatenate(parts) if parts else np.zeros(0, dtype=np.float32)
217
+
218
+ def _get_cached_ref_path(content: bytes) -> str:
219
+ digest = hashlib.sha1(content).hexdigest()
220
+ with _ref_cache_lock:
221
+ cached = _ref_cache.get(digest)
222
+ if cached and os.path.exists(cached): return cached
223
+ path = Path(tempfile.gettempdir()) / f"faster_qwen3_tts_ref_{digest}.wav"
224
+ if not path.exists(): path.write_bytes(content)
225
+ _ref_cache[digest] = str(path)
226
+ return str(path)
227
+
228
+ _fetch_preset_assets()
229
+ _load_preset_refs()
230
 
231
  @app.get("/")
232
  async def root(): return FileResponse(Path(__file__).parent / "index.html")
233
 
234
+ @app.post("/transcribe")
235
+ async def transcribe_audio(audio: UploadFile = File(...)):
236
+ if _parakeet is None: raise HTTPException(status_code=503, detail="Transcription model not loaded")
237
+ content = await audio.read()
238
+ if len(content) > MAX_AUDIO_BYTES: raise HTTPException(status_code=400, detail=_AUDIO_TOO_LARGE_MSG)
239
+ def run():
240
+ wav, sr = sf.read(io.BytesIO(content), dtype="float32", always_2d=False)
241
+ if wav.ndim > 1: wav = wav.mean(axis=1)
242
+ wav_t = torch.from_numpy(wav)
243
+ if sr != 16000: wav_t = torchaudio.functional.resample(wav_t.unsqueeze(0), sr, 16000).squeeze(0)
244
+ return _parakeet.transcribe(wav_t)
245
+ return {"text": await asyncio.to_thread(run)}
246
+
247
  @app.get("/status")
248
  async def get_status():
249
+ speakers =[]
250
+ model_type = None
251
+ active = _model_cache.get(_active_model_name) if _active_model_name else None
252
+ if active is not None:
253
+ try:
254
+ model_type = active.model.model.tts_model_type
255
+ speakers = active.model.get_supported_speakers() or[]
256
+ except Exception: pass
257
  return {
258
  "loaded": active is not None, "model": _active_model_name, "loading": _loading,
259
+ "available_models": AVAILABLE_MODELS, "model_type": model_type, "speakers": speakers,
260
+ "transcription_available": _parakeet is not None,
261
+ "preset_refs": [{"id": p["id"], "label": p["label"], "ref_text": p["ref_text"]} for p in _preset_refs.values()],
262
+ "queue_depth": _generation_waiters, "cached_models": list(_model_cache.keys()),
263
  }
264
 
265
+ @app.get("/preset_ref/{preset_id}")
266
+ async def get_preset_ref(preset_id: str):
267
+ preset = _preset_refs.get(preset_id)
268
+ if not preset: raise HTTPException(status_code=404, detail="Preset not found")
269
+ return preset
270
+
271
  @app.post("/load")
272
  async def load_model(model_id: str = Form(...)):
273
  global _active_model_name, _loading
274
  if model_id in _model_cache:
275
  _active_model_name = model_id
276
+ _model_cache.move_to_end(model_id)
277
+ return {"status": "already_loaded", "model": model_id}
278
  _loading = True
279
+ def _do_load():
280
  global _active_model_name, _loading
281
  try:
282
+ if len(_model_cache) >= _model_cache_max: _model_cache.popitem(last=False)
283
+ new_model = FasterQwen3TTS.from_pretrained(model_id, device="cpu", dtype=torch.float32)
284
+ _model_cache[model_id] = new_model
285
+ _model_cache.move_to_end(model_id)
286
  _active_model_name = model_id
287
+ _prime_preset_voice_cache(new_model)
288
  finally: _loading = False
289
+ async with _generation_lock: await asyncio.to_thread(_do_load)
290
+ return {"status": "loaded", "model": model_id}
291
 
292
  @app.post("/generate/stream")
293
  async def generate_stream(
294
+ text: str = Form(...), language: str = Form("English"), mode: str = Form("voice_clone"),
295
+ ref_text: str = Form(""), speaker: str = Form(""), instruct: str = Form(""),
296
+ xvec_only: bool = Form(True), chunk_size: int = Form(8), temperature: float = Form(0.9),
297
+ top_k: int = Form(50), repetition_penalty: float = Form(1.05),
298
  ref_preset: str = Form(""), ref_audio: UploadFile = File(None),
 
299
  ):
300
+ if not _active_model_name or _active_model_name not in _model_cache:
301
+ raise HTTPException(status_code=400, detail="Model not loaded. Click 'Load' first.")
302
+ if len(text) > MAX_TEXT_CHARS: raise HTTPException(status_code=400, detail="Text too long.")
303
 
304
  tmp_path = None
305
+ tmp_is_cached = False
306
+ if ref_preset and ref_preset in _preset_refs:
307
+ preset = _preset_refs[ref_preset]
308
+ tmp_path, tmp_is_cached = preset["path"], True
309
+ if not ref_text: ref_text = preset["ref_text"]
310
+ elif ref_audio and ref_audio.filename:
311
+ content = await ref_audio.read()
312
+ if len(content) > MAX_AUDIO_BYTES: raise HTTPException(status_code=400, detail=_AUDIO_TOO_LARGE_MSG)
313
+ tmp_path, tmp_is_cached = _get_cached_ref_path(content), True
314
 
315
  loop = asyncio.get_event_loop()
316
  queue = asyncio.Queue()
317
 
318
+ def run_generation():
319
  try:
320
+ model = _model_cache.get(_active_model_name)
321
  t0 = time.perf_counter()
322
  total_audio_s = 0.0
323
+ voice_clone_ms = 0.0
324
 
325
+ if mode == "voice_clone":
326
+ gen = model.generate_voice_clone_streaming(
327
+ text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
328
+ xvec_only=xvec_only, chunk_size=chunk_size, temperature=temperature,
329
+ top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360
330
+ )
331
+ elif mode == "custom":
332
+ gen = model.generate_custom_voice_streaming(
333
+ text=text, speaker=speaker, language=language, instruct=instruct,
334
+ chunk_size=chunk_size, temperature=temperature, top_k=top_k,
335
+ repetition_penalty=repetition_penalty, max_new_tokens=360
336
+ )
337
+ else:
338
+ gen = model.generate_voice_design_streaming(
339
+ text=text, instruct=instruct, language=language, chunk_size=chunk_size,
340
+ temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360
341
+ )
342
 
343
  ttfa_ms, total_gen_ms = None, 0.0
344
+
345
  for chunk, sr, timing in gen:
346
+ # 🛡️ PROTECCIÓN ANTI-NONE APLICADA
347
  timing = timing or {}
348
+ prefill = timing.get('prefill_ms')
349
+ decode = timing.get('decode_ms')
350
 
351
+ # Convertimos a float de forma segura (0.0 si es None)
352
+ prefill_val = float(prefill) if prefill is not None else 0.0
353
+ decode_val = float(decode) if decode is not None else 0.0
354
+
355
+ total_gen_ms += (prefill_val + decode_val)
356
  if ttfa_ms is None: ttfa_ms = total_gen_ms
357
 
358
+ chunk_audio = _concat_audio(chunk)
359
  total_audio_s += len(chunk_audio) / sr
 
 
360
  rtf = total_audio_s / (total_gen_ms / 1000) if total_gen_ms > 0 else 0.0
361
 
 
 
 
362
  payload = {
363
+ "type": "chunk", "audio_b64": _to_wav_b64(chunk_audio, sr), "sample_rate": sr,
364
+ "ttfa_ms": round(ttfa_ms), "voice_clone_ms": round(voice_clone_ms),
365
+ "rtf": round(rtf, 3), "total_audio_s": round(total_audio_s, 3),
366
+ "elapsed_ms": round((time.perf_counter() - t0) * 1000, 3)
367
  }
368
  loop.call_soon_threadsafe(queue.put_nowait, json.dumps(payload))
369
 
370
+ loop.call_soon_threadsafe(queue.put_nowait, json.dumps({
371
+ "type": "done", "ttfa_ms": round(ttfa_ms or 0), "voice_clone_ms": round(voice_clone_ms),
372
+ "rtf": round(rtf, 3) if 'rtf' in locals() else 0.0,
373
+ "total_audio_s": round(total_audio_s, 3), "total_ms": round((time.perf_counter() - t0) * 1000)
374
+ }))
375
  except Exception as e:
376
  loop.call_soon_threadsafe(queue.put_nowait, json.dumps({"type": "error", "message": str(e)}))
377
  finally:
378
  loop.call_soon_threadsafe(queue.put_nowait, None)
379
+ if tmp_path and os.path.exists(tmp_path) and not tmp_is_cached: os.unlink(tmp_path)
380
 
381
  async def sse():
382
+ global _generation_waiters
383
+ _generation_waiters += 1
384
+ lock_acquired = False
385
+ try:
386
+ await _generation_lock.acquire()
387
+ lock_acquired = True
388
+ _generation_waiters -= 1
389
+ threading.Thread(target=run_generation, daemon=True).start()
390
  while True:
391
  msg = await queue.get()
392
  if msg is None: break
393
  yield f"data: {msg}\n\n"
394
+ finally:
395
+ if lock_acquired: _generation_lock.release()
396
+ else: _generation_waiters -= 1
397
 
398
+ return StreamingResponse(sse(), media_type="text/event-stream", headers={"Cache-Control": "no-cache"})
399
 
400
+ @app.post("/generate")
401
+ async def generate_non_streaming(
402
+ text: str = Form(...), language: str = Form("English"), mode: str = Form("voice_clone"),
403
+ ref_text: str = Form(""), speaker: str = Form(""), instruct: str = Form(""),
404
+ xvec_only: bool = Form(True), temperature: float = Form(0.9), top_k: int = Form(50),
405
+ repetition_penalty: float = Form(1.05), ref_preset: str = Form(""), ref_audio: UploadFile = File(None),
406
+ ):
407
+ model = _model_cache.get(_active_model_name)
408
+ if not model: raise HTTPException(status_code=400, detail="Model not loaded.")
 
 
 
 
409
 
410
+ tmp_path = None
411
+ if ref_preset and ref_preset in _preset_refs: tmp_path = _preset_refs[ref_preset]["path"]
412
+ elif ref_audio: tmp_path = _get_cached_ref_path(await ref_audio.read())
413
+
414
+ def run():
415
+ t0 = time.perf_counter()
416
+ if mode == "voice_clone":
417
+ audio_list, sr = model.generate_voice_clone(text=text, language=language, ref_audio=tmp_path, ref_text=ref_text, xvec_only=xvec_only, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360)
418
+ elif mode == "custom":
419
+ audio_list, sr = model.generate_custom_voice(text=text, speaker=speaker, language=language, instruct=instruct, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360)
420
+ else:
421
+ audio_list, sr = model.generate_voice_design(text=text, instruct=instruct, language=language, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=360)
422
+ elapsed = time.perf_counter() - t0
423
+ audio = _concat_audio(audio_list)
424
+ return audio, sr, elapsed, len(audio)/sr
425
+
426
+ async with _generation_lock:
427
+ audio, sr, elapsed, dur = await asyncio.to_thread(run)
428
+ rtf = dur / elapsed if elapsed > 0 else 0.0
429
+ return JSONResponse({"audio_b64": _to_wav_b64(audio, sr), "sample_rate": sr, "metrics": {"total_ms": round(elapsed * 1000), "audio_duration_s": round(dur, 3), "rtf": round(rtf, 3)}})
430
+
431
+ def main():
432
+ parser = argparse.ArgumentParser(description="Faster Qwen3-TTS Demo Server")
433
+ parser.add_argument("--model", default="Qwen/Qwen3-TTS-12Hz-0.6B-Base", help="Model to preload at startup")
434
+ parser.add_argument("--port", type=int, default=int(os.environ.get("PORT", 7860)))
435
+ parser.add_argument("--host", default="0.0.0.0")
436
+ parser.add_argument("--no-preload", action="store_true", help="Skip model loading at startup")
437
+ args = parser.parse_args()
438
+
439
+ if not args.no_preload:
440
+ global _active_model_name, _parakeet
441
+ print(f"Loading model: {args.model}")
442
+ _startup_model = FasterQwen3TTS.from_pretrained(args.model, device="cpu", dtype=torch.float32)
443
+ _model_cache[args.model] = _startup_model
444
+ _active_model_name = args.model
445
+ _prime_preset_voice_cache(_startup_model)
446
+
447
+ print("Loading transcription model (nano-parakeet)…")
448
+ _parakeet = _parakeet_from_pretrained(device="cpu")
449
+ print("Transcription model ready on CPU.")
450
+ print(f"Ready. Open http://localhost:{args.port}")
451
+
452
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
453
 
454
  if __name__ == "__main__":
455
  main()