| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| """ |
| In this file, we define 3 types of datasets: |
| 1. LeRobotSingleDataset: a single dataset for a given embodiment tag |
| 2. LeRobotMixtureDataset: a mixture of datasets for a given list of embodiment tags |
| 3. CachedLeRobotSingleDataset: a single dataset for a given embodiment tag, |
| with caching for the video frames |
| |
| See `scripts/load_dataset.py` for examples on how to use these datasets. |
| """ |
| import os |
| import hashlib |
| import json, torch |
| from collections import defaultdict |
| from pathlib import Path |
| from typing import Sequence |
| import os, random |
| import numpy as np |
| import pandas as pd |
| from pydantic import BaseModel, Field, ValidationError |
| from torch.utils.data import Dataset |
| from tqdm import tqdm |
| from PIL import Image |
|
|
| from starVLA.dataloader.gr00t_lerobot.video import get_all_frames, get_frames_by_timestamps |
|
|
| from starVLA.dataloader.gr00t_lerobot.embodiment_tags import EmbodimentTag, DATASET_NAME_TO_ID |
| from starVLA.dataloader.gr00t_lerobot.schema import ( |
| DatasetMetadata, |
| DatasetStatisticalValues, |
| LeRobotModalityMetadata, |
| LeRobotStateActionMetadata, |
| ) |
| from starVLA.dataloader.gr00t_lerobot.transform import ComposedModalityTransform |
|
|
| from functools import partial |
| from typing import Tuple, List |
| import pickle |
|
|
| |
# Locations of the LeRobot metadata/data files, relative to the dataset root.
# The *.jsonl files are the ones read when `lerobot_version` is "v2.0".
LE_ROBOT_MODALITY_FILENAME = "meta/modality.json"
LE_ROBOT_EPISODE_FILENAME = "meta/episodes.jsonl"
LE_ROBOT_TASKS_FILENAME = "meta/tasks.jsonl"
LE_ROBOT_INFO_FILENAME = "meta/info.json"
LE_ROBOT_STATS_FILENAME = "meta/stats_gr00t.json"
# Glob pattern (not a single file) matching every per-episode parquet file.
LE_ROBOT_DATA_FILENAME = "data/*/*.parquet"
# NOTE(review): the step cache visible in this part of the file is written to
# "meta/steps_data_index.pkl", not to this path -- confirm whether this
# constant is still used elsewhere.
LE_ROBOT_STEPS_FILENAME = "meta/steps.pkl"
EPSILON = 5e-4

# Counterparts read when `lerobot_version` is "v3.0": tasks and episode
# metadata live in parquet files (the episode entry is a glob pattern).
LE_ROBOT3_TASKS_FILENAME = "meta/tasks.parquet"
LE_ROBOT3_EPISODE_FILENAME = "meta/episodes/*/*.parquet"
|
|
|
|
| |
| |
| |
|
|
# Width of the shared action vector every embodiment's action is embedded in.
STANDARD_ACTION_DIM = 37

# Slot each embodiment's action occupies inside the standard action vector.
# NOTE(review): "franka" (7..13) lies inside the "gr1" range (0..28) --
# presumably an intentionally shared layout; confirm before changing.
ACTION_REPRESENTATION_SLICES = {
    "franka": slice(7, 14),
    "gr1": slice(0, 29),
    "real_world_franka": slice(29, 37),
}

# Width of the shared state vector every embodiment's state is embedded in.
STANDARD_STATE_DIM = 74

# Slot each embodiment's state occupies inside the standard state vector.
# These three slots are disjoint and together cover all 74 entries.
STATE_REPRESENTATION_SLICES = {
    "franka": slice(0, 8),
    "real_world_franka": slice(8, 16),
    "gr1": slice(16, 74),
}


def _embed_in_standard_vector(values, embodiment_tag, slices, standard_dim, kind):
    """Zero-pad `values` into a `standard_dim`-wide vector at the tag's slot.

    Shared implementation of the action/state standardization functions.

    Args:
        values: Array-like whose last axis holds the per-robot vector.
        embodiment_tag: Tag string, or an enum member whose `.value` is the
            tag string.
        slices: Mapping from tag string to the slot in the standard vector.
        standard_dim: Size of the last axis of the returned array.
        kind: "action" or "state"; only used to build error messages.

    Returns:
        np.ndarray: Array with the leading axes and dtype of `values`, zeros
        everywhere except the tag's slot on the last axis.

    Raises:
        ValueError: If the tag is unknown or the last axis has the wrong size.
    """
    # Accept enum tags as well as plain strings; plain strings pass through.
    tag = getattr(embodiment_tag, "value", embodiment_tag)
    target_slice = slices.get(tag)
    if target_slice is None:
        raise ValueError(
            f"Unknown embodiment tag '{tag}' for {kind} mapping. "
            f"Known tags: {sorted(slices)}"
        )

    values = np.asarray(values)
    expected_dim = target_slice.stop - target_slice.start
    actual_dim = values.shape[-1]
    if actual_dim != expected_dim:
        raise ValueError(
            f"{kind.capitalize()} dim mismatch for tag '{tag}': "
            f"{kind}.shape[-1]={actual_dim} vs expected {expected_dim}."
        )

    standard = np.zeros((*values.shape[:-1], standard_dim), dtype=values.dtype)
    standard[..., target_slice] = values
    return standard


def standardize_action_representation(
    action: np.ndarray, embodiment_tag: str
) -> np.ndarray:
    """Map per-robot action to a fixed-size standard action vector.

    Args:
        action: Array whose last axis is the embodiment's action vector.
        embodiment_tag: Key into `ACTION_REPRESENTATION_SLICES`.

    Returns:
        np.ndarray: Same leading shape and dtype as `action`, with a last
        axis of size `STANDARD_ACTION_DIM`; entries outside the embodiment's
        slot are zero.

    Raises:
        ValueError: Unknown tag, or last-axis size different from the slot width.
    """
    return _embed_in_standard_vector(
        action,
        embodiment_tag,
        ACTION_REPRESENTATION_SLICES,
        STANDARD_ACTION_DIM,
        "action",
    )


def standardize_state_representation(
    state: np.ndarray, embodiment_tag: str
) -> np.ndarray:
    """Map per-robot state to a fixed-size standard state vector.

    Args:
        state: Array whose last axis is the embodiment's state vector.
        embodiment_tag: Key into `STATE_REPRESENTATION_SLICES`.

    Returns:
        np.ndarray: Same leading shape and dtype as `state`, with a last axis
        of size `STANDARD_STATE_DIM`; entries outside the embodiment's slot
        are zero.

    Raises:
        ValueError: Unknown tag, or last-axis size different from the slot width.
    """
    return _embed_in_standard_vector(
        state,
        embodiment_tag,
        STATE_REPRESENTATION_SLICES,
        STANDARD_STATE_DIM,
        "state",
    )
|
|
|
|
def calculate_dataset_statistics(parquet_paths: list[Path]) -> dict:
    """Calculate the dataset statistics of all columns for a list of parquet files.

    Every parquet file is loaded and concatenated in memory, then per-column
    statistics are computed along the row axis. Columns whose name starts with
    "annotation." are skipped.

    Args:
        parquet_paths: Paths of the parquet files to aggregate (any iterable
            of sortable paths is accepted).

    Returns:
        dict: Maps each column name to a dict with "mean", "std", "min",
        "max", "q01" and "q99", each a plain Python list.

    Raises:
        ValueError: If `parquet_paths` is empty.
    """
    ordered_paths = sorted(parquet_paths)
    if not ordered_paths:
        # pd.concat would fail with a far less helpful message.
        raise ValueError(
            "Cannot calculate dataset statistics: no parquet files were provided."
        )

    frames = [
        pd.read_parquet(parquet_path)
        for parquet_path in tqdm(ordered_paths, desc="Collecting all parquet files...")
    ]
    all_low_dim_data = pd.concat(frames, axis=0)

    dataset_statistics = {}
    for le_modality in all_low_dim_data.columns:
        if le_modality.startswith("annotation."):
            continue
        print(f"Computing statistics for {le_modality}...")
        # One row per step; each cell is cast to a float32 array first.
        np_data = np.vstack(
            [np.asarray(x, dtype=np.float32) for x in all_low_dim_data[le_modality]]
        )
        dataset_statistics[le_modality] = {
            "mean": np.mean(np_data, axis=0).tolist(),
            "std": np.std(np_data, axis=0).tolist(),
            "min": np.min(np_data, axis=0).tolist(),
            "max": np.max(np_data, axis=0).tolist(),
            "q01": np.quantile(np_data, 0.01, axis=0).tolist(),
            "q99": np.quantile(np_data, 0.99, axis=0).tolist(),
        }
    return dataset_statistics
|
|
|
|
class ModalityConfig(BaseModel):
    """Configuration for a modality.

    Pairs the keys to load for one modality with the step offsets to sample
    for each of those keys.
    """

    # Delta indices to sample relative to the current index. The returned data
    # corresponds to the original data at a sampled base index + delta indices.
    delta_indices: list[int]
    # The keys to load for the modality in the dataset.
    modality_keys: list[str]
|
|
|
|
| class LeRobotSingleDataset(Dataset): |
| """ |
| Base dataset class for LeRobot that supports sharding. |
| """ |
def __init__(
    self,
    dataset_path: Path | str,
    modality_configs: dict[str, ModalityConfig],
    embodiment_tag: str | EmbodimentTag,
    video_backend: str = "decord",
    video_backend_kwargs: dict | None = None,
    transforms: ComposedModalityTransform | None = None,
    delete_pause_frame: bool = False,
    **kwargs,
):
    """Load the dataset metadata and build the step index.

    Args:
        dataset_path (Path | str): The path to the dataset.
        modality_configs (dict[str, ModalityConfig]): Configuration per
            modality name. The "video" entry is dropped when video loading
            is disabled. See `ModalityConfig` for more details.
        embodiment_tag (str | EmbodimentTag): Embodiment tag for the dataset;
            stored on `self.tag` as its string value.
        video_backend (str): Backend for video reading.
        video_backend_kwargs (dict | None): Keyword arguments for the video
            backend; `None` is stored as an empty dict.
        transforms (ComposedModalityTransform | None): Transforms for the
            dataset; `None` is replaced by an empty composed transform.
        delete_pause_frame (bool): Stored on the instance and recorded in the
            step-cache key and payload.
        **kwargs: Only "data_cfg" is read. It may provide "lerobot_version"
            (default "v2.0"), "load_video" (default True) and
            "num_history_steps" (default 0).

    Raises:
        FileNotFoundError: If `dataset_path` does not exist.
    """
    root = Path(dataset_path)
    if not root.exists():
        raise FileNotFoundError(f"Dataset path {dataset_path} does not exist")

    # A missing or falsy "data_cfg" is treated as an empty config.
    data_cfg = kwargs.get("data_cfg", {}) or {}
    self._lerobot_version = data_cfg.get("lerobot_version", "v2.0")
    self.load_video = data_cfg.get("load_video", True)
    self.num_history_steps = int(data_cfg.get("num_history_steps", 0) or 0)
    self.delete_pause_frame = delete_pause_frame

    # Without video loading the "video" modality is removed up front so no
    # later step asks for frames.
    self.modality_configs = (
        modality_configs
        if self.load_video
        else {
            name: cfg for name, cfg in modality_configs.items() if name != "video"
        }
    )
    self.video_backend = video_backend
    self.video_backend_kwargs = (
        {} if video_backend_kwargs is None else video_backend_kwargs
    )
    self.transforms = (
        ComposedModalityTransform(transforms=[]) if transforms is None else transforms
    )

    self._dataset_path = root
    self._dataset_name = root.name
    self._dataset_id = DATASET_NAME_TO_ID.get(self._dataset_name)
    self.tag = (
        embodiment_tag.value
        if isinstance(embodiment_tag, EmbodimentTag)
        else embodiment_tag
    )

    self._metadata = self._get_metadata(EmbodimentTag(self.tag))

    # Raw LeRobot metadata and the values derived from info.json.
    self._lerobot_modality_meta = self._get_lerobot_modality_meta()
    self._lerobot_info_meta = self._get_lerobot_info_meta()
    self._data_path_pattern = self._get_data_path_pattern()
    self._video_path_pattern = self._get_video_path_pattern()
    self._chunk_size = self._get_chunk_size()
    self._tasks = self._get_tasks()
    self.curr_traj_data = None
    self.curr_traj_id = None

    # The step index depends on the trajectories and modality keys above.
    self._trajectory_ids, self._trajectory_lengths = self._get_trajectories()
    self._modality_keys = self._get_modality_keys()
    self._delta_indices = self._get_delta_indices()
    self._all_steps = self._get_all_steps()
    self.set_transforms_metadata(self.metadata)
    self.set_epoch(0)

    print(f"Initialized dataset {self.dataset_name} with {embodiment_tag}")

    self._check_integrity()
