# RoboMME_Interactive_Demo_cpu / tests/_shared/dataset_generation.py
# Author: HongzeFu
# HF Space: code-only (no binary assets)
# Revision: 06c11b0
from __future__ import annotations
import json
import shutil
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Callable, Optional
import numpy as np
import torch
import gymnasium as gym
from tests._shared.repo_paths import ensure_src_on_path
ensure_src_on_path(__file__)
from robomme.env_record_wrapper import RobommeRecordWrapper, FailsafeTimeout # noqa: E402
from robomme.robomme_env import * # noqa: F401,F403,E402
from robomme.robomme_env.utils.SceneGenerationError import SceneGenerationError # noqa: E402
from robomme.robomme_env.utils.planner_fail_safe import ( # noqa: E402
FailAwarePandaArmMotionPlanningSolver,
FailAwarePandaStickMotionPlanningSolver,
ScrewPlanFailure,
)
# Screw-motion attempts per patched move_to_pose_with_screw call before falling
# back to move_to_pose_with_RRTStar (see _patch_planner_screw_to_rrt).
DATASET_SCREW_MAX_ATTEMPTS = 3
# RRT* attempts made after every screw attempt failed; the patched call returns -1
# if these all fail as well.
DATASET_RRT_MAX_ATTEMPTS = 3
# Number of consecutive seeds (base_seed, base_seed + 1, ...) tried per dataset case
# before _run_episode_with_retry gives up.
MAX_SEED_ATTEMPTS = 30
@dataclass(frozen=True)
class DatasetCase:
    """Immutable description of one episode to record for tests.

    Instances are hashable (frozen dataclass) and map to a stable string via
    :meth:`cache_key`, which names the on-disk cache directory.
    """

    # Gymnasium environment id passed to gym.make.
    env_id: str
    # Episode index recorded by the wrapper; also selects failure-recovery settings.
    episode: int
    # First seed tried; retries use base_seed + 1, base_seed + 2, ...
    base_seed: int
    # Forwarded to gym.make; a falsy value is rendered as "none" in the cache key.
    difficulty: Optional[str]
    # Forwarded to the record wrapper.
    save_video: bool
    # Free-form tag; in this module it only distinguishes cache keys.
    mode_tag: str

    def cache_key(self) -> str:
        """Return an underscore-joined key that uniquely identifies this case.

        Format: ``<env_id>_ep<episode>_<difficulty|none>_<base_seed>_<0|1>_<mode_tag>``.
        """
        segments = (
            self.env_id,
            "ep" + str(self.episode),
            self.difficulty or "none",
            self.base_seed,
            int(self.save_video),
            self.mode_tag,
        )
        return "_".join(str(segment) for segment in segments)
@dataclass(frozen=True)
class GeneratedDataset:
    """Where a recorded episode for a DatasetCase lives on disk (built by generate_dataset_case)."""

    # The case this dataset was recorded for.
    case: DatasetCase
    # <cache_root>/<case.cache_key()>/work: the record wrapper's output directory.
    work_dir: Path
    # HDF5 written by the recorder:
    # <work_dir>/hdf5_files/<env_id>_ep<episode>_seed<seed>.h5.
    raw_h5_path: Path
    # <cache_root>/<case.cache_key()>/resolver_dataset
    resolver_dataset_dir: Path
    # Copy of raw_h5_path named record_dataset_<env_id>.h5 inside resolver_dataset_dir.
    resolver_h5_path: Path
    # Seed of the successful recording (base_seed + number of failed seed attempts).
    used_seed: int
