gaoyang07 committed on
Commit c5b84ea · 1 Parent(s): ccad48d

fix app.py

Files changed (1)
  1. app.py +94 -24
app.py CHANGED
@@ -1,6 +1,7 @@
  import argparse
  import functools
  import importlib.util
+ import os
  from pathlib import Path
  import re
  import time
@@ -11,6 +12,19 @@ import numpy as np
  import torch
  from transformers import AutoModel, AutoProcessor

+ try:
+     import spaces
+ except ImportError:
+     class _SpacesFallback:
+         @staticmethod
+         def GPU(*_args, **_kwargs):
+             def _decorator(func):
+                 return func
+
+             return _decorator
+
+     spaces = _SpacesFallback()
+
  # Disable the broken cuDNN SDPA backend
  torch.backends.cuda.enable_cudnn_sdp(False)
  # Keep these enabled as fallbacks
@@ -21,6 +35,7 @@ torch.backends.cuda.enable_math_sdp(True)
  MODEL_PATH = "OpenMOSS-Team/MOSS-TTS"
  DEFAULT_ATTN_IMPLEMENTATION = "auto"
  DEFAULT_MAX_NEW_TOKENS = 4096
+ PRELOAD_ENV_VAR = "MOSS_TTS_PRELOAD_AT_STARTUP"
  CONTINUATION_NOTICE = (
      "Continuation mode is active. Make sure the reference audio transcript is prepended to the input text."
  )
@@ -289,6 +304,7 @@ def apply_example_selection(
  )


+ @spaces.GPU(duration=180)
  def run_inference(
      text: str,
      reference_audio: str | None,
@@ -574,48 +590,102 @@ def build_demo(args: argparse.Namespace):
      return demo


+ def resolve_runtime_attn(args: argparse.Namespace) -> argparse.Namespace:
+     runtime_device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+     runtime_dtype = torch.bfloat16 if runtime_device.type == "cuda" else torch.float32
+     args.attn_implementation = resolve_attn_implementation(
+         requested=args.attn_implementation,
+         device=runtime_device,
+         dtype=runtime_dtype,
+     ) or "none"
+     return args
+
+
+ def parse_bool_env(name: str, default: bool) -> bool:
+     value = os.getenv(name)
+     if value is None:
+         return default
+     return value.strip().lower() in {"1", "true", "yes", "y", "on"}
+
+
+ def parse_port(value: str | None, default: int) -> int:
+     if not value:
+         return default
+     try:
+         return int(value)
+     except ValueError:
+         return default
+
+
+ def build_default_args() -> argparse.Namespace:
+     return resolve_runtime_attn(
+         argparse.Namespace(
+             model_path=MODEL_PATH,
+             device="cuda:0",
+             attn_implementation=DEFAULT_ATTN_IMPLEMENTATION,
+             host=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
+             port=parse_port(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT")), 7860),
+             share=False,
+         )
+     )
+
+
  def main():
      parser = argparse.ArgumentParser(description="MossTTS Gradio Demo")
      parser.add_argument("--model_path", type=str, default=MODEL_PATH)
      parser.add_argument("--device", type=str, default="cuda:0")
      parser.add_argument("--attn_implementation", type=str, default=DEFAULT_ATTN_IMPLEMENTATION)
      parser.add_argument("--host", type=str, default="0.0.0.0")
-     parser.add_argument("--port", type=int, default=7860)
+     parser.add_argument(
+         "--port",
+         type=int,
+         default=int(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT", "7860"))),
+     )
      parser.add_argument("--share", action="store_true")
      args = parser.parse_args()

-     runtime_device = torch.device(args.device if torch.cuda.is_available() else "cpu")
-     runtime_dtype = torch.bfloat16 if runtime_device.type == "cuda" else torch.float32
-     args.attn_implementation = resolve_attn_implementation(
-         requested=args.attn_implementation,
-         device=runtime_device,
-         dtype=runtime_dtype,
-     ) or "none"
+     args.host = os.getenv("GRADIO_SERVER_NAME", args.host)
+     args.port = parse_port(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT")), args.port)
+     args = resolve_runtime_attn(args)
      print(f"[INFO] Using attn_implementation={args.attn_implementation}", flush=True)

-     # Preload model/processor at startup to avoid first-request cold start latency.
-     preload_started_at = time.monotonic()
-     print(
-         f"[Startup] Preloading backend: model={args.model_path}, device={args.device}, attn={args.attn_implementation}",
-         flush=True,
-     )
-     load_backend(
-         model_path=args.model_path,
-         device_str=args.device,
-         attn_implementation=args.attn_implementation,
-     )
-     print(
-         f"[Startup] Backend preload finished in {time.monotonic() - preload_started_at:.2f}s",
-         flush=True,
-     )
+     preload_enabled = parse_bool_env(PRELOAD_ENV_VAR, default=not bool(os.getenv("SPACE_ID")))
+     if preload_enabled:
+         preload_started_at = time.monotonic()
+         print(
+             f"[Startup] Preloading backend: model={args.model_path}, device={args.device}, attn={args.attn_implementation}",
+             flush=True,
+         )
+         load_backend(
+             model_path=args.model_path,
+             device_str=args.device,
+             attn_implementation=args.attn_implementation,
+         )
+         print(
+             f"[Startup] Backend preload finished in {time.monotonic() - preload_started_at:.2f}s",
+             flush=True,
+         )
+     else:
+         print(
+             f"[Startup] Skipping preload (set {PRELOAD_ENV_VAR}=1 to enable).",
+             flush=True,
+         )

      demo = build_demo(args)
      demo.queue(max_size=16, default_concurrency_limit=1).launch(
          server_name=args.host,
          server_port=args.port,
          share=args.share,
+         ssr_mode=False,
      )


+ # Expose a module-level demo for Gradio hot-reload/Spaces launcher.
+ demo = build_demo(build_default_args())
+
+
  if __name__ == "__main__":
-     main()
+     if os.getenv("GRADIO_HOT_RELOAD"):
+         print("[Startup] GRADIO_HOT_RELOAD detected. Skipping explicit launch().", flush=True)
+     else:
+         main()
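
For reference, a minimal, self-contained sketch of how the new environment-variable plumbing resolves the port and the preload toggle. The helper bodies mirror parse_port/parse_bool_env from the diff above; the environment values are purely illustrative.

import os

def parse_port(value: str | None, default: int) -> int:
    # Mirrors parse_port in app.py: empty or non-numeric input falls back to the default.
    if not value:
        return default
    try:
        return int(value)
    except ValueError:
        return default

def parse_bool_env(name: str, default: bool) -> bool:
    # Mirrors parse_bool_env in app.py: common truthy strings enable the flag.
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in {"1", "true", "yes", "y", "on"}

# Illustrative environment values only.
os.environ["PORT"] = "8080"
os.environ["SPACE_ID"] = "OpenMOSS-Team/MOSS-TTS"

port = parse_port(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT")), 7860)
preload = parse_bool_env("MOSS_TTS_PRELOAD_AT_STARTUP", default=not bool(os.getenv("SPACE_ID")))

print(port)     # 8080: GRADIO_SERVER_PORT wins when set, then PORT, then the 7860 default
print(preload)  # False: preload is skipped on Spaces unless MOSS_TTS_PRELOAD_AT_STARTUP=1

Note that in main() the environment variables, when set, take precedence over the parsed --port value, which only serves as the fallback.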