SceneStreamer / scripts /_common.py
pengzhenghao97's picture
Update Gradio checkpoint defaults
e573e0e verified
from __future__ import annotations
import argparse
from dataclasses import dataclass
from typing import Any
import yaml
from scenestreamer.model_hub import DEFAULT_HF_REPO
from scenestreamer._startup import configure_startup_noise_filters, quiet_native_startup_noise
# Module-level side effect: runs once, when this module is first imported.
# NOTE(review): presumably this must happen before the heavier, lazily
# imported dependencies used below are loaded -- confirm in
# scenestreamer._startup.
configure_startup_noise_filters()
def add_model_args(p: argparse.ArgumentParser) -> None:
    """Register the model-selection options on parser *p*.

    Adds ``--device``, ``--ckpt``, ``--hf-repo`` and ``--hf-file``; these are
    the attributes that ``load_model_from_args`` reads from the parsed
    namespace.

    Args:
        p: Parser (or sub-parser) to extend in place.
    """
    # (flag, default, help) -- registered in this order so --help is stable.
    options = (
        ("--device", "auto", "torch device string: auto, cuda, mps, or cpu"),
        ("--ckpt", None, "Path to a .ckpt checkpoint"),
        ("--hf-repo", DEFAULT_HF_REPO, "HuggingFace repo id, e.g. user/repo"),
        ("--hf-file", "scenestreamer-full-large.ckpt", "HuggingFace filename, e.g. model.ckpt"),
    )
    for flag, default, help_text in options:
        p.add_argument(flag, default=default, help=help_text)
def load_model_from_args(args: argparse.Namespace):
    """Build the model selected by the options from ``add_model_args``.

    A local checkpoint (``--ckpt``) takes precedence over a HuggingFace
    download (``--hf-repo`` / ``--hf-file``).

    Args:
        args: Parsed namespace exposing ``ckpt``, ``hf_repo``, ``hf_file`` and
            ``device``.

    Returns:
        Whatever ``scenestreamer.utils.utils.get_model`` returns.

    Raises:
        SystemExit: If neither a checkpoint path nor a HuggingFace repo is set.
    """
    # Deferred import, wrapped in the project's startup-noise guard.
    with quiet_native_startup_noise():
        from scenestreamer.utils import utils

    # Work out where the weights come from, then make a single get_model call.
    if args.ckpt:
        source = {"checkpoint_path": args.ckpt}
    elif args.hf_repo:
        source = {"huggingface_repo": args.hf_repo, "huggingface_file": args.hf_file}
    else:
        raise SystemExit("Must provide either --hf-repo/--hf-file or --ckpt")
    return utils.get_model(**source, device=args.device)
def add_run_args(p: argparse.ArgumentParser) -> None:
    """Register the per-run bookkeeping options on parser *p*.

    Adds ``--artifacts-dir`` (default ``"artifacts"``), ``--run-id`` (default
    ``None``) and the integer ``--seed`` (default ``0``).

    Args:
        p: Parser (or sub-parser) to extend in place.
    """
    # Flag -> add_argument keyword arguments; dict order is registration order.
    specs = {
        "--artifacts-dir": {"default": "artifacts", "help": "Directory to write run artifacts"},
        "--run-id": {"default": None, "help": "Run ID (default: autogenerated)"},
        "--seed": {"type": int, "default": 0},
    }
    for flag, kwargs in specs.items():
        p.add_argument(flag, **kwargs)
def print_stage(prefix: str, current: int, total: int, message: str) -> None:
    """Print a one-line progress banner for a multi-stage script.

    The line has the form ``[prefix] Stage current/total: message``.

    Args:
        prefix: Tag shown in square brackets.
        current: Number of the stage being announced.
        total: Total number of stages.
        message: Description of the stage.
    """
    banner = f"[{prefix}] Stage {current}/{total}: {message}"
    print(banner)
def add_config_args(p: argparse.ArgumentParser) -> None:
    """Register the configuration options on parser *p*.

    Adds ``--config`` (path of the YAML file) and the repeatable ``--set``,
    whose values accumulate into a list of ``KEY=VALUE`` strings.

    Args:
        p: Parser (or sub-parser) to extend in place.
    """
    default_config = "cfgs/motion_default.yaml"
    p.add_argument(
        "--config",
        default=default_config,
        help="Path to YAML config",
    )
    # Each occurrence of --set appends one more KEY=VALUE string.
    p.add_argument(
        "--set",
        action="append",
        default=[],
        help="Override config KEY=VALUE (repeatable)",
    )
