import logging
from pathlib import Path
from typing import Any

import yaml

HUB_REPO_ID = "nvidia/isaaclab-arena-envs"
DEFAULT_CONFIG_FILE = "configs/config.yaml"


def validate_config(
    env,
    state_keys: tuple[str, ...],
    camera_keys: tuple[str, ...],
    cfg_state_dim: int,
    cfg_action_dim: int,
) -> None:
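    """Validate observation keys and dimensions against IsaacLab managers.

    Args:
        env: Environment exposing ``observation_manager`` and ``action_space``.
        state_keys: Observation term names expected in the "policy" group.
        camera_keys: Observation term names expected in the "camera_obs" group.
        cfg_state_dim: Configured state dimension; must equal the summed
            flattened size of the ``state_keys`` terms.
        cfg_action_dim: Configured action dimension; must equal the last
            dimension of ``env.action_space.shape``.

    Raises:
        ValueError: If a key is missing or a dimension does not match.
    """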
    obs_manager = env.observation_manager
    active_terms = obs_manager.active_terms
    policy_terms = set(active_terms.get("policy", []))
    camera_terms = set(active_terms.get("camera_obs", []))

    # Validate keys exist
    missing_state = [k for k in state_keys if k not in policy_terms]
    if missing_state:
        raise ValueError(f"Invalid state_keys: {missing_state}. Available: {sorted(policy_terms)}")

    missing_cam = [k for k in camera_keys if k not in camera_terms]
    if missing_cam:
        raise ValueError(f"Invalid camera_keys: {missing_cam}. Available: {sorted(camera_terms)}")

    # Validate dimensions
    env_action_dim = env.action_space.shape[-1]
    if cfg_action_dim != env_action_dim:
        raise ValueError(f"action_dim mismatch: config={cfg_action_dim}, env={env_action_dim}")

    # Compute expected state dimension
    policy_dims = obs_manager.group_obs_dim.get("policy", [])
    policy_names = active_terms.get("policy", [])
    term_dims = dict(zip(policy_names, policy_dims, strict=False))

    expected_state_dim = 0
    for key in state_keys:
        if key in term_dims:
            shape = term_dims[key]
            dim = 1
            for s in shape if isinstance(shape, (tuple, list)) else [shape]:
                dim *= s
            expected_state_dim += dim

    if cfg_state_dim != expected_state_dim:
        raise ValueError(
            f"state_dim mismatch: config={cfg_state_dim}, "
            f"computed={expected_state_dim}. "
            f"Term dims: {term_dims}"
        )

    logging.info("Validated: state_keys=%s, camera_keys=%s", state_keys, camera_keys)


def load_config(config_path: str | Path | None = None) -> dict[str, Any]:
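    """Load environment config from a YAML file.

    Args:
        config_path: Path to a local YAML config file.
            If None, the default configs/config.yaml is downloaded
            from the Hub repo (HUB_REPO_ID).

    Returns:
        Dictionary with environment configuration (empty if the file is empty).
    """
    if config_path is None:
        # No explicit path given: fetch the default config from the Hub.
        # Imported lazily so huggingface_hub is only needed in this case.
        from huggingface_hub import hf_hub_download

        config_path = hf_hub_download(repo_id=HUB_REPO_ID, filename=DEFAULT_CONFIG_FILE)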

    logging.info("Loading config from: %s", config_path)
    with open(config_path) as f:
        config = yaml.safe_load(f) or {}

    return config