lsnu committed on
Commit
20ce2c0
·
verified ·
1 Parent(s): 1f962bd

Add files using upload-large-folder tool

Browse files
code/reveal_vla_bimanual/eval/run_rlbench_knn_eval.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch.utils.data import DataLoader, Subset
12
+
13
+ from eval.run_rlbench_rollout_eval import (
14
+ BimanualEndEffectorPoseViaIK,
15
+ _episode_language_goal,
16
+ _load_compatible_state_dict,
17
+ _policy_config_from_checkpoint,
18
+ _reset_task_with_retries,
19
+ _step_bimanual_chunk,
20
+ _trainer_config_from_checkpoint,
21
+ )
22
+ from rlbench.action_modes.action_mode import BimanualMoveArmThenGripper
23
+ from rlbench.action_modes.arm_action_modes import BimanualEndEffectorPoseViaPlanning
24
+ from rlbench.action_modes.gripper_action_modes import BimanualDiscrete
25
+ from rlbench.environment import Environment
26
+ from sim_rlbench.camera_spec import default_three_camera_spec
27
+ from sim_rlbench.dataset import RLBenchOfflineChunkDataset, bimanual_proprio_from_obs, stack_live_rgb_obs
28
+ from sim_rlbench.obs_config import build_obs_config
29
+ from sim_rlbench.task_resolver import resolve_task_class
30
+ from train.trainer import build_policy
31
+
32
+
33
+ def _make_bank_loader(dataset: RLBenchOfflineChunkDataset, bank_stride: int, batch_size: int, num_workers: int) -> DataLoader:
34
+ indices = list(range(0, len(dataset), max(1, bank_stride)))
35
+ subset = Subset(dataset, indices)
36
+ return DataLoader(
37
+ subset,
38
+ batch_size=batch_size,
39
+ shuffle=False,
40
+ num_workers=num_workers,
41
+ pin_memory=torch.cuda.is_available(),
42
+ )
43
+
44
+
45
def _encode_bank(
    model: torch.nn.Module,
    dataset: RLBenchOfflineChunkDataset,
    device: torch.device,
    batch_size: int,
    bank_stride: int,
    num_workers: int,
) -> dict[str, torch.Tensor]:
    """Encode a strided subset of ``dataset`` into a CPU retrieval bank.

    Each batch is passed through ``model.encode_scene``; the result is
    averaged over dim 1 and normalized along the last dim to form one
    feature vector per sample.

    Args:
        model: Policy exposing ``encode_scene(images, proprio, texts=...)``.
        dataset: Offline dataset whose batches provide ``images``,
            ``proprio``, ``texts``, ``action_chunk`` and ``step_index``.
        device: Device on which images and proprio are encoded.
        batch_size: Batch size used while encoding.
        bank_stride: Keep every ``bank_stride``-th dataset sample.
        num_workers: ``DataLoader`` worker count.

    Returns:
        Dict of CPU tensors: ``"features"`` (normalized pooled scene
        features), ``"actions"`` (index 0 along dim 1 of each
        ``action_chunk`` — presumably the first action of the chunk; confirm
        against the dataset), and ``"steps"`` (the batch ``step_index``).

    Raises:
        RuntimeError: If the loader yields no batches, i.e. the bank would be
            empty.
    """
    loader = _make_bank_loader(dataset, bank_stride=bank_stride, batch_size=batch_size, num_workers=num_workers)
    feature_chunks: list[torch.Tensor] = []
    action_chunks: list[torch.Tensor] = []
    step_chunks: list[torch.Tensor] = []
    with torch.no_grad():
        for batch in loader:
            images = batch["images"].to(device)
            proprio = batch["proprio"].to(device)
            texts = list(batch["texts"])
            scene_tokens = model.encode_scene(images, proprio, texts=texts)
            pooled = F.normalize(scene_tokens.mean(dim=1), dim=-1)
            feature_chunks.append(pooled.cpu())
            action_chunks.append(batch["action_chunk"][:, 0].cpu())
            step_chunks.append(batch["step_index"].cpu())
    if not feature_chunks:
        # torch.cat on an empty list fails with an opaque message; say what
        # actually went wrong instead.
        raise RuntimeError(
            "kNN bank is empty: the offline dataset produced no samples for the requested task/episodes."
        )
    return {
        "features": torch.cat(feature_chunks, dim=0),
        "actions": torch.cat(action_chunks, dim=0),
        "steps": torch.cat(step_chunks, dim=0),
    }