|
|
@property
def dataset_path(self) -> Path:
    """Root directory of the dataset, as a `Path`."""
    return self._dataset_path
|
|
@property
def metadata(self) -> DatasetMetadata:
    """The `DatasetMetadata` built by `_get_metadata` during initialization."""
    return self._metadata
|
|
@property
def trajectory_ids(self) -> np.ndarray:
    """The trajectory IDs (episode indices) in the dataset, as a 1D numpy array."""
    return self._trajectory_ids
|
|
@property
def trajectory_lengths(self) -> np.ndarray:
    """The length of every trajectory, as a 1D numpy array.

    Entries are in the same order as `trajectory_ids`.
    """
    return self._trajectory_lengths
|
|
@property
def all_steps(self) -> list[tuple[int, int]]:
    """All (trajectory_id, base_index) pairs that can be sampled.

    Example:
        self.trajectory_ids: [0, 1, 2]
        self.trajectory_lengths: [3, 2, 4]
        return: [
            (0, 0), (0, 1), (0, 2),
            (1, 0), (1, 1),
            (2, 0), (2, 1), (2, 2), (2, 3)
        ]
    """
    return self._all_steps
|
|
@property
def modality_keys(self) -> dict:
    """Mapping from modality name to the list of keys loaded for that modality.

    Example: {
        "video": ["video.image_side_0", "video.image_side_1"],
        "state": ["state.eef_position", "state.eef_rotation"],
        "action": ["action.eef_position", "action.eef_rotation"],
        "language": ["language.human.task"],
        "timestamp": ["timestamp"],
        "reward": ["reward"],
    }
    """
    return self._modality_keys
|
|
@property
def delta_indices(self) -> dict[str, np.ndarray]:
    """Delta indices keyed by modality key (e.g. "state.eef_position"), one array per key."""
    return self._delta_indices
|
|
@property
def dataset_name(self) -> str:
    """Name of the dataset (the final component of the dataset path)."""
    return self._dataset_name
|
|
@property
def lerobot_modality_meta(self) -> LeRobotModalityMetadata:
    """The validated contents of the dataset's meta/modality.json."""
    return self._lerobot_modality_meta
|
|
@property
def lerobot_info_meta(self) -> dict:
    """The parsed contents of the dataset's meta/info.json."""
    return self._lerobot_info_meta
|
|
@property
def data_path_pattern(self) -> str:
    """Format string (the "data_path" entry of info.json) for parquet data files."""
    return self._data_path_pattern
|
|
@property
def video_path_pattern(self) -> str:
    """Format string (the "video_path" entry of info.json) for video files."""
    return self._video_path_pattern
|
|
@property
def chunk_size(self) -> int:
    """The "chunks_size" entry of info.json."""
    return self._chunk_size
|
|
@property
def tasks(self) -> pd.DataFrame:
    """The task table loaded by `_get_tasks` during initialization."""
    return self._tasks
|
|
def _get_metadata(self, embodiment_tag: EmbodimentTag) -> DatasetMetadata:
    """Assemble the `DatasetMetadata` for this dataset from its meta files.

    Reads meta/modality.json (state/action layout), meta/info.json (video
    features) and meta/stats_gr00t.json (statistics). When the statistics
    file is missing or fails validation, statistics are recomputed from the
    parquet files and written back to that path.

    Args:
        embodiment_tag (EmbodimentTag): Tag stored in the returned metadata.

    Returns:
        DatasetMetadata: Statistics and modality descriptions for the dataset.
    """
    # --- state / action layout ------------------------------------------
    modality_json = self.dataset_path / LE_ROBOT_MODALITY_FILENAME
    assert (
        modality_json.exists()
    ), f"Please provide a {LE_ROBOT_MODALITY_FILENAME} file in {self.dataset_path}"
    with open(modality_json, "r") as f:
        le_modality_meta = LeRobotModalityMetadata.model_validate(json.load(f))

    simplified_modality_meta: dict[str, dict] = {}
    for modality in ("state", "action"):
        per_key_meta: dict[str, LeRobotStateActionMetadata] = getattr(
            le_modality_meta, modality
        )
        simplified_modality_meta[modality] = {
            subkey: {
                "absolute": entry.absolute,
                "rotation_type": entry.rotation_type,
                "shape": [entry.end - entry.start],
                # Floating dtypes are flagged continuous, all others are not.
                "continuous": bool(
                    np.issubdtype(np.dtype(entry.dtype), np.floating)
                ),
            }
            for subkey, entry in per_key_meta.items()
        }

    # --- video features -------------------------------------------------
    info_json = self.dataset_path / LE_ROBOT_INFO_FILENAME
    assert (
        info_json.exists()
    ), f"Please provide a {LE_ROBOT_INFO_FILENAME} file in {self.dataset_path}"
    with open(info_json, "r") as f:
        le_info = json.load(f)

    video_meta: dict[str, dict] = {}
    simplified_modality_meta["video"] = video_meta
    for new_key, video_cfg in le_modality_meta.video.items():
        original_key = video_cfg.original_key
        if original_key is None:
            original_key = new_key
        feature = le_info["features"][original_key]
        axis_names = feature["names"]
        height = feature["shape"][axis_names.index("height")]
        width = feature["shape"][axis_names.index("width")]
        try:
            channels = feature["shape"][axis_names.index("channel")]
            fps = feature["video_info"]["video.fps"]
        except (ValueError, KeyError):
            # Fall back to the "info" sub-dict when the layout above is absent.
            channels = feature["info"]["video.channels"]
            fps = feature["info"]["video.fps"]
        video_meta[new_key] = {
            "resolution": [width, height],
            "channels": channels,
            "fps": fps,
        }

    # --- statistics (loaded, or computed and cached) --------------------
    stats_path = self.dataset_path / LE_ROBOT_STATS_FILENAME
    try:
        with open(stats_path, "r") as f:
            le_statistics = json.load(f)
        for stat in le_statistics.values():
            DatasetStatisticalValues.model_validate(stat)
    except (FileNotFoundError, ValidationError) as e:
        print(f"Failed to load dataset statistics: {e}")
        print(f"Calculating dataset statistics for {self.dataset_name}")
        # NOTE(review): one specific episode file is excluded by name; the
        # reason is not visible here -- confirm it is still needed.
        parquet_files_filtered = [
            pf
            for pf in self.dataset_path.glob(LE_ROBOT_DATA_FILENAME)
            if "episode_033675.parquet" not in pf.name
        ]
        le_statistics = calculate_dataset_statistics(parquet_files_filtered)
        with open(stats_path, "w") as f:
            json.dump(le_statistics, f, indent=4)

    # Slice the per-column statistics down to each state/action sub-key.
    dataset_statistics = {}
    for our_modality in ("state", "action"):
        modality_stats = {}
        dataset_statistics[our_modality] = modality_stats
        for subkey in simplified_modality_meta[our_modality]:
            state_action_meta = le_modality_meta.get_key_meta(
                f"{our_modality}.{subkey}"
            )
            assert isinstance(state_action_meta, LeRobotStateActionMetadata)
            le_modality = state_action_meta.original_key
            indices = np.arange(state_action_meta.start, state_action_meta.end)
            modality_stats[subkey] = {
                stat_name: np.array(stat_values)[indices].tolist()
                for stat_name, stat_values in le_statistics[le_modality].items()
            }

    return DatasetMetadata(
        statistics=dataset_statistics,
        modalities=simplified_modality_meta,
        embodiment_tag=embodiment_tag,
    )
|
|
def _get_trajectories(self) -> tuple[np.ndarray, np.ndarray]:
    """Get the trajectory IDs and lengths of the dataset.

    For "v2.0" the episodes are read from meta/episodes.jsonl. For "v3.0"
    they are read from the episode parquet files, and
    `self.trajectory_ids_to_metadata` is (re)built with the per-episode data
    chunk/file indices.

    Returns:
        tuple[np.ndarray, np.ndarray]: Trajectory IDs and trajectory lengths,
        in matching order.

    Raises:
        ValueError: If the configured LeRobot version is not supported.
    """
    if self._lerobot_version == "v2.0":
        file_path = self.dataset_path / LE_ROBOT_EPISODE_FILENAME
        with open(file_path, "r") as f:
            # Tolerate blank lines (e.g. a trailing newline) in the jsonl file.
            episode_metadata = [json.loads(line) for line in f if line.strip()]
        trajectory_ids = [episode["episode_index"] for episode in episode_metadata]
        trajectory_lengths = [episode["length"] for episode in episode_metadata]
        return np.array(trajectory_ids), np.array(trajectory_lengths)

    if self._lerobot_version == "v3.0":
        # glob() order is filesystem-dependent; sort so the trajectory order
        # (and therefore the step index) is reproducible.
        file_paths = sorted(self.dataset_path.glob(LE_ROBOT3_EPISODE_FILENAME))
        trajectory_ids = []
        trajectory_lengths = []
        self.trajectory_ids_to_metadata = {}
        for file_path in file_paths:
            episodes_data = pd.read_parquet(file_path)
            for index, episode in episodes_data.iterrows():
                trajectory_ids.append(episode["episode_index"])
                trajectory_lengths.append(episode["length"])

                episode_meta = {
                    "data/chunk_index": episode["data/chunk_index"],
                    "data/file_index": episode["data/file_index"],
                    # Row label of the episode inside its metadata file.
                    "data/file_from_index": index,
                }
                if self.load_video:
                    # NOTE(review): only the wrist camera's start timestamp
                    # is recorded; a dataset without this column raises
                    # KeyError here -- confirm this is intended.
                    episode_meta["videos/observation.images.wrist/from_timestamp"] = episode[
                        "videos/observation.images.wrist/from_timestamp"
                    ]
                self.trajectory_ids_to_metadata[trajectory_ids[-1]] = episode_meta
        return np.array(trajectory_ids), np.array(trajectory_lengths)

    # Previously this fell through and returned None, which surfaced later as
    # an opaque unpacking TypeError in the caller.
    raise ValueError(
        f"Unsupported LeRobot version {self._lerobot_version!r}; "
        "expected 'v2.0' or 'v3.0'."
    )