def _tensor_to_bool(value) -> bool:
if value is None:
return False
if isinstance(value, torch.Tensor):
return bool(value.detach().cpu().bool().item())
if isinstance(value, np.ndarray):
return bool(np.any(value))
return bool(value)
def _patch_planner_screw_to_rrt(planner) -> None:
    """Monkey-patch ``planner.move_to_pose_with_screw`` with a retrying screw -> RRT* fallback.

    The replacement does the following, passing the same arguments to both methods:

    1. Calls the original screw planner up to ``DATASET_SCREW_MAX_ATTEMPTS`` times.
       An attempt that raises ``ScrewPlanFailure`` or returns ``-1`` is retried.
    2. Then calls ``planner.move_to_pose_with_RRTStar`` up to
       ``DATASET_RRT_MAX_ATTEMPTS`` times. An attempt that raises or returns ``-1``
       is retried.
    3. Returns ``-1`` if every attempt failed; otherwise returns the first
       non-failing result.

    ``FailsafeTimeout`` is never retried. It propagates from both stages so the
    episode loop can handle it.

    NOTE(review): assumes the RRT* method accepts the same call signature as the
    screw method. Confirm against the planner classes.
    """
    original_screw = planner.move_to_pose_with_screw
    original_rrt = planner.move_to_pose_with_RRTStar

    def _move_screw_then_rrt(*args, **kwargs):
        for _ in range(DATASET_SCREW_MAX_ATTEMPTS):
            try:
                result = original_screw(*args, **kwargs)
            except ScrewPlanFailure:
                continue
            if isinstance(result, int) and result == -1:
                continue
            return result
        for _ in range(DATASET_RRT_MAX_ATTEMPTS):
            try:
                result = original_rrt(*args, **kwargs)
            except FailsafeTimeout:
                # Not a planning failure: retrying cannot help, and the caller
                # handles it explicitly. Keep it consistent with the screw stage,
                # which also lets it through.
                raise
            except Exception:
                continue
            if isinstance(result, int) and result == -1:
                continue
            return result
        return -1

    planner.move_to_pose_with_screw = _move_screw_then_rrt
def _mark_current_task_failed(env) -> None:
    """Set the unwrapped env's flags so the current task is recorded as failed."""
    unwrapped = env.unwrapped
    unwrapped.failureflag = torch.tensor([True])
    unwrapped.successflag = torch.tensor([False])
    unwrapped.current_task_failure = True


def _build_planner(env, case: DatasetCase):
    """Create the fail-aware planner: stick solver for PatternLock/RouteStick, arm solver otherwise."""
    shared_kwargs = dict(
        debug=False,
        vis=False,
        base_pose=env.unwrapped.agent.robot.pose,
        visualize_target_grasp_pose=False,
        print_env_info=False,
    )
    if case.env_id in ("PatternLock", "RouteStick"):
        return FailAwarePandaStickMotionPlanningSolver(
            env, joint_vel_limits=0.3, **shared_kwargs
        )
    return FailAwarePandaArmMotionPlanningSolver(env, **shared_kwargs)


