Zhyw committed on
Commit
a2e0662
·
verified ·
1 Parent(s): 09f1ddb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -24
app.py CHANGED
@@ -55,8 +55,8 @@ WARMUP_BASE_ASSISTANT_TEXT = (
55
  def _apply_seed(seed: int | None) -> None:
56
  if seed is None:
57
  return
 
58
  torch.manual_seed(seed)
59
- torch.cuda.manual_seed_all(seed)
60
 
61
 
62
  def _load_audio(path: Path, target_sample_rate: int = SAMPLE_RATE) -> torch.Tensor:
@@ -553,14 +553,13 @@ def _load_backend(
553
  device_str: str,
554
  attn_impl: str,
555
  ):
556
-
557
  device = torch.device(device_str)
558
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
559
  processor = MossTTSRealtimeProcessor(tokenizer)
560
 
561
- # dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
562
  dtype = torch.float16
563
-
564
  if attn_impl and attn_impl.lower() not in {"none", ""}:
565
  model = MossTTSRealtime.from_pretrained(model_path, attn_implementation=attn_impl, torch_dtype=dtype).to(device)
566
  if (
@@ -798,10 +797,11 @@ class WarmupManager:
798
  self._lock = threading.Lock()
799
  self._thread: threading.Thread | None = None
800
  self._started = False
801
- self._state = "pending"
802
- self._progress = 0.0
803
- self._message = "Waiting for startup warmup."
804
- self._detail = "The app warms the streaming path before the first real request."
 
805
  self._error: str | None = None
806
 
807
  def start(self) -> None:
@@ -1288,8 +1288,7 @@ def _build_demo(
1288
  tts_demo: StreamingTTSDemo,
1289
  warmup_manager: WarmupManager,
1290
  ):
1291
- # initial_warmup_snapshot = warmup_manager.snapshot()
1292
-
1293
  with gr.Blocks(title="MossTTSRealtime") as demo:
1294
  gr.Markdown("MossTTSRealtime demo")
1295
  gr.Markdown("Note: The first run may take a while to load the model.")
@@ -1324,15 +1323,10 @@ def _build_demo(
1324
  chunk_duration = gr.Slider(0.01, 1.0, value=0.24, step=0.01, label="Codec Chunk Duration (s)")
1325
  stream_prebuffer_seconds = gr.Slider(0.0, 20.0, value=0.0, step=0.05, label="Initial Buffer (s)")
1326
 
1327
- # run_btn = gr.Button(
1328
- # "Generate" if initial_warmup_snapshot.ready else "Warming Up...",
1329
- # elem_id="tts_generate",
1330
- # interactive=initial_warmup_snapshot.ready,
1331
- # )
1332
  run_btn = gr.Button(
1333
- "Generate",
1334
  elem_id="tts_generate",
1335
- interactive=True,
1336
  )
1337
 
1338
  with gr.Column():
@@ -1341,7 +1335,7 @@ def _build_demo(
1341
  initial_status = _status_from_snapshot(initial_warmup_snapshot)
1342
  status = gr.Textbox(label="Status", lines=3, value=initial_status)
1343
 
1344
- warmup_timer = gr.Timer(value=WARMUP_POLL_INTERVAL_SECONDS, active=True)
1345
 
1346
  def _poll_warmup_state():
1347
  snapshot = warmup_manager.snapshot()
@@ -1350,7 +1344,7 @@ def _build_demo(
1350
  _warmup_status_update(snapshot),
1351
  _warmup_timer_update(snapshot),
1352
  )
1353
-
1354
  @spaces.GPU
1355
  def _on_generate(
1356
  user_text_value,
@@ -1374,10 +1368,6 @@ def _build_demo(
1374
  chunk_duration_value,
1375
  stream_prebuffer_seconds_value,
1376
  ):
1377
- # warmup_snapshot = warmup_manager.snapshot()
1378
- # if not warmup_snapshot.ready:
1379
- # yield json.dumps({"reset": True}), gr.update(value=None), _warmup_gate_message(warmup_snapshot)
1380
- # return
1381
  try:
1382
  started_at = time.monotonic()
1383
  full_chunks: list[np.ndarray] = []
@@ -1530,6 +1520,8 @@ def main():
1530
  attn_impl=args.attn_implementation,
1531
  ),
1532
  )
 
 
1533
  # warmup_manager.start()
1534
  demo = _build_demo(args, tts_demo, warmup_manager)
1535
  demo.queue(max_size=10, default_concurrency_limit=1).launch(
@@ -1540,4 +1532,4 @@ def main():
1540
 
1541
 
1542
  if __name__ == "__main__":
1543
- main()
 
55
  def _apply_seed(seed: int | None) -> None:
56
  if seed is None:
57
  return
58
+ # ZeroGPU: avoid touching torch.cuda outside the managed GPU call.
59
  torch.manual_seed(seed)
 
60
 
61
 
62
  def _load_audio(path: Path, target_sample_rate: int = SAMPLE_RATE) -> torch.Tensor:
 
553
  device_str: str,
554
  attn_impl: str,
555
  ):
556
+ # ZeroGPU: do not call torch.cuda.is_available() here; it may trigger low-level CUDA init.
557
  device = torch.device(device_str)
558
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
559
  processor = MossTTSRealtimeProcessor(tokenizer)
560
 
561
+ # ZeroGPU: avoid torch.cuda.is_bf16_supported() before CUDA is fully managed.
562
  dtype = torch.float16
 
563
  if attn_impl and attn_impl.lower() not in {"none", ""}:
564
  model = MossTTSRealtime.from_pretrained(model_path, attn_implementation=attn_impl, torch_dtype=dtype).to(device)
565
  if (
 
797
  self._lock = threading.Lock()
798
  self._thread: threading.Thread | None = None
799
  self._started = False
800
+ # ZeroGPU: startup warmup is disabled because it initializes CUDA outside @spaces.GPU.
801
+ self._state = "ready"
802
+ self._progress = 1.0
803
+ self._message = "Ready."
804
+ self._detail = "Startup warmup disabled for ZeroGPU; the first generation will load the model."
805
  self._error: str | None = None
806
 
807
  def start(self) -> None:
 
1288
  tts_demo: StreamingTTSDemo,
1289
  warmup_manager: WarmupManager,
1290
  ):
1291
+ initial_warmup_snapshot = warmup_manager.snapshot()
 
1292
  with gr.Blocks(title="MossTTSRealtime") as demo:
1293
  gr.Markdown("MossTTSRealtime demo")
1294
  gr.Markdown("Note: The first run may take a while to load the model.")
 
1323
  chunk_duration = gr.Slider(0.01, 1.0, value=0.24, step=0.01, label="Codec Chunk Duration (s)")
1324
  stream_prebuffer_seconds = gr.Slider(0.0, 20.0, value=0.0, step=0.05, label="Initial Buffer (s)")
1325
 
 
 
 
 
 
1326
  run_btn = gr.Button(
1327
+ "Generate" if initial_warmup_snapshot.ready else "Warming Up...",
1328
  elem_id="tts_generate",
1329
+ interactive=initial_warmup_snapshot.ready,
1330
  )
1331
 
1332
  with gr.Column():
 
1335
  initial_status = _status_from_snapshot(initial_warmup_snapshot)
1336
  status = gr.Textbox(label="Status", lines=3, value=initial_status)
1337
 
1338
+ warmup_timer = gr.Timer(value=WARMUP_POLL_INTERVAL_SECONDS, active=not initial_warmup_snapshot.ready)
1339
 
1340
  def _poll_warmup_state():
1341
  snapshot = warmup_manager.snapshot()
 
1344
  _warmup_status_update(snapshot),
1345
  _warmup_timer_update(snapshot),
1346
  )
1347
+
1348
  @spaces.GPU
1349
  def _on_generate(
1350
  user_text_value,
 
1368
  chunk_duration_value,
1369
  stream_prebuffer_seconds_value,
1370
  ):
 
 
 
 
1371
  try:
1372
  started_at = time.monotonic()
1373
  full_chunks: list[np.ndarray] = []
 
1520
  attn_impl=args.attn_implementation,
1521
  ),
1522
  )
1523
+ # ZeroGPU: do not run startup warmup, because it would initialize CUDA
1524
+ # in a background thread outside @spaces.GPU.
1525
  # warmup_manager.start()
1526
  demo = _build_demo(args, tts_demo, warmup_manager)
1527
  demo.queue(max_size=10, default_concurrency_limit=1).launch(
 
1532
 
1533
 
1534
  if __name__ == "__main__":
1535
+ main()