72
+
73
+
74
+ def _choose_action(
75
+ bank: dict[str, torch.Tensor],
76
+ query_feature: torch.Tensor,
77
+ timestep: int,
78
+ top_k: int,
79
+ time_window: int,
80
+ ) -> np.ndarray:
81
+ features = bank["features"]
82
+ actions = bank["actions"]
83
+ steps = bank["steps"]
84
+ if time_window >= 0:
85
+ mask = (steps - int(timestep)).abs() <= int(time_window)
86
+ if mask.any():
87
+ features = features[mask]
88
+ actions = actions[mask]
89
+ similarities = torch.matmul(features, query_feature.cpu())
90
+ k = min(int(top_k), similarities.numel())
91
+ top_values, top_indices = torch.topk(similarities, k=k, largest=True)
92
+ top_actions = actions[top_indices]
93
+ weights = torch.softmax(top_values.float(), dim=0).unsqueeze(-1)
94
+ return torch.sum(top_actions.float() * weights, dim=0).numpy().astype(np.float32)
95
+
96
+
97
def main() -> None:
    """Run a kNN-retrieval rollout evaluation on one RLBench task.

    Steps:
      1. Load the checkpoint and build the policy. The load must be exact:
         no shape mismatches, missing keys or unexpected keys.
      2. Encode the requested training episodes into a retrieval bank.
      3. Roll out ``--episodes-per-task`` live episodes. At each step the
         current observation is encoded and the action comes from
         ``_choose_action``.
      4. Write ``rollout_eval.json`` into ``--output-dir`` and print the same
         JSON to stdout.

    A failure inside an episode is recorded in ``episode_errors`` and the
    run continues with the next episode.

    Raises:
        RuntimeError: If the checkpoint does not load exactly into the model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--task", required=True)
    parser.add_argument("--train-episodes", nargs="+", type=int, required=True)
    parser.add_argument("--episodes-per-task", type=int, default=1)
    parser.add_argument("--episode-length", type=int, default=180)
    parser.add_argument("--resolution", type=int, default=224)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--arm-mode", choices=("planning", "ik"), default="ik")
    parser.add_argument("--delta-scale", type=float, default=1.0)
    parser.add_argument("--bank-batch-size", type=int, default=32)
    parser.add_argument("--bank-stride", type=int, default=4)
    parser.add_argument("--bank-num-workers", type=int, default=4)
    parser.add_argument("--top-k", type=int, default=5)
    parser.add_argument("--time-window", type=int, default=8)
    parser.add_argument("--reset-retries", type=int, default=20)
    # Headless stays the default. A lone store_true flag with default=True
    # could never be switched off, so --no-headless is the opt-out.
    parser.add_argument("--headless", dest="headless", action="store_true", default=True)
    parser.add_argument("--no-headless", dest="headless", action="store_false")
    # Was hard-coded; the default keeps the previous behavior.
    parser.add_argument("--dataset-root", default="/workspace/data/rlbench2")
    args = parser.parse_args()

    checkpoint = torch.load(Path(args.checkpoint), map_location="cpu", weights_only=False)
    policy_config = _policy_config_from_checkpoint(checkpoint)
    trainer_config = _trainer_config_from_checkpoint(checkpoint)
    device = torch.device("cuda" if torch.cuda.is_available() and args.device == "cuda" else "cpu")
    model = build_policy(policy_config, trainer_config).to(device)
    incompatible, skipped_shape_mismatches = _load_compatible_state_dict(model, checkpoint["state_dict"])
    # Check shape mismatches first. Skipped keys are withheld from
    # load_state_dict, so they also show up as missing keys; the more
    # specific error would otherwise never be reached.
    if skipped_shape_mismatches:
        raise RuntimeError(f"kNN eval does not support shape-mismatch loads: {skipped_shape_mismatches}")
    if incompatible.unexpected_keys or incompatible.missing_keys:
        raise RuntimeError(
            f"Checkpoint incompatibility for kNN eval. Missing={list(incompatible.missing_keys)} unexpected={list(incompatible.unexpected_keys)}"
        )
    model.eval()

    bank_dataset = RLBenchOfflineChunkDataset(
        dataset_root=args.dataset_root,
        tasks=[args.task],
        episode_indices=args.train_episodes,
        resolution=args.resolution,
        chunk_size=policy_config.decoder.chunk_size,
        proprio_dim=policy_config.fusion.proprio_dim,
        history_steps=policy_config.memory.history_steps,
    )
    bank = _encode_bank(
        model=model,
        dataset=bank_dataset,
        device=device,
        batch_size=args.bank_batch_size,
        bank_stride=args.bank_stride,
        num_workers=args.bank_num_workers,
    )

    camera_spec = default_three_camera_spec(args.resolution)
    task_class = resolve_task_class(args.task)
    obs_config = build_obs_config(list(camera_spec.upstream_cameras), args.resolution)
    if args.arm_mode == "ik":
        arm_action_mode: Any = BimanualEndEffectorPoseViaIK(absolute_mode=True, frame="world", collision_checking=False)
    else:
        arm_action_mode = BimanualEndEffectorPoseViaPlanning(absolute_mode=True, frame="world", collision_checking=False)
    action_mode = BimanualMoveArmThenGripper(arm_action_mode, BimanualDiscrete())
    env = Environment(
        action_mode=action_mode,
        obs_config=obs_config,
        headless=args.headless,
        robot_setup="dual_panda",
    )
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    results: dict[str, Any] = {
        "checkpoint": str(Path(args.checkpoint).resolve()),
        "task": args.task,
        "dataset_root": str(args.dataset_root),
        "train_episodes": list(args.train_episodes),
        "episodes_per_task": args.episodes_per_task,
        "episode_length": args.episode_length,
        "resolution": args.resolution,
        "arm_mode": args.arm_mode,
        "delta_scale": args.delta_scale,
        "bank_stride": args.bank_stride,
        "top_k": args.top_k,
        "time_window": args.time_window,
        "bank_size": int(bank["features"].shape[0]),
    }
    env.launch()
    try:
        task = env.get_task(task_class)
        successes: list[float] = []
        returns: list[float] = []
        path_recoveries: list[int] = []
        noop_fallbacks: list[int] = []
        episode_errors: list[str | None] = []
        for _ in range(args.episodes_per_task):
            total_reward = 0.0
            success = 0.0
            episode_recoveries = 0
            episode_noops = 0
            episode_error: str | None = None
            try:
                descriptions, obs, _reset_count = _reset_task_with_retries(task, max_attempts=max(1, args.reset_retries))
                language_goal = _episode_language_goal(descriptions)
                for timestep in range(args.episode_length):
                    images = stack_live_rgb_obs(obs, resolution=args.resolution).unsqueeze(0).to(device)
                    proprio = torch.from_numpy(
                        bimanual_proprio_from_obs(
                            obs,
                            timestep=timestep,
                            episode_length=args.episode_length,
                            target_dim=policy_config.fusion.proprio_dim,
                        )
                    ).unsqueeze(0).to(device)
                    with torch.no_grad():
                        scene_tokens = model.encode_scene(images, proprio, texts=[language_goal])
                        # Same pooling and normalization as the bank features.
                        query_feature = F.normalize(scene_tokens.mean(dim=1), dim=-1)[0]
                    step_action = _choose_action(
                        bank=bank,
                        query_feature=query_feature,
                        timestep=timestep,
                        top_k=args.top_k,
                        time_window=args.time_window,
                    )
                    obs, reward, done, recovered_steps, noop_count = _step_bimanual_chunk(
                        task,
                        obs,
                        step_action,
                        delta_scale=args.delta_scale,
                    )
                    episode_recoveries += int(recovered_steps)
                    episode_noops += int(noop_count)
                    total_reward += float(reward)
                    if reward >= 1.0:
                        success = 1.0
                    if done or success >= 1.0:
                        break
            except Exception as exc:  # pragma: no cover - live RLBench failure path
                # Best effort: record the failure and keep evaluating the
                # remaining episodes.
                episode_error = str(exc)
            successes.append(success)
            returns.append(total_reward)
            path_recoveries.append(episode_recoveries)
            noop_fallbacks.append(episode_noops)
            episode_errors.append(episode_error)
        results["successes"] = successes
        results["returns"] = returns
        results["path_recoveries"] = path_recoveries
        results["noop_fallbacks"] = noop_fallbacks
        results["episode_errors"] = episode_errors
        results["mean_success"] = float(np.mean(successes)) if successes else 0.0
        results["mean_return"] = float(np.mean(returns)) if returns else 0.0
    finally:
        env.shutdown()

    (output_dir / "rollout_eval.json").write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(json.dumps(results, indent=2))
250
+
251
+
252
+ if __name__ == "__main__":
253
+ main()
code/reveal_vla_bimanual/eval/run_rlbench_rollout_eval.py CHANGED
@@ -8,7 +8,14 @@ from typing import Any, Sequence
8
  import numpy as np
9
  import torch
10
  from rlbench.action_modes.action_mode import BimanualMoveArmThenGripper
11
- from rlbench.action_modes.arm_action_modes import BimanualEndEffectorPoseViaPlanning
 
 
 
 
 
 
 
12
  from rlbench.action_modes.gripper_action_modes import BimanualDiscrete
13
  from rlbench.environment import Environment
14
 
@@ -31,6 +38,51 @@ from sim_rlbench.task_resolver import resolve_task_class
31
  from train.trainer import TrainerConfig, build_policy, planner_enabled, policy_supports_planning
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def _policy_config_from_checkpoint(checkpoint: dict[str, Any]) -> PolicyConfig:
35
  cfg = checkpoint["policy_config"]
36
  return PolicyConfig(
@@ -48,6 +100,26 @@ def _trainer_config_from_checkpoint(checkpoint: dict[str, Any]) -> TrainerConfig
48
  return TrainerConfig(**checkpoint["trainer_config"])
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def _episode_language_goal(descriptions: Sequence[str]) -> str:
52
  return str(descriptions[0]) if descriptions else ""
53
 
@@ -92,13 +164,23 @@ def _scaled_bimanual_delta(delta_action: np.ndarray, scale: float) -> np.ndarray
92
  return scaled
93
 
94
 
95
- def _step_bimanual_chunk(task: Any, obs: Any, delta_action: np.ndarray) -> tuple[Any, float, bool, int, int]:
 
 
 
 
 
96
  last_error: Exception | None = None
97
- for scale in (1.0, 0.5, 0.25, 0.1):
98
  try:
99
- env_action = absolute_action_from_delta(obs, _scaled_bimanual_delta(delta_action, scale), ignore_collisions=True)
 
 
 
 
 
100
  next_obs, reward, done = task.step(env_action)
101
- recovered_steps = 1 if scale < 1.0 else 0
102
  return next_obs, float(reward), bool(done), recovered_steps, 0
103
  except Exception as exc: # pragma: no cover - live RLBench failure path
104
  last_error = exc
@@ -131,6 +213,8 @@ def main() -> None:
131
  parser.add_argument("--reset-retries", type=int, default=20)
132
  parser.add_argument("--no-geometry", action="store_true")
133
  parser.add_argument("--compact-world-model", action="store_true")
 
 
134
  args = parser.parse_args()
135
 
136
  checkpoint = torch.load(Path(args.checkpoint), map_location="cpu", weights_only=False)
@@ -138,7 +222,7 @@ def main() -> None:
138
  trainer_config = _trainer_config_from_checkpoint(checkpoint)
139
  device = torch.device("cuda" if torch.cuda.is_available() and args.device == "cuda" else "cpu")
140
  model = build_policy(policy_config, trainer_config).to(device)
141
- incompatible = model.load_state_dict(checkpoint["state_dict"], strict=False)
142
  allowed_missing = {
143
  key
144
  for key in incompatible.missing_keys
@@ -148,6 +232,8 @@ def main() -> None:
148
  or key.startswith("elastic_state_head.decoder.task_")
149
  or key.startswith("world_model.task_")
150
  or key.startswith("world_model.spatial_")
 
 
151
  }
152
  missing_other = sorted(set(incompatible.missing_keys) - allowed_missing)
153
  if missing_other or incompatible.unexpected_keys:
@@ -182,12 +268,24 @@ def main() -> None:
182
  "episode_length": args.episode_length,
183
  "resolution": args.resolution,
184
  "reset_retries": args.reset_retries,
 
 
185
  "cameras": list(camera_spec.cameras),
186
  "tasks": {},
187
  }
 
 
188
  if planning_note is not None:
189
  results["planning_note"] = planning_note
190
 
 
 
 
 
 
 
 
 
191
  for task_name in args.tasks:
192
  task_successes: list[float] = []
193
  task_returns: list[float] = []
@@ -195,8 +293,13 @@ def main() -> None:
195
  try:
196
  task_class = resolve_task_class(task_name)
197
  obs_config = build_obs_config(list(camera_spec.upstream_cameras), args.resolution)
 
 
 
 
 
198
  action_mode = BimanualMoveArmThenGripper(
199
- BimanualEndEffectorPoseViaPlanning(absolute_mode=True, frame="world", collision_checking=False),
200
  BimanualDiscrete(),
201
  )
202
  env = Environment(
@@ -323,7 +426,12 @@ def main() -> None:
323
  history_images.append(live_images)
324
  history_proprio.append(live_proprio)
325
  history_actions.append(step_action.astype(np.float32))
326
- obs, reward, done, recovered_steps, noop_fallbacks = _step_bimanual_chunk(task, obs, step_action)
 
 
 
 
 
327
  episode_recoveries += recovered_steps
328
  episode_noop_fallbacks += noop_fallbacks
329
  episode_trace["steps"].append(
@@ -368,14 +476,21 @@ def main() -> None:
368
  except Exception as exc:
369
  results["tasks"][task_name] = {"error": str(exc), "mean_success": 0.0, "mean_return": 0.0}
370
  finally:
 
 
 
 
 
 
 
 
 
 
371
  if env is not None:
372
  env.shutdown()
373
 
374
  task_scores = [task_data["mean_success"] for task_data in results["tasks"].values()]
375
  results["mean_success"] = float(np.mean(task_scores)) if task_scores else 0.0
376
-
377
- output_dir = Path(args.output_dir)
378
- output_dir.mkdir(parents=True, exist_ok=True)
379
  (output_dir / "rollout_eval.json").write_text(json.dumps(results, indent=2), encoding="utf-8")
380
  lines = [
381
  "# RLBench Rollout Eval",
 
8
  import numpy as np
9
  import torch
10
  from rlbench.action_modes.action_mode import BimanualMoveArmThenGripper
11
+ from rlbench.action_modes.arm_action_modes import (
12
+ BimanualEndEffectorPoseViaPlanning,
13
+ EndEffectorPoseViaIK,
14
+ IKError,
15
+ InvalidActionError,
16
+ assert_action_shape,
17
+ assert_unit_quaternion,
18
+ )
19
  from rlbench.action_modes.gripper_action_modes import BimanualDiscrete
20
  from rlbench.environment import Environment
21
 
 
38
  from train.trainer import TrainerConfig, build_policy, planner_enabled, policy_supports_planning
39
 
40
 
41
class BimanualEndEffectorPoseViaIK(EndEffectorPoseViaIK):
    """Bimanual end-effector pose control via Jacobian IK.

    The action is a flat 14-vector. The first 7 entries drive
    ``scene.robot.right_arm`` and the last 7 drive ``scene.robot.left_arm``.
    Each 7-slice is 3 position values followed by a 4-value unit quaternion.
    """

    def action(self, scene: Any, action: np.ndarray, ignore_collisions: Sequence[bool] | None = None) -> None:
        """Solve IK for both arms, then step the scene until they settle.

        Args:
            scene: Scene exposing ``robot.right_arm``, ``robot.left_arm`` and
                ``step()``.
            action: Array of shape ``(14,)`` (right arm first, then left).
            ignore_collisions: Accepted for interface compatibility; not used
                here.

        Raises:
            InvalidActionError: If IK fails for either arm. No joint target
                is set on either arm in that case.
        """
        assert_action_shape(action, (14,))
        right_action = action[:7]
        left_action = action[7:]
        assert_unit_quaternion(right_action[3:])
        assert_unit_quaternion(left_action[3:])

        arms = (scene.robot.right_arm, scene.robot.left_arm)
        solutions: list[Any] = []
        # Solve both arms before setting any target. Otherwise a left-arm IK
        # failure would leave the right arm already driven toward a pose from
        # a rejected action.
        for arm_action, arm in zip((right_action, left_action), arms):
            try:
                solutions.append(
                    arm.solve_ik_via_jacobian(
                        arm_action[:3],
                        quaternion=arm_action[3:],
                        relative_to=None,
                    )
                )
            except IKError as exc:
                raise InvalidActionError(
                    "Could not perform bimanual IK via Jacobian; target pose is likely too far from the current pose."
                ) from exc

        for arm, joint_positions in zip(arms, solutions):
            arm.set_joint_target_positions(joint_positions)
        target_positions = [np.asarray(joint_positions, dtype=np.float32) for joint_positions in solutions]

        # Step until both arms are at their targets (atol=0.01) or neither
        # arm moved between two consecutive steps (atol=0.001).
        done = False
        prev_right = None
        prev_left = None
        while not done:
            scene.step()
            cur_right = np.asarray(scene.robot.right_arm.get_joint_positions(), dtype=np.float32)
            cur_left = np.asarray(scene.robot.left_arm.get_joint_positions(), dtype=np.float32)
            reached = np.allclose(cur_right, target_positions[0], atol=0.01) and np.allclose(cur_left, target_positions[1], atol=0.01)
            not_moving = False
            if prev_right is not None and prev_left is not None:
                not_moving = np.allclose(cur_right, prev_right, atol=0.001) and np.allclose(cur_left, prev_left, atol=0.001)
            prev_right = cur_right
            prev_left = cur_left
            done = reached or not_moving

    def action_shape(self, scene: Any) -> tuple[int]:
        """Return the shape of the combined two-arm action, ``(14,)``."""
        return (14,)

    def unimanual_action_shape(self, scene: Any) -> tuple[int]:
        """Return the shape of a single arm's action slice, ``(7,)``."""
        return (7,)
