from __future__ import annotations import argparse from dataclasses import dataclass from typing import Any import yaml from scenestreamer.model_hub import DEFAULT_HF_REPO from scenestreamer._startup import configure_startup_noise_filters, quiet_native_startup_noise configure_startup_noise_filters() def add_model_args(p: argparse.ArgumentParser) -> None: p.add_argument("--device", default="auto", help="torch device string: auto, cuda, mps, or cpu") p.add_argument("--ckpt", default=None, help="Path to a .ckpt checkpoint") p.add_argument("--hf-repo", default=DEFAULT_HF_REPO, help="HuggingFace repo id, e.g. user/repo") p.add_argument("--hf-file", default="scenestreamer-full-large.ckpt", help="HuggingFace filename, e.g. model.ckpt") def load_model_from_args(args: argparse.Namespace): with quiet_native_startup_noise(): from scenestreamer.utils import utils if args.ckpt: return utils.get_model(checkpoint_path=args.ckpt, device=args.device) if args.hf_repo: return utils.get_model(huggingface_repo=args.hf_repo, huggingface_file=args.hf_file, device=args.device) raise SystemExit("Must provide either --hf-repo/--hf-file or --ckpt") def add_run_args(p: argparse.ArgumentParser) -> None: p.add_argument("--artifacts-dir", default="artifacts", help="Directory to write run artifacts") p.add_argument("--run-id", default=None, help="Run ID (default: autogenerated)") p.add_argument("--seed", type=int, default=0) def print_stage(prefix: str, current: int, total: int, message: str) -> None: print(f"[{prefix}] Stage {current}/{total}: {message}") def add_config_args(p: argparse.ArgumentParser) -> None: p.add_argument("--config", default="cfgs/motion_default.yaml", help="Path to YAML config") p.add_argument("--set", action="append", default=[], help="Override config KEY=VALUE (repeatable)") def load_yaml_config(path: str): # Keep dependency-light: EasyDict is optional, we can just use attribute dicts from the library. from easydict import EasyDict with open(path, "r") as f: data = yaml.safe_load(f) def _to_edict(obj: Any): if isinstance(obj, dict): return EasyDict({k: _to_edict(v) for k, v in obj.items()}) if isinstance(obj, list): return [_to_edict(v) for v in obj] return obj return _to_edict(data) def apply_overrides(cfg, overrides: list[str]) -> None: """ Apply overrides of form KEY=VALUE where KEY is dot-delimited. VALUE is parsed using yaml.safe_load (so numbers/bools/lists work). """ for item in overrides: if "=" not in item: raise SystemExit(f"Invalid override (expected KEY=VALUE): {item}") key, raw_val = item.split("=", 1) value = yaml.safe_load(raw_val) cur = cfg parts = key.split(".") for p in parts[:-1]: if not hasattr(cur, p): setattr(cur, p, type(cfg)()) cur = getattr(cur, p) setattr(cur, parts[-1], value) def require_scenarionet() -> None: try: with quiet_native_startup_noise(): import scenarionet # noqa: F401 except ModuleNotFoundError as e: raise SystemExit( "Missing dependency 'scenarionet'. Install it via:\n" " pip install git+https://github.com/metadriverse/scenarionet.git\n" ) from e