def _run_one_episode(
    case: DatasetCase,
    seed: int,
    output_dir: Path,
) -> bool:
    """Record one episode of ``case.env_id`` with ``seed``, solving its tasks in order.

    The env is wrapped in ``RobommeRecordWrapper`` writing under ``output_dir``. Each
    task's ``"solve"`` callable runs with a fail-aware planner whose screw moves fall
    back to RRT* (see ``_patch_planner_screw_to_rrt``).

    The loop stops at the first of these:

    * evaluation reports success;
    * a task fails: the solve callable returned -1 or raised ``ScrewPlanFailure``,
      or evaluation reports fail;
    * ``FailsafeTimeout`` is raised.

    Returns:
        True if an evaluation reported success or the wrapper's ``episode_success``
        attribute is truthy. False otherwise, including when
        ``SceneGenerationError`` is raised. Any other exception propagates; the env
        is closed in every case.
    """
    env_kwargs = dict(
        obs_mode="rgb+depth+segmentation",
        control_mode="pd_joint_pos",
        render_mode="rgb_array",
        reward_mode="dense",
        seed=seed,
        difficulty=case.difficulty,
    )
    # Episodes <= 5 record with failure recovery enabled: mode "z" for episodes <= 2,
    # "xy" for 3..5. NOTE(review): the mode semantics live in robomme_env.
    if case.episode <= 5:
        env_kwargs["robomme_failure_recovery"] = True
        env_kwargs["robomme_failure_recovery_mode"] = "z" if case.episode <= 2 else "xy"
    env = gym.make(case.env_id, **env_kwargs)
    episode_successful = False
    try:
        # Wrap inside the try: if the wrapper's constructor raises, the finally
        # clause still closes the raw env instead of leaking it.
        env = RobommeRecordWrapper(
            env,
            dataset=str(output_dir),
            env_id=case.env_id,
            episode=case.episode,
            seed=seed,
            save_video=case.save_video,
        )
        env.reset()
        planner = _build_planner(env, case)
        _patch_planner_screw_to_rrt(planner)
        tasks = list(getattr(env.unwrapped, "task_list", []) or [])
        for task_entry in tasks:
            solve_callable = task_entry.get("solve")
            if not callable(solve_callable):
                continue
            # The result is intentionally discarded. Presumably this refreshes
            # env-side task state before solving. TODO confirm.
            env.unwrapped.evaluate(solve_complete_eval=True)
            try:
                solve_result = solve_callable(env, planner)
            except ScrewPlanFailure:
                screw_failed = True
            except FailsafeTimeout:
                break
            else:
                # A -1 return from the solve callable is treated as a planning failure.
                screw_failed = isinstance(solve_result, int) and solve_result == -1
            if screw_failed:
                _mark_current_task_failed(env)
            evaluation = env.unwrapped.evaluate(solve_complete_eval=True)
            fail_flag = evaluation.get("fail", False)
            success_flag = evaluation.get("success", False)
            if _tensor_to_bool(success_flag):
                episode_successful = True
                break
            if screw_failed or _tensor_to_bool(fail_flag):
                break
        else:
            # No break: every task was attempted, so judge by a final evaluation.
            evaluation = env.unwrapped.evaluate(solve_complete_eval=True)
            episode_successful = _tensor_to_bool(evaluation.get("success", False))
        episode_successful = episode_successful or _tensor_to_bool(
            getattr(env, "episode_success", False)
        )
    except SceneGenerationError:
        episode_successful = False
    finally:
        try:
            env.close()
        except Exception:
            # Best-effort cleanup: a failing close must neither replace the result
            # nor mask an exception already propagating out of the try.
            pass
    return episode_successful
def _run_episode_with_retry(case: DatasetCase, output_dir: Path) -> tuple[Path, int]:
    """Record one successful episode, trying consecutive seeds from ``case.base_seed``.

    Seeds ``base_seed .. base_seed + MAX_SEED_ATTEMPTS - 1`` are tried in order. An
    attempt that raises or reports failure moves on to the next seed.

    Returns:
        ``(h5_path, seed)`` for the first successful attempt, where ``h5_path`` is
        ``output_dir / "hdf5_files" / "<env_id>_ep<episode>_seed<seed>.h5"``.

    Raises:
        FileNotFoundError: an attempt succeeded but its expected HDF5 file is
            missing. This is not retried.
        RuntimeError: no seed succeeded. It is chained to the last exception raised
            by an attempt, if any, so the root cause stays visible.
    """
    last_error: Optional[BaseException] = None
    for attempt in range(MAX_SEED_ATTEMPTS):
        seed = case.base_seed + attempt
        try:
            success = _run_one_episode(case=case, seed=seed, output_dir=output_dir)
        except Exception as exc:
            # Any crash during an attempt just means "try the next seed". Keep the
            # error so the final failure can report it.
            last_error = exc
            continue
        if not success:
            continue
        h5_path = output_dir / "hdf5_files" / f"{case.env_id}_ep{case.episode}_seed{seed}.h5"
        if not h5_path.exists():
            raise FileNotFoundError(f"Missing expected HDF5: {h5_path}")
        return h5_path, seed
    raise RuntimeError(
        f"[{case.env_id}] Failed to generate successful record in {MAX_SEED_ATTEMPTS} attempts."
    ) from last_error