|
|
def _get_all_steps(self) -> list[tuple[int, int]]:
    """Get the trajectory IDs and base indices for all steps in the dataset.

    The result is cached in `<dataset>/meta/steps_data_index.pkl`. An
    unreadable or corrupt cache is ignored and rebuilt instead of aborting
    dataset construction.

    Returns:
        list[tuple[int, int]]: A list of (trajectory_id, base_index) tuples.
    """
    config_key = self._get_steps_config_key()
    steps_path = self.dataset_path / "meta" / "steps_data_index.pkl"

    # NOTE(review): the cache is returned without comparing its stored
    # "config_key" to the current one -- confirm that is intended.
    if steps_path.exists():
        try:
            # pickle must only be used on trusted files; this cache is
            # expected to be produced by this same method.
            with open(steps_path, "rb") as f:
                cached_data = pickle.load(f)
            return cached_data["steps"]
        except (
            OSError,
            pickle.PickleError,
            KeyError,
            # Raised by unpickling truncated or foreign data.
            EOFError,
            AttributeError,
            ImportError,
            IndexError,
            TypeError,
        ) as e:
            print(f"Failed to load cached steps: {e}")
            print("Computing steps from scratch...")

    all_steps = self._get_all_steps_single_process()

    # Caching is best-effort: a failure to write must not fail dataset init.
    try:
        cache_data = {
            "config_key": config_key,
            "steps": all_steps,
            "num_trajectories": len(self.trajectory_ids),
            "total_steps": len(all_steps),
            "computed_timestamp": pd.Timestamp.now().isoformat(),
            "delete_pause_frame": self.delete_pause_frame,
        }
        steps_path.parent.mkdir(parents=True, exist_ok=True)

        # Write to a per-process temp file and rename it into place so a
        # concurrent reader never sees a half-written pickle.
        tmp_path = steps_path.with_name(f"{steps_path.name}.{os.getpid()}.tmp")
        try:
            with open(tmp_path, "wb") as f:
                pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
            os.replace(tmp_path, steps_path)
        finally:
            # No-op after a successful rename; cleans up after a failure.
            tmp_path.unlink(missing_ok=True)
        print(f"Cached steps saved to {steps_path}")
    except Exception as e:
        print(f"Failed to cache steps: {e}")

    return all_steps
|
|
def _get_steps_config_key(self) -> str:
    """Build a short fingerprint of the settings recorded with the step cache.

    Returns:
        str: First 12 hex characters of the MD5 digest of the sorted
        (name, value) pairs for `delete_pause_frame` and `dataset_name`.
    """
    settings = {
        "delete_pause_frame": self.delete_pause_frame,
        "dataset_name": self.dataset_name,
    }
    # Sorting the pairs makes the serialized form order-independent.
    serialized = str(sorted(settings.items()))
    digest = hashlib.md5(serialized.encode())
    return digest.hexdigest()[:12]
|
|
|
|
def _get_all_steps_single_process(self) -> list[tuple[int, int]]:
    """Build the step index by reading every trajectory sequentially.

    A trajectory is skipped when its data cannot be read, or, if a language
    modality is configured, when its first language instruction is empty.
    Every remaining trajectory contributes one (trajectory_id, base_index)
    entry per step.

    Side effect: when a language modality is configured,
    `self.curr_traj_data` is overwritten with each trajectory's data so that
    `get_language` can read it.

    Returns:
        list[tuple[int, int]]: (trajectory_id, base_index) for all kept steps.
    """
    all_steps: list[tuple[int, int]] = []
    skipped_trajectories = 0
    processed_trajectories = 0

    language_keys = self.modality_keys.get('language')
    has_language_modality = bool(language_keys)

    for trajectory_id, trajectory_length in tqdm(
        zip(self.trajectory_ids, self.trajectory_lengths),
        total=len(self.trajectory_ids),
        desc="Getting All Step",
    ):
        try:
            # get_trajectory_data dispatches on the LeRobot version itself.
            data = self.get_trajectory_data(trajectory_id)

            if has_language_modality:
                self.curr_traj_data = data
                language_instruction = self.get_language(trajectory_id, language_keys[0], 0)
                if not language_instruction or language_instruction[0] == "":
                    print(f"Skipping trajectory {trajectory_id} due to empty language instruction")
                    skipped_trajectories += 1
                    continue
        except Exception as e:
            # Best-effort: one unreadable trajectory must not abort indexing.
            print(f"Skipping trajectory {trajectory_id} due to read error: {e}")
            skipped_trajectories += 1
            continue

        processed_trajectories += 1
        all_steps.extend(
            (trajectory_id, base_index) for base_index in range(trajectory_length)
        )

    print(f"Single-process summary: Processed {processed_trajectories} trajectories, skipped {skipped_trajectories} empty trajectories")
    print(f"Total steps: {len(all_steps)} from {len(self.trajectory_ids)} trajectories")

    return all_steps
|
|
def _get_position_and_gripper_values(self, data: pd.DataFrame) -> tuple[list, list]:
    """Get position and gripper values based on available columns in the dataset.

    Position lookup order:
      1. the configured "action.delta_eef_position" key,
      2. the configured "action.x" / "action.y" / "action.z" keys (all three),
      3. raw DataFrame columns with those same names.
    Gripper lookup order: configured "action.gripper_close", then configured
    "action.gripper", then raw columns with those names.

    Args:
        data: Trajectory data, one row per step.

    Returns:
        tuple[list, list]: (delta position values, gripper values).

    Raises:
        ValueError: If no position source or no gripper source is found.
    """
    action_keys = self.modality_keys.get('action', [])

    def extract(subkey):
        """Return the 2D slice of `action.<subkey>` from `data`, or None if unavailable."""
        if f"action.{subkey}" not in action_keys:
            return None
        try:
            le_action_cfg = self.lerobot_modality_meta.action
            if subkey not in le_action_cfg:
                return None
            key_meta = le_action_cfg[subkey]
            le_key = key_meta.original_key or subkey
            if le_key not in data.columns:
                return None
            data_array = np.stack(data[le_key])
            le_indices = np.arange(key_meta.start, key_meta.end)
            return data_array[:, le_indices]
        except Exception:
            # Deliberate best-effort: any failure means "try the next source".
            return None

    # --- position -------------------------------------------------------
    delta_position_values = None
    position_block = extract('delta_eef_position')
    if position_block is not None:
        delta_position_values = position_block.tolist()

    if delta_position_values is None:
        xyz = [extract(coord) for coord in ('x', 'y', 'z')]
        if all(column is not None for column in xyz):
            delta_position_values = np.column_stack(
                [column.flatten() for column in xyz]
            ).tolist()

    if delta_position_values is None:
        if 'action.delta_eef_position' in data.columns:
            delta_position_values = data['action.delta_eef_position'].to_numpy().tolist()
        elif all(col in data.columns for col in ['action.x', 'action.y', 'action.z']):
            delta_position_values = np.column_stack(
                (
                    data['action.x'].to_numpy(),
                    data['action.y'].to_numpy(),
                    data['action.z'].to_numpy(),
                )
            ).tolist()
        else:
            raise ValueError(f"No suitable position columns found. Available columns: {data.columns.tolist()}")

    # --- gripper --------------------------------------------------------
    gripper_values = None
    for grip_key in ('gripper_close', 'gripper'):
        gripper_block = extract(grip_key)
        if gripper_block is not None:
            gripper_values = gripper_block.flatten().tolist()
            break

    if gripper_values is None:
        if 'action.gripper_close' in data.columns:
            gripper_values = data['action.gripper_close'].to_numpy().tolist()
        elif 'action.gripper' in data.columns:
            gripper_values = data['action.gripper'].to_numpy().tolist()
        else:
            raise ValueError(f"No suitable gripper columns found. Available columns: {data.columns.tolist()}")

    return delta_position_values, gripper_values
|
|
def _get_modality_keys(self) -> dict:
    """Collect the configured keys of every modality.

    Returns:
        dict: A `defaultdict(list)` mapping modality name to its configured
        key list; unknown modalities therefore yield an empty list. See
        property `modality_keys` for the expected format.
    """
    keys_by_modality = defaultdict(list)
    keys_by_modality.update(
        (modality, config.modality_keys)
        for modality, config in self.modality_configs.items()
    )
    return keys_by_modality
|
|
def _get_delta_indices(self) -> dict[str, np.ndarray]:
    """Flatten the per-modality delta indices into a per-key mapping.

    Returns:
        dict[str, np.ndarray]: For every configured modality key, a fresh
        array built from its modality's `delta_indices`.
    """
    return {
        key: np.array(config.delta_indices)
        for config in self.modality_configs.values()
        for key in config.modality_keys
    }
|
|
def _get_lerobot_modality_meta(self) -> LeRobotModalityMetadata:
    """Load and validate meta/modality.json from the dataset directory.

    Returns:
        LeRobotModalityMetadata: The validated modality description.
    """
    meta_file = self.dataset_path / LE_ROBOT_MODALITY_FILENAME
    assert (
        meta_file.exists()
    ), f"Please provide a {LE_ROBOT_MODALITY_FILENAME} file in {self.dataset_path}"
    with open(meta_file, "r") as f:
        raw_meta = json.load(f)
    return LeRobotModalityMetadata.model_validate(raw_meta)
|
|
def _get_lerobot_info_meta(self) -> dict:
    """Load meta/info.json from the dataset directory.

    Returns:
        dict: The parsed JSON content.
    """
    with open(self.dataset_path / LE_ROBOT_INFO_FILENAME, "r") as f:
        return json.load(f)
|
|
def _get_data_path_pattern(self) -> str:
    """Return the "data_path" entry of the dataset's info metadata."""
    info = self.lerobot_info_meta
    return info["data_path"]
|
|
def _get_video_path_pattern(self) -> str:
    """Return the "video_path" entry of the dataset's info metadata."""
    info = self.lerobot_info_meta
    return info["video_path"]
|
|
def _get_chunk_size(self) -> int:
    """Return the "chunks_size" entry of the dataset's info metadata."""
    info = self.lerobot_info_meta
    return info["chunks_size"]
|
|
def _get_tasks(self) -> pd.DataFrame:
    """Get the tasks for the dataset.

    Returns:
        pd.DataFrame: For "v2.0", the rows of meta/tasks.jsonl indexed by
        "task_index". For "v3.0", a frame with the columns "task_index" and
        "task" read from meta/tasks.parquet.

    Raises:
        ValueError: If the configured LeRobot version is not supported.
    """
    if self._lerobot_version == "v2.0":
        tasks_path = self.dataset_path / LE_ROBOT_TASKS_FILENAME
        with open(tasks_path, "r") as f:
            # Tolerate blank lines (e.g. a trailing newline) in the jsonl file.
            tasks = [json.loads(line) for line in f if line.strip()]
        return pd.DataFrame(tasks).set_index("task_index")

    if self._lerobot_version == "v3.0":
        tasks_path = self.dataset_path / LE_ROBOT3_TASKS_FILENAME
        df = pd.read_parquet(tasks_path)
        # The parquet index is moved into a regular column named "task".
        df = df.reset_index()
        df = df.rename(columns={'index': 'task'})
        # NOTE(review): unlike the v2.0 branch, this frame is not indexed by
        # "task_index" -- confirm consumers handle both layouts.
        return df[['task_index', 'task']]

    # Previously this fell through and silently returned None.
    raise ValueError(
        f"Unsupported LeRobot version {self._lerobot_version!r}; "
        "expected 'v2.0' or 'v3.0'."
    )
