# RoboMME/src/robomme/robomme_env/utils/planner_denseStep.py
# Hugging Face Space page header (HongzeFu): code-only, no binary assets; commit 06c11b0
"""
Step-by-step planning collection tool.

Every step-by-step collection interface returns the same 5-tuple on success:
    (obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch)
- obs_batch/info_batch: dict[str, list]; keys come from the single-step dictionaries.
- reward_batch: torch.float32, shape [N].
- terminated_batch/truncated_batch: torch.bool, shape [N].
If the wrapped planner call returns -1, the interface returns -1 instead of a batch.
"""
import numpy as np
import torch
def _collapse_singleton_lists(value, key=None):
    """Strip nested one-element ``list`` wrappers from ``value``.

    Peeling continues while the current value is a ``list`` holding exactly
    one item, so ``[[x]]`` becomes ``x``. Lists of any other length, and
    non-list containers such as tuples, are returned untouched.

    Args:
        value: Object to unwrap.
        key: Name of the field the value belongs to. Values stored under
            ``"task_goal"`` are never unwrapped.

    Returns:
        The innermost value reached, or ``value`` itself when nothing was
        peeled.
    """
    if key == "task_goal":
        # This field is exempt from unwrapping: hand it back as received.
        return value
    inner = value
    while True:
        if not isinstance(inner, list) or len(inner) != 1:
            return inner
        inner = inner[0]
def _to_scalar(value):
    """Convert a scalar-like value to a plain Python scalar.

    Tensors and NumPy arrays are reduced to their first element in flattened
    order; NumPy scalars are unboxed. This keeps the batched reward and flag
    tensors one-dimensional even when a step reports them as size-1 arrays.

    Args:
        value: A ``torch.Tensor``, ``numpy.ndarray``, NumPy scalar, or any
            other object.

    Returns:
        ``0`` for an empty tensor or array; the first flattened element as a
        Python scalar for a non-empty tensor or array; ``value.item()`` for a
        NumPy scalar; otherwise ``value`` unchanged.
    """
    if isinstance(value, torch.Tensor):
        if value.numel() == 0:
            return 0
        return value.reshape(-1)[0].item()
    if isinstance(value, np.ndarray):
        if value.size == 0:
            return 0
        first = value.reshape(-1)[0]
        # Object arrays yield the stored object itself, which has no .item().
        return first.item() if isinstance(first, np.generic) else first
    if isinstance(value, np.generic):
        return value.item()
    return value
def _snapshot_value(value):
    """Return a copy of ``value`` that shares no tensor or array storage.

    Tensors are detached and cloned, NumPy arrays are copied, and dicts,
    lists, and tuples are rebuilt with each element snapshotted recursively.
    This keeps a stored frame from being overwritten when a later step reuses
    the same underlying buffer.

    Args:
        value: Any object; only the container and array types listed above
            receive special handling.

    Returns:
        The copied structure. Values of any other type are returned as the
        same object.
    """
    if isinstance(value, torch.Tensor):
        copied = value.detach().clone()
    elif isinstance(value, np.ndarray):
        copied = value.copy()
    elif isinstance(value, dict):
        copied = {name: _snapshot_value(item) for name, item in value.items()}
    elif isinstance(value, (list, tuple)):
        items = [_snapshot_value(item) for item in value]
        # Lists stay lists; tuples are rebuilt as plain tuples.
        copied = items if isinstance(value, list) else tuple(items)
    else:
        copied = value
    return copied
def _snapshot_step(out):
    """Snapshot every element of a single step output.

    Args:
        out: Expected to be a 5-element tuple
            ``(obs, reward, terminated, truncated, info)``.

    Returns:
        A new 5-tuple whose elements were each passed through
        ``_snapshot_value``. Anything that is not a 5-element tuple is
        returned unchanged.
    """
    is_step_tuple = isinstance(out, tuple) and len(out) == 5
    if is_step_tuple:
        return tuple(_snapshot_value(part) for part in out)
    return out
