"""Dataset-generation helpers for the robomme record-wrapper tests.

Runs scripted episodes with fail-aware motion planners under the
RobommeRecordWrapper and caches the resulting HDF5 recordings on disk so
multiple tests can reuse the same generated episode.
"""
from __future__ import annotations

import json
import shutil
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Callable, Optional

import numpy as np
import torch
import gymnasium as gym

from tests._shared.repo_paths import ensure_src_on_path

ensure_src_on_path(__file__)
from robomme.env_record_wrapper import RobommeRecordWrapper, FailsafeTimeout  # noqa: E402
from robomme.robomme_env import *  # noqa: F401,F403,E402
from robomme.robomme_env.utils.SceneGenerationError import SceneGenerationError  # noqa: E402
from robomme.robomme_env.utils.planner_fail_safe import (  # noqa: E402
    FailAwarePandaArmMotionPlanningSolver,
    FailAwarePandaStickMotionPlanningSolver,
    ScrewPlanFailure,
)

# Retry budgets: screw-motion attempts, RRT* fallback attempts, and how many
# consecutive seeds to try before giving up on a case.
DATASET_SCREW_MAX_ATTEMPTS = 3
DATASET_RRT_MAX_ATTEMPTS = 3
MAX_SEED_ATTEMPTS = 30


@dataclass(frozen=True)
class DatasetCase:
    env_id: str
    episode: int
    base_seed: int
    difficulty: Optional[str]
    save_video: bool
    mode_tag: str

    def cache_key(self) -> str:
        difficulty = self.difficulty if self.difficulty else "none"
        return (
            f"{self.env_id}_ep{self.episode}_{difficulty}_"
            f"{self.base_seed}_{int(self.save_video)}_{self.mode_tag}"
        )
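    # Illustrative example (values assumed, not taken from the test suite):
    #   DatasetCase(env_id="RouteStick", episode=3, base_seed=7, difficulty="easy",
    #               save_video=False, mode_tag="smoke").cache_key()
    #   -> "RouteStick_ep3_easy_7_0_smoke"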


@dataclass(frozen=True)
class GeneratedDataset:
    case: DatasetCase
    work_dir: Path
    raw_h5_path: Path
    resolver_dataset_dir: Path
    resolver_h5_path: Path
    used_seed: int


def _tensor_to_bool(value) -> bool:
    """Coerce a success/fail flag (tensor, ndarray, or scalar) to a plain bool."""
    if value is None:
        return False
    if isinstance(value, torch.Tensor):
        return bool(value.detach().cpu().bool().item())
    if isinstance(value, np.ndarray):
        return bool(np.any(value))
    return bool(value)


def _patch_planner_screw_to_rrt(planner) -> None:
    """Wrap the planner's screw motion so repeated failures fall back to RRT*."""
    original_screw = planner.move_to_pose_with_screw
    original_rrt = planner.move_to_pose_with_RRTStar

    def _move_screw_then_rrt(*args, **kwargs):
        # Try screw interpolation first; a ScrewPlanFailure or a -1 return value
        # both count as a failed attempt.
        for _ in range(DATASET_SCREW_MAX_ATTEMPTS):
            try:
                result = original_screw(*args, **kwargs)
            except ScrewPlanFailure:
                continue
            if isinstance(result, int) and result == -1:
                continue
            return result
        # Screw planning exhausted its budget; fall back to RRT* with its own budget.
        for _ in range(DATASET_RRT_MAX_ATTEMPTS):
            try:
                result = original_rrt(*args, **kwargs)
            except Exception:
                continue
            if isinstance(result, int) and result == -1:
                continue
            return result
        return -1

    planner.move_to_pose_with_screw = _move_screw_then_rrt


def _run_one_episode(
    case: DatasetCase,
    seed: int,
    output_dir: Path,
) -> bool:
    """Run one scripted episode for ``case`` with ``seed``; return True on success."""
    env_kwargs = dict(
        obs_mode="rgb+depth+segmentation",
        control_mode="pd_joint_pos",
        render_mode="rgb_array",
        reward_mode="dense",
        seed=seed,
        difficulty=case.difficulty,
    )
    # Episodes 1-5 enable failure recovery ("z" mode for episodes 1-2, "xy" for 3-5).
    if case.episode <= 5:
        env_kwargs["robomme_failure_recovery"] = True
        env_kwargs["robomme_failure_recovery_mode"] = "z" if case.episode <= 2 else "xy"

    env = gym.make(case.env_id, **env_kwargs)
    env = RobommeRecordWrapper(
        env,
        dataset=str(output_dir),
        env_id=case.env_id,
        episode=case.episode,
        seed=seed,
        save_video=case.save_video,
    )

    episode_successful = False
    try:
        env.reset()
        # Stick-based environments use the stick planner; everything else uses the arm planner.
        is_stick = case.env_id in ("PatternLock", "RouteStick")
        if is_stick:
            planner = FailAwarePandaStickMotionPlanningSolver(
                env,
                debug=False,
                vis=False,
                base_pose=env.unwrapped.agent.robot.pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
                joint_vel_limits=0.3,
            )
        else:
            planner = FailAwarePandaArmMotionPlanningSolver(
                env,
                debug=False,
                vis=False,
                base_pose=env.unwrapped.agent.robot.pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
            )
        _patch_planner_screw_to_rrt(planner)

        tasks = list(getattr(env.unwrapped, "task_list", []) or [])
        for task_entry in tasks:
            solve_callable = task_entry.get("solve")
            if not callable(solve_callable):
                continue
            env.unwrapped.evaluate(solve_complete_eval=True)
            screw_failed = False
            try:
                solve_result = solve_callable(env, planner)
                if isinstance(solve_result, int) and solve_result == -1:
                    screw_failed = True
                    env.unwrapped.failureflag = torch.tensor([True])
                    env.unwrapped.successflag = torch.tensor([False])
                    env.unwrapped.current_task_failure = True
            except ScrewPlanFailure:
                screw_failed = True
                env.unwrapped.failureflag = torch.tensor([True])
                env.unwrapped.successflag = torch.tensor([False])
                env.unwrapped.current_task_failure = True
            except FailsafeTimeout:
                break
            evaluation = env.unwrapped.evaluate(solve_complete_eval=True)
            fail_flag = evaluation.get("fail", False)
            success_flag = evaluation.get("success", False)
            if _tensor_to_bool(success_flag):
                episode_successful = True
                break
            if screw_failed or _tensor_to_bool(fail_flag):
                break
        else:
            # No task broke out of the loop: fall back to the final evaluation.
            evaluation = env.unwrapped.evaluate(solve_complete_eval=True)
            episode_successful = _tensor_to_bool(evaluation.get("success", False))

        # Also honor a wrapper-level episode_success flag if the wrapper set one.
        episode_successful = episode_successful or _tensor_to_bool(
            getattr(env, "episode_success", False)
        )
    except SceneGenerationError:
        episode_successful = False
    finally:
        try:
            env.close()
        except Exception:
            pass
    return episode_successful


def _run_episode_with_retry(case: DatasetCase, output_dir: Path) -> tuple[Path, int]:
    """Retry episode generation with consecutive seeds until one succeeds.

    Returns the recorded HDF5 path and the seed that produced it.
    """
    for attempt in range(MAX_SEED_ATTEMPTS):
        seed = case.base_seed + attempt
        try:
            success = _run_one_episode(case=case, seed=seed, output_dir=output_dir)
        except Exception:
            continue
        if not success:
            continue
        h5_path = output_dir / "hdf5_files" / f"{case.env_id}_ep{case.episode}_seed{seed}.h5"
        if not h5_path.exists():
            raise FileNotFoundError(f"Missing expected HDF5: {h5_path}")
        return h5_path, seed
    raise RuntimeError(
        f"[{case.env_id}] Failed to generate a successful record in {MAX_SEED_ATTEMPTS} attempts."
    )


def _write_meta(meta_path: Path, payload: dict) -> None:
    meta_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")


def _read_meta(meta_path: Path) -> dict:
    return json.loads(meta_path.read_text(encoding="utf-8"))


def generate_dataset_case(case: DatasetCase, cache_root: Path) -> GeneratedDataset:
    """Generate the dataset for ``case``, reusing cached artifacts when they are intact."""
    case_dir = cache_root / case.cache_key()
    work_dir = case_dir / "work"
    resolver_dataset_dir = case_dir / "resolver_dataset"
    resolver_h5_path = resolver_dataset_dir / f"record_dataset_{case.env_id}.h5"
    meta_path = case_dir / "meta.json"

    # Cache hit: both the raw recording and the resolver copy must still exist.
    if meta_path.exists():
        meta = _read_meta(meta_path)
        raw_h5_path = Path(meta["raw_h5_path"])
        if raw_h5_path.exists() and resolver_h5_path.exists():
            return GeneratedDataset(
                case=case,
                work_dir=work_dir,
                raw_h5_path=raw_h5_path,
                resolver_dataset_dir=resolver_dataset_dir,
                resolver_h5_path=resolver_h5_path,
                used_seed=int(meta["used_seed"]),
            )

    case_dir.mkdir(parents=True, exist_ok=True)
    work_dir.mkdir(parents=True, exist_ok=True)
    resolver_dataset_dir.mkdir(parents=True, exist_ok=True)

    raw_h5_path, used_seed = _run_episode_with_retry(case=case, output_dir=work_dir)
    shutil.copy2(raw_h5_path, resolver_h5_path)

    payload = {
        "case": asdict(case),
        "used_seed": used_seed,
        "raw_h5_path": str(raw_h5_path),
        "resolver_h5_path": str(resolver_h5_path),
    }
    _write_meta(meta_path, payload)

    return GeneratedDataset(
        case=case,
        work_dir=work_dir,
        raw_h5_path=raw_h5_path,
        resolver_dataset_dir=resolver_dataset_dir,
        resolver_h5_path=resolver_h5_path,
        used_seed=used_seed,
    )


class DatasetFactoryCache:
    """In-process memoization layer on top of the on-disk dataset cache."""

    def __init__(self, cache_root: Path):
        self.cache_root = cache_root
        self._memo: dict[str, GeneratedDataset] = {}

    def get(self, case: DatasetCase) -> GeneratedDataset:
        key = case.cache_key()
        cached = self._memo.get(key)
        if cached is not None:
            return cached
        generated = generate_dataset_case(case, self.cache_root)
        self._memo[key] = generated
        return generated


DatasetFactory = Callable[[DatasetCase], GeneratedDataset]
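

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch; this function is not called anywhere
# in the module). The cache location and the DatasetCase values below are
# assumptions for demonstration only -- real tests would supply their own,
# typically exposing ``DatasetFactoryCache(...).get`` as a ``DatasetFactory``
# fixture.
def _example_usage(cache_root: Path) -> GeneratedDataset:
    factory: DatasetFactory = DatasetFactoryCache(cache_root).get
    case = DatasetCase(
        env_id="PatternLock",  # assumed choice; one of the env ids referenced above
        episode=1,
        base_seed=0,
        difficulty=None,
        save_video=False,
        mode_tag="example",  # free-form tag; value assumed for the example
    )
    # Repeated calls with the same case hit the in-memory memo, then the
    # on-disk cache, before regenerating the episode from scratch.
    return factory(case)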