tianheng.wu committed on
Commit
8c5a841
·
1 Parent(s): 99cd48a

[feat] move IsaacLabEnvWrapper to EnvHub

Browse files
Files changed (3) hide show
  1. env.py +1 -1
  2. error.py +30 -0
  3. isaaclab_env_wrapper.py +220 -0
env.py CHANGED
@@ -1,6 +1,6 @@
1
  import logging
2
  from typing import Any
3
- from lerobot.envs.isaaclab import IsaacLabEnvWrapper
4
 
5
 
6
  def make_env(n_envs: int = 1, use_async_envs: bool = False, **kwargs: Any):
 
1
  import logging
2
  from typing import Any
3
+ from .isaaclab_env_wrapper import IsaacLabEnvWrapper
4
 
5
 
6
  def make_env(n_envs: int = 1, use_async_envs: bool = False, **kwargs: Any):
error.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class IsaacLabArenaError(RuntimeError):
    """Base exception for IsaacLab Arena environment errors.

    Attributes:
        message: The human-readable error text; also passed to
            ``RuntimeError`` so ``str(exc)`` returns it.
    """

    def __init__(self, message: str = "IsaacLab Arena error"):
        """Store ``message`` and forward it to ``RuntimeError``.

        Args:
            message: Description of the failure.
        """
        self.message = message
        super().__init__(self.message)
7
+
8
+
9
class IsaacLabArenaConfigError(IsaacLabArenaError):
    """Exception raised for invalid environment configuration.

    Attributes:
        invalid: The rejected keys, exactly as passed in.
        available: The accepted keys, exactly as passed in.
    """

    def __init__(self, invalid: list, available: list, key_type: str = "keys"):
        """Build the error message from the rejected and accepted keys.

        Args:
            invalid: Keys that were requested but are not available.
            available: Keys that may be requested; shown sorted in the message.
            key_type: Label used in the message for the kind of key.
        """
        try:
            shown = sorted(available)
        except TypeError:
            # Unorderable entries (e.g. None mixed with str) must not turn the
            # configuration error into an unrelated TypeError, so fall back to
            # a deterministic ordering by repr.
            shown = sorted(available, key=repr)
        msg = f"Invalid {key_type}: {invalid}. Available: {shown}"
        super().__init__(msg)
        self.invalid = invalid
        self.available = available
17
+
18
+
19
class IsaacLabArenaCameraKeyError(IsaacLabArenaConfigError):
    """Exception raised when camera_keys don't match available cameras."""

    def __init__(self, invalid: list, available: list):
        """Report ``invalid`` camera keys against the ``available`` ones.

        Args:
            invalid: Camera keys that were requested but are not available.
            available: Camera keys that may be requested.
        """
        super().__init__(invalid, available, "camera_keys")
24
+
25
+
26
class IsaacLabArenaStateKeyError(IsaacLabArenaConfigError):
    """Exception raised when state_keys don't match available state terms."""

    def __init__(self, invalid: list, available: list):
        """Report ``invalid`` state keys against the ``available`` ones.

        Args:
            invalid: State keys that were requested but are not available.
            available: State keys that may be requested.
        """
        super().__init__(invalid, available, "state_keys")