def _is_columnar_dict(batch_dict, n):
    """Tell whether ``batch_dict`` is a dict whose values are all lists of length ``n``.

    Args:
        batch_dict: Candidate object.
        n: Length required of every column.

    Returns:
        ``True`` when ``batch_dict`` is a dict and each value is a ``list``
        of exactly ``n`` items (an empty dict qualifies); ``False`` otherwise.
    """
    if not isinstance(batch_dict, dict):
        return False
    return all(
        isinstance(column, list) and len(column) == n
        for column in batch_dict.values()
    )
def _output_to_steps(out):
    """Normalize a step output into a list of per-step tuples.

    Two input shapes are recognized:

    * A unified batch: a 5-tuple whose reward, terminated, and truncated
      parts are 1-D tensors of equal length ``n`` and whose obs and info
      parts are dicts of length-``n`` lists. It is split into ``n``
      individual step tuples, with every value snapshotted.
    * Anything else is treated as one raw step and returned as a
      single-element list holding its snapshot.

    Args:
        out: Value returned by an environment ``step`` call.

    Returns:
        A list of ``(obs, reward, terminated, truncated, info)`` tuples, or a
        one-element list wrapping the snapshot of ``out``.
    """
    if not (isinstance(out, tuple) and len(out) == 5):
        return [_snapshot_step(out)]

    obs_cols, rewards, terminateds, truncateds, info_cols = out
    flat_tensors = all(
        isinstance(part, torch.Tensor) and part.ndim == 1
        for part in (rewards, terminateds, truncateds)
    )
    if flat_tensors:
        count = int(rewards.numel())
        lengths_agree = terminateds.numel() == count and truncateds.numel() == count
        if (
            lengths_agree
            and _is_columnar_dict(obs_cols, count)
            and _is_columnar_dict(info_cols, count)
        ):
            rows = []
            for row in range(count):
                obs_row = {
                    name: _snapshot_value(column[row])
                    for name, column in obs_cols.items()
                }
                info_row = {
                    name: _snapshot_value(column[row])
                    for name, column in info_cols.items()
                }
                rows.append(
                    (
                        obs_row,
                        _snapshot_value(rewards[row]),
                        _snapshot_value(terminateds[row]),
                        _snapshot_value(truncateds[row]),
                        info_row,
                    )
                )
            return rows

    # Not a recognizable batch: keep it as one raw step.
    return [_snapshot_step(out)]
def _dicts_to_columnar_dict(dict_steps):
    """Transpose a sequence of per-step dicts into one dict of columns.

    Args:
        dict_steps: Sequence of step dictionaries. Entries that are not
            dicts are treated as empty dicts.

    Returns:
        ``dict[str, list]`` with one list per key seen in any step, in order
        of first appearance. Each list has one entry per step; steps lacking
        a key contribute ``None``. Present values are passed through
        ``_collapse_singleton_lists``.
    """
    rows = [row if isinstance(row, dict) else {} for row in dict_steps]

    # First pass: fix the column order by first appearance of each key.
    columns = {}
    for row in rows:
        for name in row:
            columns.setdefault(name, None)

    # Second pass: fill every column; a missing key reads as None.
    for name in columns:
        columns[name] = [
            _collapse_singleton_lists(row.get(name, None), key=name)
            for row in rows
        ]
    return columns
def empty_step_batch():
    """Build a zero-length batch in the unified 5-tuple layout.

    Returns:
        ``(obs, reward, terminated, truncated, info)`` where ``obs`` and
        ``info`` are fresh empty dicts, ``reward`` is an empty float32 tensor,
        and the two flags are separate empty bool tensors.
    """
    obs_batch, info_batch = {}, {}
    reward_batch = torch.empty(0, dtype=torch.float32)
    terminated_batch = torch.empty(0, dtype=torch.bool)
    truncated_batch = torch.empty(0, dtype=torch.bool)
    return (obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch)
