| | import os |
| |
|
| | import argparse |
| | import json |
| | import shutil |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | from typing import Any, Dict, Iterable, List, Optional, Set |
| | import h5py |
| |
|
| | import gymnasium as gym |
| |
|
| | |
| | from robomme.env_record_wrapper import RobommeRecordWrapper, FailsafeTimeout |
| | from robomme.robomme_env import * |
| | from robomme.robomme_env.utils.SceneGenerationError import SceneGenerationError |
| |
|
| | |
| | import torch |
| |
|
| | |
| | from robomme.robomme_env.utils.planner_fail_safe import ( |
| | FailAwarePandaArmMotionPlanningSolver, |
| | FailAwarePandaStickMotionPlanningSolver, |
| | ScrewPlanFailure, |
| | ) |
| |
|
"""
Script purpose: parallel generation of Robomme environment datasets.

This script supports multi-process parallel environment simulation and generates
HDF5 datasets containing RGB, depth, segmentation, etc.
Key features:
1. Configure the environment list and its parameters.
2. Execute multiple episode simulations in parallel.
3. Use the FailAware planners to attempt to solve each task.
4. Record the data and save it as HDF5 files.
5. Merge the temporarily generated HDF5 files into a final dataset.
"""
| |
|
| | |
# Environments accepted by run_env_dataset(); any other env_id is rejected there.
DEFAULT_ENVS = [
    "PickHighlight",
]

# Directory holding the reference "record_dataset_<env>_metadata.json" files.
SOURCE_METADATA_ROOT = Path(
    "/data/hongzefu/robomme_benchmark/src/robomme/env_metadata/1206"
)

# Difficulty labels accepted when validating metadata records.
VALID_DIFFICULTIES: Set[str] = {"easy", "medium", "hard"}