isaaclab_env_wrapper.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import atexit
4
+ import logging
5
+ import os
6
+ import signal
7
+ from contextlib import suppress
8
+ from typing import Any
9
+
10
+ import gymnasium as gym
11
+ import numpy as np
12
+ import torch
13
+
14
+ from .errors import IsaacLabArenaError
15
+
16
+
17
def cleanup_isaaclab(env, simulation_app) -> None:
    """Cleanup IsaacLab env and simulation app resources.

    Both close calls are best-effort: a failure in one is logged at debug
    level and does not prevent the other from running or propagate to the
    caller.

    Args:
        env: Object exposing ``close()``, or ``None`` to skip.
        simulation_app: Object whose ``.app`` attribute exposes ``close()``,
            or ``None`` to skip.
    """
    # Ignore signals during cleanup to prevent interruption.
    saved_handlers = {}
    for signum in (signal.SIGINT, signal.SIGTERM):
        try:
            saved_handlers[signum] = signal.signal(signum, signal.SIG_IGN)
        except ValueError:
            # signal.signal() is only permitted in the main thread. Cleanup
            # can be reached from elsewhere (e.g. a finalizer), and must still
            # release resources, just without the signal shielding.
            logging.debug("Cannot mask signal %s outside the main thread", signum)

    try:
        if env is not None:
            try:
                env.close()
            except Exception:
                logging.debug("Ignoring error while closing IsaacLab env", exc_info=True)
        if simulation_app is not None:
            try:
                simulation_app.app.close()
            except Exception:
                logging.debug("Ignoring error while closing simulation app", exc_info=True)
    finally:
        # Restore signal handlers. A previous handler of None means it was not
        # installed from Python and cannot be re-installed; passing None to
        # signal.signal() would raise TypeError, so fall back to the default.
        for signum, previous in saved_handlers.items():
            signal.signal(signum, signal.SIG_DFL if previous is None else previous)
