| from __future__ import annotations |
|
|
| import json |
| import shutil |
| from dataclasses import asdict, dataclass |
| from pathlib import Path |
| from typing import Callable, Optional |
|
|
| import numpy as np |
| import torch |
| import gymnasium as gym |
|
|
| from tests._shared.repo_paths import ensure_src_on_path |
|
|
| ensure_src_on_path(__file__) |
|
|
| from robomme.env_record_wrapper import RobommeRecordWrapper, FailsafeTimeout |
| from robomme.robomme_env import * |
| from robomme.robomme_env.utils.SceneGenerationError import SceneGenerationError |
| from robomme.robomme_env.utils.planner_fail_safe import ( |
| FailAwarePandaArmMotionPlanningSolver, |
| FailAwarePandaStickMotionPlanningSolver, |
| ScrewPlanFailure, |
| ) |
|
|
|
|
# Number of screw-motion planning attempts made before falling back to RRT*.
DATASET_SCREW_MAX_ATTEMPTS: int = 3
# Number of RRT* planning attempts made once the screw attempts are exhausted.
DATASET_RRT_MAX_ATTEMPTS: int = 3
# Number of consecutive seeds (base_seed, base_seed + 1, ...) tried per case.
MAX_SEED_ATTEMPTS: int = 30
|
|
|
|
@dataclass(frozen=True)
class DatasetCase:
    """Immutable description of one dataset-generation request."""

    env_id: str  # gym environment id passed to ``gym.make``
    episode: int  # episode index recorded by the wrapper
    base_seed: int  # first seed tried; later retries use base_seed + attempt
    difficulty: Optional[str]  # forwarded to the env; falsy means "none" in the key
    save_video: bool  # forwarded to RobommeRecordWrapper
    mode_tag: str  # free-form tag that only distinguishes cache entries

    def cache_key(self) -> str:
        """Return the cache directory name that identifies this case.

        The format is ``{env_id}_ep{episode}_{difficulty}_{base_seed}_{0|1}_{mode_tag}``.
        A falsy ``difficulty`` is written as ``"none"``, and ``save_video`` is
        written as ``0`` or ``1``.
        """
        # NOTE(review): difficulty=None and difficulty="none" produce the same key.
        parts = (
            self.env_id,
            f"ep{self.episode}",
            self.difficulty or "none",
            self.base_seed,
            int(self.save_video),
            self.mode_tag,
        )
        return "_".join(map(str, parts))
|
|
|
|
@dataclass(eq=True, frozen=True)
class GeneratedDataset:
    """Locations and provenance of a generated (or cache-hit) dataset case."""

    case: DatasetCase  # the request this dataset was generated for
    work_dir: Path  # directory the record wrapper wrote its output into
    raw_h5_path: Path  # HDF5 file produced by the recording run
    resolver_dataset_dir: Path  # directory holding the copied HDF5
    resolver_h5_path: Path  # copy of raw_h5_path named record_dataset_{env_id}.h5
    used_seed: int  # seed of the successful episode
|
|
|
|
| def _tensor_to_bool(value) -> bool: |
| if value is None: |
| return False |
| if isinstance(value, torch.Tensor): |
| return bool(value.detach().cpu().bool().item()) |
| if isinstance(value, np.ndarray): |
| return bool(np.any(value)) |
| return bool(value) |
|
|
|
|
def _patch_planner_screw_to_rrt(planner) -> None:
    """Make ``planner.move_to_pose_with_screw`` retry, then fall back to RRT*.

    The patch is applied to this planner instance only. The replacement makes up
    to ``DATASET_SCREW_MAX_ATTEMPTS`` screw attempts, then up to
    ``DATASET_RRT_MAX_ATTEMPTS`` ``move_to_pose_with_RRTStar`` attempts with the
    same arguments. It returns the first non-failing result, or ``-1`` if every
    attempt failed.

    An attempt fails when it returns the int ``-1`` or raises: ``ScrewPlanFailure``
    for screw attempts, any ``Exception`` except ``FailsafeTimeout`` for RRT*.
    """
    original_screw = planner.move_to_pose_with_screw
    original_rrt = planner.move_to_pose_with_RRTStar

    def _is_failure(result) -> bool:
        # The planners report a failed plan with the int sentinel -1.
        return isinstance(result, int) and result == -1

    def _move_screw_then_rrt(*args, **kwargs):
        for _ in range(DATASET_SCREW_MAX_ATTEMPTS):
            try:
                result = original_screw(*args, **kwargs)
            except ScrewPlanFailure:
                continue
            if not _is_failure(result):
                return result

        for _ in range(DATASET_RRT_MAX_ATTEMPTS):
            try:
                result = original_rrt(*args, **kwargs)
            except FailsafeTimeout:
                # A timeout is not a planning failure. Retrying cannot help, and
                # the episode runner stops solving when it sees this exception,
                # so let it propagate instead of converting it to -1.
                raise
            except Exception:
                # Best effort: any other RRT* error counts as a failed attempt.
                continue
            if not _is_failure(result):
                return result
        return -1

    planner.move_to_pose_with_screw = _move_screw_then_rrt
