Make env.py self-contained
Browse files
env.py
CHANGED
|
@@ -26,15 +26,64 @@ from typing import Any
|
|
| 26 |
|
| 27 |
import gymnasium as gym
|
| 28 |
import numpy as np
|
| 29 |
-
import sys
|
| 30 |
import torch
|
| 31 |
import yaml
|
| 32 |
|
| 33 |
-
#
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# Module path for environment class resolution
|
| 40 |
ISAACLAB_ARENA_ENV_MODULE = os.environ.get("ISAACLAB_ARENA_ENV_MODULE", "isaaclab_arena_environments")
|
|
@@ -246,9 +295,8 @@ class IsaacLabVectorEnvWrapper:
|
|
| 246 |
self.close()
|
| 247 |
|
| 248 |
|
| 249 |
-
|
| 250 |
def _find_config_file(config_path: str | Path | None = None) -> Path | None:
|
| 251 |
-
"""Find config.yaml file in various locations.
|
| 252 |
|
| 253 |
Search order:
|
| 254 |
1. Explicit path from ISAACLAB_ARENA_CONFIG_PATH env var
|
|
@@ -257,6 +305,7 @@ def _find_config_file(config_path: str | Path | None = None) -> Path | None:
|
|
| 257 |
4. config.yaml relative to this module
|
| 258 |
5. configs/config.yaml in current directory
|
| 259 |
6. config.yaml in current directory
|
|
|
|
| 260 |
|
| 261 |
Returns:
|
| 262 |
Path to config file, or None if not found (uses defaults).
|
|
@@ -286,6 +335,19 @@ def _find_config_file(config_path: str | Path | None = None) -> Path | None:
|
|
| 286 |
if path.exists():
|
| 287 |
return path
|
| 288 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
logging.info("No config.yaml found, using defaults")
|
| 290 |
return None
|
| 291 |
|
|
|
|
| 26 |
|
| 27 |
import gymnasium as gym
|
| 28 |
import numpy as np
|
|
|
|
| 29 |
import torch
|
| 30 |
import yaml
|
| 31 |
|
| 32 |
+
# Hub constants for downloading additional files
|
| 33 |
+
HUB_REPO_ID = "nvkartik/isaaclab-arena-envs"
|
| 34 |
+
DEFAULT_CONFIG_FILE = "configs/config.yaml"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def validate_config(
|
| 38 |
+
env,
|
| 39 |
+
state_keys: tuple[str, ...],
|
| 40 |
+
camera_keys: tuple[str, ...],
|
| 41 |
+
cfg_state_dim: int,
|
| 42 |
+
cfg_action_dim: int,
|
| 43 |
+
) -> None:
|
| 44 |
+
"""Validate observation keys and dimensions against IsaacLab managers."""
|
| 45 |
+
obs_manager = env.observation_manager
|
| 46 |
+
active_terms = obs_manager.active_terms
|
| 47 |
+
policy_terms = set(active_terms.get("policy", []))
|
| 48 |
+
camera_terms = set(active_terms.get("camera_obs", []))
|
| 49 |
+
|
| 50 |
+
# Validate keys exist
|
| 51 |
+
missing_state = [k for k in state_keys if k not in policy_terms]
|
| 52 |
+
if missing_state:
|
| 53 |
+
raise ValueError(f"Invalid state_keys: {missing_state}. Available: {sorted(policy_terms)}")
|
| 54 |
+
|
| 55 |
+
missing_cam = [k for k in camera_keys if k not in camera_terms]
|
| 56 |
+
if missing_cam:
|
| 57 |
+
raise ValueError(f"Invalid camera_keys: {missing_cam}. Available: {sorted(camera_terms)}")
|
| 58 |
+
|
| 59 |
+
# Validate dimensions
|
| 60 |
+
env_action_dim = env.action_space.shape[-1]
|
| 61 |
+
if cfg_action_dim != env_action_dim:
|
| 62 |
+
raise ValueError(f"action_dim mismatch: config={cfg_action_dim}, env={env_action_dim}")
|
| 63 |
+
|
| 64 |
+
# Compute expected state dimension
|
| 65 |
+
policy_dims = obs_manager.group_obs_dim.get("policy", [])
|
| 66 |
+
policy_names = active_terms.get("policy", [])
|
| 67 |
+
term_dims = dict(zip(policy_names, policy_dims, strict=False))
|
| 68 |
+
|
| 69 |
+
expected_state_dim = 0
|
| 70 |
+
for key in state_keys:
|
| 71 |
+
if key in term_dims:
|
| 72 |
+
shape = term_dims[key]
|
| 73 |
+
dim = 1
|
| 74 |
+
for s in shape if isinstance(shape, (tuple, list)) else [shape]:
|
| 75 |
+
dim *= s
|
| 76 |
+
expected_state_dim += dim
|
| 77 |
+
|
| 78 |
+
if cfg_state_dim != expected_state_dim:
|
| 79 |
+
raise ValueError(
|
| 80 |
+
f"state_dim mismatch: config={cfg_state_dim}, "
|
| 81 |
+
f"computed={expected_state_dim}. "
|
| 82 |
+
f"Term dims: {term_dims}"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
logging.info(f"Validated: state_keys={state_keys}, camera_keys={camera_keys}")
|
| 86 |
+
|
| 87 |
|
| 88 |
# Module path for environment class resolution
|
| 89 |
ISAACLAB_ARENA_ENV_MODULE = os.environ.get("ISAACLAB_ARENA_ENV_MODULE", "isaaclab_arena_environments")
|
|
|
|
| 295 |
self.close()
|
| 296 |
|
| 297 |
|
|
|
|
| 298 |
def _find_config_file(config_path: str | Path | None = None) -> Path | None:
|
| 299 |
+
"""Find config.yaml file in various locations, downloading from Hub if needed.
|
| 300 |
|
| 301 |
Search order:
|
| 302 |
1. Explicit path from ISAACLAB_ARENA_CONFIG_PATH env var
|
|
|
|
| 305 |
4. config.yaml relative to this module
|
| 306 |
5. configs/config.yaml in current directory
|
| 307 |
6. config.yaml in current directory
|
| 308 |
+
7. Download from Hugging Face Hub (fallback for Hub-loaded env.py)
|
| 309 |
|
| 310 |
Returns:
|
| 311 |
Path to config file, or None if not found (uses defaults).
|
|
|
|
| 335 |
if path.exists():
|
| 336 |
return path
|
| 337 |
|
| 338 |
+
# Fallback: download config from Hugging Face Hub
|
| 339 |
+
# This is needed when env.py is loaded from HF cache (only env.py is downloaded)
|
| 340 |
+
try:
|
| 341 |
+
from huggingface_hub import hf_hub_download
|
| 342 |
+
logging.info(f"Downloading config from Hub: {HUB_REPO_ID}/{DEFAULT_CONFIG_FILE}")
|
| 343 |
+
hub_config_path = hf_hub_download(
|
| 344 |
+
repo_id=HUB_REPO_ID,
|
| 345 |
+
filename=DEFAULT_CONFIG_FILE,
|
| 346 |
+
)
|
| 347 |
+
return Path(hub_config_path)
|
| 348 |
+
except Exception as e:
|
| 349 |
+
logging.warning(f"Failed to download config from Hub: {e}")
|
| 350 |
+
|
| 351 |
logging.info("No config.yaml found, using defaults")
|
| 352 |
return None
|
| 353 |
|