nvkartik commited on
Commit
76ba10d
·
1 Parent(s): 533a63d

Make env.py self-contained

Browse files
Files changed (1) hide show
  1. env.py +70 -8
env.py CHANGED
@@ -26,15 +26,64 @@ from typing import Any
26
 
27
  import gymnasium as gym
28
  import numpy as np
29
- import sys
30
  import torch
31
  import yaml
32
 
33
- # Import utils from the same directory as this file (works when loaded from HF Hub cache)
34
- _this_dir = os.path.dirname(os.path.abspath(__file__))
35
- if _this_dir not in sys.path:
36
- sys.path.insert(0, _this_dir)
37
- from utils import validate_config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  # Module path for environment class resolution
40
  ISAACLAB_ARENA_ENV_MODULE = os.environ.get("ISAACLAB_ARENA_ENV_MODULE", "isaaclab_arena_environments")
@@ -246,9 +295,8 @@ class IsaacLabVectorEnvWrapper:
246
  self.close()
247
 
248
 
249
-
250
  def _find_config_file(config_path: str | Path | None = None) -> Path | None:
251
- """Find config.yaml file in various locations.
252
 
253
  Search order:
254
  1. Explicit path from ISAACLAB_ARENA_CONFIG_PATH env var
@@ -257,6 +305,7 @@ def _find_config_file(config_path: str | Path | None = None) -> Path | None:
257
  4. config.yaml relative to this module
258
  5. configs/config.yaml in current directory
259
  6. config.yaml in current directory
 
260
 
261
  Returns:
262
  Path to config file, or None if not found (uses defaults).
@@ -286,6 +335,19 @@ def _find_config_file(config_path: str | Path | None = None) -> Path | None:
286
  if path.exists():
287
  return path
288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  logging.info("No config.yaml found, using defaults")
290
  return None
291
 
 
26
 
27
  import gymnasium as gym
28
  import numpy as np
 
29
  import torch
30
  import yaml
31
 
32
+ # Hub constants for downloading additional files
33
+ HUB_REPO_ID = "nvkartik/isaaclab-arena-envs"
34
+ DEFAULT_CONFIG_FILE = "configs/config.yaml"
35
+
36
+
37
+ def validate_config(
38
+ env,
39
+ state_keys: tuple[str, ...],
40
+ camera_keys: tuple[str, ...],
41
+ cfg_state_dim: int,
42
+ cfg_action_dim: int,
43
+ ) -> None:
44
+ """Validate observation keys and dimensions against IsaacLab managers."""
45
+ obs_manager = env.observation_manager
46
+ active_terms = obs_manager.active_terms
47
+ policy_terms = set(active_terms.get("policy", []))
48
+ camera_terms = set(active_terms.get("camera_obs", []))
49
+
50
+ # Validate keys exist
51
+ missing_state = [k for k in state_keys if k not in policy_terms]
52
+ if missing_state:
53
+ raise ValueError(f"Invalid state_keys: {missing_state}. Available: {sorted(policy_terms)}")
54
+
55
+ missing_cam = [k for k in camera_keys if k not in camera_terms]
56
+ if missing_cam:
57
+ raise ValueError(f"Invalid camera_keys: {missing_cam}. Available: {sorted(camera_terms)}")
58
+
59
+ # Validate dimensions
60
+ env_action_dim = env.action_space.shape[-1]
61
+ if cfg_action_dim != env_action_dim:
62
+ raise ValueError(f"action_dim mismatch: config={cfg_action_dim}, env={env_action_dim}")
63
+
64
+ # Compute expected state dimension
65
+ policy_dims = obs_manager.group_obs_dim.get("policy", [])
66
+ policy_names = active_terms.get("policy", [])
67
+ term_dims = dict(zip(policy_names, policy_dims, strict=False))
68
+
69
+ expected_state_dim = 0
70
+ for key in state_keys:
71
+ if key in term_dims:
72
+ shape = term_dims[key]
73
+ dim = 1
74
+ for s in shape if isinstance(shape, (tuple, list)) else [shape]:
75
+ dim *= s
76
+ expected_state_dim += dim
77
+
78
+ if cfg_state_dim != expected_state_dim:
79
+ raise ValueError(
80
+ f"state_dim mismatch: config={cfg_state_dim}, "
81
+ f"computed={expected_state_dim}. "
82
+ f"Term dims: {term_dims}"
83
+ )
84
+
85
+ logging.info(f"Validated: state_keys={state_keys}, camera_keys={camera_keys}")
86
+
87
 
88
  # Module path for environment class resolution
89
  ISAACLAB_ARENA_ENV_MODULE = os.environ.get("ISAACLAB_ARENA_ENV_MODULE", "isaaclab_arena_environments")
 
295
  self.close()
296
 
297
 
 
298
  def _find_config_file(config_path: str | Path | None = None) -> Path | None:
299
+ """Find config.yaml file in various locations, downloading from Hub if needed.
300
 
301
  Search order:
302
  1. Explicit path from ISAACLAB_ARENA_CONFIG_PATH env var
 
305
  4. config.yaml relative to this module
306
  5. configs/config.yaml in current directory
307
  6. config.yaml in current directory
308
+ 7. Download from Hugging Face Hub (fallback for Hub-loaded env.py)
309
 
310
  Returns:
311
  Path to config file, or None if not found (uses defaults).
 
335
  if path.exists():
336
  return path
337
 
338
+ # Fallback: download config from Hugging Face Hub
339
+ # This is needed when env.py is loaded from HF cache (only env.py is downloaded)
340
+ try:
341
+ from huggingface_hub import hf_hub_download
342
+ logging.info(f"Downloading config from Hub: {HUB_REPO_ID}/{DEFAULT_CONFIG_FILE}")
343
+ hub_config_path = hf_hub_download(
344
+ repo_id=HUB_REPO_ID,
345
+ filename=DEFAULT_CONFIG_FILE,
346
+ )
347
+ return Path(hub_config_path)
348
+ except Exception as e:
349
+ logging.warning(f"Failed to download config from Hub: {e}")
350
+
351
  logging.info("No config.yaml found, using defaults")
352
  return None
353