84
+
85
+
86
  def _policy_config_from_checkpoint(checkpoint: dict[str, Any]) -> PolicyConfig:
87
  cfg = checkpoint["policy_config"]
88
  return PolicyConfig(
 
100
  return TrainerConfig(**checkpoint["trainer_config"])
101
 
102
 
103
+ def _load_compatible_state_dict(
104
+ model: torch.nn.Module,
105
+ checkpoint_state: dict[str, Any],
106
+ ) -> tuple[Any, list[str]]:
107
+ model_state = model.state_dict()
108
+ compatible_state: dict[str, Any] = {}
109
+ skipped_shape_mismatches: list[str] = []
110
+ for key, value in checkpoint_state.items():
111
+ target = model_state.get(key)
112
+ if target is None:
113
+ compatible_state[key] = value
114
+ continue
115
+ if hasattr(value, "shape") and tuple(value.shape) != tuple(target.shape):
116
+ skipped_shape_mismatches.append(key)
117
+ continue
118
+ compatible_state[key] = value
119
+ incompatible = model.load_state_dict(compatible_state, strict=False)
120
+ return incompatible, skipped_shape_mismatches
121
+
122
+
123
  def _episode_language_goal(descriptions: Sequence[str]) -> str:
124
  return str(descriptions[0]) if descriptions else ""
125
 
 
164
  return scaled
165
 
166
 
167
+ def _step_bimanual_chunk(
168
+ task: Any,
169
+ obs: Any,
170
+ delta_action: np.ndarray,
171
+ delta_scale: float = 1.0,
172
+ ) -> tuple[Any, float, bool, int, int]:
173
  last_error: Exception | None = None
174
+ for scale in (1.0, 0.5, 0.25, 0.1, 0.05, 0.02, 0.01):
175
  try:
176
+ effective_scale = float(delta_scale) * float(scale)
177
+ env_action = absolute_action_from_delta(
178
+ obs,
179
+ _scaled_bimanual_delta(delta_action, effective_scale),
180
+ ignore_collisions=True,
181
+ )
182
  next_obs, reward, done = task.step(env_action)
183
+ recovered_steps = 1 if effective_scale < 1.0 else 0
184
  return next_obs, float(reward), bool(done), recovered_steps, 0
185
  except Exception as exc: # pragma: no cover - live RLBench failure path
186
  last_error = exc
 
213
  parser.add_argument("--reset-retries", type=int, default=20)
214
  parser.add_argument("--no-geometry", action="store_true")
215
  parser.add_argument("--compact-world-model", action="store_true")
216
+ parser.add_argument("--arm-mode", choices=("planning", "ik"), default="planning")
217
+ parser.add_argument("--delta-scale", type=float, default=1.0)
218
  args = parser.parse_args()
219
 
220
  checkpoint = torch.load(Path(args.checkpoint), map_location="cpu", weights_only=False)
 
222
  trainer_config = _trainer_config_from_checkpoint(checkpoint)
223
  device = torch.device("cuda" if torch.cuda.is_available() and args.device == "cuda" else "cpu")
224
  model = build_policy(policy_config, trainer_config).to(device)
225
+ incompatible, skipped_shape_mismatches = _load_compatible_state_dict(model, checkpoint["state_dict"])
226
  allowed_missing = {
227
  key
228
  for key in incompatible.missing_keys
 
232
  or key.startswith("elastic_state_head.decoder.task_")
233
  or key.startswith("world_model.task_")
234
  or key.startswith("world_model.spatial_")
235
+ or key.startswith("decoder.proposal_score.")
236
+ or key.startswith("world_model.initial.")
237
  }
238
  missing_other = sorted(set(incompatible.missing_keys) - allowed_missing)
239
  if missing_other or incompatible.unexpected_keys:
 
268
  "episode_length": args.episode_length,
269
  "resolution": args.resolution,
270
  "reset_retries": args.reset_retries,
271
+ "arm_mode": args.arm_mode,
272
+ "delta_scale": args.delta_scale,
273
  "cameras": list(camera_spec.cameras),
274
  "tasks": {},
275
  }