def _check_integrity(self):
    """Check that every configured modality key is known to the modality metadata.

    The keys "lapa_action" and "dream_actions" are exempt from the check.

    Raises:
        ValueError: If a configured key cannot be resolved by
            `self.lerobot_modality_meta.get_key_meta`; the original exception
            is attached as the cause.
    """
    ERROR_MSG_HEADER = f"Error occurred in initializing dataset {self.dataset_name}:\n"
    exempt_keys = ("lapa_action", "dream_actions")

    for modality_config in self.modality_configs.values():
        for key in modality_config.modality_keys:
            if key in exempt_keys:
                continue
            try:
                self.lerobot_modality_meta.get_key_meta(key)
            except Exception as e:
                # Chain the cause so the original traceback is preserved.
                raise ValueError(
                    ERROR_MSG_HEADER + f"Unable to find key {key} in modality metadata:\n{e}"
                ) from e
|
|
def set_transforms_metadata(self, metadata: DatasetMetadata):
    """Forward the dataset metadata to the configured transforms.

    Args:
        metadata (DatasetMetadata): Metadata passed to
            `self.transforms.set_metadata`.
    """
    transforms = self.transforms
    transforms.set_metadata(metadata)
|
|
def set_epoch(self, epoch: int):
    """Record the current epoch on the dataset.

    Args:
        epoch (int): The epoch to store as `self.epoch`.
    """
    setattr(self, "epoch", epoch)
|
|
def __len__(self) -> int:
    """Return the number of sampleable steps.

    Returns:
        int: The length of `self.all_steps`.
    """
    step_count = len(self.all_steps)
    return step_count
|
|
def __str__(self) -> str:
    """Return "<dataset name> (<number of steps> steps)"."""
    step_count = len(self)
    return f"{self.dataset_name} ({step_count} steps)"
|
|
|
|
def __getitem__(self, index: int) -> dict:
    """Build the training sample for one step.

    Args:
        index (int): Position in `self.all_steps`.

    Returns:
        dict: With the keys
            "action": concatenated action keys, embedded in the standard
                action vector for `self.tag`,
            "state": concatenated state keys, embedded in the standard state
                vector for `self.tag`,
            "image": one 224x224 PIL image per video key (first loaded frame),
            "language": first entry of the first language key,
            "dataset_id": the dataset's numeric ID (may be None),
        plus "mid_image" (one 224x224 PIL image per video key) when
        `num_history_steps` is non-zero.
    """
    trajectory_id, base_index = self.all_steps[index]
    data = self.get_step_data(trajectory_id, base_index)

    with_history = self.num_history_steps != 0
    target_size = (224, 224)

    images = []
    mid_images = []
    for video_key in self.modality_keys.get("video", []):
        video_frames = data[video_key]
        images.append(Image.fromarray(video_frames[0]).resize(target_size))
        if with_history:
            # Clamp so short frame stacks still yield a valid frame.
            history_index = min(self.num_history_steps - 1, len(video_frames) - 1)
            mid_images.append(
                Image.fromarray(video_frames[history_index]).resize(target_size)
            )

    language = data[self.modality_keys["language"][0]][0]

    # Per-key arrays are joined along axis 1 before being embedded.
    action = np.concatenate(
        [data[action_key] for action_key in self.modality_keys["action"]], axis=1
    )
    action = standardize_action_representation(action, self.tag)

    state = np.concatenate(
        [data[state_key] for state_key in self.modality_keys["state"]], axis=1
    )
    state = standardize_state_representation(state, self.tag)

    sample = {
        "action": action,
        "state": state,
        "image": images,
        "language": language,
        "dataset_id": self._dataset_id,
    }
    if with_history:
        sample["mid_image"] = mid_images
    return sample
|
|
def get_step_data(self, trajectory_id: int, base_index: int) -> dict:
    """Get the RAW data for a single step in a trajectory. No transforms are applied.

    Side effect: `self.curr_traj_data` is replaced with the data of
    `trajectory_id` before any key is read.

    Args:
        trajectory_id (int): The ID of the trajectory.
        base_index (int): The base step index in the trajectory.

    Returns:
        dict: A flat mapping from modality key to whatever
        `get_data_by_modality` returns for that key, e.g.
            {
                "video.image_side_0": ...,
                "state.eef_position": ...,
                "action.eef_position": ...,
            }
    """
    self.curr_traj_data = self.get_trajectory_data(trajectory_id)
    return {
        key: self.get_data_by_modality(trajectory_id, modality, key, base_index)
        for modality, keys in self.modality_keys.items()
        for key in keys
    }
|
|
def get_trajectory_data(self, trajectory_id: int) -> pd.DataFrame:
    """Get the data for a trajectory.

    For "v2.0" the trajectory's own parquet file is read, unless
    `self.curr_traj_id` already equals `trajectory_id` and
    `self.curr_traj_data` is set, in which case that cached frame is
    returned. "v3.0" is delegated to `get_trajectory_data_lerobot_v3`.

    Args:
        trajectory_id (int): The ID of the trajectory.

    Returns:
        pd.DataFrame: The trajectory's rows.

    Raises:
        ValueError: If the configured LeRobot version is not supported.
    """
    if self._lerobot_version == "v2.0":
        if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None:
            return self.curr_traj_data
        chunk_index = self.get_episode_chunk(trajectory_id)
        parquet_path = self.dataset_path / self.data_path_pattern.format(
            episode_chunk=chunk_index, episode_index=trajectory_id
        )
        assert parquet_path.exists(), f"Parquet file not found at {parquet_path}"
        return pd.read_parquet(parquet_path)

    if self._lerobot_version == "v3.0":
        return self.get_trajectory_data_lerobot_v3(trajectory_id)

    # Previously this fell through and silently returned None.
    raise ValueError(
        f"Unsupported LeRobot version {self._lerobot_version!r}; "
        "expected 'v2.0' or 'v3.0'."
    )
| |
def get_trajectory_data_lerobot_v3(self, trajectory_id: int) -> pd.DataFrame:
    """Get the data for a trajectory from lerobot v3.

    The parquet file is located through the chunk and file indices recorded
    for the episode in `self.trajectory_ids_to_metadata`; the rows whose
    "episode_index" equals `trajectory_id` are then selected from it.

    Args:
        trajectory_id (int): The ID of the trajectory.

    Returns:
        pd.DataFrame: A copy of the trajectory's rows. When video loading is
        enabled, "timestamp" is shifted by the episode's recorded
        "videos/observation.images.wrist/from_timestamp" (0 if absent).
    """
    if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None:
        return self.curr_traj_data

    # Use the chunk index stored with the episode: deriving it from
    # `episode_index // chunk_size` points at the wrong directory once the
    # episode index reaches the chunk size.
    episode_meta = self.trajectory_ids_to_metadata[trajectory_id]
    chunk_index = episode_meta.get("data/chunk_index")
    if chunk_index is None:
        chunk_index = self.get_episode_chunk(trajectory_id)
    file_index = self.get_episode_file_index(trajectory_id)

    parquet_path = self.dataset_path / self.data_path_pattern.format(
        chunk_index=chunk_index, file_index=file_index
    )
    assert parquet_path.exists(), f"Parquet file not found at {parquet_path}"
    file_data = pd.read_parquet(parquet_path)

    # Copy so the timestamp shift below never writes into `file_data`.
    episode_data = file_data.loc[file_data["episode_index"] == trajectory_id].copy()

    if self.load_video:
        from_timestamp = episode_meta.get(
            "videos/observation.images.wrist/from_timestamp", 0
        )
        episode_data["timestamp"] = episode_data["timestamp"] + from_timestamp

    return episode_data
|
|
|
|
def get_trajectory_index(self, trajectory_id: int) -> int:
    """Return the position of ``trajectory_id`` inside ``self.trajectory_ids``.

    Handy for looking up per-trajectory values (such as the trajectory
    length) that are stored in the same order as ``self.trajectory_ids``.

    Args:
        trajectory_id (int): The ID of the trajectory.

    Returns:
        int: The index of the trajectory in the dataset.

    Raises:
        ValueError: If the ID matches zero or more than one entry.
    """
    # The local name is part of the error text ("trajectory_indices=..."),
    # so it must not be renamed.
    trajectory_indices = np.where(self.trajectory_ids == trajectory_id)[0]
    if len(trajectory_indices) == 1:
        return trajectory_indices[0]
    raise ValueError(
        f"Error finding trajectory index for {trajectory_id}, found {trajectory_indices=}"
    )
|
|
def get_episode_chunk(self, ep_index: int) -> int:
    """Return the chunk index of an episode: ``ep_index`` floor-divided by ``self.chunk_size``."""
    chunk_index, _ = divmod(ep_index, self.chunk_size)
    return chunk_index
def get_episode_file_index(self, ep_index: int) -> int:
    """Return the ``data/file_index`` entry of the episode's metadata record."""
    return self.trajectory_ids_to_metadata[ep_index]["data/file_index"]
| |
def get_episode_file_from_index(self, ep_index: int) -> int:
    """Return the ``data/file_from_index`` entry of the episode's metadata record."""
    return self.trajectory_ids_to_metadata[ep_index]["data/file_from_index"]
|
|
|
|
def retrieve_data_and_pad(
    self,
    array: np.ndarray,
    step_indices: np.ndarray,
    max_length: int,
    padding_strategy: str = "first_last",
) -> np.ndarray:
    """Gather rows of ``array`` at ``step_indices``, padding out-of-range steps.

    Indices below 0 or at/after ``max_length`` are treated as padding.

    Args:
        array (np.ndarray): The array to retrieve the data from.
        step_indices (np.ndarray): The step indices to retrieve the data for.
        max_length (int): Number of valid steps; indices >= this are padded.
        padding_strategy (str): "first_last" fills leading padding with
            ``array[0]`` and trailing padding with ``array[-1]``; "zero"
            fills padding with zeros.

    Returns:
        np.ndarray: Array created with ``np.zeros`` (default dtype) with one
        entry per step index.

    Raises:
        ValueError: If padding is required and ``padding_strategy`` is unknown.
    """
    before_start = step_indices < 0
    past_end = step_indices >= max_length
    needs_padding = before_start | past_end
    in_range = ~needs_padding

    gathered = array[step_indices[in_range]]
    assert isinstance(gathered, np.ndarray), f"{type(gathered)=}".replace(
        "gathered", "raw_data"
    )

    num_steps = len(step_indices)
    out_shape = (num_steps,) if gathered.ndim == 1 else (num_steps, *array.shape[1:])

    result = np.zeros(out_shape)
    result[in_range] = gathered

    # Nothing to pad: the strategy is not inspected in this case.
    if not needs_padding.any():
        return result

    if padding_strategy == "first_last":
        result[before_start] = array[0]
        result[past_end] = array[-1]
    elif padding_strategy == "zero":
        result[needs_padding] = 0
    else:
        raise ValueError(f"Invalid padding strategy: {padding_strategy}")
    return result
