File size: 3,415 Bytes
e573e0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from __future__ import annotations

import argparse
from dataclasses import dataclass
from typing import Any

import yaml
from scenestreamer.model_hub import DEFAULT_HF_REPO
from scenestreamer._startup import configure_startup_noise_filters, quiet_native_startup_noise

# Install the startup noise filters at import time, i.e. before any helper below
# runs. What exactly gets filtered is defined in scenestreamer._startup (not
# visible here).
configure_startup_noise_filters()


def add_model_args(p: argparse.ArgumentParser) -> None:
    """Register the model-selection CLI options on parser *p*.

    Adds ``--device``, ``--ckpt``, ``--hf-repo`` and ``--hf-file``, each a plain
    string option with the default listed below.
    """
    # (flag, default, help) in the order the options appear in --help output.
    model_options = (
        ("--device", "auto", "torch device string: auto, cuda, mps, or cpu"),
        ("--ckpt", None, "Path to a .ckpt checkpoint"),
        ("--hf-repo", DEFAULT_HF_REPO, "HuggingFace repo id, e.g. user/repo"),
        ("--hf-file", "scenestreamer-full-large.ckpt", "HuggingFace filename, e.g. model.ckpt"),
    )
    for flag, default_value, help_text in model_options:
        p.add_argument(flag, default=default_value, help=help_text)


def load_model_from_args(args: argparse.Namespace):
    """Build a model from parsed CLI arguments.

    A local checkpoint (``args.ckpt``) takes precedence. Otherwise the model is
    fetched from ``args.hf_repo`` / ``args.hf_file``. ``args.device`` is forwarded
    in both cases.

    Raises:
        SystemExit: if neither a checkpoint path nor a HuggingFace repo is set.
    """
    # The heavy utils module is imported lazily, with native startup output muted.
    with quiet_native_startup_noise():
        from scenestreamer.utils import utils

    checkpoint = args.ckpt
    if checkpoint:
        return utils.get_model(checkpoint_path=checkpoint, device=args.device)

    repo_id = args.hf_repo
    if not repo_id:
        raise SystemExit("Must provide either --hf-repo/--hf-file or --ckpt")
    return utils.get_model(huggingface_repo=repo_id, huggingface_file=args.hf_file, device=args.device)


def add_run_args(p: argparse.ArgumentParser) -> None:
    """Register run bookkeeping options on parser *p*.

    Adds ``--artifacts-dir`` (default ``"artifacts"``), ``--run-id`` (default
    ``None``) and ``--seed`` (int, default ``0``).
    """
    p.add_argument("--artifacts-dir", help="Directory to write run artifacts", default="artifacts")
    p.add_argument("--run-id", help="Run ID (default: autogenerated)", default=None)
    p.add_argument("--seed", default=0, type=int)


def print_stage(prefix: str, current: int, total: int, message: str) -> None:
    """Print a one-line progress banner such as ``[prefix] Stage 2/5: message``."""
    progress = f"{current}/{total}"
    print(f"[{prefix}] Stage {progress}: {message}")


def add_config_args(p: argparse.ArgumentParser) -> None:
    """Register config-file options on parser *p*.

    ``--config`` is the YAML path. ``--set`` may be repeated and collects
    ``KEY=VALUE`` strings into a list (empty when never given).
    """
    p.add_argument(
        "--config",
        default="cfgs/motion_default.yaml",
        help="Path to YAML config",
    )
    p.add_argument(
        "--set",
        action="append",
        default=[],
        help="Override config KEY=VALUE (repeatable)",
    )


def load_yaml_config(path: str):
    """Load a YAML file and convert every mapping in it to an ``EasyDict``.

    Mappings at any depth, including those inside lists, become ``EasyDict``
    instances, so values can be read with attribute access (``cfg.a.b``). Lists
    are rebuilt with their elements converted, and other values are returned
    unchanged. An empty or comment-only file yields an empty ``EasyDict``.

    Args:
        path: Filesystem path of the YAML config.

    Returns:
        An ``EasyDict`` when the document is a mapping (or empty), a list when the
        top-level document is a sequence, otherwise the scalar itself.

    Raises:
        ModuleNotFoundError: if the ``easydict`` package is not installed.
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
    """
    # Imported lazily, so importing this module does not need easydict; calling
    # this function does.
    from easydict import EasyDict

    # YAML files are normally UTF-8; don't depend on the platform locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # safe_load returns None for an empty document. Use an empty mapping instead,
    # so callers still get an attribute-dict they can read and assign into.
    if data is None:
        data = {}

    def _to_edict(obj: Any):
        if isinstance(obj, dict):
            return EasyDict({k: _to_edict(v) for k, v in obj.items()})
        if isinstance(obj, list):
            return [_to_edict(v) for v in obj]
        return obj

    return _to_edict(data)


def apply_overrides(cfg, overrides: list[str]) -> None:
    """
    Apply overrides of form KEY=VALUE where KEY is dot-delimited.
    VALUE is parsed using yaml.safe_load (so numbers/bools/lists work).

    ``cfg`` is modified in place via ``setattr``. Whitespace around KEY is
    ignored. Intermediate keys that are missing, or present but null, are
    created as empty instances of ``type(cfg)``.

    Raises:
        SystemExit: if an override has any of these problems:
            - it has no ``=``;
            - its KEY is empty or has an empty segment (``a..b``);
            - its VALUE is not valid YAML;
            - its path runs through a value that cannot take attributes
              (e.g. ``a.b=1`` when ``a`` is an int).
    """
    for item in overrides:
        if "=" not in item:
            raise SystemExit(f"Invalid override (expected KEY=VALUE): {item}")
        key, raw_val = item.split("=", 1)
        key = key.strip()
        parts = key.split(".")
        if not all(parts):
            raise SystemExit(f"Invalid override (empty key segment): {item}")
        try:
            value = yaml.safe_load(raw_val)
        except yaml.YAMLError as e:
            raise SystemExit(f"Invalid override value for {key}: {e}") from e

        cur = cfg
        try:
            for p in parts[:-1]:
                # A null intermediate (e.g. YAML `key:` with no value) is treated
                # like a missing one, so the user can populate it.
                if getattr(cur, p, None) is None:
                    setattr(cur, p, type(cfg)())
                # Re-read instead of reusing the object just created: an attribute
                # dict may store a converted copy of the assigned value.
                cur = getattr(cur, p)
            setattr(cur, parts[-1], value)
        except AttributeError as e:
            # Raised when the path runs through a scalar/list that cannot hold attributes.
            raise SystemExit(f"Cannot apply override {item}: {e}") from e


def require_scenarionet() -> None:
    """Exit with a helpful message if the ``scenarionet`` package cannot be imported.

    The import runs under ``quiet_native_startup_noise()``.

    Raises:
        SystemExit: if the import raises ``ModuleNotFoundError``.
            - If the missing module is scenarionet (or one of its submodules),
              the message is an install hint.
            - If scenarionet was found but a module it imports is missing, the
              message names that missing module instead.
    """
    try:
        with quiet_native_startup_noise():
            import scenarionet  # noqa: F401
    except ModuleNotFoundError as e:
        missing = e.name or "scenarionet"
        if missing.split(".", 1)[0] != "scenarionet":
            # Reinstalling scenarionet would not fix a missing transitive dependency.
            raise SystemExit(
                f"Importing 'scenarionet' failed: required module '{missing}' is not installed.\n"
            ) from e
        raise SystemExit(
            "Missing dependency 'scenarionet'. Install it via:\n"
            "  pip install git+https://github.com/metadriverse/scenarionet.git\n"
        ) from e