276
+ if skipped_shape_mismatches:
277
+ results["skipped_shape_mismatches"] = skipped_shape_mismatches
278
  if planning_note is not None:
279
  results["planning_note"] = planning_note
280
 
281
+ output_dir = Path(args.output_dir)
282
+ output_dir.mkdir(parents=True, exist_ok=True)
283
+
284
+ def write_results(filename: str = "rollout_eval.partial.json") -> None:
285
+ task_scores = [task_data["mean_success"] for task_data in results["tasks"].values()]
286
+ results["mean_success"] = float(np.mean(task_scores)) if task_scores else 0.0
287
+ (output_dir / filename).write_text(json.dumps(results, indent=2), encoding="utf-8")
288
+
289
  for task_name in args.tasks:
290
  task_successes: list[float] = []
291
  task_returns: list[float] = []
 
293
  try:
294
  task_class = resolve_task_class(task_name)
295
  obs_config = build_obs_config(list(camera_spec.upstream_cameras), args.resolution)
296
+ arm_action_mode: Any
297
+ if args.arm_mode == "ik":
298
+ arm_action_mode = BimanualEndEffectorPoseViaIK(absolute_mode=True, frame="world", collision_checking=False)
299
+ else:
300
+ arm_action_mode = BimanualEndEffectorPoseViaPlanning(absolute_mode=True, frame="world", collision_checking=False)
301
  action_mode = BimanualMoveArmThenGripper(
302
+ arm_action_mode,
303
  BimanualDiscrete(),
304
  )