|
|
def get_video_path(self, trajectory_id: int, key: str) -> Path:
    """Build the path of the video file that holds ``key`` for a trajectory.

    Args:
        trajectory_id (int): The ID (episode index) of the trajectory.
        key (str): The video key as listed in ``self.lerobot_modality_meta.video``
            (without the "video." prefix).

    Returns:
        Path: ``self.dataset_path`` joined with the formatted
        ``self.video_path_pattern``.

    Raises:
        ValueError: If ``self._lerobot_version`` is neither "v2.0" nor "v3.0".
    """
    original_key = self.lerobot_modality_meta.video[key].original_key
    if original_key is None:
        original_key = key

    version = self._lerobot_version
    if version == "v2.0":
        # v2.0 stores one video file per episode, grouped by episode chunk.
        video_filename = self.video_path_pattern.format(
            episode_chunk=self.get_episode_chunk(trajectory_id),
            episode_index=trajectory_id,
            video_key=original_key,
        )
    elif version == "v3.0":
        # v3.0 addresses files through the indices in the episode metadata.
        episode_meta = self.trajectory_ids_to_metadata[trajectory_id]
        video_filename = self.video_path_pattern.format(
            video_key=original_key,
            chunk_index=episode_meta["data/chunk_index"],
            file_index=episode_meta["data/file_index"],
        )
    else:
        # Previously this surfaced as an UnboundLocalError on video_filename.
        raise ValueError(
            f"Unsupported LeRobot version: {version!r} (expected 'v2.0' or 'v3.0')"
        )
    return self.dataset_path / video_filename
|
|
def get_video(
    self,
    trajectory_id: int,
    key: str,
    base_index: int,
) -> np.ndarray:
    """Get the video frames for a trajectory around a base index.

    The requested steps are ``self.delta_indices[key] + base_index``, clamped
    to ``[0, trajectory_length - 1]``. Their timestamps are read from
    ``self.curr_traj_data`` and handed to ``get_frames_by_timestamps``.

    Args:
        trajectory_id (int): The ID of the trajectory.
        key (str): The key of the video; must start with "video.".
        base_index (int): The base index within the trajectory.

    Returns:
        np.ndarray: Whatever ``get_frames_by_timestamps`` returns for the
        selected timestamps (presumably frames shaped (T, H, W, C) — confirm
        against the video module).

    Raises:
        AssertionError: If the key lacks the "video." prefix, or if the
            current trajectory data is missing or has no "timestamp" column.
    """
    step_indices = self.delta_indices[key] + base_index

    trajectory_index = self.get_trajectory_index(trajectory_id)

    # Clamp out-of-range steps to the first/last frame of the trajectory.
    step_indices = np.maximum(step_indices, 0)
    step_indices = np.minimum(step_indices, self.trajectory_lengths[trajectory_index] - 1)
    assert key.startswith("video."), f"Video key must start with 'video.', got {key}"

    # Strip only the leading prefix; str.replace would also remove a later
    # "video." inside the key (e.g. "video.front_video.rgb").
    key = key.removeprefix("video.")
    video_path = self.get_video_path(trajectory_id, key)

    assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"
    assert "timestamp" in self.curr_traj_data.columns, f"No timestamp found in {trajectory_id=}"
    timestamp: np.ndarray = self.curr_traj_data["timestamp"].to_numpy()

    video_timestamp = timestamp[step_indices]

    return get_frames_by_timestamps(
        video_path.as_posix(),
        video_timestamp,
        video_backend=self.video_backend,
        video_backend_kwargs=self.video_backend_kwargs,
    )
|
|
def get_state_or_action(
    self,
    trajectory_id: int,
    modality: str,
    key: str,
    base_index: int,
) -> np.ndarray:
    """Get the state or action data for a trajectory around a base index.

    The requested steps are ``self.delta_indices[key] + base_index``. Steps
    outside the trajectory are padded by ``retrieve_data_and_pad``: with the
    first/last step when the key's metadata marks it as ``absolute``,
    otherwise with zeros.

    Args:
        trajectory_id (int): The ID of the trajectory.
        modality (str): The modality of the data ("state" or "action").
        key (str): The key of the data; must start with ``modality + "."``.
        base_index (int): The base index within the trajectory.

    Returns:
        np.ndarray: The (padded) data for the requested steps.

    Raises:
        AssertionError: If the key prefix is wrong, the current trajectory
            data is missing, the source column is absent, or the stacked
            column is not 2-D.
    """
    step_indices = self.delta_indices[key] + base_index

    trajectory_index = self.get_trajectory_index(trajectory_id)
    max_length = self.trajectory_lengths[trajectory_index]

    prefix = modality + "."
    assert key.startswith(prefix), f"{key} must start with {modality + '.'}, got {key}"

    # Strip only the leading prefix; str.replace would also remove later
    # occurrences (e.g. "action.reaction.x" would collapse to "rex").
    key = key.removeprefix(prefix)

    le_state_or_action_cfg = getattr(self.lerobot_modality_meta, modality)
    key_cfg = le_state_or_action_cfg[key]
    le_key = key if key_cfg.original_key is None else key_cfg.original_key

    assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"
    assert le_key in self.curr_traj_data.columns, f"No {le_key} found in {trajectory_id=}"
    data_array: np.ndarray = np.stack(self.curr_traj_data[le_key])
    assert data_array.ndim == 2, f"Expected 2D array, got key {le_key} is{data_array.shape} array"

    # Keep only the columns that belong to this sub-key.
    le_indices = np.arange(key_cfg.start, key_cfg.end)
    data_array = data_array[:, le_indices]

    state_or_action_cfg = getattr(self.metadata.modalities, modality)[key]

    return self.retrieve_data_and_pad(
        array=data_array,
        step_indices=step_indices,
        max_length=max_length,
        padding_strategy="first_last" if state_or_action_cfg.absolute else "zero",
    )
|
|
def get_language(
    self,
    trajectory_id: int,
    key: str,
    base_index: int,
) -> list[str]:
    """Get the language annotations for a trajectory around a base index.

    The requested steps are ``self.delta_indices[key] + base_index``, clamped
    to ``[0, trajectory_length - 1]``. For each step the task index stored in
    the annotation column is read and looked up in ``self.tasks``.

    Args:
        trajectory_id (int): The ID of the trajectory.
        key (str): The key of the annotation; must start with "annotation.".
        base_index (int): The base index within the trajectory.

    Returns:
        list[str]: The "task" entries of ``self.tasks`` for the selected steps.

    Raises:
        AssertionError: If the current trajectory data is missing, the key
            lacks the "annotation." prefix, or the annotation metadata does
            not contain the key.
    """
    assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"

    step_indices = self.delta_indices[key] + base_index

    trajectory_index = self.get_trajectory_index(trajectory_id)
    max_length = self.trajectory_lengths[trajectory_index]

    # Clamp out-of-range steps to the first/last step of the trajectory.
    step_indices = np.maximum(step_indices, 0)
    step_indices = np.minimum(step_indices, max_length - 1)

    assert key.startswith(
        "annotation."
    ), f"Language key must start with 'annotation.', got {key}"
    # Strip only the leading prefix; str.replace would also remove a later
    # "annotation." inside the key.
    subkey = key.removeprefix("annotation.")
    annotation_meta = self.lerobot_modality_meta.annotation
    assert annotation_meta is not None, f"Annotation metadata is None for {subkey}"
    assert (
        subkey in annotation_meta
    ), f"Annotation key {subkey} not found in metadata, available annotation keys: {annotation_meta.keys()}"

    original_key = annotation_meta[subkey].original_key
    if original_key is None:
        # Without an explicit mapping, the column carries the full key name.
        original_key = key

    column = self.curr_traj_data[original_key]
    task_indices: list[int] = []
    for step in step_indices:
        value = column.iloc[step]
        # NumPy integers are unwrapped to plain Python values for the lookup.
        task_indices.append(value if isinstance(value, (int, float)) else value.item())

    return self.tasks.loc[task_indices]["task"].tolist()
|
|
def get_data_by_modality(
    self,
    trajectory_id: int,
    modality: str,
    key: str,
    base_index: int,
):
    """Fetch the data of one key by dispatching on its modality.

    "video" goes to ``get_video``, "state" and "action" go to
    ``get_state_or_action`` and "language" goes to ``get_language``; see
    those helpers for the details of each modality.

    Args:
        trajectory_id (int): The ID of the trajectory.
        modality (str): The modality of the data.
        key (str): The key of the data.
        base_index (int): The base index within the trajectory.

    Raises:
        ValueError: If ``modality`` is not one of the four supported values.
    """
    if modality in ("state", "action"):
        return self.get_state_or_action(trajectory_id, modality, key, base_index)
    if modality == "video":
        return self.get_video(trajectory_id, key, base_index)
    if modality == "language":
        return self.get_language(trajectory_id, key, base_index)
    raise ValueError(f"Invalid modality: {modality}")
|
|
def _save_dataset_statistics_(self, save_path: Path | str, format: str = "json") -> None:
    """Write this dataset's statistics to ``save_path``.

    Only the action/state keys listed in ``self.modality_keys`` are exported,
    in the order in which they appear there. The output maps ``self.tag`` to
    the combined "action" statistics (plus a "mask"), the combined "state"
    statistics, and the transition/trajectory counts.

    Args:
        save_path (Path | str): Destination file; the parent directory is
            created if needed and a ".json" suffix is enforced.
        format (str): Output format; only "json" is supported.

    Raises:
        ValueError: If ``format`` is not "json" (case-insensitive).
    """
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    used_action_keys, used_state_keys = get_used_modality_keys(self.modality_keys)

    per_tag_stats = {}

    # Action statistics, restricted to the keys in use, plus the mask.
    action_stats = getattr(self.metadata.statistics, "action", None)
    if action_stats:
        selected_actions = {
            name: action_stats[name] for name in used_action_keys if name in action_stats
        }
        if selected_actions:
            merged_actions = combine_modality_stats(selected_actions)
            merged_actions["mask"] = generate_action_mask_for_used_keys(
                self.metadata.modalities.action, selected_actions.keys()
            )
            per_tag_stats["action"] = merged_actions

    # State statistics, restricted to the keys in use.
    state_stats = getattr(self.metadata.statistics, "state", None)
    if state_stats:
        selected_states = {
            name: state_stats[name] for name in used_state_keys if name in state_stats
        }
        if selected_states:
            per_tag_stats["state"] = combine_modality_stats(selected_states)

    per_tag_stats["num_transitions"] = len(self)
    per_tag_stats["num_trajectories"] = len(self.trajectory_ids)

    statistics_data = {self.tag: per_tag_stats}

    if format.lower() != "json":
        raise ValueError(f"Unsupported format: {format}. Currently only 'json' is supported.")

    if not str(save_path).endswith('.json'):
        save_path = save_path.with_suffix('.json')
    with open(save_path, 'w', encoding='utf-8') as handle:
        json.dump(statistics_data, handle, indent=2, ensure_ascii=False)

    print(f"Single dataset statistics saved to: {save_path}")
    print(f"Used action keys (reordered): {list(used_action_keys)}")
    print(f"Used state keys (reordered): {list(used_state_keys)}")