# Retry budgets used by the screw -> RRT* fallback installed on the planner.
DATASET_SCREW_MAX_ATTEMPTS = 3
DATASET_RRT_MAX_ATTEMPTS = 3
| |
|
| |
|
| | def _load_env_metadata_records( |
| | env_id: str, |
| | metadata_root: Path, |
| | ) -> List[Dict[str, Any]]: |
| | """ |
| | Read metadata records for an environment from the reference directory to control difficulty and seed. |
| | """ |
| | metadata_path = metadata_root / f"record_dataset_{env_id}_metadata.json" |
| | if not metadata_path.exists(): |
| | raise FileNotFoundError( |
| | f"Metadata file not found for env '{env_id}': {metadata_path}" |
| | ) |
| |
|
| | with metadata_path.open("r", encoding="utf-8") as metadata_file: |
| | payload = json.load(metadata_file) |
| |
|
| | raw_records = payload.get("records") |
| | if not isinstance(raw_records, list) or not raw_records: |
| | raise ValueError( |
| | f"Metadata file has no valid 'records' list: {metadata_path}" |
| | ) |
| |
|
| | normalized_records: List[Dict[str, Any]] = [] |
| | for idx, raw_record in enumerate(raw_records): |
| | if not isinstance(raw_record, dict): |
| | raise ValueError( |
| | f"Invalid metadata record at index {idx} in {metadata_path}" |
| | ) |
| | if "episode" not in raw_record or "seed" not in raw_record or "difficulty" not in raw_record: |
| | raise ValueError( |
| | f"Metadata record missing episode/seed/difficulty at index {idx} in {metadata_path}" |
| | ) |
| |
|
| | try: |
| | episode = int(raw_record["episode"]) |
| | seed = int(raw_record["seed"]) |
| | except (TypeError, ValueError) as exc: |
| | raise ValueError( |
| | f"Metadata record has non-integer episode/seed at index {idx} in {metadata_path}" |
| | ) from exc |
| |
|
| | difficulty_raw = str(raw_record["difficulty"]).strip().lower() |
| | if difficulty_raw not in VALID_DIFFICULTIES: |
| | raise ValueError( |
| | f"Metadata record has invalid difficulty '{raw_record['difficulty']}' " |
| | f"at index {idx} in {metadata_path}. Expected one of {sorted(VALID_DIFFICULTIES)}." |
| | ) |
| |
|
| | normalized_records.append( |
| | { |
| | "episode": episode, |
| | "seed": seed, |
| | "difficulty": difficulty_raw, |
| | } |
| | ) |
| |
|
| | normalized_records.sort(key=lambda rec: rec["episode"]) |
| | print( |
| | f"Loaded {len(normalized_records)} metadata records for {env_id} from {metadata_path}" |
| | ) |
| | return normalized_records |
| |
|
| |
|
| | def _build_seed_candidates_from_metadata( |
| | episode: int, |
| | metadata_records: List[Dict[str, Any]], |
| | ) -> List[Dict[str, Any]]: |
| | """ |
| | Construct candidate (seed, difficulty) list for current episode. |
| | Strictly use only the seed from metadata for the same episode, no cross-episode fallback. |
| | """ |
| | if not metadata_records: |
| | return [] |
| |
|
| | same_episode_records = [rec for rec in metadata_records if rec["episode"] == episode] |
| | if not same_episode_records: |
| | return [] |
| | if len(same_episode_records) > 1: |
| | raise ValueError( |
| | f"Found duplicated metadata records for episode {episode}. " |
| | "Strict mode requires exactly one source record per episode." |
| | ) |
| |
|
| | rec = same_episode_records[0] |
| | return [{"seed": int(rec["seed"]), "difficulty": rec["difficulty"]}] |
| |
|
| | def _tensor_to_bool(value) -> bool: |
| | """ |
| | Helper function: Convert Tensor or numpy array to Python bool type. |
| | Used to handle success/failure flags from different sources. |
| | """ |
| | if value is None: |
| | return False |
| | if isinstance(value, torch.Tensor): |
| | return bool(value.detach().cpu().bool().item()) |
| | if isinstance(value, np.ndarray): |
| | return bool(np.any(value)) |
| | return bool(value) |
| |
|
| |
|
| | def _split_episode_indices(num_episodes: int, max_chunks: int) -> List[List[int]]: |
| | """ |
| | Helper function: Split total episodes into multiple chunks for parallel processing by different processes. |
| | |
| | Args: |
| | num_episodes: Total number of episodes |
| | max_chunks: Max number of chunks (usually equals number of workers) |
| | |
| | Returns: |
| | List containing lists of episode indices |
| | """ |
| | if num_episodes <= 0: |
| | return [] |
| |
|
| | chunk_count = min(max_chunks, num_episodes) |
| | base_size, remainder = divmod(num_episodes, chunk_count) |
| |
|
| | chunks: List[List[int]] = [] |
| | start = 0 |
| | for chunk_idx in range(chunk_count): |
| | |
| | stop = start + base_size + (1 if chunk_idx < remainder else 0) |
| | chunks.append(list(range(start, stop))) |
| | start = stop |
| |
|
| | return chunks |
| |
|
| |
|
def _run_episode_attempt(
    env_id: str,
    episode: int,
    seed: int,
    temp_dataset_path: Path,
    save_video: bool,
    difficulty: Optional[str],
) -> bool:
    """
    Run a single episode attempt and report success or failure.

    Main steps:
    1. Build the Gym environment and wrap it in RobommeRecordWrapper for recording.
    2. Select the planner by environment (stick planner for PatternLock/RouteStick,
       arm planner otherwise).
    3. Patch the planner's screw motion so it retries and then falls back to RRT*.
    4. Run every entry of the environment's ``task_list`` through its ``solve``
       callable, evaluating the environment after each one.
    5. Stop at the first reported success or failure.
    6. Close the environment and return whether the episode succeeded.

    Args:
        env_id: Environment ID passed to ``gym.make``.
        episode: Episode index (also selects the failure-recovery settings).
        seed: Seed passed to the environment and the record wrapper.
        temp_dataset_path: Dataset path handed to the record wrapper.
        save_video: Whether the record wrapper should save video.
        difficulty: Difficulty passed to the environment.

    Returns:
        True if the evaluation, or the wrapper's ``episode_success`` attribute,
        reports success; False otherwise (including scene generation failure).

    Raises:
        ValueError: If a task entry does not supply a callable ``solve``.
    """
    print(f"--- Running simulation for episode:{episode}, seed:{seed}, env: {env_id} ---")

    env: Optional[gym.Env] = None
    # Defined before the try block so every exit path below can read it.
    episode_successful = False
    try:
        env_kwargs = dict(
            obs_mode="rgb+depth+segmentation",
            control_mode="pd_joint_pos",
            render_mode="rgb_array",
            reward_mode="dense",
            seed=seed,
            difficulty=difficulty,
        )

        # Episodes 0-5 enable failure recovery: mode "z" up to episode 2, "xy" after.
        # NOTE(review): nesting reconstructed from flattened source — confirm that
        # the recovery mode is meant to be set only when episode <= 5.
        if episode <= 5:
            env_kwargs["robomme_failure_recovery"] = True
            if episode <= 2:
                env_kwargs["robomme_failure_recovery_mode"] = "z"
            else:
                env_kwargs["robomme_failure_recovery_mode"] = "xy"

        env = gym.make(env_id, **env_kwargs)

        env = RobommeRecordWrapper(
            env,
            dataset=str(temp_dataset_path),
            env_id=env_id,
            episode=episode,
            seed=seed,
            save_video=save_video,
        )

        env.reset()

        if env_id in ("PatternLock", "RouteStick"):
            planner = FailAwarePandaStickMotionPlanningSolver(
                env,
                debug=False,
                vis=False,
                base_pose=env.unwrapped.agent.robot.pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
                joint_vel_limits=0.3,
            )
        else:
            planner = FailAwarePandaArmMotionPlanningSolver(
                env,
                debug=False,
                vis=False,
                base_pose=env.unwrapped.agent.robot.pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
            )

        original_move_to_pose_with_screw = planner.move_to_pose_with_screw
        original_move_to_pose_with_rrt = planner.move_to_pose_with_RRTStar

        def _move_to_pose_with_screw_then_rrt_retry(*args, **kwargs):
            """Try screw planning, then RRT*, each with a bounded number of attempts."""
            for attempt in range(1, DATASET_SCREW_MAX_ATTEMPTS + 1):
                try:
                    result = original_move_to_pose_with_screw(*args, **kwargs)
                except ScrewPlanFailure as exc:
                    print(
                        f"[DatasetGen] screw planning failed "
                        f"(attempt {attempt}/{DATASET_SCREW_MAX_ATTEMPTS}): {exc}"
                    )
                    continue

                if isinstance(result, int) and result == -1:
                    print(
                        f"[DatasetGen] screw planning returned -1 "
                        f"(attempt {attempt}/{DATASET_SCREW_MAX_ATTEMPTS})"
                    )
                    continue

                return result

            print(
                "[DatasetGen] screw planning exhausted; "
                f"fallback to RRT* (max {DATASET_RRT_MAX_ATTEMPTS} attempts)"
            )

            for attempt in range(1, DATASET_RRT_MAX_ATTEMPTS + 1):
                try:
                    result = original_move_to_pose_with_rrt(*args, **kwargs)
                except FailsafeTimeout:
                    # The task loop below has a dedicated FailsafeTimeout handler
                    # that stops the episode; the broad handler that follows
                    # would otherwise swallow the timeout and retry planning.
                    raise
                except Exception as exc:
                    print(
                        f"[DatasetGen] RRT* planning failed "
                        f"(attempt {attempt}/{DATASET_RRT_MAX_ATTEMPTS}): {exc}"
                    )
                    continue

                if isinstance(result, int) and result == -1:
                    print(
                        f"[DatasetGen] RRT* planning returned -1 "
                        f"(attempt {attempt}/{DATASET_RRT_MAX_ATTEMPTS})"
                    )
                    continue

                return result

            print("[DatasetGen] screw->RRT* planning exhausted; return -1")
            return -1

        # Every screw motion requested by a task's solve() now goes through the
        # retry/fallback wrapper above.
        planner.move_to_pose_with_screw = _move_to_pose_with_screw_then_rrt_retry

        def _flag_planning_failure() -> None:
            """Mark the current task as failed on the unwrapped environment."""
            env.unwrapped.failureflag = torch.tensor([True])
            env.unwrapped.successflag = torch.tensor([False])
            env.unwrapped.current_task_failure = True

        env.unwrapped.evaluate()

        tasks = list(getattr(env.unwrapped, "task_list", []) or [])

        print(f"{env_id}: Task list has {len(tasks)} tasks")

        for idx, task_entry in enumerate(tasks):
            task_name = task_entry.get("name", f"Task {idx}")
            print(f"Executing task {idx + 1}/{len(tasks)}: {task_name}")

            solve_callable = task_entry.get("solve")
            if not callable(solve_callable):
                raise ValueError(
                    f"Task '{task_name}' must supply a callable 'solve'."
                )

            env.unwrapped.evaluate(solve_complete_eval=True)
            screw_failed = False
            try:
                solve_result = solve_callable(env, planner)
                if isinstance(solve_result, int) and solve_result == -1:
                    # -1 is the "planning exhausted" result of the wrapper above.
                    screw_failed = True
                    print(f"Screw->RRT* planning exhausted during '{task_name}'")
                    _flag_planning_failure()
            except ScrewPlanFailure as exc:
                screw_failed = True
                print(f"Screw plan failure during '{task_name}': {exc}")
                _flag_planning_failure()
            except FailsafeTimeout as exc:
                print(f"Failsafe: {exc}")
                break

            evaluation = env.unwrapped.evaluate(solve_complete_eval=True)

            fail_flag = evaluation.get("fail", False)
            success_flag = evaluation.get("success", False)

            if _tensor_to_bool(success_flag):
                print("All tasks completed successfully.")
                episode_successful = True
                break

            if screw_failed or _tensor_to_bool(fail_flag):
                print("Encountered failure condition; stopping task sequence.")
                break

        else:
            # Loop ended without a break: take the final evaluation as the verdict.
            evaluation = env.unwrapped.evaluate(solve_complete_eval=True)
            episode_successful = _tensor_to_bool(evaluation.get("success", False))

        # The record wrapper may expose its own success flag; honour it as well.
        episode_successful = episode_successful or _tensor_to_bool(
            getattr(env, "episode_success", False)
        )

    except SceneGenerationError as exc:
        print(
            f"Scene generation failed for env {env_id}, episode {episode}, seed {seed}: {exc}"
        )
        episode_successful = False
    finally:
        if env is not None:
            try:
                env.close()
            except Exception as close_exc:
                # Closing is best-effort; a failure here must not mask the result.
                print(f"Warning: Exception during env.close() for episode {episode}, seed {seed}: {close_exc}")

    status_text = "SUCCESS" if episode_successful else "FAILED"
    print(
        f"--- Finished Running simulation for episode:{episode}, seed:{seed}, env: {env_id} [{status_text}] ---"
    )

    return episode_successful