305
  env = Environment(
 
426
  history_images.append(live_images)
427
  history_proprio.append(live_proprio)
428
  history_actions.append(step_action.astype(np.float32))
429
+ obs, reward, done, recovered_steps, noop_fallbacks = _step_bimanual_chunk(
430
+ task,
431
+ obs,
432
+ step_action,
433
+ delta_scale=args.delta_scale,
434
+ )
435
  episode_recoveries += recovered_steps
436
  episode_noop_fallbacks += noop_fallbacks
437
  episode_trace["steps"].append(
 
476
  except Exception as exc:
477
  results["tasks"][task_name] = {"error": str(exc), "mean_success": 0.0, "mean_return": 0.0}
478
  finally:
479
+ write_results()
480
+ task_result = results["tasks"][task_name]
481
+ if "error" in task_result:
482
+ print(f"[task] {task_name}: error={task_result['error']}", flush=True)
483
+ else:
484
+ print(
485
+ f"[task] {task_name}: mean_success={task_result['mean_success']:.3f} "
486
+ f"mean_return={task_result['mean_return']:.3f}",
487
+ flush=True,
488
+ )
489
  if env is not None:
490
  env.shutdown()
491
 
492
  task_scores = [task_data["mean_success"] for task_data in results["tasks"].values()]
493
  results["mean_success"] = float(np.mean(task_scores)) if task_scores else 0.0
 
 
 
494
  (output_dir / "rollout_eval.json").write_text(json.dumps(results, indent=2), encoding="utf-8")
495
  lines = [
496
  "# RLBench Rollout Eval",
code/reveal_vla_bimanual/sim_rlbench/dataset.py CHANGED
@@ -287,6 +287,7 @@ class RLBenchOfflineChunkDataset(Dataset[dict[str, Any]]):
287
  proprio_dim: int = 32,
288
  cameras: Sequence[str] = THREE_CAMERAS,
289
  history_steps: int = 2,
 
290
  max_samples: int | None = None,
291
  ) -> None:
292
  self.dataset_root = Path(dataset_root)
@@ -297,6 +298,9 @@ class RLBenchOfflineChunkDataset(Dataset[dict[str, Any]]):
297
  self.proprio_dim = int(proprio_dim)
298
  self.cameras = tuple(cameras)
299
  self.history_steps = int(history_steps)
 
 
 
300
  self._episodes: dict[str, EpisodeRecord] = {}
301
  self._samples: list[SampleRecord] = []
302
 
@@ -362,6 +366,13 @@ class RLBenchOfflineChunkDataset(Dataset[dict[str, Any]]):
362
  actions.append(action)
363
  return torch.from_numpy(np.stack(actions, axis=0))
364
 
 
 
 
 
 
 
 
365
  def _history_rgb_stack(self, episode_dir: Path, step_index: int) -> torch.Tensor:
366
  if self.history_steps <= 0:
367
  return torch.zeros((0, len(self.cameras), 3, self.resolution, self.resolution), dtype=torch.float32)
@@ -423,6 +434,7 @@ class RLBenchOfflineChunkDataset(Dataset[dict[str, Any]]):
423
  ),
424
  "texts": episode.language_goal,
425
  "action_chunk": self._action_chunk(observations, sample.step_index),
 
426
  "task": sample.task,
427
  "episode_index": sample.episode_index,
428
  "step_index": sample.step_index,
@@ -439,4 +451,5 @@ class RLBenchOfflineChunkDataset(Dataset[dict[str, Any]]):
439
  "chunk_size": self.chunk_size,
440
  "proprio_dim": self.proprio_dim,
441
  "history_steps": self.history_steps,
 
442
  }
 