def load_yaml_config(path: str):
    """Load a YAML file and expose its mappings with attribute-style access.

    Every mapping in the document, at any depth and including mappings nested
    inside lists, is converted to an ``EasyDict``. Lists stay lists and scalars
    are returned unchanged.

    Args:
        path: Path of the YAML file to read.

    Returns:
        An ``EasyDict`` when the document is a mapping, a ``list`` when it is a
        sequence, otherwise the scalar itself. An empty document yields an
        empty ``EasyDict`` (still falsy, like the ``None`` that
        ``yaml.safe_load`` produces) so that ``apply_overrides`` and attribute
        access keep working.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Imported lazily: easydict is only needed once a config is actually
    # loaded, not merely to import this module.
    from easydict import EasyDict

    # Decode explicitly as UTF-8 instead of the platform's locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # safe_load returns None for an empty document; hand back an empty config
    # rather than a None that cannot take attributes.
    if data is None:
        return EasyDict()

    def _to_edict(obj: Any):
        """Recursively wrap every dict found in *obj* in an EasyDict."""
        if isinstance(obj, dict):
            return EasyDict({k: _to_edict(v) for k, v in obj.items()})
        if isinstance(obj, list):
            return [_to_edict(v) for v in obj]
        return obj

    return _to_edict(data)
def apply_overrides(cfg, overrides: list[str]) -> None:
    """Apply ``KEY=VALUE`` overrides to *cfg* in place.

    ``KEY`` is a dot-delimited attribute path (``a.b.c``). Missing intermediate
    sections are created as empty instances of ``type(cfg)``. ``VALUE`` is
    parsed with ``yaml.safe_load``, so numbers, booleans, lists and ``null``
    work; an empty value (``KEY=``) parses to ``None``.

    Overrides are applied in order; if one is rejected, the earlier ones have
    already been applied.

    Args:
        cfg: Attribute-style config object (for example the result of
            ``load_yaml_config``); it must be constructible with no arguments.
        overrides: Strings of the form ``KEY=VALUE``.

    Raises:
        SystemExit: If an item has no ``=``, has an empty key or empty key
            segment, has a value that is not valid YAML, or addresses a path
            that runs through an existing value that cannot hold attributes.
    """
    for item in overrides:
        key, sep, raw_val = item.partition("=")
        parts = key.split(".")
        # Reject a missing '=' and empty segments ("=1", "a..b=1", ".a=1"),
        # which would otherwise create attributes literally named "".
        if not sep or not all(parts):
            raise SystemExit(f"Invalid override (expected KEY=VALUE): {item}")

        try:
            value = yaml.safe_load(raw_val)
        except yaml.YAMLError as e:
            raise SystemExit(f"Invalid override value (not valid YAML): {item}") from e

        cur = cfg
        try:
            for name in parts[:-1]:
                if not hasattr(cur, name):
                    setattr(cur, name, type(cfg)())
                cur = getattr(cur, name)
            setattr(cur, parts[-1], value)
        except (AttributeError, TypeError) as e:
            # Raised when the path walks through something that is not a
            # config section (a scalar, None, a list, a bound method, ...).
            raise SystemExit(
                f"Cannot apply override {item!r}: an existing value on the path "
                f"{key!r} is not a config section"
            ) from e
def require_scenarionet() -> None:
    """Exit with an actionable message if ``scenarionet`` cannot be imported.

    The import runs inside ``quiet_native_startup_noise()``.
    NOTE(review): presumably that context manager silences native-library
    output emitted during import -- confirm in ``scenestreamer._startup``.

    Raises:
        SystemExit: If ``scenarionet`` itself is not installed (the message
            carries the install command), or if it is installed but one of the
            modules it imports is missing (the message names that module).
    """
    try:
        with quiet_native_startup_noise():
            import scenarionet  # noqa: F401
    except ModuleNotFoundError as e:
        # e.name is the module that could not be found. It is not necessarily
        # scenarionet: it may be one of scenarionet's own dependencies.
        missing = e.name or "scenarionet"
        if missing == "scenarionet" or missing.startswith("scenarionet."):
            raise SystemExit(
                "Missing dependency 'scenarionet'. Install it via:\n"
                " pip install git+https://github.com/metadriverse/scenarionet.git\n"
            ) from e
        # scenarionet is present; telling the user to install it again would
        # be misleading, so report the module that is really missing.
        raise SystemExit(
            f"Importing 'scenarionet' failed because module {missing!r} is not installed.\n"
            f"Install {missing!r} (or reinstall scenarionet together with its dependencies).\n"
        ) from e