|
|
|
|
def _mark_task_failed(env) -> None:
    """Set the unwrapped env's failure/success flags to "current task failed"."""
    base = env.unwrapped
    base.failureflag = torch.tensor([True])
    base.successflag = torch.tensor([False])
    base.current_task_failure = True


def _build_planner(env, env_id: str):
    """Return the fail-aware motion planner to use for ``env_id``.

    ``PatternLock`` and ``RouteStick`` use the stick solver with
    ``joint_vel_limits=0.3``; every other environment uses the arm solver.
    """
    common = dict(
        debug=False,
        vis=False,
        base_pose=env.unwrapped.agent.robot.pose,
        visualize_target_grasp_pose=False,
        print_env_info=False,
    )
    if env_id in ("PatternLock", "RouteStick"):
        return FailAwarePandaStickMotionPlanningSolver(env, joint_vel_limits=0.3, **common)
    return FailAwarePandaArmMotionPlanningSolver(env, **common)


def _run_one_episode(
    case: DatasetCase,
    seed: int,
    output_dir: Path,
) -> bool:
    """Record one episode of ``case`` using ``seed`` and report whether it succeeded.

    The env is wrapped in ``RobommeRecordWrapper`` with ``dataset=str(output_dir)``.
    Each callable ``"solve"`` entry of the env's ``task_list`` runs in order with
    a fail-aware planner whose screw moves fall back to RRT*. Solving stops at
    the first evaluation reporting success or failure, at a failed plan, or on
    ``FailsafeTimeout``.

    Returns:
        True if an evaluation reports success, or if the wrapper's
        ``episode_success`` attribute is truthy. False otherwise, including when
        ``SceneGenerationError`` is raised.

    Raises:
        Any other exception raised while creating, resetting, or solving the env.
    """
    env_kwargs = dict(
        obs_mode="rgb+depth+segmentation",
        control_mode="pd_joint_pos",
        render_mode="rgb_array",
        reward_mode="dense",
        seed=seed,
        difficulty=case.difficulty,
    )
    # Episodes numbered <= 5 enable failure recovery: mode "z" for episodes
    # <= 2, mode "xy" otherwise.
    if case.episode <= 5:
        env_kwargs["robomme_failure_recovery"] = True
        env_kwargs["robomme_failure_recovery_mode"] = "z" if case.episode <= 2 else "xy"

    base_env = gym.make(case.env_id, **env_kwargs)
    try:
        env = RobommeRecordWrapper(
            base_env,
            dataset=str(output_dir),
            env_id=case.env_id,
            episode=case.episode,
            seed=seed,
            save_video=case.save_video,
        )
    except Exception:
        # The wrapper never took ownership of the raw env, so release it here
        # before propagating the construction error.
        try:
            base_env.close()
        except Exception:
            pass
        raise

    episode_successful = False
    try:
        env.reset()
        planner = _build_planner(env, case.env_id)
        _patch_planner_screw_to_rrt(planner)

        tasks = list(getattr(env.unwrapped, "task_list", []) or [])
        for task_entry in tasks:
            solve_callable = task_entry.get("solve")
            if not callable(solve_callable):
                continue
            # Result intentionally unused; presumably called for its side
            # effects on env state before solving -- TODO confirm.
            env.unwrapped.evaluate(solve_complete_eval=True)
            try:
                solve_result = solve_callable(env, planner)
            except ScrewPlanFailure:
                screw_failed = True
            except FailsafeTimeout:
                break
            else:
                # Solvers signal a failed plan by returning the int -1.
                screw_failed = isinstance(solve_result, int) and solve_result == -1
            if screw_failed:
                _mark_task_failed(env)

            evaluation = env.unwrapped.evaluate(solve_complete_eval=True)
            fail_flag = evaluation.get("fail", False)
            success_flag = evaluation.get("success", False)

            if _tensor_to_bool(success_flag):
                episode_successful = True
                break
            if screw_failed or _tensor_to_bool(fail_flag):
                break
        else:
            # Every task ran without an early stop: take the final verdict.
            evaluation = env.unwrapped.evaluate(solve_complete_eval=True)
            episode_successful = _tensor_to_bool(evaluation.get("success", False))

        episode_successful = episode_successful or _tensor_to_bool(
            getattr(env, "episode_success", False)
        )
    except SceneGenerationError:
        episode_successful = False
    finally:
        # Best-effort cleanup; a failing close must not mask the result.
        try:
            env.close()
        except Exception:
            pass

    return episode_successful