|
|
|
|
|
|
class MixtureSpecElement(BaseModel):
    # Pydantic specification of one entry of a dataset mixture: where the data
    # lives and how strongly it should be weighted.

    # One path, or a list of paths that share the weight below.
    dataset_path: list[Path] | Path = Field(..., description="The path to the dataset.")
    # Required; no default weight is assumed.
    dataset_weight: float = Field(..., description="The weight of the dataset in the mixture.")
    # Opt-in flag; see the field description for its meaning.
    distribute_weights: bool = Field(
        default=False,
        description="Whether to distribute the weights of the dataset across all the paths. If True, the weights will be evenly distributed across all the paths.",
    )
|
|
|
|
| |
|
|
def combine_modality_stats(modality_stats: dict) -> dict:
    """Concatenate the statistics of all sub-keys of a modality.

    Args:
        modality_stats (dict): Maps each sub-key to an object exposing the
            attributes ``mean``, ``std``, ``max``, ``min``, ``q01`` and
            ``q99``. Each attribute may be a list/tuple, an object with a
            ``tolist()`` method (e.g. a NumPy array or scalar), or a plain
            number.

    Returns:
        dict: One flat list per statistic name, built by appending the values
        of every sub-key in the iteration order of ``modality_stats``.
    """
    stat_names = ("mean", "std", "max", "min", "q01", "q99")
    combined_stats = {stat_name: [] for stat_name in stat_names}

    for subkey_stats in modality_stats.values():
        for stat_name in stat_names:
            stat_value = getattr(subkey_stats, stat_name)
            if not isinstance(stat_value, (list, tuple)) and hasattr(stat_value, "tolist"):
                stat_value = stat_value.tolist()

            if isinstance(stat_value, (list, tuple)):
                combined_stats[stat_name].extend(stat_value)
            else:
                # NumPy scalars and 0-d arrays come back from tolist() as a
                # bare number; extending with that used to raise TypeError.
                combined_stats[stat_name].append(float(stat_value))

    return combined_stats
|
|
def generate_action_mask_for_used_keys(action_modalities: dict, used_action_keys_ordered) -> list[bool]:
    """Build a per-dimension boolean mask for the action keys in use.

    Keys missing from ``action_modalities`` are skipped. Every dimension of a
    key whose name contains "gripper" (case-insensitive) is masked ``False``;
    all other dimensions are ``True``.

    Args:
        action_modalities (dict): Configuration of the action modalities,
            keyed by sub-key. A config's ``shape[0]`` gives the number of
            dimensions; a config without a non-empty ``shape`` counts as one.
        used_action_keys_ordered: Iterable of the action keys in use, in the
            order the mask should follow.

    Returns:
        list[bool]: The concatenated mask values.
    """
    mask = []

    for subkey in used_action_keys_ordered:
        if subkey not in action_modalities:
            continue
        config = action_modalities[subkey]

        has_shape = hasattr(config, 'shape') and len(config.shape) > 0
        num_dims = config.shape[0] if has_shape else 1

        keep = "gripper" not in subkey.lower()
        mask.extend(keep for _ in range(num_dims))

    return mask
|
|
def get_used_modality_keys(modality_keys: dict) -> tuple[list[str], list[str]]:
    """Extract the action and state sub-keys in use from a modality config.

    Args:
        modality_keys (dict): Maps a modality name to its list of full keys,
            e.g. ``{"action": ["action.arm"], "state": ["state.arm"]}``.

    Returns:
        tuple[list[str], list[str]]: ``(action_keys, state_keys)`` with the
        leading "action." / "state." prefix removed. Lists are returned, not
        sets, so the config order is preserved. Keys without the expected
        prefix are ignored.
    """

    def _strip(prefix: str, keys) -> list[str]:
        # removeprefix only touches the leading prefix; str.replace would also
        # eat later occurrences (e.g. "action.reaction.x" -> "rex").
        return [key.removeprefix(prefix) for key in keys if key.startswith(prefix)]

    used_action_keys = _strip("action.", modality_keys.get("action", []))
    used_state_keys = _strip("state.", modality_keys.get("state", []))
    return used_action_keys, used_state_keys
|
|
|
|
def safe_hash(input_tuple):
    """Derive a deterministic integer seed from ``input_tuple``.

    The ``repr`` of the input is hashed with SHA-256, so the result is stable
    across processes (unlike the built-in ``hash`` for strings). The digest is
    masked with the constant below to bound the seed's size.
    """
    digest = hashlib.sha256(repr(input_tuple).encode("utf-8")).digest()
    # Big-endian interpretation of the digest equals int(hexdigest, 16).
    return int.from_bytes(digest, "big") & 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
| |
|
|
| class LeRobotMixtureDataset(Dataset): |
| """ |
| A mixture of multiple datasets. This class samples a single dataset based on the dataset weights and then calls the `__getitem__` method of the sampled dataset. |
| It is recommended to modify the single dataset class instead of this class. |
| """ |
|
|
def __init__(
    self,
    data_mixture: Sequence[tuple[LeRobotSingleDataset, float]],
    mode: str,
    balance_dataset_weights: bool = True,
    balance_trajectory_weights: bool = True,
    seed: int = 42,
    metadata_config: dict | None = None,
    **kwargs,
):
    """
    Initialize the mixture dataset.

    Args:
        data_mixture (list[tuple[LeRobotSingleDataset, float]]): Datasets and their corresponding weights.
            Datasets whose length is 0 are skipped with a warning.
        mode (str): If "train", sampling depends on the epoch, the index and the seed; otherwise the index alone seeds the sampling.
        balance_dataset_weights (bool): If True, the weight of each dataset is multiplied by its length.
        balance_trajectory_weights (bool): If True, the stored trajectory weights are proportional to the trajectory lengths; otherwise they are uniform.
        seed (int): Random seed for sampling.
        metadata_config (dict | None): Passed to `update_metadata`. Defaults to
            `{"percentile_mixing_method": "min_max"}` when None.
        **kwargs: Accepted for call-site compatibility; not used here.

    Raises:
        ValueError: If every dataset in the mixture is empty.
    """
    # Build the default per call instead of sharing one dict across instances.
    if metadata_config is None:
        metadata_config = {"percentile_mixing_method": "min_max"}

    datasets: list[LeRobotSingleDataset] = []
    dataset_sampling_weights: list[float] = []
    for dataset, weight in data_mixture:
        if len(dataset) == 0:
            print(f"Warning: Skipping empty dataset {dataset.dataset_name}")
            continue
        datasets.append(dataset)
        dataset_sampling_weights.append(weight)

    if len(datasets) == 0:
        raise ValueError("No valid datasets found in the mixture. All datasets are empty.")

    self.datasets = datasets
    self.balance_dataset_weights = balance_dataset_weights
    self.balance_trajectory_weights = balance_trajectory_weights
    self.seed = seed
    self.mode = mode

    self._dataset_lengths = np.array([len(dataset) for dataset in self.datasets])
    print(f"Dataset lengths: {self._dataset_lengths}")

    # Force a float array: with integer weights the in-place normalisation
    # below (`/=`) raised a casting error.
    self._dataset_sampling_weights = np.array(dataset_sampling_weights, dtype=float)

    if self.balance_dataset_weights:
        self._dataset_sampling_weights *= self._dataset_lengths

    # Guard against non-positive weights before normalising.
    if np.any(self._dataset_sampling_weights <= 0):
        print(f"Warning: Found zero or negative sampling weights: {self._dataset_sampling_weights}")
        self._dataset_sampling_weights = np.maximum(self._dataset_sampling_weights, 1e-8)

    weights_sum = self._dataset_sampling_weights.sum()
    if weights_sum == 0 or np.isnan(weights_sum):
        print(f"Error: Invalid weights sum: {weights_sum}")
        self._dataset_sampling_weights = np.ones(len(self.datasets)) / len(self.datasets)
        print("Fallback to equal weights")
    else:
        self._dataset_sampling_weights /= weights_sum

    # Per-dataset trajectory weights, each normalised to sum to 1.
    self._trajectory_sampling_weights: list[np.ndarray] = []
    for i, dataset in enumerate(self.datasets):
        trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths))
        if self.balance_trajectory_weights:
            trajectory_sampling_weights *= dataset.trajectory_lengths

        if np.any(trajectory_sampling_weights <= 0):
            print(f"Warning: Dataset {i} has zero or negative trajectory weights")
            trajectory_sampling_weights = np.maximum(trajectory_sampling_weights, 1e-8)

        weights_sum = trajectory_sampling_weights.sum()
        if weights_sum == 0 or np.isnan(weights_sum):
            print(f"Error: Dataset {i} has invalid trajectory weights sum: {weights_sum}")
            trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths)) / len(dataset.trajectory_lengths)
        else:
            trajectory_sampling_weights /= weights_sum

        self._trajectory_sampling_weights.append(trajectory_sampling_weights)

    # Primary datasets are those given an original weight of exactly 1.0;
    # fall back to the maximum weight, then to the first dataset.
    self._primary_dataset_indices = np.array(dataset_sampling_weights) == 1.0
    if not np.any(self._primary_dataset_indices):
        print(f"Warning: No dataset with weight 1.0 found. Original weights: {dataset_sampling_weights}")
        max_weight = max(dataset_sampling_weights)
        self._primary_dataset_indices = np.array(dataset_sampling_weights) == max_weight
        print(f"Using datasets with maximum weight {max_weight} as primary: {self._primary_dataset_indices}")

    if not np.any(self._primary_dataset_indices):
        print("Error: Still no primary dataset found. Using first dataset as primary.")
        self._primary_dataset_indices = np.zeros(len(self.datasets), dtype=bool)
        self._primary_dataset_indices[0] = True

    self.set_epoch(0)

    self.update_metadata(metadata_config)
|
|
@property
def dataset_lengths(self) -> np.ndarray:
    """Read-only view of ``self._dataset_lengths`` (one length per dataset)."""
    return self._dataset_lengths
|
|
@property
def dataset_sampling_weights(self) -> np.ndarray:
    """Read-only view of ``self._dataset_sampling_weights`` (one weight per dataset)."""
    return self._dataset_sampling_weights
|
|
@property
def trajectory_sampling_weights(self) -> list[np.ndarray]:
    """Read-only view of ``self._trajectory_sampling_weights`` (one array per dataset)."""
    return self._trajectory_sampling_weights
|
|
@property
def primary_dataset_indices(self) -> np.ndarray:
    """Read-only view of ``self._primary_dataset_indices`` (boolean flag per dataset)."""
    return self._primary_dataset_indices
|
|
def __str__(self) -> str:
    """Render the mixture as indented JSON listing each dataset and its sampling weight."""
    entries = [
        {"Dataset": str(member), "Sampling weight": float(share)}
        for member, share in zip(self.datasets, self.dataset_sampling_weights)
    ]
    return json.dumps({"Mixture dataset": entries}, indent=2)
|
|
def set_epoch(self, epoch: int):
    """Record the current epoch on the dataset.

    The stored value feeds the sampling seed when the mode is "train".

    Args:
        epoch (int): The epoch to set.
    """
    self.epoch = epoch
