| | """ |
| | Step-by-step planning collection tool. |
| | |
| | All step-by-step collection interfaces uniformly return a 5-tuple: |
| | (obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch) |
| | |
| | - obs_batch/info_batch: dict[str, list], keys from single step dictionary. |
| | - reward_batch: torch.float32, shape [N]. |
| | - terminated_batch/truncated_batch: torch.bool, shape [N]. |
| | """ |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| |
|
| | def _collapse_singleton_lists(value, key=None): |
| | """Recursively unwrap singleton lists while preserving non-singleton lists.""" |
| | if key == "task_goal": |
| | return value |
| | while isinstance(value, list) and len(value) == 1: |
| | value = value[0] |
| | return value |
| |
|
| |
|
| | def _to_scalar(value): |
| | """Convert scalar or tensor value to Python scalar.""" |
| | if isinstance(value, torch.Tensor): |
| | if value.numel() == 0: |
| | return 0 |
| | return value.reshape(-1)[0].item() |
| | return value |
| |
|
| |
|
| | def _snapshot_value(value): |
| | """Snapshot values that might reuse underlying memory to avoid overwriting previous frames in subsequent steps.""" |
| | if isinstance(value, torch.Tensor): |
| | return value.detach().clone() |
| | if isinstance(value, np.ndarray): |
| | return value.copy() |
| | if isinstance(value, dict): |
| | return {k: _snapshot_value(v) for k, v in value.items()} |
| | if isinstance(value, list): |
| | return [_snapshot_value(v) for v in value] |
| | if isinstance(value, tuple): |
| | return tuple(_snapshot_value(v) for v in value) |
| | return value |
| |
|
| |
|
| | def _snapshot_step(out): |
| | """Deep copy snapshot of single step output.""" |
| | if not (isinstance(out, tuple) and len(out) == 5): |
| | return out |
| | obs, reward, terminated, truncated, info = out |
| | return ( |
| | _snapshot_value(obs), |
| | _snapshot_value(reward), |
| | _snapshot_value(terminated), |
| | _snapshot_value(truncated), |
| | _snapshot_value(info), |
| | ) |
| |
|
| |
|
| | def _is_columnar_dict(batch_dict, n): |
| | if not isinstance(batch_dict, dict): |
| | return False |
| | for value in batch_dict.values(): |
| | if not isinstance(value, list): |
| | return False |
| | if len(value) != n: |
| | return False |
| | return True |
| |
|
| |
|
| | def _output_to_steps(out): |
| | """ |
| | Normalize step output to "list of raw step tuples". |
| | Supports both single step tuple and unified batch tuple input formats. |
| | """ |
| | if isinstance(out, tuple) and len(out) == 5: |
| | obs_part, reward_part, terminated_part, truncated_part, info_part = out |
| | if ( |
| | isinstance(reward_part, torch.Tensor) |
| | and isinstance(terminated_part, torch.Tensor) |
| | and isinstance(truncated_part, torch.Tensor) |
| | and reward_part.ndim == 1 |
| | and terminated_part.ndim == 1 |
| | and truncated_part.ndim == 1 |
| | ): |
| | n = int(reward_part.numel()) |
| | if ( |
| | terminated_part.numel() == n |
| | and truncated_part.numel() == n |
| | and _is_columnar_dict(obs_part, n) |
| | and _is_columnar_dict(info_part, n) |
| | ): |
| | steps = [] |
| | obs_keys = list(obs_part.keys()) |
| | info_keys = list(info_part.keys()) |
| | for idx in range(n): |
| | obs = {k: _snapshot_value(obs_part[k][idx]) for k in obs_keys} |
| | info = {k: _snapshot_value(info_part[k][idx]) for k in info_keys} |
| | steps.append( |
| | ( |
| | obs, |
| | _snapshot_value(reward_part[idx]), |
| | _snapshot_value(terminated_part[idx]), |
| | _snapshot_value(truncated_part[idx]), |
| | info, |
| | ) |
| | ) |
| | return steps |
| | return [_snapshot_step(out)] |
| |
|
| |
|
| | def _dicts_to_columnar_dict(dict_steps): |
| | """ |
| | Convert step dictionaries to dict[str, list], filling missing keys with None. |
| | """ |
| | n = len(dict_steps) |
| | out = {} |
| | for idx, item in enumerate(dict_steps): |
| | current = item if isinstance(item, dict) else {} |
| | for key in current: |
| | if key not in out: |
| | out[key] = [None] * idx |
| | for key in out: |
| | out[key].append(_collapse_singleton_lists(current.get(key, None), key=key)) |
| | for key in out: |
| | if len(out[key]) < n: |
| | out[key].extend([None] * (n - len(out[key]))) |
| | return out |
| |
|
| |
|
| | def empty_step_batch(): |
| | """Return an empty batch following the unified contract.""" |
| | return ( |
| | {}, |
| | torch.empty(0, dtype=torch.float32), |
| | torch.empty(0, dtype=torch.bool), |
| | torch.empty(0, dtype=torch.bool), |
| | {}, |
| | ) |
| |
|
| |
|
| | def to_step_batch(collected_steps): |
| | """ |
| | Convert collected step tuples to unified batch output. |
| | collected_steps: [(obs, reward, terminated, truncated, info), ...] |
| | """ |
| | if not collected_steps: |
| | return empty_step_batch() |
| |
|
| | obs_steps = [x[0] for x in collected_steps] |
| | reward_steps = [_to_scalar(x[1]) for x in collected_steps] |
| | terminated_steps = [bool(_to_scalar(x[2])) for x in collected_steps] |
| | truncated_steps = [bool(_to_scalar(x[3])) for x in collected_steps] |
| | info_steps = [x[4] for x in collected_steps] |
| |
|
| | obs_batch = _dicts_to_columnar_dict(obs_steps) |
| | info_batch = _dicts_to_columnar_dict(info_steps) |
| | reward_batch = torch.tensor(reward_steps, dtype=torch.float32) |
| | terminated_batch = torch.tensor(terminated_steps, dtype=torch.bool) |
| | truncated_batch = torch.tensor(truncated_steps, dtype=torch.bool) |
| | return (obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch) |
| |
|
| |
|
| | def concat_step_batches(batches): |
| | """ |
| | Concatenate multiple unified batches into one unified batch. |
| | """ |
| | valid = [] |
| | for batch in batches: |
| | if batch is None: |
| | continue |
| | obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch |
| | if reward_batch.numel() == 0: |
| | continue |
| | valid.append((obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch)) |
| | if not valid: |
| | return empty_step_batch() |
| |
|
| | obs_out = {} |
| | info_out = {} |
| | reward_out = [] |
| | terminated_out = [] |
| | truncated_out = [] |
| | n_total = 0 |
| |
|
| | for obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch in valid: |
| | n = int(reward_batch.numel()) |
| | for key in obs_batch: |
| | if key not in obs_out: |
| | obs_out[key] = [None] * n_total |
| | for key in obs_out: |
| | values = obs_batch.get(key, None) |
| | if values is None: |
| | obs_out[key].extend([None] * n) |
| | else: |
| | obs_out[key].extend(_collapse_singleton_lists(v, key=key) for v in values) |
| |
|
| | for key in info_batch: |
| | if key not in info_out: |
| | info_out[key] = [None] * n_total |
| | for key in info_out: |
| | values = info_batch.get(key, None) |
| | if values is None: |
| | info_out[key].extend([None] * n) |
| | else: |
| | info_out[key].extend(_collapse_singleton_lists(v, key=key) for v in values) |
| |
|
| | reward_out.append(reward_batch.reshape(-1).to(torch.float32)) |
| | terminated_out.append(terminated_batch.reshape(-1).to(torch.bool)) |
| | truncated_out.append(truncated_batch.reshape(-1).to(torch.bool)) |
| | n_total += n |
| |
|
| | return ( |
| | obs_out, |
| | torch.cat(reward_out, dim=0) if reward_out else torch.empty(0, dtype=torch.float32), |
| | torch.cat(terminated_out, dim=0) if terminated_out else torch.empty(0, dtype=torch.bool), |
| | torch.cat(truncated_out, dim=0) if truncated_out else torch.empty(0, dtype=torch.bool), |
| | info_out, |
| | ) |
| |
|
| |
|
| | def _collect_dense_steps(planner, fn): |
| | """ |
| | Intercept planner.env.step when running fn(), collecting raw step tuples. |
| | If fn returns -1, return -1; otherwise return collected result list. |
| | """ |
| | collected = [] |
| | original_step = planner.env.step |
| |
|
| | def _step(action): |
| | out = original_step(action) |
| | collected.extend(_output_to_steps(out)) |
| | return out |
| |
|
| | planner.env.step = _step |
| | try: |
| | result = fn() |
| | if result == -1: |
| | return -1 |
| | return collected |
| | finally: |
| | planner.env.step = original_step |
| |
|
| |
|
| | def _run_with_dense_collection(planner, fn): |
| | """ |
| | Run fn() and return unified batch; return -1 if fn returns -1. |
| | """ |
| | collected = _collect_dense_steps(planner, fn) |
| | if collected == -1: |
| | return -1 |
| | return to_step_batch(collected) |
| |
|
| |
|
| | def move_to_pose_with_RRTStar(planner, pose): |
| | """ |
| | Call planner.move_to_pose_with_RRTStar(pose) and return unified batch. |
| | Return -1 on planning failure. |
| | """ |
| | return _run_with_dense_collection( |
| | planner, lambda: planner.move_to_pose_with_RRTStar(pose) |
| | ) |
| |
|
| |
|
| | def move_to_pose_with_screw(planner, pose): |
| | """ |
| | Call planner.move_to_pose_with_screw(pose) and return unified batch. |
| | Return -1 on planning failure. |
| | """ |
| | return _run_with_dense_collection( |
| | planner, lambda: planner.move_to_pose_with_screw(pose) |
| | ) |
| |
|
| |
|
| | def close_gripper(planner): |
| | """ |
| | Call planner.close_gripper() and return unified batch. |
| | Return -1 on failure. |
| | """ |
| | return _run_with_dense_collection(planner, lambda: planner.close_gripper()) |
| |
|
| |
|
| | def open_gripper(planner): |
| | """ |
| | Call planner.open_gripper() and return unified batch. |
| | Return -1 on failure. |
| | """ |
| | return _run_with_dense_collection(planner, lambda: planner.open_gripper()) |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|