33
+
34
+
35
class IsaacLabEnvWrapper(gym.vector.AsyncVectorEnv):
    """Wrapper adapting IsaacLab batched GPU env to AsyncVectorEnv.

    IsaacLab handles vectorization internally on GPU. We inherit from
    AsyncVectorEnv for compatibility with LeRobot; every call is forwarded to
    the single wrapped env.

    On construction the wrapper registers an ``atexit`` hook and installs
    SIGINT/SIGTERM handlers so the env and simulation app are closed on exit.
    The previous signal handlers are restored once the wrapper is closed.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
    _cleanup_in_progress = False  # Class-level flag for re-entrant protection

    def __init__(
        self,
        env,
        episode_length: int = 500,
        task: str | None = None,
        render_mode: str | None = "rgb_array",
        simulation_app=None,
    ):
        """Wrap ``env`` and register cleanup hooks.

        Args:
            env: The batched env; must expose ``num_envs``,
                ``observation_space``, ``action_space``, ``reset`` and ``step``.
            episode_length: Value reported as ``_max_episode_steps``.
            task: Task description reported through ``call("task")``.
            render_mode: Only ``"rgb_array"`` produces frames in ``render``.
            simulation_app: Optional object whose ``.app.close()`` is invoked
                during cleanup.
        """
        # NOTE: the parent __init__ is not invoked here (as in the original
        # code); the wrapper only reuses the AsyncVectorEnv type.
        self._env = env
        self._num_envs = env.num_envs
        self._episode_length = episode_length
        self._closed = False
        self.render_mode = render_mode
        self._simulation_app = simulation_app
        # Handlers that were active before this wrapper replaced them.
        self._previous_signal_handlers: dict[int, Any] = {}

        self.observation_space = env.observation_space
        self.action_space = env.action_space
        # NOTE(review): the "single" spaces alias the env's spaces as-is;
        # confirm whether the env exposes dedicated per-env spaces.
        self.single_observation_space = env.observation_space
        self.single_action_space = env.action_space
        self.task = task

        if hasattr(env, "metadata") and env.metadata:
            self.metadata = {**self.metadata, **env.metadata}

        # Register cleanup handlers
        atexit.register(self._cleanup)
        for signum in (signal.SIGINT, signal.SIGTERM):
            try:
                self._previous_signal_handlers[signum] = signal.signal(
                    signum, self._signal_handler
                )
            except ValueError:
                # signal.signal() only works in the main thread; the atexit
                # hook still provides cleanup in that case.
                logging.warning(
                    "Could not install handler for signal %s (not in main thread)",
                    signum,
                )

    def _signal_handler(self, signum, frame):
        """Clean up and terminate the process when SIGINT/SIGTERM arrives."""
        if IsaacLabEnvWrapper._cleanup_in_progress:
            return  # Prevent re-entrant cleanup
        IsaacLabEnvWrapper._cleanup_in_progress = True
        logging.info("Received signal %s, cleaning up...", signum)
        self._cleanup()
        # Exit without raising to avoid propagating through callbacks
        os._exit(0)

    def _check_closed(self):
        """Raise ``IsaacLabArenaError`` if the wrapper has been closed."""
        if self._closed:
            raise IsaacLabArenaError("IsaacLab Arena environment is closed")

    @staticmethod
    def _to_numpy(value, dtype) -> np.ndarray:
        """Convert a torch tensor or array-like to a numpy array of ``dtype``."""
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        return np.asarray(value).astype(dtype)

    @property
    def unwrapped(self):
        """Return the wrapper itself."""
        return self

    @property
    def num_envs(self) -> int:
        """Number of environments, read from the wrapped env at construction."""
        return self._num_envs

    @property
    def _max_episode_steps(self) -> int:
        """Episode length given at construction."""
        return self._episode_length

    @property
    def device(self) -> str:
        """Device of the wrapped env, or ``"cpu"`` if it does not expose one."""
        return getattr(self._env, "device", "cpu")

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Reset the wrapped env.

        Args:
            seed: A single seed, or a sequence of which only the first entry
                is used (an empty sequence means no seed).
            options: Forwarded unchanged to the wrapped env.

        Returns:
            ``(obs, info)``; ``info["final_info"]["is_success"]`` is added as
            an all-False array when the env did not provide ``final_info``.

        Raises:
            IsaacLabArenaError: If the wrapper is closed.
        """
        self._check_closed()
        if isinstance(seed, (list, tuple, range)):
            seed = seed[0] if len(seed) > 0 else None

        obs, info = self._env.reset(seed=seed, options=options)

        if "final_info" not in info:
            zeros = np.zeros(self._num_envs, dtype=bool)
            info["final_info"] = {"is_success": zeros}

        return obs, info

    def step(
        self, actions: np.ndarray | torch.Tensor
    ) -> tuple[dict, np.ndarray, np.ndarray, np.ndarray, dict]:
        """Step the wrapped env with a batch of actions.

        Args:
            actions: Numpy arrays are converted to a tensor on ``self.device``;
                tensors are passed through unchanged.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` with reward as
            float32 and the flags as bool numpy arrays;
            ``info["final_info"]["is_success"]`` is always set.

        Raises:
            IsaacLabArenaError: If the wrapper is closed.
        """
        self._check_closed()
        if isinstance(actions, np.ndarray):
            # Use the property so envs without a ``device`` fall back to CPU.
            actions = torch.from_numpy(actions).to(self.device)

        obs, reward, terminated, truncated, info = self._env.step(actions)

        # Convert to numpy for gym compatibility
        reward = self._to_numpy(reward, np.float32)
        terminated = self._to_numpy(terminated, bool)
        truncated = self._to_numpy(truncated, bool)

        is_success = self._get_success(terminated, truncated)
        info["final_info"] = {"is_success": is_success}

        return obs, reward, terminated, truncated, info

    def _get_success(self, terminated: np.ndarray, truncated: np.ndarray) -> np.ndarray:
        """Return per-env success flags, only for envs whose episode ended.

        Success is read from the env's termination manager term named
        ``"success"``; when that is unavailable every env reports False.
        """
        no_success = np.zeros(self._num_envs, dtype=bool)

        term_manager = getattr(self._env, "termination_manager", None)
        get_term = getattr(term_manager, "get_term", None)
        if get_term is None:
            return no_success

        try:
            success_tensor = get_term("success")
        except (KeyError, ValueError):
            # No "success" term registered; a lookup failure must not abort step().
            return no_success
        if success_tensor is None:
            return no_success

        return self._to_numpy(success_tensor, bool) & (terminated | truncated)

    def call(self, method_name: str, *args, **kwargs) -> list[Any]:
        """Emulate the vector-env ``call`` API against the single wrapped env.

        Args:
            method_name: Attribute or method name. ``_max_episode_steps``,
                ``task`` and ``render`` are answered by the wrapper itself;
                anything else is looked up on the wrapped env.

        Returns:
            A list; non-list results are repeated once per env.

        Raises:
            AttributeError: If the wrapped env has no such attribute.
        """
        if method_name == "_max_episode_steps":
            return [self._episode_length] * self._num_envs
        if method_name == "task":
            return [self.task] * self._num_envs
        if method_name == "render":
            return self.render_all()

        if hasattr(self._env, method_name):
            attr = getattr(self._env, method_name)
            result = attr(*args, **kwargs) if callable(attr) else attr
            if isinstance(result, list):
                return result
            return [result] * self._num_envs

        raise AttributeError(f"IsaacLab-Arena has no method/attribute '{method_name}'")

    def render_all(self) -> list[np.ndarray]:
        """Return one frame per env (the same frame repeated).

        A black 480x640 RGB placeholder is used when no frame is available.

        Raises:
            IsaacLabArenaError: If the wrapper is closed.
        """
        self._check_closed()
        frames = self.render()
        if frames is None:
            placeholder = np.zeros((480, 640, 3), dtype=np.uint8)
            return [placeholder] * self._num_envs

        return [frames] * self._num_envs

    def render(self) -> np.ndarray | None:
        """Render the wrapped env and return a single frame.

        Returns:
            ``None`` when ``render_mode`` is not ``"rgb_array"`` or the env
            yields no frame; otherwise the frame as a numpy array, taking the
            first entry when the result is 4-dimensional.

        Raises:
            IsaacLabArenaError: If the wrapper is closed.
        """
        self._check_closed()
        if self.render_mode != "rgb_array":
            return None

        frames = self._env.render() if hasattr(self._env, "render") else None
        if frames is None:
            return None

        if isinstance(frames, torch.Tensor):
            frames = frames.cpu().numpy()

        return frames[0] if frames.ndim == 4 else frames

    def _restore_signal_handlers(self) -> None:
        """Put back the handlers that were active before this wrapper."""
        previous_handlers = getattr(self, "_previous_signal_handlers", {})
        for signum, previous in previous_handlers.items():
            try:
                # Only restore if our handler is still the active one, so a
                # handler installed later by someone else is left untouched.
                if signal.getsignal(signum) == self._signal_handler:
                    signal.signal(
                        signum, signal.SIG_DFL if previous is None else previous
                    )
            except ValueError:
                # Not in the main thread: handlers cannot be changed from here.
                logging.debug("Cannot restore handler for signal %s", signum)
        previous_handlers.clear()

    def _cleanup(self) -> None:
        """Close the env and simulation app once; later calls are no-ops."""
        # A partially constructed instance (no ``_closed`` yet) has nothing to release.
        if getattr(self, "_closed", True):
            return
        self._closed = True
        exiting_on_signal = IsaacLabEnvWrapper._cleanup_in_progress
        IsaacLabEnvWrapper._cleanup_in_progress = True
        logging.info("Cleaning up IsaacLab Arena environment...")
        try:
            cleanup_isaaclab(self._env, self._simulation_app)
        finally:
            if not exiting_on_signal:
                # Normal close: release the guard and hand signals back, so the
                # process still reacts to SIGINT/SIGTERM after the env is closed.
                # On the signal path the process is about to exit, so the guard
                # stays set to swallow repeated signals until then.
                IsaacLabEnvWrapper._cleanup_in_progress = False
                self._restore_signal_handlers()

    def close(self) -> None:
        """Release the env and simulation app."""
        self._cleanup()

    @property
    def envs(self) -> list[IsaacLabEnvWrapper]:
        """List with this wrapper repeated once per env."""
        return [self] * self._num_envs

    def __del__(self):
        self._cleanup()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._cleanup()
        return False