| |
|
|
def sample_step(self, index: int) -> tuple[LeRobotSingleDataset, int, int]:
    """Pick a dataset and one of its steps for the given sample index.

    In "train" mode the RNG seed is ``safe_hash((epoch, index, seed))``; in
    any other mode the index itself is the seed, so the same index always
    yields the same step. The dataset is drawn according to
    ``self.dataset_sampling_weights``, then a step is drawn uniformly from
    that dataset's ``all_steps``.

    Returns:
        tuple: ``(dataset, trajectory_id, base_index)``.
    """
    if self.mode == "train":
        seed = safe_hash((self.epoch, index, self.seed))
    else:
        seed = index
    rng = np.random.default_rng(seed)

    # Draw order matters for reproducibility: dataset first, then the step.
    chosen = rng.choice(len(self.datasets), p=self.dataset_sampling_weights)
    dataset = self.datasets[chosen]

    step_position = rng.choice(len(dataset.all_steps))
    trajectory_id, base_index = dataset.all_steps[step_position]
    return dataset, trajectory_id, base_index
|
|
def __getitem__(self, index: int) -> dict:
    """Sample one step from the mixture and assemble the training sample.

    A failed attempt is retried with a freshly drawn random index, up to 10
    attempts in total; the last exception is re-raised when all fail.

    Args:
        index (int): The sample index handed to ``sample_step``.

    Returns:
        dict: Keys "action", "state", "image", "lang" and "dataset_id", plus
        "mid_image" when the dataset's ``num_history_steps`` is non-zero.
    """
    max_retries = 10
    last_exception = None

    for attempt in range(max_retries):
        try:
            dataset, trajectory_name, step = self.sample_step(index)
            data = dataset.transforms(dataset.get_step_data(trajectory_name, step))

            images = []
            mid_images = []
            num_history_steps = int(getattr(dataset, "num_history_steps", 0) or 0)
            for video_key in dataset.modality_keys.get("video", []):
                video_frames = data[video_key]
                # The first frame of every view, resized to 224x224.
                images.append(Image.fromarray(video_frames[0]).resize((224, 224)))
                if num_history_steps != 0:
                    # Frame at the history boundary, capped at the last frame.
                    history_index = min(num_history_steps - 1, len(video_frames) - 1)
                    mid_images.append(
                        Image.fromarray(video_frames[history_index]).resize((224, 224))
                    )

            language = data[dataset.modality_keys["language"][0]][0]

            action_parts = [data[action_key] for action_key in dataset.modality_keys["action"]]
            action = np.concatenate(action_parts, axis=1).astype(np.float16)
            action = standardize_action_representation(action, dataset.tag)

            state_parts = [data[state_key] for state_key in dataset.modality_keys["state"]]
            state = np.concatenate(state_parts, axis=1).astype(np.float16)
            state = standardize_state_representation(state, dataset.tag)

            sample = {
                "action": action,
                "state": state,
                "image": images,
                "lang": language,
                "dataset_id": dataset._dataset_id,
            }
            if num_history_steps != 0:
                sample["mid_image"] = mid_images
            return sample

        except Exception as e:
            last_exception = e
            if attempt >= max_retries - 1:
                print(f"All {max_retries} attempts failed for index {index}")
                print(f"Last error: {last_exception}")
                raise last_exception

            print(f"Attempt {attempt + 1}/{max_retries} failed for index {index}: {e}")
            print("Retrying with new sample...")
            # Try a different sample instead of repeating the failing one.
            index = random.randint(0, len(self) - 1)
|
|
def __len__(self) -> int:
    """Get the length of a single epoch in the mixture.

    The length is the largest ``dataset_length / sampling_weight`` ratio over
    the primary datasets with a positive, non-NaN length and weight. Without
    such a primary dataset, the largest valid dataset length is used.

    Returns:
        int: The length of a single epoch in the mixture (0 when nothing is valid).
    """
    num_datasets = len(self.datasets)
    if num_datasets == 0:
        return 0

    lengths = self.dataset_lengths
    weights = self.dataset_sampling_weights

    # Step 1: datasets with a usable length.
    usable = np.ones(num_datasets, dtype=bool)
    if np.any(lengths == 0) or np.any(np.isnan(lengths)):
        print(f"Warning: Found zero or NaN dataset lengths: {lengths}")
        usable = (lengths > 0) & (~np.isnan(lengths))
        if not np.any(usable):
            print("Error: All datasets have zero or NaN length")
            return 0

    # Step 2: of those, datasets with a usable sampling weight.
    if np.any(weights == 0) or np.any(np.isnan(weights)):
        print(f"Warning: Found zero or NaN sampling weights: {weights}")
        usable = usable & ((weights > 0) & (~np.isnan(weights)))
        if not np.any(usable):
            print("Error: All sampling weights are zero or NaN")
            return 0

    # Step 3: restrict to primary datasets, falling back to the longest one.
    primary_and_usable = self.primary_dataset_indices & usable
    if not np.any(primary_and_usable):
        print(f"Warning: No valid primary datasets found. Primary indices: {self.primary_dataset_indices}, Valid indices: {usable}")
        if not np.any(usable):
            return 0
        longest = lengths[usable].max()
        print(f"Fallback: Using maximum dataset length: {longest}")
        return int(longest)

    ratios = (lengths / weights)[primary_and_usable]

    finite = ~np.isnan(ratios) & ~np.isinf(ratios)
    if np.all(finite):
        max_ratio = ratios.max()
    else:
        print(f"Warning: Found NaN or inf in ratios: {ratios}")
        print(f"Dataset lengths: {lengths[primary_and_usable]}")
        print(f"Sampling weights: {weights[primary_and_usable]}")
        finite_ratios = ratios[finite]
        if len(finite_ratios) == 0:
            print("Error: All ratios are NaN or inf")
            return 0
        max_ratio = finite_ratios.max()

    result = int(max_ratio)
    if result == 0:
        print("Warning: Dataset mixture length is 0")
    return result
|
|
| @staticmethod |
| def compute_overall_statistics( |
| per_task_stats: list[dict[str, dict[str, list[float] | np.ndarray]]], |
| dataset_sampling_weights: list[float] | np.ndarray, |
| percentile_mixing_method: str = "weighted_average", |
| ) -> dict[str, dict[str, list[float]]]: |
| """ |
| Computes overall statistics from per-task statistics using dataset sample weights. |
| |
| Args: |
| per_task_stats: List of per-task statistics. |
| Example format of one element in the per-task statistics list: |
| { |
| "state.gripper": { |
| "min": [...], |
| "max": [...], |
| "mean": [...], |
| "std": [...], |
| "q01": [...], |
| "q99": [...], |
| }, |
| ... |
| } |
| dataset_sampling_weights: List of sample weights for each task. |
| percentile_mixing_method: The method to mix the percentiles, either "weighted_average" or "weighted_std". |
| |
| Returns: |
| A dict of overall statistics per modality. |
| """ |
| |
| dataset_sampling_weights = np.array(dataset_sampling_weights) |
| normalized_weights = dataset_sampling_weights / dataset_sampling_weights.sum() |
|
|
| |
| overall_stats: dict[str, dict[str, list[float]]] = {} |
|
|
| |
| modality_keys = per_task_stats[0].keys() |
|
|
| for modality in modality_keys: |
| |
| num_dims = len(per_task_stats[0][modality]["mean"]) |
|
|
| |
| weighted_means = np.zeros(num_dims) |
| weighted_squares = np.zeros(num_dims) |
|
|
| |
| min_list = [] |
| max_list = [] |
| q01_list = [] |
| q99_list = [] |
|
|
| for task_idx, task_stats in enumerate(per_task_stats): |
| w_i = normalized_weights[task_idx] |
| stats = task_stats[modality] |
| means = np.array(stats["mean"]) |
| stds = np.array(stats["std"]) |
|
|
| |
| weighted_means += w_i * means |
| weighted_squares += w_i * (stds**2 + means**2) |
|
|
| |
| min_list.append(stats["min"]) |
| max_list.append(stats["max"]) |
| q01_list.append(stats["q01"]) |
| q99_list.append(stats["q99"]) |
|
|
| |
| overall_mean = weighted_means.tolist() |
|
|
| |
| overall_variance = weighted_squares - weighted_means**2 |
| overall_std = np.sqrt(overall_variance).tolist() |
|
|
| |
| overall_min = np.min(np.array(min_list), axis=0).tolist() |
| overall_max = np.max(np.array(max_list), axis=0).tolist() |
|
|
| |
| |
| q01_array = np.array(q01_list) |
| q99_array = np.array(q99_list) |
| if percentile_mixing_method == "weighted_average": |
| weighted_q01 = np.average(q01_array, axis=0, weights=normalized_weights).tolist() |
| weighted_q99 = np.average(q99_array, axis=0, weights=normalized_weights).tolist() |
| |
| |
| |
| |
| |
| elif percentile_mixing_method == "min_max": |
| weighted_q01 = np.min(q01_array, axis=0).tolist() |
| weighted_q99 = np.max(q99_array, axis=0).tolist() |
| else: |
| raise ValueError(f"Invalid percentile mixing method: {percentile_mixing_method}") |
|
|
| |
| overall_stats[modality] = { |
| "min": overall_min, |
| "max": overall_max, |
| "mean": overall_mean, |
| "std": overall_std, |
| "q01": weighted_q01, |
| "q99": weighted_q99, |
| } |
|
|
| return overall_stats |
|
|
@staticmethod
def merge_metadata(
    metadatas: list[DatasetMetadata],
    dataset_sampling_weights: list[float],
    percentile_mixing_method: str,
) -> DatasetMetadata:
    """Merge several dataset metadata objects into a single one.

    All inputs must share one embodiment tag and identical modality configs.
    State and action statistics are combined with
    ``compute_overall_statistics``.

    Raises:
        AssertionError: If the embodiment tags differ, or if a modality has
            more than one distinct config across the inputs.
    """
    dumped = [metadata.model_dump(mode="json") for metadata in metadatas]

    merged_metadata = {}

    # All inputs must describe the same embodiment.
    assert all(
        metadata.embodiment_tag == metadatas[0].embodiment_tag for metadata in metadatas
    ), "All metadata must have the same embodiment tag"
    merged_metadata["embodiment_tag"] = metadatas[0].embodiment_tag

    # Statistics are mixed per modality group, "state" before "action".
    merged_metadata["statistics"] = {
        group: LeRobotMixtureDataset.compute_overall_statistics(
            per_task_stats=[entry["statistics"][group] for entry in dumped],
            dataset_sampling_weights=dataset_sampling_weights,
            percentile_mixing_method=percentile_mixing_method,
        )
        for group in ("state", "action")
    }

    # Serialise each config so identical ones collapse into one set entry.
    distinct_configs = defaultdict(set)
    for entry in dumped:
        for modality, configs in entry["modalities"].items():
            distinct_configs[modality].add(json.dumps(configs))

    merged_metadata["modalities"] = {}
    for modality, configs in distinct_configs.items():
        assert (
            len(configs) == 1
        ), f"Multiple modality configs for modality {modality}: {list(configs)}"
        merged_metadata["modalities"][modality] = json.loads(next(iter(configs)))

    return DatasetMetadata.model_validate(merged_metadata)