287
  proprio_dim: int = 32,
288
  cameras: Sequence[str] = THREE_CAMERAS,
289
  history_steps: int = 2,
290
+ supervise_action_steps: int | None = None,
291
  max_samples: int | None = None,
292
  ) -> None:
293
  self.dataset_root = Path(dataset_root)
 
298
  self.proprio_dim = int(proprio_dim)
299
  self.cameras = tuple(cameras)
300
  self.history_steps = int(history_steps)
301
+ self.supervise_action_steps = (
302
+ None if supervise_action_steps is None else max(1, min(int(supervise_action_steps), self.chunk_size))
303
+ )
304
  self._episodes: dict[str, EpisodeRecord] = {}
305
  self._samples: list[SampleRecord] = []
306
 
 
366
  actions.append(action)
367
  return torch.from_numpy(np.stack(actions, axis=0))
368
 
369
+ def _action_mask(self) -> torch.Tensor:
370
+ mask = torch.ones((self.chunk_size,), dtype=torch.float32)
371
+ if self.supervise_action_steps is None:
372
+ return mask
373
+ mask[self.supervise_action_steps :] = 0.0
374
+ return mask
375
+
376
  def _history_rgb_stack(self, episode_dir: Path, step_index: int) -> torch.Tensor:
377
  if self.history_steps <= 0:
