Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| from dataclasses import dataclass | |
| from typing import Any | |
| import yaml | |
| from scenestreamer.model_hub import DEFAULT_HF_REPO | |
| from scenestreamer._startup import configure_startup_noise_filters, quiet_native_startup_noise | |
# Module-level call: runs as soon as this module is imported, before any function
# below is used. Presumably installs filters for noisy startup output (warnings/logs);
# see scenestreamer._startup.configure_startup_noise_filters to confirm.
configure_startup_noise_filters()
def add_model_args(p: argparse.ArgumentParser) -> None:
    """Register the model-selection flags on ``p``.

    Adds ``--device`` (default ``"auto"``), ``--ckpt`` (local checkpoint path,
    default ``None``) and the HuggingFace source pair ``--hf-repo`` /
    ``--hf-file`` (defaults ``DEFAULT_HF_REPO`` and
    ``"scenestreamer-full-large.ckpt"``), in that order.
    """
    # (flag, default, help) rows, registered in declaration order.
    flag_specs = (
        ("--device", "auto", "torch device string: auto, cuda, mps, or cpu"),
        ("--ckpt", None, "Path to a .ckpt checkpoint"),
        ("--hf-repo", DEFAULT_HF_REPO, "HuggingFace repo id, e.g. user/repo"),
        ("--hf-file", "scenestreamer-full-large.ckpt", "HuggingFace filename, e.g. model.ckpt"),
    )
    for flag, default_value, help_text in flag_specs:
        p.add_argument(flag, default=default_value, help=help_text)
def load_model_from_args(args: argparse.Namespace):
    """Build a model from parsed CLI arguments.

    A truthy ``args.ckpt`` wins and is loaded as a local checkpoint. Otherwise
    the model is fetched from ``args.hf_repo`` / ``args.hf_file``.
    ``args.device`` is forwarded unchanged to ``utils.get_model``.

    Raises:
        SystemExit: if neither ``args.ckpt`` nor ``args.hf_repo`` is truthy.
    """
    # Deferred import, wrapped so that output emitted while importing is
    # (presumably) suppressed by quiet_native_startup_noise().
    with quiet_native_startup_noise():
        from scenestreamer.utils import utils

    if args.ckpt:
        return utils.get_model(checkpoint_path=args.ckpt, device=args.device)
    if not args.hf_repo:
        raise SystemExit("Must provide either --hf-repo/--hf-file or --ckpt")
    return utils.get_model(
        huggingface_repo=args.hf_repo,
        huggingface_file=args.hf_file,
        device=args.device,
    )
def add_run_args(p: argparse.ArgumentParser) -> None:
    """Register run bookkeeping flags on ``p``.

    ``--artifacts-dir`` (default ``"artifacts"``), ``--run-id`` (default ``None``)
    and an integer ``--seed`` (default ``0``).
    """
    p.add_argument(
        "--artifacts-dir",
        default="artifacts",
        help="Directory to write run artifacts",
    )
    p.add_argument(
        "--run-id",
        default=None,
        help="Run ID (default: autogenerated)",
    )
    p.add_argument("--seed", default=0, type=int)
def print_stage(prefix: str, current: int, total: int, message: str) -> None:
    """Print one progress line: ``[prefix] Stage current/total: message``."""
    progress = "{}/{}".format(current, total)
    print("[{}] Stage {}: {}".format(prefix, progress, message))
def add_config_args(p: argparse.ArgumentParser) -> None:
    """Register ``--config`` and the repeatable ``--set`` override flag on ``p``.

    ``--set`` uses ``action="append"``: each occurrence adds one ``KEY=VALUE``
    string to ``args.set``, which is an empty list when the flag is absent.
    """
    config_help = "Path to YAML config"
    p.add_argument("--config", help=config_help, default="cfgs/motion_default.yaml")
    override_help = "Override config KEY=VALUE (repeatable)"
    p.add_argument("--set", default=[], action="append", help=override_help)