| |
|
| |
|
def run_env_dataset(
    env_id: str,
    episode_indices: Iterable[int],
    temp_folder: Path,
    save_video: bool,
    metadata_records: List[Dict[str, Any]],
    gpu_id: int,
) -> List[Dict[str, Any]]:
    """
    Generate the requested episodes of one environment into a temporary folder.

    Each episode starts from the seed/difficulty found in the reference metadata.
    When an attempt fails or raises, the seed is incremented by one and the
    episode is retried, up to 20 attempts in total.

    Args:
        env_id: Environment ID
        episode_indices: Episode indices to run
        temp_folder: Temporary folder to save data (created if missing)
        save_video: Whether to save video
        metadata_records: Records from reference dataset metadata
        gpu_id: GPU ID written to ``CUDA_VISIBLE_DEVICES``

    Returns:
        One ``{"task", "episode", "seed", "difficulty"}`` record per successful
        episode, carrying the seed that actually succeeded.

    Raises:
        ValueError: If episodes were requested for an env not in DEFAULT_ENVS.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    temp_folder.mkdir(parents=True, exist_ok=True)
    requested_episodes = list(episode_indices)
    if not requested_episodes:
        return []

    if env_id not in DEFAULT_ENVS:
        raise ValueError(f"Unsupported environment: {env_id}")

    chunk_path = temp_folder / "temp_chunk.h5"
    max_attempts = 20
    completed: List[Dict[str, Any]] = []

    for episode in requested_episodes:
        candidates = _build_seed_candidates_from_metadata(episode, metadata_records)
        if not candidates:
            print(f"Episode {episode}: no metadata candidate seeds found, skipping.")
            continue

        solved = False
        for candidate in candidates:
            base_seed = int(candidate["seed"])
            difficulty = str(candidate["difficulty"])

            # Attempt k (1-based) uses seed base_seed + k - 1.
            seeds_to_try = range(base_seed, base_seed + max_attempts)
            for attempt_number, current_seed in enumerate(seeds_to_try, start=1):
                print(
                    f"Episode {episode} attempt {attempt_number}/{max_attempts} "
                    f"with seed={current_seed} (base={base_seed}, diff={difficulty})"
                )

                try:
                    if _run_episode_attempt(
                        env_id=env_id,
                        episode=episode,
                        seed=current_seed,
                        temp_dataset_path=chunk_path,
                        save_video=save_video,
                        difficulty=difficulty,
                    ):
                        completed.append(
                            {
                                "task": env_id,
                                "episode": episode,
                                "seed": current_seed,
                                "difficulty": difficulty,
                            }
                        )
                        solved = True
                        break

                    print(
                        f"Episode {episode} failed with seed {current_seed}; retrying with seed+1..."
                    )
                except Exception as exc:
                    # Best-effort: any error in one attempt just moves on to the next seed.
                    print(
                        f"Episode {episode} exception with seed {current_seed}: {exc}; retrying with seed+1..."
                    )

            if solved:
                break

        if not solved:
            print(
                f"Episode {episode} failed with strict source metadata seed; "
                "metadata will not be recorded for this episode."
            )

    return completed
| |
|
| |
|
| | def _merge_dataset_from_folder( |
| | env_id: str, |
| | temp_folder: Path, |
| | final_dataset_path: Path, |
| | ) -> None: |
| | """ |
| | Merge all episode files from temporary folder into final dataset. |
| | |
| | Args: |
| | env_id: Environment ID |
| | temp_folder: Temporary folder containing episode files |
| | final_dataset_path: Final output HDF5 file path |
| | """ |
| | if not temp_folder.exists() or not temp_folder.is_dir(): |
| | print(f"Warning: Temporary folder {temp_folder} does not exist") |
| | return |
| |
|
| | final_dataset_path.parent.mkdir(parents=True, exist_ok=True) |
| |
|
| | |
| | |
| | hdf5_folders = list(temp_folder.glob("*_hdf5_files")) |
| |
|
| | if not hdf5_folders: |
| | print(f"Warning: No HDF5 folders found in {temp_folder}") |
| | return |
| |
|
| | print(f"Merging episodes from {temp_folder} into {final_dataset_path}") |
| |
|
| | |
| | with h5py.File(final_dataset_path, "a") as final_file: |
| | for hdf5_folder in sorted(hdf5_folders): |
| | |
| | h5_files = sorted(hdf5_folder.glob("*.h5")) |
| |
|
| | if not h5_files: |
| | print(f"Warning: No h5 files found in {hdf5_folder}") |
| | continue |
| |
|
| | print(f"Found {len(h5_files)} episode files in {hdf5_folder.name}") |
| |
|
| | |
| | for h5_file in h5_files: |
| | print(f" - Merging {h5_file.name}") |
| |
|
| | try: |
| | with h5py.File(h5_file, "r") as episode_file: |
| | file_keys = list(episode_file.keys()) |
| | if len(file_keys) == 0: |
| | print(f" Warning: {h5_file.name} is empty, skipping...") |
| | continue |
| | |
| | for env_group_name, src_env_group in episode_file.items(): |
| | episode_keys = list(src_env_group.keys()) if isinstance(src_env_group, h5py.Group) else [] |
| | if len(episode_keys) == 0: |
| | print(f" Warning: {env_group_name} in {h5_file.name} has no episodes, skipping...") |
| | continue |
| | |
| | |
| | if env_group_name not in final_file: |
| | final_file.copy(src_env_group, env_group_name) |
| | continue |
| |
|
| | dest_env_group = final_file[env_group_name] |
| | if not isinstance(dest_env_group, h5py.Group): |
| | print(f" Warning: {env_group_name} is not a group, skipping...") |
| | continue |
| |
|
| | |
| | for episode_name in src_env_group.keys(): |
| | if episode_name in dest_env_group: |
| | print(f" Warning: Episode {episode_name} already exists, overwriting...") |
| | del dest_env_group[episode_name] |
| | src_env_group.copy(episode_name, dest_env_group, name=episode_name) |
| | except Exception as e: |
| | print(f" Error merging {h5_file.name}: {e}") |
| | continue |
| |
|
| | |
| | temp_videos_dir = temp_folder / "videos" |
| | final_videos_dir = final_dataset_path.parent / "videos" |
| | if temp_videos_dir.exists() and temp_videos_dir.is_dir(): |
| | final_videos_dir.mkdir(parents=True, exist_ok=True) |
| | moved_count = 0 |
| | for video_path in sorted(temp_videos_dir.glob("*.mp4")): |
| | target_path = final_videos_dir / video_path.name |
| | if target_path.exists(): |
| | stem = target_path.stem |
| | suffix = target_path.suffix |
| | index = 1 |
| | while True: |
| | candidate = final_videos_dir / f"{stem}_dup{index}{suffix}" |
| | if not candidate.exists(): |
| | target_path = candidate |
| | break |
| | index += 1 |
| | try: |
| | shutil.move(str(video_path), str(target_path)) |
| | moved_count += 1 |
| | except Exception as exc: |
| | print(f"Warning: Failed to move video {video_path.name}: {exc}") |
| | if moved_count > 0: |
| | print(f"Moved {moved_count} videos to {final_videos_dir}") |
| |
|
| | |
| | try: |
| | shutil.rmtree(temp_folder) |
| | print(f"Cleaned up temporary folder: {temp_folder}") |
| | except Exception as e: |
| | print(f"Warning: Failed to remove temporary folder {temp_folder}: {e}") |
| |
|
| |
|
| | def _save_episode_metadata( |
| | records: List[Dict[str, Any]], |
| | metadata_path: Path, |
| | env_id: str, |
| | ) -> None: |
| | """Save seed/difficulty metadata for each episode to JSON file.""" |
| | metadata_path.parent.mkdir(parents=True, exist_ok=True) |
| | sorted_records = sorted(records, key=lambda rec: rec.get("episode", -1)) |
| | metadata = { |
| | "env_id": env_id, |
| | "record_count": len(sorted_records), |
| | "records": sorted_records, |
| | } |
| | try: |
| | with metadata_path.open("w", encoding="utf-8") as metadata_file: |
| | json.dump(metadata, metadata_file, indent=2) |
| | print(f"Saved episode metadata to {metadata_path}") |
| | except Exception as exc: |
| | print(f"Warning: Failed to save episode metadata to {metadata_path}: {exc}") |
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """
    Parse the command-line options of the dataset generator.

    Returns:
        Namespace with ``episodes`` (list of int, default ``[0]``),
        ``save_video`` (bool, default True; cleared by ``--no-save-video``)
        and ``gpus`` (str, default ``"1"``).
    """
    parser = argparse.ArgumentParser(description="Robomme Dataset Generator")
    parser.add_argument(
        "--episodes",
        "-n",
        type=int,
        nargs="+",
        default=[0],
        help="List of episodes to generate.",
    )
    # --save-video / --no-save-video share one destination; recording is on by default.
    parser.add_argument(
        "--save-video",
        dest="save_video",
        action="store_true",
        default=True,
        help="Enable video recording via RobommeRecordWrapper (Default: Enabled).",
    )
    parser.add_argument(
        "--no-save-video",
        dest="save_video",
        action="store_false",
        help="Disable video recording.",
    )
    parser.add_argument(
        "--gpus",
        type=str,
        default="1",
        # %(default)s keeps the help in sync with the real default; the help
        # previously advertised '0' while the default was '1'.
        help="GPU selection. Supported values: '0', '1', '0,1' (or '1,0'). Default: '%(default)s'.",
    )
    return parser.parse_args()
| |
|
| |
|
| | def _parse_gpu_ids(gpu_spec: str) -> List[int]: |
| | """Parse user GPU spec string to a deduplicated GPU id list.""" |
| | valid_gpu_ids = {0, 1} |
| | raw_tokens = [token.strip() for token in gpu_spec.split(",") if token.strip()] |
| | if not raw_tokens: |
| | raise ValueError("GPU spec is empty. Use one of: 0, 1, 0,1") |
| |
|
| | gpu_ids: List[int] = [] |
| | for token in raw_tokens: |
| | try: |
| | gpu_id = int(token) |
| | except ValueError as exc: |
| | raise ValueError( |
| | f"Invalid GPU id '{token}'. Supported values are 0 and 1." |
| | ) from exc |
| |
|
| | if gpu_id not in valid_gpu_ids: |
| | raise ValueError( |
| | f"Unsupported GPU id '{gpu_id}'. Supported values are 0 and 1." |
| | ) |
| | if gpu_id not in gpu_ids: |
| | gpu_ids.append(gpu_id) |
| |
|
| | if not gpu_ids: |
| | raise ValueError("No valid GPU id provided. Use one of: 0, 1, 0,1") |
| | return gpu_ids |
| |
|
| |
|
def main() -> None:
    """
    Entry point: generate, merge and describe the dataset of each environment.

    For every environment the reference metadata is loaded, the requested
    episodes are run on the first selected GPU, the temporary episode files are
    merged into the final HDF5 dataset and a metadata JSON is written next to it.
    """
    args = parse_args()
    env_ids: List[str] = ["PickHighlight"]

    num_workers = 1
    gpu_ids = _parse_gpu_ids(args.gpus)
    episode_indices = args.episodes
    separator = "=" * 80

    if len(gpu_ids) == 1:
        gpu_mode_line = f"GPU mode: Single GPU ({gpu_ids[0]})"
    else:
        gpu_mode_line = f"GPU mode: Multi GPU ({','.join(str(gpu) for gpu in gpu_ids)})"

    for env_id in env_ids:
        source_metadata_records = _load_env_metadata_records(
            env_id=env_id,
            metadata_root=SOURCE_METADATA_ROOT,
        )

        temp_folder = Path(f"/data/hongzefu/data_0226-test/temp_{env_id}_episodes")
        final_dataset_path = Path(f"/data/hongzefu/data_0226-test/record_dataset_{env_id}.h5")

        print(f"\n{separator}")
        print(f"Environment: {env_id}")
        print(f"Episodes: {args.episodes}")
        print(f"Workers: {num_workers}")
        print(gpu_mode_line)
        print(f"Temporary folder: {temp_folder}")
        print(f"Final dataset: {final_dataset_path}")
        print(f"{separator}\n")

        # Only the first selected GPU is handed to the generation step.
        episode_records: List[Dict[str, Any]] = run_env_dataset(
            env_id,
            episode_indices,
            temp_folder,
            args.save_video,
            source_metadata_records,
            gpu_ids[0],
        )

        print("\nMerging all episodes into final dataset...")
        _merge_dataset_from_folder(
            env_id,
            temp_folder,
            final_dataset_path,
        )

        # The metadata JSON sits next to the dataset: <dataset stem>_metadata.json
        metadata_name = f"{final_dataset_path.stem}_metadata.json"
        metadata_path = final_dataset_path.with_name(metadata_name)
        _save_episode_metadata(episode_records, metadata_path, env_id)

        print(f"\n✓ Finished! Final dataset saved to: {final_dataset_path}\n")

    print("✓ All requested environments processed.")


if __name__ == "__main__":
    main()
| |
|