def to_step_batch(collected_steps):
    """Pack collected per-step tuples into one unified batch.

    Args:
        collected_steps: List of ``(obs, reward, terminated, truncated, info)``
            tuples.

    Returns:
        ``(obs_batch, reward_batch, terminated_batch, truncated_batch,
        info_batch)``. The obs and info parts are columnar dicts; reward is a
        float32 tensor and the flags are bool tensors, each with one entry
        per step. An empty input yields ``empty_step_batch()``.
    """
    if not collected_steps:
        return empty_step_batch()

    obs_rows, info_rows = [], []
    rewards, terminated_flags, truncated_flags = [], [], []
    for step in collected_steps:
        obs_rows.append(step[0])
        # Reward and flags may arrive as tensors; reduce them to scalars.
        rewards.append(_to_scalar(step[1]))
        terminated_flags.append(bool(_to_scalar(step[2])))
        truncated_flags.append(bool(_to_scalar(step[3])))
        info_rows.append(step[4])

    return (
        _dicts_to_columnar_dict(obs_rows),
        torch.tensor(rewards, dtype=torch.float32),
        torch.tensor(terminated_flags, dtype=torch.bool),
        torch.tensor(truncated_flags, dtype=torch.bool),
        _dicts_to_columnar_dict(info_rows),
    )
def concat_step_batches(batches):
    """Join several unified batches end to end.

    ``None`` entries and batches whose reward tensor has no elements are
    skipped. Columns missing from a batch are padded with ``None`` so every
    column of the result has one entry per step.

    Args:
        batches: Iterable of unified 5-tuples (or ``None``).

    Returns:
        One unified 5-tuple covering all kept batches in order, or
        ``empty_step_batch()`` when nothing is kept.
    """

    def _append_columns(merged, incoming, rows_before, rows_added):
        """Extend every column in ``merged`` with this batch's rows."""
        for name in incoming:
            if name not in merged:
                # A column first seen now is back-filled for earlier rows.
                merged[name] = [None] * rows_before
        for name, column in merged.items():
            fresh = incoming.get(name, None)
            if fresh is None:
                column.extend([None] * rows_added)
            else:
                column.extend(
                    _collapse_singleton_lists(item, key=name) for item in fresh
                )

    usable = []
    for batch in batches:
        if batch is None:
            continue
        obs_cols, rewards, terminated, truncated, info_cols = batch
        if rewards.numel() > 0:
            usable.append((obs_cols, rewards, terminated, truncated, info_cols))
    if not usable:
        return empty_step_batch()

    merged_obs, merged_info = {}, {}
    reward_chunks, terminated_chunks, truncated_chunks = [], [], []
    rows_so_far = 0
    for obs_cols, rewards, terminated, truncated, info_cols in usable:
        rows_here = int(rewards.numel())
        _append_columns(merged_obs, obs_cols, rows_so_far, rows_here)
        _append_columns(merged_info, info_cols, rows_so_far, rows_here)
        reward_chunks.append(rewards.reshape(-1).to(torch.float32))
        terminated_chunks.append(terminated.reshape(-1).to(torch.bool))
        truncated_chunks.append(truncated.reshape(-1).to(torch.bool))
        rows_so_far += rows_here

    # At least one batch was kept, so each chunk list is non-empty here.
    return (
        merged_obs,
        torch.cat(reward_chunks, dim=0),
        torch.cat(terminated_chunks, dim=0),
        torch.cat(truncated_chunks, dim=0),
        merged_info,
    )
def _collect_dense_steps(planner, fn):
    """Run ``fn`` while recording every ``planner.env.step`` call.

    ``planner.env.step`` is temporarily replaced by a wrapper that forwards
    to the original method, records the normalized output, and returns the
    output unchanged. The original method is reinstated when ``fn`` finishes,
    including when it raises.

    Args:
        planner: Object exposing ``env.step``.
        fn: Zero-argument callable to execute.

    Returns:
        ``-1`` when ``fn`` returns ``-1``; otherwise the list of recorded
        step tuples.
    """
    recorded = []
    real_step = planner.env.step

    def _recording_step(action):
        step_out = real_step(action)
        recorded.extend(_output_to_steps(step_out))
        return step_out

    planner.env.step = _recording_step
    try:
        outcome = fn()
        failed = outcome == -1
    finally:
        # Always undo the patch, even if fn() raised.
        planner.env.step = real_step

    if failed:
        return -1
    return recorded
