import os
import argparse
import json
import shutil
from pathlib import Path
import numpy as np
from typing import Any, Dict, Iterable, List, Optional, Set
import h5py
import gymnasium as gym
# Import Robomme related environment wrappers and exception classes
from robomme.env_record_wrapper import RobommeRecordWrapper, FailsafeTimeout
from robomme.robomme_env import *
from robomme.robomme_env.utils.SceneGenerationError import SceneGenerationError
# from util import *
import torch
# Import planner and related exception classes
from robomme.robomme_env.utils.planner_fail_safe import (
FailAwarePandaArmMotionPlanningSolver,
FailAwarePandaStickMotionPlanningSolver,
ScrewPlanFailure,
)
"""
Script function: Parallel generation of Robomme environment datasets.
This script supports multi-process parallel environment simulation, generating HDF5 datasets containing RGB, depth, segmentation, etc.
Key features include:
1. Configure environment list and parameters.
2. Parallel execution of multiple episode simulations.
3. Use FailAware planner to attempt to solve tasks.
4. Record data and save as HDF5 file.
5. Merge multiple temporarily generated HDF5 files into a final dataset.
"""
# List of all supported environment module names.
# Only the uncommented entries are generated; toggle comment markers to select envs.
DEFAULT_ENVS =[
    # "PickXtimes",
    #"StopCube",
    # "SwingXtimes",
    # "BinFill",
    # "VideoUnmaskSwap",
    # "VideoUnmask",
    # "ButtonUnmaskSwap",
    # "ButtonUnmask",
    # "VideoRepick",
    # "VideoPlaceButton",
    # "VideoPlaceOrder",
    "PickHighlight",
    # "InsertPeg",
    #'MoveCube',
    # "PatternLock",
    # "RouteStick"
]
# Reference dataset metadata root directory: used to read difficulty and seed
SOURCE_METADATA_ROOT = Path("/data/hongzefu/robomme_benchmark/src/robomme/env_metadata/1206")
VALID_DIFFICULTIES: Set[str] = {"easy", "medium", "hard"}
DATASET_SCREW_MAX_ATTEMPTS = 3
DATASET_RRT_MAX_ATTEMPTS = 3
def _load_env_metadata_records(
env_id: str,
metadata_root: Path,
) -> List[Dict[str, Any]]:
"""
Read metadata records for an environment from the reference directory to control difficulty and seed.
"""
metadata_path = metadata_root / f"record_dataset_{env_id}_metadata.json"
if not metadata_path.exists():
raise FileNotFoundError(
f"Metadata file not found for env '{env_id}': {metadata_path}"
)
with metadata_path.open("r", encoding="utf-8") as metadata_file:
payload = json.load(metadata_file)
raw_records = payload.get("records")
if not isinstance(raw_records, list) or not raw_records:
raise ValueError(
f"Metadata file has no valid 'records' list: {metadata_path}"
)
normalized_records: List[Dict[str, Any]] = []
for idx, raw_record in enumerate(raw_records):
if not isinstance(raw_record, dict):
raise ValueError(
f"Invalid metadata record at index {idx} in {metadata_path}"
)
if "episode" not in raw_record or "seed" not in raw_record or "difficulty" not in raw_record:
raise ValueError(
f"Metadata record missing episode/seed/difficulty at index {idx} in {metadata_path}"
)
try:
episode = int(raw_record["episode"])
seed = int(raw_record["seed"])
except (TypeError, ValueError) as exc:
raise ValueError(
f"Metadata record has non-integer episode/seed at index {idx} in {metadata_path}"
) from exc
difficulty_raw = str(raw_record["difficulty"]).strip().lower()
if difficulty_raw not in VALID_DIFFICULTIES:
raise ValueError(
f"Metadata record has invalid difficulty '{raw_record['difficulty']}' "
f"at index {idx} in {metadata_path}. Expected one of {sorted(VALID_DIFFICULTIES)}."
)
normalized_records.append(
{
"episode": episode,
"seed": seed,
"difficulty": difficulty_raw,
}
)
normalized_records.sort(key=lambda rec: rec["episode"])
print(
f"Loaded {len(normalized_records)} metadata records for {env_id} from {metadata_path}"
)
return normalized_records
def _build_seed_candidates_from_metadata(
episode: int,
metadata_records: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""
Construct candidate (seed, difficulty) list for current episode.
Strictly use only the seed from metadata for the same episode, no cross-episode fallback.
"""
if not metadata_records:
return []
same_episode_records = [rec for rec in metadata_records if rec["episode"] == episode]
if not same_episode_records:
return []
if len(same_episode_records) > 1:
raise ValueError(
f"Found duplicated metadata records for episode {episode}. "
"Strict mode requires exactly one source record per episode."
)
rec = same_episode_records[0]
return [{"seed": int(rec["seed"]), "difficulty": rec["difficulty"]}]
def _tensor_to_bool(value) -> bool:
"""
Helper function: Convert Tensor or numpy array to Python bool type.
Used to handle success/failure flags from different sources.
"""
if value is None:
return False
if isinstance(value, torch.Tensor):
return bool(value.detach().cpu().bool().item())
if isinstance(value, np.ndarray):
return bool(np.any(value))
return bool(value)
def _split_episode_indices(num_episodes: int, max_chunks: int) -> List[List[int]]:
"""
Helper function: Split total episodes into multiple chunks for parallel processing by different processes.
Args:
num_episodes: Total number of episodes
max_chunks: Max number of chunks (usually equals number of workers)
Returns:
List containing lists of episode indices
"""
if num_episodes <= 0:
return []
chunk_count = min(max_chunks, num_episodes)
base_size, remainder = divmod(num_episodes, chunk_count)
chunks: List[List[int]] = []
start = 0
for chunk_idx in range(chunk_count):
# If there is a remainder, allocate one extra episode to the first 'remainder' chunks
stop = start + base_size + (1 if chunk_idx < remainder else 0)
chunks.append(list(range(start, stop)))
start = stop
return chunks
def _run_episode_attempt(
    env_id: str,
    episode: int,
    seed: int,
    temp_dataset_path: Path,
    save_video: bool,
    difficulty: Optional[str],
) -> bool:
    """
    Run a single episode attempt and report success or failure.

    Main steps:
    1. Initialize environment parameters and Gym environment.
    2. Apply RobommeRecordWrapper for data recording.
    3. Select appropriate planner based on environment type (PandaStick or PandaArm).
    4. Get task list and execute tasks one by one.
    5. Use planner to solve task and handle possible planning failures.
    6. Check task execution result (fail/success).
    7. Return whether episode is finally successful.

    Args:
        env_id: Gym environment id to instantiate.
        episode: Episode index (also selects the failure-recovery demo settings below).
        seed: Random seed passed to the environment.
        temp_dataset_path: Path handed to RobommeRecordWrapper for HDF5 output.
        save_video: Whether the wrapper should also record video.
        difficulty: Difficulty string forwarded to the environment (may be None).

    Returns:
        True if the episode finished successfully, False otherwise.
    """
    print(f"--- Running simulation for episode:{episode}, seed:{seed}, env: {env_id} ---")
    env: Optional[gym.Env] = None
    try:
        # 1. Environment parameter configuration
        env_kwargs = dict(
            obs_mode="rgb+depth+segmentation",  # Observation mode: RGB + Depth + Segmentation
            control_mode="pd_joint_pos",  # Control mode: Position control
            render_mode="rgb_array",  # Render mode
            reward_mode="dense",  # Reward mode
            seed=seed,  # Random seed
            difficulty=difficulty,  # Difficulty setting
        )
        # Special failure recovery settings for the first few episodes (for testing or demonstration purposes only)
        if episode <= 5:
            env_kwargs["robomme_failure_recovery"] = True
            if episode <=2:
                env_kwargs["robomme_failure_recovery_mode"] = "z"  # z-axis recovery
            else:
                env_kwargs["robomme_failure_recovery_mode"] = "xy"  # xy-axis recovery
        env = gym.make(env_id, **env_kwargs)
        # 2. Wrap environment to record data
        env = RobommeRecordWrapper(
            env,
            dataset=str(temp_dataset_path),  # Data save path
            env_id=env_id,
            episode=episode,
            seed=seed,
            save_video=save_video,
        )
        episode_successful = False
        env.reset()
        # 3. Select planner
        # PatternLock and RouteStick require Stick planner, others use Arm planner
        if env_id == "PatternLock" or env_id == "RouteStick":
            planner = FailAwarePandaStickMotionPlanningSolver(
                env,
                debug=False,
                vis=False,
                base_pose=env.unwrapped.agent.robot.pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
                joint_vel_limits=0.3,
            )
        else:
            planner = FailAwarePandaArmMotionPlanningSolver(
                env,
                debug=False,
                vis=False,
                base_pose=env.unwrapped.agent.robot.pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
            )
        # Keep references to the planner's original planning methods so the
        # retry wrapper below can delegate to them after the monkey-patch.
        original_move_to_pose_with_screw = planner.move_to_pose_with_screw
        original_move_to_pose_with_rrt = planner.move_to_pose_with_RRTStar
        def _move_to_pose_with_screw_then_rrt_retry(*args, **kwargs):
            # Retry screw planning up to DATASET_SCREW_MAX_ATTEMPTS times, then
            # fall back to RRT* for DATASET_RRT_MAX_ATTEMPTS; -1 means both exhausted.
            for attempt in range(1, DATASET_SCREW_MAX_ATTEMPTS + 1):
                try:
                    result = original_move_to_pose_with_screw(*args, **kwargs)
                except ScrewPlanFailure as exc:
                    print(
                        f"[DatasetGen] screw planning failed "
                        f"(attempt {attempt}/{DATASET_SCREW_MAX_ATTEMPTS}): {exc}"
                    )
                    continue
                # A -1 return is the planner's soft-failure signal; retry
                if isinstance(result, int) and result == -1:
                    print(
                        f"[DatasetGen] screw planning returned -1 "
                        f"(attempt {attempt}/{DATASET_SCREW_MAX_ATTEMPTS})"
                    )
                    continue
                return result
            print(
                "[DatasetGen] screw planning exhausted; "
                f"fallback to RRT* (max {DATASET_RRT_MAX_ATTEMPTS} attempts)"
            )
            for attempt in range(1, DATASET_RRT_MAX_ATTEMPTS + 1):
                try:
                    result = original_move_to_pose_with_rrt(*args, **kwargs)
                except Exception as exc:
                    print(
                        f"[DatasetGen] RRT* planning failed "
                        f"(attempt {attempt}/{DATASET_RRT_MAX_ATTEMPTS}): {exc}"
                    )
                    continue
                if isinstance(result, int) and result == -1:
                    print(
                        f"[DatasetGen] RRT* planning returned -1 "
                        f"(attempt {attempt}/{DATASET_RRT_MAX_ATTEMPTS})"
                    )
                    continue
                return result
            print("[DatasetGen] screw->RRT* planning exhausted; return -1")
            return -1
        # Monkey-patch so every task's solve() transparently gets screw->RRT* retries
        planner.move_to_pose_with_screw = _move_to_pose_with_screw_then_rrt_retry
        env.unwrapped.evaluate()
        # Get environment task list
        tasks = list(getattr(env.unwrapped, "task_list", []) or [])
        print(f"{env_id}: Task list has {len(tasks)} tasks")
        # 4. Iterate and execute all subtasks
        for idx, task_entry in enumerate(tasks):
            task_name = task_entry.get("name", f"Task {idx}")
            print(f"Executing task {idx + 1}/{len(tasks)}: {task_name}")
            solve_callable = task_entry.get("solve")
            if not callable(solve_callable):
                raise ValueError(
                    f"Task '{task_name}' must supply a callable 'solve'."
                )
            # Evaluate once before executing solve
            env.unwrapped.evaluate(solve_complete_eval=True)
            screw_failed = False
            try:
                # 5. Call planner to solve current task
                solve_result = solve_callable(env, planner)
                if isinstance(solve_result, int) and solve_result == -1:
                    # -1 propagated from the retry wrapper: mark the task failed
                    screw_failed = True
                    print(f"Screw->RRT* planning exhausted during '{task_name}'")
                    env.unwrapped.failureflag = torch.tensor([True])
                    env.unwrapped.successflag = torch.tensor([False])
                    env.unwrapped.current_task_failure = True
            except ScrewPlanFailure as exc:
                # Planning failure handling
                screw_failed = True
                print(f"Screw plan failure during '{task_name}': {exc}")
                env.unwrapped.failureflag = torch.tensor([True])
                env.unwrapped.successflag = torch.tensor([False])
                env.unwrapped.current_task_failure = True
            except FailsafeTimeout as exc:
                # Timeout handling: abort the remaining task sequence
                print(f"Failsafe: {exc}")
                break
            # Evaluation after task execution
            evaluation = env.unwrapped.evaluate(solve_complete_eval=True)
            fail_flag = evaluation.get("fail", False)
            success_flag = evaluation.get("success", False)
            # 6. Check success/failure conditions
            if _tensor_to_bool(success_flag):
                print("All tasks completed successfully.")
                episode_successful = True
                break
            if screw_failed or _tensor_to_bool(fail_flag):
                print("Encountered failure condition; stopping task sequence.")
                break
        else:
            # If loop ends normally (no break), check success again
            evaluation = env.unwrapped.evaluate(solve_complete_eval=True)
            episode_successful = _tensor_to_bool(evaluation.get("success", False))
        # 7. Prioritize wrapper's success signal (double check)
        episode_successful = episode_successful or _tensor_to_bool(
            getattr(env, "episode_success", False)
        )
    except SceneGenerationError as exc:  # Scene generation failure may occur in environments like swingxtimes
        print(
            f"Scene generation failed for env {env_id}, episode {episode}, seed {seed}: {exc}"
        )
        episode_successful = False
    finally:
        if env is not None:
            try:
                env.close()
            except Exception as close_exc:
                # Even if close() fails, return success if episode was successful
                # Because HDF5 data was written before close() (in write() method)
                print(f"Warning: Exception during env.close() for episode {episode}, seed {seed}: {close_exc}")
    # If episode was successful, close() exception should not affect return value
    # episode_successful was determined before close()
    status_text = "SUCCESS" if episode_successful else "FAILED"
    print(
        f"--- Finished Running simulation for episode:{episode}, seed:{seed}, env: {env_id} [{status_text}] ---"
    )
    return episode_successful
def run_env_dataset(
    env_id: str,
    episode_indices: Iterable[int],
    temp_folder: Path,
    save_video: bool,
    metadata_records: List[Dict[str, Any]],
    gpu_id: int,
) -> List[Dict[str, Any]]:
    """
    Run dataset generation for a batch of episodes and save data to temporary folder.

    Args:
        env_id: Environment ID
        episode_indices: List of episode indices to run
        temp_folder: Temporary folder to save data
        save_video: Whether to save video
        metadata_records: Records from reference dataset metadata
        gpu_id: GPU ID to use

    Returns:
        Generated episode metadata record list (one entry per successful episode).
    """
    # Maximum seed-increment retries per metadata candidate. Hoisted out of the
    # episode loop: it is a constant, not per-episode state.
    MAX_RETRY_ATTEMPTS = 20
    # Pin the GPU used by this process before any CUDA context is created
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    temp_folder.mkdir(parents=True, exist_ok=True)
    episode_indices = list(episode_indices)
    if not episode_indices:
        return []
    if env_id not in DEFAULT_ENVS:
        raise ValueError(f"Unsupported environment: {env_id}")
    # Pass a temporary h5 file path to wrapper.
    # Note: wrapper will actually create separate episode files in a subfolder of that path's directory.
    temp_dataset_path = temp_folder / "temp_chunk.h5"  # plain string: no interpolation needed
    episode_records: List[Dict[str, Any]] = []
    for episode in episode_indices:
        candidate_pairs = _build_seed_candidates_from_metadata(episode, metadata_records)
        if not candidate_pairs:
            print(f"Episode {episode}: no metadata candidate seeds found, skipping.")
            continue
        episode_success = False
        for attempt_idx, candidate in enumerate(candidate_pairs, start=1):
            base_seed = int(candidate["seed"])
            difficulty = str(candidate["difficulty"])
            current_seed = base_seed
            for retry_count in range(MAX_RETRY_ATTEMPTS):
                if retry_count > 0:
                    current_seed += 1  # bump seed so the retry draws a fresh scene
                print(
                    f"Episode {episode} attempt {retry_count + 1}/{MAX_RETRY_ATTEMPTS} "
                    f"with seed={current_seed} (base={base_seed}, diff={difficulty})"
                )
                try:
                    success = _run_episode_attempt(
                        env_id=env_id,
                        episode=episode,
                        seed=current_seed,
                        temp_dataset_path=temp_dataset_path,
                        save_video=save_video,
                        difficulty=difficulty,
                    )
                    if success:
                        # Record successful episode information
                        episode_records.append(
                            {
                                "task": env_id,
                                "episode": episode,
                                "seed": current_seed,
                                "difficulty": difficulty,
                            }
                        )
                        episode_success = True
                        break  # Break retry loop (seed increment loop)
                    print(
                        f"Episode {episode} failed with seed {current_seed}; retrying with seed+1..."
                    )
                except Exception as exc:
                    # Deliberate best-effort: any simulation crash burns one retry
                    print(
                        f"Episode {episode} exception with seed {current_seed}: {exc}; retrying with seed+1..."
                    )
            if episode_success:
                break  # Break candidate loop
        if not episode_success:
            print(
                f"Episode {episode} failed with strict source metadata seed; "
                "metadata will not be recorded for this episode."
            )
    return episode_records
def _merge_dataset_from_folder(
    env_id: str,
    temp_folder: Path,
    final_dataset_path: Path,
) -> None:
    """
    Merge all episode files from temporary folder into final dataset.

    Also moves any recorded videos next to the final dataset and removes the
    temporary folder afterwards.

    Args:
        env_id: Environment ID (not referenced in the body; kept for call-site symmetry)
        temp_folder: Temporary folder containing episode files
        final_dataset_path: Final output HDF5 file path
    """
    if not temp_folder.exists() or not temp_folder.is_dir():
        print(f"Warning: Temporary folder {temp_folder} does not exist")
        return
    final_dataset_path.parent.mkdir(parents=True, exist_ok=True)
    # Find subfolders created by RobommeRecordWrapper
    # It usually creates directories ending with "_hdf5_files"
    hdf5_folders = list(temp_folder.glob("*_hdf5_files"))
    if not hdf5_folders:
        print(f"Warning: No HDF5 folders found in {temp_folder}")
        return
    print(f"Merging episodes from {temp_folder} into {final_dataset_path}")
    # Open final HDF5 file for append mode writing
    with h5py.File(final_dataset_path, "a") as final_file:
        for hdf5_folder in sorted(hdf5_folders):
            # Get all h5 files in folder
            h5_files = sorted(hdf5_folder.glob("*.h5"))
            if not h5_files:
                print(f"Warning: No h5 files found in {hdf5_folder}")
                continue
            print(f"Found {len(h5_files)} episode files in {hdf5_folder.name}")
            # Merge each episode file; a failure on one file skips it but keeps merging
            for h5_file in h5_files:
                print(f" - Merging {h5_file.name}")
                try:
                    with h5py.File(h5_file, "r") as episode_file:
                        file_keys = list(episode_file.keys())
                        if len(file_keys) == 0:
                            print(f" Warning: {h5_file.name} is empty, skipping...")
                            continue
                        for env_group_name, src_env_group in episode_file.items():
                            episode_keys = list(src_env_group.keys()) if isinstance(src_env_group, h5py.Group) else []
                            if len(episode_keys) == 0:
                                print(f" Warning: {env_group_name} in {h5_file.name} has no episodes, skipping...")
                                continue
                            # If environment group (e.g. 'PickXtimes') does not exist, copy directly
                            if env_group_name not in final_file:
                                final_file.copy(src_env_group, env_group_name)
                                continue
                            dest_env_group = final_file[env_group_name]
                            if not isinstance(dest_env_group, h5py.Group):
                                print(f" Warning: {env_group_name} is not a group, skipping...")
                                continue
                            # If environment group exists, copy episodes one by one,
                            # replacing any episode group that is already present
                            for episode_name in src_env_group.keys():
                                if episode_name in dest_env_group:
                                    print(f" Warning: Episode {episode_name} already exists, overwriting...")
                                    del dest_env_group[episode_name]
                                src_env_group.copy(episode_name, dest_env_group, name=episode_name)
                except Exception as e:
                    print(f" Error merging {h5_file.name}: {e}")
                    continue
    # Keep videos: wrapper writes videos to 'videos' under temp dir, move to final dir before cleanup
    temp_videos_dir = temp_folder / "videos"
    final_videos_dir = final_dataset_path.parent / "videos"
    if temp_videos_dir.exists() and temp_videos_dir.is_dir():
        final_videos_dir.mkdir(parents=True, exist_ok=True)
        moved_count = 0
        for video_path in sorted(temp_videos_dir.glob("*.mp4")):
            target_path = final_videos_dir / video_path.name
            if target_path.exists():
                # Name collision: append _dupN until a free filename is found
                stem = target_path.stem
                suffix = target_path.suffix
                index = 1
                while True:
                    candidate = final_videos_dir / f"{stem}_dup{index}{suffix}"
                    if not candidate.exists():
                        target_path = candidate
                        break
                    index += 1
            try:
                shutil.move(str(video_path), str(target_path))
                moved_count += 1
            except Exception as exc:
                print(f"Warning: Failed to move video {video_path.name}: {exc}")
        if moved_count > 0:
            print(f"Moved {moved_count} videos to {final_videos_dir}")
    # Clean up temporary folder after successful merge
    try:
        shutil.rmtree(temp_folder)
        print(f"Cleaned up temporary folder: {temp_folder}")
    except Exception as e:
        print(f"Warning: Failed to remove temporary folder {temp_folder}: {e}")
def _save_episode_metadata(
records: List[Dict[str, Any]],
metadata_path: Path,
env_id: str,
) -> None:
"""Save seed/difficulty metadata for each episode to JSON file."""
metadata_path.parent.mkdir(parents=True, exist_ok=True)
sorted_records = sorted(records, key=lambda rec: rec.get("episode", -1))
metadata = {
"env_id": env_id,
"record_count": len(sorted_records),
"records": sorted_records,
}
try:
with metadata_path.open("w", encoding="utf-8") as metadata_file:
json.dump(metadata, metadata_file, indent=2)
print(f"Saved episode metadata to {metadata_path}")
except Exception as exc:
print(f"Warning: Failed to save episode metadata to {metadata_path}: {exc}")
def parse_args() -> argparse.Namespace:
    """Build and evaluate the command-line interface for the dataset generator."""
    cli = argparse.ArgumentParser(description="Robomme Dataset Generator")
    # Declarative option table: (flags, add_argument keyword options).
    # --save-video / --no-save-video are an on/off pair sharing the same dest.
    option_specs = [
        (
            ("--episodes", "-n"),
            dict(type=int, nargs="+", default=[0], help="List of episodes to generate."),
        ),
        (
            ("--save-video",),
            dict(
                dest="save_video",
                action="store_true",
                default=True,
                help="Enable video recording via RobommeRecordWrapper (Default: Enabled).",
            ),
        ),
        (
            ("--no-save-video",),
            dict(dest="save_video", action="store_false", help="Disable video recording."),
        ),
        (
            ("--gpus",),
            dict(
                type=str,
                default="1",
                help="GPU selection. Supported values: '0', '1', '0,1' (or '1,0'). Default: '0'.",
            ),
        ),
    ]
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
def _parse_gpu_ids(gpu_spec: str) -> List[int]:
"""Parse user GPU spec string to a deduplicated GPU id list."""
valid_gpu_ids = {0, 1}
raw_tokens = [token.strip() for token in gpu_spec.split(",") if token.strip()]
if not raw_tokens:
raise ValueError("GPU spec is empty. Use one of: 0, 1, 0,1")
gpu_ids: List[int] = []
for token in raw_tokens:
try:
gpu_id = int(token)
except ValueError as exc:
raise ValueError(
f"Invalid GPU id '{token}'. Supported values are 0 and 1."
) from exc
if gpu_id not in valid_gpu_ids:
raise ValueError(
f"Unsupported GPU id '{gpu_id}'. Supported values are 0 and 1."
)
if gpu_id not in gpu_ids:
gpu_ids.append(gpu_id)
if not gpu_ids:
raise ValueError("No valid GPU id provided. Use one of: 0, 1, 0,1")
return gpu_ids
def main() -> None:
    """Entry point: generate, merge, and annotate the dataset for each configured env."""
    args = parse_args()
    # Use the shared env list instead of a duplicated literal so enabling an
    # environment in DEFAULT_ENVS automatically includes it here
    # (currently only "PickHighlight" is active, so behavior is unchanged).
    env_ids: List[str] = list(DEFAULT_ENVS)
    num_workers = 1  # single-process mode; chunking helpers exist for future parallelism
    gpu_ids = _parse_gpu_ids(args.gpus)
    episode_indices = args.episodes
    for env_id in env_ids:
        source_metadata_records = _load_env_metadata_records(
            env_id=env_id,
            metadata_root=SOURCE_METADATA_ROOT,
        )
        # Create shared temporary folder for all episodes
        temp_folder = Path(f"/data/hongzefu/data_0226-test/temp_{env_id}_episodes")
        final_dataset_path = Path(f"/data/hongzefu/data_0226-test/record_dataset_{env_id}.h5")
        print(f"\n{'='*80}")
        print(f"Environment: {env_id}")
        print(f"Episodes: {args.episodes}")
        print(f"Workers: {num_workers}")
        if len(gpu_ids) == 1:
            print(f"GPU mode: Single GPU ({gpu_ids[0]})")
        else:
            print(f"GPU mode: Multi GPU ({','.join(str(gpu) for gpu in gpu_ids)})")
        print(f"Temporary folder: {temp_folder}")
        print(f"Final dataset: {final_dataset_path}")
        print(f"{'='*80}\n")
        # Single worker mode: run all requested episodes sequentially on the first GPU
        episode_records = run_env_dataset(
            env_id,
            episode_indices,
            temp_folder,
            args.save_video,
            source_metadata_records,
            gpu_ids[0],  # gpu_id
        )
        # Merge episodes into final dataset
        print("\nMerging all episodes into final dataset...")
        _merge_dataset_from_folder(
            env_id,
            temp_folder,
            final_dataset_path,
        )
        # Save per-episode seed/difficulty metadata next to the dataset
        metadata_path = final_dataset_path.with_name(
            f"{final_dataset_path.stem}_metadata.json"
        )
        _save_episode_metadata(episode_records, metadata_path, env_id)
        print(f"\n✓ Finished! Final dataset saved to: {final_dataset_path}\n")
    print("✓ All requested environments processed.")
# Script entry point: only run dataset generation when executed directly.
if __name__ == "__main__":
    main()
|