| """Configuration loader for ReMDM-MiniHack. |
| |
| Loads YAML configs with deep-merge and CLI override support, |
| following the Craftax config pattern. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import os |
| import secrets |
| from datetime import datetime, timezone |
| from pathlib import Path |
| from types import SimpleNamespace |
|
|
| import yaml |
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


# Directory two levels above this file. ``load_config`` reads
# ``configs/defaults.yaml`` from here and resolves relative config paths
# against it.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
|
|
|
def _deep_merge(base: dict, override: dict) -> dict:
    """Fold *override* into *base* in place, descending into nested dicts.

    When both sides hold a dict under the same key, the two are merged
    recursively. In every other case the value from *override* replaces
    (or adds) the entry in *base*.

    Args:
        base: Dictionary that receives the merged result (mutated).
        override: Dictionary whose entries win on conflict.

    Returns:
        *base* itself, after merging.
    """
    for name, incoming in override.items():
        current = base.get(name)
        mergeable = isinstance(current, dict) and isinstance(incoming, dict)
        if not mergeable:
            # Missing key, or at least one side is not a dict: plain overwrite.
            base[name] = incoming
            continue
        _deep_merge(current, incoming)
    return base
|
|
|
|
def _cast_value(value: str) -> int | float | bool | str | None:
    """Convert a command-line string into the closest Python scalar.

    Recognised keywords (case-insensitive) are ``true``/``yes``,
    ``false``/``no`` and ``null``. Otherwise the string is tried as an
    ``int``, then as a ``float``; if neither parses it is returned as-is.

    Args:
        value: Raw string from the command line.

    Returns:
        The parsed value (bool, None, int or float), or the original
        string when no conversion applies.
    """
    keywords = {
        "true": True,
        "yes": True,
        "false": False,
        "no": False,
        "null": None,
    }
    lowered = value.lower()
    if lowered in keywords:
        return keywords[lowered]

    # int is tried before float so that "3" stays an integer.
    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            continue
    return value
|
|
|
|
def load_config(
    config_path: str | None = None,
    cli_overrides: dict | None = None,
) -> SimpleNamespace:
    """Load configuration from YAML with optional overrides.

    1. Load ``configs/defaults.yaml`` (an empty file yields an empty config).
    2. Deep-merge *config_path* on top (if provided and different from
       defaults). Relative paths are resolved against the project root.
    3. Apply *cli_overrides* key=value pairs; string values are cast to
       scalars with ``_cast_value``.
    4. Select the device: a non-empty ``DEVICE`` env-var always wins;
       otherwise an existing ``device`` entry is kept; otherwise ``cuda``
       if torch is importable and CUDA is available, else ``cpu``.
    5. Validate invariants.

    Args:
        config_path: Path to a YAML file merged on top of defaults.
            ``None`` uses defaults only.
        cli_overrides: ``{key: value}`` pairs applied last (top-level
            keys only).

    Returns:
        A ``SimpleNamespace`` containing all hyperparameters.

    Raises:
        AssertionError: If ``mask_token != action_dim`` or
            ``pad_token != action_dim + 1``.
    """
    if cli_overrides is None:
        cli_overrides = {}

    defaults_path = _PROJECT_ROOT / "configs" / "defaults.yaml"
    with open(defaults_path, "r", encoding="utf-8") as fh:
        # safe_load returns None for an empty document; fall back to an
        # empty mapping so the later steps do not crash.
        cfg = yaml.safe_load(fh) or {}

    if config_path is not None:
        config_path_resolved = Path(config_path)
        if not config_path_resolved.is_absolute():
            config_path_resolved = _PROJECT_ROOT / config_path_resolved
        # Skip the merge when the caller passed the defaults file itself.
        if config_path_resolved.resolve() != defaults_path.resolve():
            with open(config_path_resolved, "r", encoding="utf-8") as fh:
                overrides = yaml.safe_load(fh) or {}
            _deep_merge(cfg, overrides)

    for key, value in cli_overrides.items():
        if isinstance(value, str):
            value = _cast_value(value)
        cfg[key] = value

    env_device = os.environ.get("DEVICE")
    if env_device:
        cfg["device"] = env_device
    elif "device" not in cfg:
        try:
            import torch

            cfg["device"] = "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            cfg["device"] = "cpu"

    ns = SimpleNamespace(**cfg)

    # Raise explicitly instead of using ``assert`` so the checks still run
    # under ``python -O``; the exception type is unchanged for callers.
    if ns.mask_token != ns.action_dim:
        raise AssertionError(
            f"mask_token ({ns.mask_token}) must equal action_dim ({ns.action_dim})"
        )
    if ns.pad_token != ns.action_dim + 1:
        raise AssertionError(
            f"pad_token ({ns.pad_token}) must equal action_dim + 1 "
            f"({ns.action_dim + 1})"
        )

    return ns
|
|
|
|
def make_run_dir(cfg: SimpleNamespace, tag: str = "run") -> Path:
    """Create a unique run subdirectory under ``cfg.checkpoint_dir``.

    Generates a directory named ``{tag}_{YYYYMMDD}_{HHMMSS}_{hex4}``
    (UTC timestamp) to prevent concurrent runs from overwriting each
    other's checkpoints. Updates ``cfg.checkpoint_dir`` in place.

    The directory is created with ``exist_ok=False``. If the random
    suffix collides with an existing directory, a new suffix is drawn
    instead of silently sharing that directory.

    Args:
        cfg: Config namespace (``checkpoint_dir`` is mutated).
        tag: Prefix for the directory name (e.g. ``"dagger"``,
            ``"offline"``).

    Returns:
        The created directory path.

    Raises:
        FileExistsError: If no unused directory name could be found
            after several attempts.
    """
    base_dir = Path(cfg.checkpoint_dir).resolve()
    ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")

    max_attempts = 16
    for _ in range(max_attempts):
        run_dir = base_dir / f"{tag}_{ts}_{secrets.token_hex(2)}"
        try:
            # Parents are created as needed; only the final directory must
            # not already exist.
            run_dir.mkdir(parents=True, exist_ok=False)
        except FileExistsError:
            continue
        break
    else:
        raise FileExistsError(
            f"Could not create a unique run directory under {base_dir}"
        )

    cfg.checkpoint_dir = str(run_dir)
    logger.info("Checkpoint directory: %s", run_dir)
    return run_dir
|
|