378
  return torch.zeros((0, len(self.cameras), 3, self.resolution, self.resolution), dtype=torch.float32)
 
434
  ),
435
  "texts": episode.language_goal,
436
  "action_chunk": self._action_chunk(observations, sample.step_index),
437
+ "action_mask": self._action_mask(),
438
  "task": sample.task,
439
  "episode_index": sample.episode_index,
440
  "step_index": sample.step_index,
 
451
  "chunk_size": self.chunk_size,
452
  "proprio_dim": self.proprio_dim,
453
  "history_steps": self.history_steps,
454
+ "supervise_action_steps": self.supervise_action_steps,
455
  }
code/reveal_vla_bimanual/sim_rlbench/dataset_download.py CHANGED
@@ -85,12 +85,14 @@ def main() -> None:
85
  archive_path = archive_root / filename
86
  expected_sha = checksums[filename]
87
  url = f"{base_url}/{filename}"
 
 
88
 
89
  print(f"[plan] {filename}", flush=True)
90
  print(f" url={url}", flush=True)
91
  print(f" archive={archive_path}", flush=True)
92
  if args.extract:
93
- print(f" extract_root={extract_root}", flush=True)
94
 
95
  if args.dry_run:
96
  continue
@@ -112,8 +114,9 @@ def main() -> None:
112
  print(f"[done] downloaded {filename}", flush=True)
113
 
114
  if args.extract:
 
115
  subprocess.run(
116
- ["unsquashfs", "-f", "-d", str(extract_root), str(archive_path)],
117
  check=True,
118
  )
119
  print(f"[done] extracted {filename}", flush=True)
 
85
  archive_path = archive_root / filename
86
  expected_sha = checksums[filename]
87
  url = f"{base_url}/{filename}"
88
+ task_name = filename.split(".", 1)[0]
89
+ task_extract_root = extract_root / task_name
90
 
91
  print(f"[plan] {filename}", flush=True)
92
  print(f" url={url}", flush=True)
93
  print(f" archive={archive_path}", flush=True)
94
  if args.extract:
95
+ print(f" extract_root={task_extract_root}", flush=True)
96
 
97
  if args.dry_run:
98
  continue
 
114
  print(f"[done] downloaded {filename}", flush=True)
115
 
116
  if args.extract:
117
+ task_extract_root.mkdir(parents=True, exist_ok=True)
118
  subprocess.run(
119
+ ["unsquashfs", "-f", "-q", "-no-progress", "-d", str(task_extract_root), str(archive_path)],
120
  check=True,
121
  )
122
  print(f"[done] extracted {filename}", flush=True)