|
|
|
|
def _run_episode_with_retry(case: DatasetCase, output_dir: Path) -> tuple[Path, int]:
    """Run ``case`` with successive seeds until one episode succeeds.

    Seeds ``case.base_seed + attempt`` are tried for ``attempt`` in
    ``range(MAX_SEED_ATTEMPTS)``. An attempt that raises or reports failure moves
    on to the next seed.

    Returns:
        The HDF5 path written for the successful seed, and that seed.

    Raises:
        FileNotFoundError: a successful run did not produce the expected
            ``hdf5_files/{env_id}_ep{episode}_seed{seed}.h5`` file.
        RuntimeError: no seed succeeded. The last exception raised by an
            attempt, if any, is chained as ``__cause__`` for diagnosis.
    """
    last_error: Optional[Exception] = None
    for attempt in range(MAX_SEED_ATTEMPTS):
        seed = case.base_seed + attempt
        try:
            success = _run_one_episode(case=case, seed=seed, output_dir=output_dir)
        except Exception as exc:
            # Best effort: remember why this seed failed and try the next one.
            last_error = exc
            continue
        if not success:
            continue

        h5_path = output_dir / "hdf5_files" / f"{case.env_id}_ep{case.episode}_seed{seed}.h5"
        if not h5_path.exists():
            raise FileNotFoundError(f"Missing expected HDF5: {h5_path}")
        return h5_path, seed
    raise RuntimeError(
        f"[{case.env_id}] Failed to generate successful record in {MAX_SEED_ATTEMPTS} attempts."
    ) from last_error
|
|
|
|
| def _write_meta(meta_path: Path, payload: dict) -> None: |
| meta_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") |
|
|
|
|
| def _read_meta(meta_path: Path) -> dict: |
| return json.loads(meta_path.read_text(encoding="utf-8")) |
|
|
|
|
def generate_dataset_case(case: DatasetCase, cache_root: Path) -> GeneratedDataset:
    """Return the dataset for ``case``, generating it under ``cache_root`` if needed.

    Layout under ``cache_root / case.cache_key()``:
    ``work/`` holds the recording output, ``resolver_dataset/`` holds a copy
    named ``record_dataset_{env_id}.h5``, and ``meta.json`` records the seed and
    paths. ``meta.json`` is written only after the copy succeeds, so it acts as
    the completion marker.

    The case counts as a cache hit when ``meta.json`` parses, contains
    ``raw_h5_path`` and ``used_seed``, and both HDF5 files exist. Anything else,
    including unreadable or corrupt metadata, triggers regeneration.

    Raises:
        RuntimeError / FileNotFoundError: propagated from the episode runner
            when no successful recording can be produced.
    """
    case_dir = cache_root / case.cache_key()
    work_dir = case_dir / "work"
    resolver_dataset_dir = case_dir / "resolver_dataset"
    resolver_h5_path = resolver_dataset_dir / f"record_dataset_{case.env_id}.h5"
    meta_path = case_dir / "meta.json"

    if meta_path.exists():
        try:
            meta = _read_meta(meta_path)
            cached_raw_h5_path = Path(meta["raw_h5_path"])
            cached_seed = int(meta["used_seed"])
        except (OSError, ValueError, KeyError, TypeError):
            # Unreadable, truncated or schema-mismatched metadata (e.g. from an
            # interrupted earlier run): treat as a cache miss and regenerate.
            pass
        else:
            if cached_raw_h5_path.exists() and resolver_h5_path.exists():
                return GeneratedDataset(
                    case=case,
                    work_dir=work_dir,
                    raw_h5_path=cached_raw_h5_path,
                    resolver_dataset_dir=resolver_dataset_dir,
                    resolver_h5_path=resolver_h5_path,
                    used_seed=cached_seed,
                )

    case_dir.mkdir(parents=True, exist_ok=True)
    work_dir.mkdir(parents=True, exist_ok=True)
    resolver_dataset_dir.mkdir(parents=True, exist_ok=True)

    raw_h5_path, used_seed = _run_episode_with_retry(case=case, output_dir=work_dir)
    shutil.copy2(raw_h5_path, resolver_h5_path)

    payload = {
        "case": asdict(case),
        "used_seed": used_seed,
        "raw_h5_path": str(raw_h5_path),
        "resolver_h5_path": str(resolver_h5_path),
    }
    # Written last so a partially generated case is never reported as a cache hit.
    _write_meta(meta_path, payload)

    return GeneratedDataset(
        case=case,
        work_dir=work_dir,
        raw_h5_path=raw_h5_path,
        resolver_dataset_dir=resolver_dataset_dir,
        resolver_h5_path=resolver_h5_path,
        used_seed=used_seed,
    )
|
|
|
|
class DatasetFactoryCache:
    """In-process memo in front of ``generate_dataset_case``.

    Each distinct ``DatasetCase.cache_key()`` is resolved at most once per
    instance. Later lookups return the same ``GeneratedDataset`` object.
    """

    def __init__(self, cache_root: Path):
        self.cache_root = cache_root
        # cache_key() -> dataset already produced (or loaded) for that case.
        self._memo: dict[str, GeneratedDataset] = {}

    def get(self, case: DatasetCase) -> GeneratedDataset:
        """Return the dataset for ``case``, generating it on the first request."""
        memo_key = case.cache_key()
        if memo_key not in self._memo:
            self._memo[memo_key] = generate_dataset_case(case, self.cache_root)
        return self._memo[memo_key]
|
|
|
|
# Signature of a fixture-style callable that maps a case to its generated dataset
# (for example, a bound ``DatasetFactoryCache.get``).
DatasetFactory = Callable[
    [DatasetCase],
    GeneratedDataset,
]
|
|
|
|