| """ |
| Shared configuration / utility helpers for framework components: |
| - NamespaceWithGet: lightweight namespace behaving like a dict |
| - OmegaConf conversion helpers |
| - Config merging decorator for model __init__ |
| - Checkpoint config/statistics loading |
| """ |
|
|
| import os |
| from pathlib import Path |
| from types import SimpleNamespace |
| import json |
|
|
| from typing import Union, List |
| import torchvision |
| import os |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| from types import SimpleNamespace |
| import torch, json |
| import torch.nn as nn |
| import numpy as np |
| from PIL import Image |
| import re |
| from omegaconf import OmegaConf |
| from types import SimpleNamespace |
| import inspect |
| import functools |
| from typing import Any |
|
|
|
|
| from starVLA.training.trainer_utils import initialize_overwatch |
|
|
| |
| overwatch = initialize_overwatch(__name__) |
|
|
|
|
| from types import SimpleNamespace |
|
|
|
|
| class NamespaceWithGet(SimpleNamespace): |
| def get(self, key, default=None): |
| """ |
| Return attribute value if present, else default (dict-like API). |
| |
| Args: |
| key: Attribute name. |
| default: Fallback if attribute missing. |
| |
| Returns: |
| Any: Stored value or default. |
| """ |
| return getattr(self, key, default) |
|
|
| def items(self): |
| """ |
| Iterate (key, value) pairs like dict.items(). |
| |
| Returns: |
| Generator[Tuple[str, Any], None, None] |
| """ |
| return ((key, getattr(self, key)) for key in self.__dict__) |
|
|
| def __iter__(self): |
| """ |
| Return iterator over attribute keys (enables dict unpacking **obj). |
| |
| Returns: |
| Iterator[str] |
| """ |
| return iter(self.__dict__) |
|
|
| def to_dict(self): |
| """ |
| Recursively convert nested NamespaceWithGet objects into plain dicts. |
| |
| Returns: |
| dict: Fully materialized dictionary structure. |
| """ |
| return {key: value.to_dict() if isinstance(value, NamespaceWithGet) else value for key, value in self.items()} |
|
|
|
|
| def dict_to_namespace(d): |
| """ |
| Create an OmegaConf config from a plain dictionary. |
| |
| Args: |
| d: Input dictionary. |
| |
| Returns: |
| OmegaConf: DictConfig instance. |
| """ |
| return OmegaConf.create(d) |
|
|
|
|
| def _to_omegaconf(x: Any): |
| """ |
| Convert diverse input types into an OmegaConf object. |
| |
| Accepted types: |
| - None -> empty DictConfig |
| - str path -> load YAML/JSON via OmegaConf.load |
| - dict -> DictConfig |
| - DictConfig / ListConfig -> returned unchanged |
| - NamespaceWithGet / SimpleNamespace -> converted via vars()/to_dict() |
| |
| Args: |
| x: Input candidate. |
| |
| Returns: |
| OmegaConf: Normalized configuration node. |
| """ |
| if x is None: |
| return OmegaConf.create({}) |
| if isinstance(x, OmegaConf.__class__): |
| return x |
| try: |
| |
| from omegaconf import DictConfig, ListConfig |
|
|
| if isinstance(x, (DictConfig, ListConfig)): |
| return x |
| except Exception: |
| pass |
|
|
| if isinstance(x, str): |
| |
| return OmegaConf.load(x) |
| if isinstance(x, dict): |
| return OmegaConf.create(x) |
| if isinstance(x, NamespaceWithGet) or isinstance(x, SimpleNamespace): |
| |
| try: |
| d = x.to_dict() if hasattr(x, "to_dict") else vars(x) |
| except Exception: |
| d = vars(x) |
| return OmegaConf.create(d) |
| |
| return OmegaConf.create(x) |
|
|
|
|
| def merge_pram_config(init): |
| """ |
| Decorator for __init__ to unify config handling. |
| |
| Behavior: |
| 1. Extract 'config' kwarg / arg (path | dict | OmegaConf | namespace) |
| 2. Convert to OmegaConf |
| 3. Merge with explicitly passed init parameters (explicit overrides file) |
| 4. Attach merged config to self.config |
| 5. Call original __init__ with merged config |
| |
| Args: |
| init: Original __init__ function. |
| |
| Returns: |
| Wrapped initializer. |
| """ |
|
|
| @functools.wraps(init) |
| def wrapper(self, *args, **kwargs): |
| |
| sig = inspect.signature(init) |
| param_names = [name for i, (name, p) in enumerate(sig.parameters.items()) if i > 0] |
|
|
| init_kwargs = {} |
| for name, val in zip(param_names, args): |
| init_kwargs[name] = val |
| |
| init_kwargs.update(kwargs) |
|
|
| |
| provided_config = init_kwargs.get("config", None) |
|
|
| loaded_cfg = _to_omegaconf(provided_config) |
|
|
| |
| params = {k: v for k, v in init_kwargs.items() if k != "config"} |
| params_cfg = OmegaConf.create(params) if params else OmegaConf.create({}) |
|
|
| |
| merged = OmegaConf.merge(loaded_cfg, params_cfg) |
|
|
| |
| try: |
| |
| self.config = merged |
| except Exception: |
| |
| self.config = OmegaConf.to_container(merged, resolve=True) |
|
|
| |
| call_kwargs = dict(init_kwargs) |
| call_kwargs["config"] = merged |
|
|
| |
| return init(self, **call_kwargs) |
|
|
| return wrapper |
|
|
|
|
| _SUPPORTED_CKPT_SUFFIXES = {".pt", ".safetensors"} |
|
|
|
|
| def _resolve_run_dir(checkpoint_pt): |
| """Resolve run directory from a checkpoint path. |
| |
| For files in checkpoints/ (e.g. run_dir/checkpoints/model.pt), go up 2 levels. |
| """ |
| return checkpoint_pt.parents[1] |
|
|
|
|
| def read_model_config(pretrained_checkpoint): |
| """ |
| Load global model configuration and dataset normalization statistics |
| associated with a saved checkpoint (.pt or .safetensors). |
| |
| Expected directory layout: |
| <run_dir>/checkpoints/<name>.pt|.safetensors |
| <run_dir>/config.json |
| <run_dir>/dataset_statistics.json |
| |
| Args: |
| pretrained_checkpoint: Path to a .pt or .safetensors checkpoint file. |
| |
| Returns: |
| tuple: |
| global_cfg (dict): Loaded config.json contents. |
| norm_stats (dict): Dataset statistics for (de)normalization. |
| |
| Raises: |
| FileNotFoundError: If checkpoint or required JSON files are missing. |
| AssertionError: If file suffix or structure invalid. |
| """ |
| if os.path.isfile(pretrained_checkpoint): |
| checkpoint_pt = Path(pretrained_checkpoint) |
| overwatch.info(f"Loading from local checkpoint path `{checkpoint_pt}`") |
|
|
| assert checkpoint_pt.suffix in _SUPPORTED_CKPT_SUFFIXES, ( |
| f"Unsupported checkpoint suffix `{checkpoint_pt.suffix}`, expected one of {_SUPPORTED_CKPT_SUFFIXES}" |
| ) |
| run_dir = _resolve_run_dir(checkpoint_pt) |
|
|
| |
| config_json, dataset_statistics_json = run_dir / "config.json", run_dir / "dataset_statistics.json" |
| assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`" |
| assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`" |
|
|
| with open(config_json, "r") as f: |
| global_cfg = json.load(f) |
|
|
| with open(dataset_statistics_json, "r") as f: |
| norm_stats = json.load(f) |
| else: |
| overwatch.error(f"❌ Pretrained checkpoint `{pretrained_checkpoint}` does not exist.") |
| raise FileNotFoundError(f"Pretrained checkpoint `{pretrained_checkpoint}` does not exist.") |
| return global_cfg, norm_stats |
|
|
|
|
| def read_mode_config(pretrained_checkpoint): |
| """ |
| Load YAML model configuration and dataset normalization statistics |
| associated with a saved checkpoint (.pt or .safetensors). |
| |
| Args: |
| pretrained_checkpoint: Path to a .pt or .safetensors checkpoint file. |
| |
| Returns: |
| tuple: |
| vla_cfg (dict) |
| norm_stats (dict) |
| """ |
| if os.path.isfile(pretrained_checkpoint): |
| checkpoint_pt = Path(pretrained_checkpoint) |
| overwatch.info(f"Loading from local checkpoint path `{checkpoint_pt}`") |
|
|
| assert checkpoint_pt.suffix in _SUPPORTED_CKPT_SUFFIXES, ( |
| f"Unsupported checkpoint suffix `{checkpoint_pt.suffix}`, expected one of {_SUPPORTED_CKPT_SUFFIXES}" |
| ) |
| run_dir = _resolve_run_dir(checkpoint_pt) |
|
|
| config_yaml, dataset_statistics_json = run_dir / "config.yaml", run_dir / "dataset_statistics.json" |
| assert config_yaml.exists(), f"Missing `config.yaml` for `{run_dir = }`" |
| assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`" |
|
|
| try: |
| ocfg = OmegaConf.load(str(config_yaml)) |
| global_cfg = OmegaConf.to_container(ocfg, resolve=True) |
| except Exception as e: |
| overwatch.error(f"❌ Failed to load YAML config `{config_yaml}`: {e}") |
| raise |
|
|
| with open(dataset_statistics_json, "r") as f: |
| norm_stats = json.load(f) |
| else: |
| overwatch.error(f"❌ Pretrained checkpoint `{pretrained_checkpoint}` does not exist.") |
| raise FileNotFoundError(f"Pretrained checkpoint `{pretrained_checkpoint}` does not exist.") |
| return global_cfg, norm_stats |
|
|