def _write_meta(meta_path: Path, payload: dict) -> None:
meta_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
def _read_meta(meta_path: Path) -> dict:
return json.loads(meta_path.read_text(encoding="utf-8"))
def generate_dataset_case(case: DatasetCase, cache_root: Path) -> GeneratedDataset:
    """Return the recorded dataset for *case*, reusing the on-disk cache when valid.

    Layout under ``cache_root / case.cache_key()``:

    * ``work/``: recorder output directory (raw HDF5 under ``work/hdf5_files/``).
    * ``resolver_dataset/record_dataset_<env_id>.h5``: copy of the raw HDF5.
    * ``meta.json``: the case, the used seed and both HDF5 paths; written last.

    The cache is reused only if ``meta.json`` parses, has the expected keys, and
    both HDF5 files still exist. Otherwise the episode is re-recorded.

    Raises:
        RuntimeError, FileNotFoundError: propagated from ``_run_episode_with_retry``.
    """
    case_dir = cache_root / case.cache_key()
    work_dir = case_dir / "work"
    resolver_dataset_dir = case_dir / "resolver_dataset"
    resolver_h5_path = resolver_dataset_dir / f"record_dataset_{case.env_id}.h5"
    meta_path = case_dir / "meta.json"

    def _result(raw_h5_path: Path, used_seed: int) -> GeneratedDataset:
        # Single place that assembles the return value for both cache hit and miss.
        return GeneratedDataset(
            case=case,
            work_dir=work_dir,
            raw_h5_path=raw_h5_path,
            resolver_dataset_dir=resolver_dataset_dir,
            resolver_h5_path=resolver_h5_path,
            used_seed=used_seed,
        )

    if meta_path.exists():
        try:
            meta = _read_meta(meta_path)
            cached_raw_h5 = Path(meta["raw_h5_path"])
            cached_seed = int(meta["used_seed"])
        except (OSError, ValueError, KeyError, TypeError):
            # Unreadable, truncated or malformed meta.json: treat it as a cache
            # miss and regenerate below rather than failing the caller.
            pass
        else:
            if cached_raw_h5.exists() and resolver_h5_path.exists():
                return _result(cached_raw_h5, cached_seed)

    case_dir.mkdir(parents=True, exist_ok=True)
    work_dir.mkdir(parents=True, exist_ok=True)
    resolver_dataset_dir.mkdir(parents=True, exist_ok=True)
    raw_h5_path, used_seed = _run_episode_with_retry(case=case, output_dir=work_dir)
    shutil.copy2(raw_h5_path, resolver_h5_path)
    payload = {
        "case": asdict(case),
        "used_seed": used_seed,
        "raw_h5_path": str(raw_h5_path),
        "resolver_h5_path": str(resolver_h5_path),
    }
    # Written after the copy so an existing meta file implies both HDF5 files exist.
    _write_meta(meta_path, payload)
    return _result(raw_h5_path, used_seed)
class DatasetFactoryCache:
    """In-process memo in front of ``generate_dataset_case``.

    Datasets produced for a given ``DatasetCase.cache_key()`` are remembered by this
    instance. Later lookups with the same key return the same ``GeneratedDataset``
    object without touching disk.
    """

    def __init__(self, cache_root: Path):
        self.cache_root = cache_root
        # cache_key() -> dataset already produced by this instance.
        self._memo: dict[str, GeneratedDataset] = {}

    def get(self, case: DatasetCase) -> GeneratedDataset:
        """Return the dataset for *case*, generating (or loading) it on first request."""
        memo_key = case.cache_key()
        hit = self._memo.get(memo_key)
        if hit is None:
            hit = generate_dataset_case(case, self.cache_root)
            self._memo[memo_key] = hit
        return hit
# Signature of a dataset provider: maps a DatasetCase to its GeneratedDataset.
# A bound DatasetFactoryCache.get has exactly this shape.
DatasetFactory = Callable[[DatasetCase], GeneratedDataset]