def load_yaml_config(path: str):
    """Load a YAML config file into nested ``EasyDict`` objects.

    Every mapping in the document becomes an ``EasyDict``, at any depth and
    including mappings inside lists, so values can be read as attributes
    (``cfg.model.name``). Lists stay lists and scalars are returned unchanged.

    Args:
        path: Path to the YAML file. It is read as UTF-8 and parsed with
            ``yaml.safe_load`` (no arbitrary Python object construction).

    Returns:
        The converted document. An empty (or comment-only) file yields an empty
        ``EasyDict`` rather than ``None``, so callers such as ``apply_overrides``
        always receive a config object.
    """
    # EasyDict is required by this function. Importing it lazily only means the
    # module itself can be imported without easydict installed.
    from easydict import EasyDict

    # Explicit UTF-8 rather than the platform-dependent locale default encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    if data is None:
        # yaml.safe_load returns None for an empty document.
        return EasyDict()

    def _to_edict(obj: Any):
        # Recursively convert dicts (and dicts nested in lists) to EasyDict.
        if isinstance(obj, dict):
            return EasyDict({k: _to_edict(v) for k, v in obj.items()})
        if isinstance(obj, list):
            return [_to_edict(v) for v in obj]
        return obj

    return _to_edict(data)
def _cfg_child(node, name: str, default):
    """Return the child ``name`` of a config node, or ``default`` if it has none."""
    # For mappings, look at keys (``in`` / ``[]``) rather than hasattr()/getattr():
    # on a dict-backed config such as EasyDict, hasattr() also matches dict
    # methods like "items" or "update", which are not config entries.
    if isinstance(node, dict):
        return node[name] if name in node else default
    return getattr(node, name, default)


def _cfg_assign(node, name: str, value) -> None:
    """Set the child ``name`` of a config node (item for mappings, attribute otherwise)."""
    if isinstance(node, dict):
        # EasyDict aliases __setitem__ to __setattr__, so for EasyDict this is the
        # same as setattr(). For a plain dict it is the only form that works.
        node[name] = value
    else:
        setattr(node, name, value)


def apply_overrides(cfg, overrides: list[str]) -> None:
    """
    Apply overrides of form KEY=VALUE where KEY is dot-delimited.
    VALUE is parsed using yaml.safe_load (so numbers/bools/lists work).

    ``cfg`` is modified in place. Missing intermediate sections are created as
    empty ``type(cfg)()`` instances. Whitespace around each KEY segment is
    ignored. An empty VALUE parses to ``None``.

    Raises:
        SystemExit: if an item has no ``=``, KEY has an empty segment (e.g.
            ``a..b`` or ``=1``), VALUE is not valid YAML, or KEY descends into an
            existing value that cannot hold attributes (e.g. a number, list or None).
    """
    missing = object()  # sentinel: distinguishes "absent" from a stored None
    for item in overrides:
        if "=" not in item:
            raise SystemExit(f"Invalid override (expected KEY=VALUE): {item}")
        key, raw_val = item.split("=", 1)
        parts = [part.strip() for part in key.split(".")]
        if not all(parts):
            raise SystemExit(f"Invalid override (empty key segment): {item}")
        try:
            value = yaml.safe_load(raw_val)
        except yaml.YAMLError as e:
            raise SystemExit(f"Invalid override (VALUE is not valid YAML): {item}") from e

        cur = cfg
        try:
            for p in parts[:-1]:
                child = _cfg_child(cur, p, missing)
                if child is missing:
                    _cfg_assign(cur, p, type(cfg)())
                    # Re-read so we descend into the object actually stored (a config
                    # class may convert values on assignment).
                    child = _cfg_child(cur, p, missing)
                cur = child
            _cfg_assign(cur, parts[-1], value)
        except AttributeError as e:
            # Raised when an existing intermediate value (int, list, None, ...)
            # cannot accept attributes.
            dotted = ".".join(parts)
            raise SystemExit(
                f"Invalid override (cannot set {dotted!r}: a parent value is not "
                f"a config section): {item}"
            ) from e
def require_scenarionet() -> None:
    """Exit with an actionable message if ``scenarionet`` cannot be imported.

    The import runs inside ``quiet_native_startup_noise()``.

    Raises:
        SystemExit: if ``scenarionet`` itself is not installed (with a
            ``pip install`` hint), or if it is installed but a module it imports
            is missing (naming that module instead of blaming scenarionet).
    """
    try:
        with quiet_native_startup_noise():
            import scenarionet  # noqa: F401
    except ModuleNotFoundError as e:
        # ModuleNotFoundError is also raised when scenarionet is present but one of
        # *its* imports is missing; e.name tells the two cases apart.
        if e.name not in (None, "scenarionet"):
            raise SystemExit(
                f"'scenarionet' is installed but failed to import because module "
                f"'{e.name}' is missing. Install that module (or reinstall "
                f"scenarionet with its dependencies).\n"
            ) from e
        raise SystemExit(
            "Missing dependency 'scenarionet'. Install it via:\n"
            " pip install git+https://github.com/metadriverse/scenarionet.git\n"
        ) from e