def _run_with_dense_collection(planner, fn):
    """Execute ``fn`` with step recording and package the result.

    Args:
        planner: Object exposing ``env.step``; passed on to
            ``_collect_dense_steps``.
        fn: Zero-argument callable to execute.

    Returns:
        ``-1`` when ``_collect_dense_steps`` reports ``-1``; otherwise the
        recorded steps converted by ``to_step_batch``.
    """
    steps = _collect_dense_steps(planner, fn)
    return -1 if steps == -1 else to_step_batch(steps)
def move_to_pose_with_RRTStar(planner, pose):
    """Run ``planner.move_to_pose_with_RRTStar(pose)`` with dense step collection.

    Args:
        planner: Planner exposing ``move_to_pose_with_RRTStar`` and ``env.step``.
        pose: Target pose forwarded to the planner unchanged.

    Returns:
        The unified batch of steps taken during the call, or ``-1`` when the
        planner call returns ``-1``.
    """

    def _plan_and_move():
        return planner.move_to_pose_with_RRTStar(pose)

    return _run_with_dense_collection(planner, _plan_and_move)
def move_to_pose_with_screw(planner, pose):
    """Run ``planner.move_to_pose_with_screw(pose)`` with dense step collection.

    Args:
        planner: Planner exposing ``move_to_pose_with_screw`` and ``env.step``.
        pose: Target pose forwarded to the planner unchanged.

    Returns:
        The unified batch of steps taken during the call, or ``-1`` when the
        planner call returns ``-1``.
    """

    def _plan_and_move():
        return planner.move_to_pose_with_screw(pose)

    return _run_with_dense_collection(planner, _plan_and_move)
def close_gripper(planner):
    """Run ``planner.close_gripper()`` with dense step collection.

    Args:
        planner: Planner exposing ``close_gripper`` and ``env.step``.

    Returns:
        The unified batch of steps taken during the call, or ``-1`` when the
        planner call returns ``-1``.
    """

    def _close():
        return planner.close_gripper()

    return _run_with_dense_collection(planner, _close)
def open_gripper(planner):
    """Run ``planner.open_gripper()`` with dense step collection.

    Args:
        planner: Planner exposing ``open_gripper`` and ``env.step``.

    Returns:
        The unified batch of steps taken during the call, or ``-1`` when the
        planner call returns ``-1``.
    """

    def _open():
        return planner.open_gripper()

    return _run_with_dense_collection(planner, _open)
# ---- Call relationships ----
#
# _collect_dense_steps:
#   - DemonstrationWrapper.get_demonstration_trajectory()
#     Wraps the entire solve_callable and monkey-patches planner.env.step to
#     collect every underlying step.
#
# _run_with_dense_collection:
#   - OraclePlannerDemonstrationWrapper
#     Wraps solve() in solve_options, collects every underlying step, and
#     returns the unified batch directly.
#
# move_to_pose_with_RRTStar:
#   - Executes a single-step move in MultiStepDemonstrationWrapper
#     (MultiStepDemonstrationWrapper.py line 106).
#
# move_to_pose_with_screw:
#   - Currently has no external caller; kept as the symmetric counterpart of
#     move_to_pose_with_RRTStar.
#
# close_gripper:
#   - Closes the gripper in MultiStepDemonstrationWrapper
#     (MultiStepDemonstrationWrapper.py line 112).
#
# open_gripper:
#   - Opens the gripper in MultiStepDemonstrationWrapper
#     (MultiStepDemonstrationWrapper.py line 121).