gaoyang07 commited on
Commit ·
c5b84ea
1
Parent(s): ccad48d
fix app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import argparse
|
| 2 |
import functools
|
| 3 |
import importlib.util
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
import re
|
| 6 |
import time
|
|
@@ -11,6 +12,19 @@ import numpy as np
|
|
| 11 |
import torch
|
| 12 |
from transformers import AutoModel, AutoProcessor
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
# Disable the broken cuDNN SDPA backend
|
| 15 |
torch.backends.cuda.enable_cudnn_sdp(False)
|
| 16 |
# Keep these enabled as fallbacks
|
|
@@ -21,6 +35,7 @@ torch.backends.cuda.enable_math_sdp(True)
|
|
| 21 |
MODEL_PATH = "OpenMOSS-Team/MOSS-TTS"
|
| 22 |
DEFAULT_ATTN_IMPLEMENTATION = "auto"
|
| 23 |
DEFAULT_MAX_NEW_TOKENS = 4096
|
|
|
|
| 24 |
CONTINUATION_NOTICE = (
|
| 25 |
"Continuation mode is active. Make sure the reference audio transcript is prepended to the input text."
|
| 26 |
)
|
|
@@ -289,6 +304,7 @@ def apply_example_selection(
|
|
| 289 |
)
|
| 290 |
|
| 291 |
|
|
|
|
| 292 |
def run_inference(
|
| 293 |
text: str,
|
| 294 |
reference_audio: str | None,
|
|
@@ -574,48 +590,102 @@ def build_demo(args: argparse.Namespace):
|
|
| 574 |
return demo
|
| 575 |
|
| 576 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
def main():
|
| 578 |
parser = argparse.ArgumentParser(description="MossTTS Gradio Demo")
|
| 579 |
parser.add_argument("--model_path", type=str, default=MODEL_PATH)
|
| 580 |
parser.add_argument("--device", type=str, default="cuda:0")
|
| 581 |
parser.add_argument("--attn_implementation", type=str, default=DEFAULT_ATTN_IMPLEMENTATION)
|
| 582 |
parser.add_argument("--host", type=str, default="0.0.0.0")
|
| 583 |
-
parser.add_argument(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 584 |
parser.add_argument("--share", action="store_true")
|
| 585 |
args = parser.parse_args()
|
| 586 |
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
args
|
| 590 |
-
requested=args.attn_implementation,
|
| 591 |
-
device=runtime_device,
|
| 592 |
-
dtype=runtime_dtype,
|
| 593 |
-
) or "none"
|
| 594 |
print(f"[INFO] Using attn_implementation={args.attn_implementation}", flush=True)
|
| 595 |
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
|
| 612 |
demo = build_demo(args)
|
| 613 |
demo.queue(max_size=16, default_concurrency_limit=1).launch(
|
| 614 |
server_name=args.host,
|
| 615 |
server_port=args.port,
|
| 616 |
share=args.share,
|
|
|
|
| 617 |
)
|
| 618 |
|
| 619 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 620 |
if __name__ == "__main__":
|
| 621 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import argparse
|
| 2 |
import functools
|
| 3 |
import importlib.util
|
| 4 |
+
import os
|
| 5 |
from pathlib import Path
|
| 6 |
import re
|
| 7 |
import time
|
|
|
|
| 12 |
import torch
|
| 13 |
from transformers import AutoModel, AutoProcessor
|
| 14 |
|
| 15 |
+
try:
    import spaces
except ImportError:
    # Not running on Hugging Face Spaces: install a no-op stand-in so the
    # @spaces.GPU(...) decorator used elsewhere in this module keeps working.
    class _SpacesFallback:
        @staticmethod
        def GPU(*_unused_args, **_unused_kwargs):
            """Accept (and ignore) any arguments; return an identity decorator."""
            return lambda func: func

    spaces = _SpacesFallback()
| 27 |
+
|
| 28 |
# Disable the broken cuDNN SDPA backend
|
| 29 |
torch.backends.cuda.enable_cudnn_sdp(False)
|
| 30 |
# Keep these enabled as fallbacks
|
|
|
|
| 35 |
MODEL_PATH = "OpenMOSS-Team/MOSS-TTS"
|
| 36 |
DEFAULT_ATTN_IMPLEMENTATION = "auto"
|
| 37 |
DEFAULT_MAX_NEW_TOKENS = 4096
|
| 38 |
+
PRELOAD_ENV_VAR = "MOSS_TTS_PRELOAD_AT_STARTUP"
|
| 39 |
CONTINUATION_NOTICE = (
|
| 40 |
"Continuation mode is active. Make sure the reference audio transcript is prepended to the input text."
|
| 41 |
)
|
|
|
|
| 304 |
)
|
| 305 |
|
| 306 |
|
| 307 |
+
@spaces.GPU(duration=180)
|
| 308 |
def run_inference(
|
| 309 |
text: str,
|
| 310 |
reference_audio: str | None,
|
|
|
|
| 590 |
return demo
|
| 591 |
|
| 592 |
|
| 593 |
+
def resolve_runtime_attn(args: argparse.Namespace) -> argparse.Namespace:
    """Resolve and store args.attn_implementation for the actual runtime.

    Falls back to CPU when CUDA is unavailable; uses bfloat16 on CUDA and
    float32 otherwise. Stores the resolved implementation name back on
    *args* ("none" when resolution yields a falsy result) and returns *args*.
    """
    cuda_available = torch.cuda.is_available()
    runtime_device = torch.device(args.device if cuda_available else "cpu")
    runtime_dtype = torch.bfloat16 if runtime_device.type == "cuda" else torch.float32
    resolved = resolve_attn_implementation(
        requested=args.attn_implementation,
        device=runtime_device,
        dtype=runtime_dtype,
    )
    args.attn_implementation = resolved or "none"
    return args
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def parse_bool_env(name: str, default: bool) -> bool:
    """Read a boolean flag from environment variable *name*.

    Returns *default* when the variable is unset; otherwise True only for
    the truthy spellings "1", "true", "yes", "y", "on" (case/space-insensitive).
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "y", "on")
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def parse_port(value: str | None, default: int) -> int:
|
| 612 |
+
if not value:
|
| 613 |
+
return default
|
| 614 |
+
try:
|
| 615 |
+
return int(value)
|
| 616 |
+
except ValueError:
|
| 617 |
+
return default
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def build_default_args() -> argparse.Namespace:
    """Build the launch configuration used when no CLI arguments are parsed.

    Mirrors main()'s argparse defaults, honouring the standard Gradio
    environment variables (GRADIO_SERVER_NAME / GRADIO_SERVER_PORT / PORT)
    for host and port.
    """
    defaults = argparse.Namespace(
        model_path=MODEL_PATH,
        device="cuda:0",
        attn_implementation=DEFAULT_ATTN_IMPLEMENTATION,
        host=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        port=parse_port(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT")), 7860),
        share=False,
    )
    return resolve_runtime_attn(defaults)
| 631 |
+
|
| 632 |
+
|
| 633 |
def main():
    """CLI entry point: parse arguments, optionally preload the model, launch the UI."""
    parser = argparse.ArgumentParser(description="MossTTS Gradio Demo")
    parser.add_argument("--model_path", type=str, default=MODEL_PATH)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--attn_implementation", type=str, default=DEFAULT_ATTN_IMPLEMENTATION)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    # FIX: use parse_port() for the default instead of a bare int() so a set-but-
    # empty or non-numeric GRADIO_SERVER_PORT / PORT env var (os.getenv returns ""
    # rather than None for an empty variable) cannot crash parser construction.
    parser.add_argument(
        "--port",
        type=int,
        default=parse_port(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT")), 7860),
    )
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()

    # Environment variables (typically set by the hosting platform) override
    # the host/port, including values passed explicitly on the command line.
    args.host = os.getenv("GRADIO_SERVER_NAME", args.host)
    args.port = parse_port(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT")), args.port)
    args = resolve_runtime_attn(args)
    print(f"[INFO] Using attn_implementation={args.attn_implementation}", flush=True)

    # Preload by default only when NOT on Spaces (SPACE_ID unset); presumably the
    # Spaces/ZeroGPU worker loads the model lazily instead — TODO confirm.
    preload_enabled = parse_bool_env(PRELOAD_ENV_VAR, default=not bool(os.getenv("SPACE_ID")))
    if preload_enabled:
        preload_started_at = time.monotonic()
        print(
            f"[Startup] Preloading backend: model={args.model_path}, device={args.device}, attn={args.attn_implementation}",
            flush=True,
        )
        load_backend(
            model_path=args.model_path,
            device_str=args.device,
            attn_implementation=args.attn_implementation,
        )
        print(
            f"[Startup] Backend preload finished in {time.monotonic() - preload_started_at:.2f}s",
            flush=True,
        )
    else:
        print(
            f"[Startup] Skipping preload (set {PRELOAD_ENV_VAR}=1 to enable).",
            flush=True,
        )

    demo = build_demo(args)
    demo.queue(max_size=16, default_concurrency_limit=1).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        ssr_mode=False,
    )
| 681 |
|
| 682 |
|
| 683 |
+
# Expose a module-level demo for Gradio hot-reload/Spaces launcher.
# NOTE(review): built at import time with env-derived defaults — the hosting
# launcher appears to serve this `demo` object directly instead of calling
# main(); confirm against the deployment setup before removing.
demo = build_demo(build_default_args())


if __name__ == "__main__":
    # Under the Gradio reloader (GRADIO_HOT_RELOAD set) the module-level `demo`
    # above is served by the reloader itself, so skip the explicit launch in main().
    if os.getenv("GRADIO_HOT_RELOAD"):
        print("[Startup] GRADIO_HOT_RELOAD detected. Skipping explicit launch().", flush=True)
    else:
        main()