|
|
def update_metadata(self, metadata_config: dict, cached_statistics_path: Path | str | None = None) -> None:
    """
    Merge the metadata of all datasets per embodiment tag and hand it to the transforms.

    When ``cached_statistics_path`` is given, the cached statistics are tried first.
    If they cannot be loaded, fail validation, or do not cover every embodiment tag
    of the current datasets, the statistics are recomputed from the datasets' metadata.

    Args:
        metadata_config (dict): Configuration for the metadata.
            "percentile_mixing_method": The method to mix the percentiles, either
                "weighted_average" or "min_max".
                weighted_average: Use the weighted average of the percentiles using
                    the weight used in sampling the datasets.
                min_max: Use the min of the 1st percentile and max of the 99th percentile.
        cached_statistics_path (Path | str | None): Optional statistics file readable by
            ``load_merged_statistics``. ``None`` skips the cache entirely.
    """
    # Set on both the cached and the recomputed path so the attribute always exists.
    self.tag = EmbodimentTag.NEW_EMBODIMENT.value

    if cached_statistics_path is not None:
        # Start from an empty mapping so stale entries from an earlier call cannot
        # make an unsuccessful cache application look successful below.
        self.merged_metadata = {}
        try:
            cached_stats = self.load_merged_statistics(cached_statistics_path)
            self.apply_cached_statistics(cached_stats)
        except (FileNotFoundError, KeyError, ValueError, ValidationError) as e:
            # ValueError covers an unsupported file suffix and malformed JSON.
            print(f"Failed to load cached statistics: {e}")
        else:
            # apply_cached_statistics may return without applying anything (e.g. on a
            # dataset-name mismatch); only trust the cache if every tag is covered.
            missing_tags = [
                tag
                for tag in dict.fromkeys(dataset.tag for dataset in self.datasets)
                if tag not in self.merged_metadata
            ]
            if not missing_tags:
                return
            print(f"Cached statistics do not cover embodiment tags: {missing_tags}")
        print("Falling back to computing statistics from scratch...")

    # Group metadata and the matching sampling weights by embodiment tag.
    # NOTE(review): assumes dataset_sampling_weights holds one weight per entry of
    # self.datasets, in the same order; if the lengths differ the full weight list is
    # passed through unchanged, as before.
    all_weights = self.dataset_sampling_weights.tolist()
    weights_aligned = len(all_weights) == len(self.datasets)
    metadatas_by_tag: dict[str, list[DatasetMetadata]] = defaultdict(list)
    weights_by_tag: dict[str, list[float]] = defaultdict(list)
    for index, dataset in enumerate(self.datasets):
        metadatas_by_tag[dataset.tag].append(dataset.metadata)
        if weights_aligned:
            weights_by_tag[dataset.tag].append(all_weights[index])

    self.merged_metadata: dict[str, DatasetMetadata] = {}
    for tag, metadatas in metadatas_by_tag.items():
        if weights_aligned:
            # Only the weights of the datasets that share this tag take part in the merge.
            tag_weights = weights_by_tag[tag]
            if not any(tag_weights):
                # All-zero weights cannot be normalized; weight the datasets equally.
                tag_weights = [1.0] * len(tag_weights)
        else:
            tag_weights = all_weights
        self.merged_metadata[tag] = self.merge_metadata(
            metadatas=metadatas,
            dataset_sampling_weights=tag_weights,
            percentile_mixing_method=metadata_config["percentile_mixing_method"],
        )

    for dataset in self.datasets:
        dataset.set_transforms_metadata(self.merged_metadata[dataset.tag])
|
|
| def save_dataset_statistics(self, save_path: Path | str, format: str = "json") -> None: |
| """ |
| Save merged dataset statistics to specified path in the required format. |
| Only includes statistics for keys that are actually used in the datasets. |
| Key order follows each tag's modality config order. |
| |
| Args: |
| save_path (Path | str): Path to save the statistics file |
| format (str): Save format, currently only supports "json" |
| """ |
| save_path = Path(save_path) |
| save_path.parent.mkdir(parents=True, exist_ok=True) |
| |
| |
| statistics_data = {} |
| |
| |
| tag_to_used_action_keys = {} |
| tag_to_used_state_keys = {} |
| for dataset in self.datasets: |
| if dataset.tag in tag_to_used_action_keys: |
| continue |
| used_action_keys, used_state_keys = get_used_modality_keys(dataset.modality_keys) |
| tag_to_used_action_keys[dataset.tag] = used_action_keys |
| tag_to_used_state_keys[dataset.tag] = used_state_keys |
| |
| |
| for tag, merged_metadata in self.merged_metadata.items(): |
| tag_stats = {} |
| |
| |
| if hasattr(merged_metadata.statistics, 'action') and merged_metadata.statistics.action: |
| action_stats = merged_metadata.statistics.action |
| |
| used_action_keys = tag_to_used_action_keys.get(tag, []) |
| filtered_action_stats = { |
| key: action_stats[key] |
| for key in used_action_keys |
| if key in action_stats |
| } |
| |
| if filtered_action_stats: |
| combined_action_stats = combine_modality_stats(filtered_action_stats) |
| |
| mask = generate_action_mask_for_used_keys( |
| merged_metadata.modalities.action, filtered_action_stats.keys() |
| ) |
| combined_action_stats["mask"] = mask |
| |
| tag_stats["action"] = combined_action_stats |
| |
| |
| if hasattr(merged_metadata.statistics, 'state') and merged_metadata.statistics.state: |
| state_stats = merged_metadata.statistics.state |
| |
| used_state_keys = tag_to_used_state_keys.get(tag, []) |
| filtered_state_stats = { |
| key: state_stats[key] |
| for key in used_state_keys |
| if key in state_stats |
| } |
| |
| if filtered_state_stats: |
| combined_state_stats = combine_modality_stats(filtered_state_stats) |
| tag_stats["state"] = combined_state_stats |
| |
| |
| tag_stats.update(self._get_dataset_counts(tag)) |
| |
| statistics_data[tag] = tag_stats |
| |
| |
| if format.lower() == "json": |
| if not str(save_path).endswith('.json'): |
| save_path = save_path.with_suffix('.json') |
| with open(save_path, 'w', encoding='utf-8') as f: |
| json.dump(statistics_data, f, indent=2, ensure_ascii=False) |
| else: |
| raise ValueError(f"Unsupported format: {format}. Currently only 'json' is supported.") |
| |
| print(f"Merged dataset statistics saved to: {save_path}") |
| print(f"Used action keys by tag: {tag_to_used_action_keys}") |
| print(f"Used state keys by tag: {tag_to_used_state_keys}") |
|
|
|
|
def _combine_modality_stats(self, modality_stats: dict) -> dict:
    """
    Backward compatibility wrapper.

    Forwards ``modality_stats`` to the module-level ``combine_modality_stats``
    and returns its result unchanged.
    """
    combined = combine_modality_stats(modality_stats)
    return combined
|
|
def _generate_action_mask_for_used_keys(self, action_modalities: dict, used_action_keys_ordered) -> list[bool]:
    """
    Backward compatibility wrapper.

    Forwards both arguments, in order, to the module-level
    ``generate_action_mask_for_used_keys`` and returns its result unchanged.
    """
    mask = generate_action_mask_for_used_keys(action_modalities, used_action_keys_ordered)
    return mask
|
|
def _get_dataset_counts(self, tag: str) -> dict:
    """
    Count transitions and trajectories over all datasets with the given tag.

    Args:
        tag (str): embodiment tag

    Returns:
        dict: ``{"num_transitions": ..., "num_trajectories": ...}`` where the first
        value sums ``len(dataset)`` and the second sums ``len(dataset.trajectory_ids)``
        over the matching datasets (both 0 when no dataset matches).
    """
    matching = [dataset for dataset in self.datasets if dataset.tag == tag]
    return {
        "num_transitions": sum(len(dataset) for dataset in matching),
        "num_trajectories": sum(len(dataset.trajectory_ids) for dataset in matching),
    }
|
|
| @classmethod |
| def load_merged_statistics(cls, load_path: Path | str) -> dict: |
| """ |
| Load merged dataset statistics from file. |
| |
| Args: |
| load_path (Path | str): Path to the statistics file |
| |
| Returns: |
| dict: Dictionary containing merged statistics |
| """ |
| load_path = Path(load_path) |
| if not load_path.exists(): |
| raise FileNotFoundError(f"Statistics file not found: {load_path}") |
| |
| if load_path.suffix.lower() == '.json': |
| with open(load_path, 'r', encoding='utf-8') as f: |
| return json.load(f) |
| elif load_path.suffix.lower() == '.pkl': |
| import pickle |
| with open(load_path, 'rb') as f: |
| return pickle.load(f) |
| else: |
| raise ValueError(f"Unsupported file format: {load_path.suffix}") |
|
|
def apply_cached_statistics(self, cached_statistics: dict) -> None:
    """
    Apply cached statistics to avoid recomputation.

    Every top-level key of ``cached_statistics`` other than ``"metadata"`` is treated
    as an embodiment tag whose value may hold ``"action"`` and ``"state"`` statistics.
    The rebuilt metadata is stored in ``self.merged_metadata`` and passed to
    ``set_transforms_metadata`` of every dataset whose tag has a cached entry.

    ``self.merged_metadata`` is only replaced after every tag validated, so a failing
    payload leaves the previous value untouched.

    If ``cached_statistics["metadata"]["dataset_names"]`` is present and differs from
    the names of the current datasets, a warning is printed and nothing is applied.

    Args:
        cached_statistics (dict): Statistics loaded from file

    Raises:
        KeyError: If a ``"metadata"`` entry exists without ``"dataset_names"``.
        ValidationError: If a rebuilt metadata dict is rejected by ``DatasetMetadata``.
    """
    # Refuse statistics that were computed for a different set of datasets.
    if "metadata" in cached_statistics:
        cached_dataset_names = set(cached_statistics["metadata"]["dataset_names"])
        current_dataset_names = {dataset.dataset_name for dataset in self.datasets}

        if cached_dataset_names != current_dataset_names:
            print("Warning: Cached statistics dataset names don't match current datasets.")
            print(f"Cached: {cached_dataset_names}")
            print(f"Current: {current_dataset_names}")
            return

    # Build into a local mapping first; commit only when every tag validated.
    rebuilt = {}
    for tag, stats_data in cached_statistics.items():
        if tag == "metadata":
            continue

        # NOTE(review): "modalities" is left empty and the cached action/state blocks
        # are passed through as stored; whether DatasetMetadata accepts this depends
        # on its schema, which is not visible here — confirm.
        metadata_dict = {
            "embodiment_tag": tag,
            "statistics": {
                "action": stats_data["action"] if "action" in stats_data else {},
                "state": stats_data["state"] if "state" in stats_data else {},
            },
            "modalities": {},
        }
        rebuilt[tag] = DatasetMetadata.model_validate(metadata_dict)

    self.merged_metadata = rebuilt

    # Push the rebuilt metadata to the datasets; remember those without an entry.
    uncovered_names = []
    for dataset in self.datasets:
        if dataset.tag in rebuilt:
            dataset.set_transforms_metadata(rebuilt[dataset.tag])
        else:
            uncovered_names.append(dataset.dataset_name)

    if uncovered_names:
        print(f"Warning: No cached statistics for datasets: {uncovered_names}")

    print(f"Applied cached statistics for {len(self.merged_metadata)} embodiment tags.")
|
|
|
|