| """Compute normalization statistics for a config. |
| |
| This script is used to compute the normalization statistics for a given config. It |
| will compute the mean and standard deviation of the data in the dataset and save it |
| to the config assets directory. |
| """ |
|
|
| import json |
| import pathlib |
|
|
| import numpy as np |
| import polars as pl |
| import tqdm |
| import tyro |
|
|
| import openpi.models.model as _model |
| import openpi.shared.normalize as normalize |
| import openpi.training.config as _config |
| import openpi.training.data_loader as _data_loader |
| import openpi.transforms as transforms |
|
|
|
|
class RemoveStrings(transforms.DataTransformFn):
    """Data transform that drops every entry whose value is string-typed.

    A value counts as string-typed when ``np.asarray(value)`` has a dtype that
    is a sub-dtype of ``np.str_``. All other entries are returned unchanged.
    """

    def __call__(self, x: dict) -> dict:
        kept = {}
        for name, value in x.items():
            # Skip anything that NumPy would represent as a unicode string array.
            if np.issubdtype(np.asarray(value).dtype, np.str_):
                continue
            kept[name] = value
        return kept
|
|
|
|
def create_torch_dataloader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    model_config: _model.BaseModelConfig,
    num_workers: int,
    max_frames: int | None = None,
) -> tuple[_data_loader.TorchDataLoader, int]:
    """Build a torch data loader over the transformed dataset for computing stats.

    The raw dataset is wrapped with the config's repack input transforms, then
    its data input transforms, and finally ``RemoveStrings``.

    Args:
        data_config: Data config; its ``repo_id`` must be set.
        action_horizon: Forwarded to ``_data_loader.create_torch_dataset``.
        batch_size: Local batch size of the loader.
        model_config: Forwarded to ``_data_loader.create_torch_dataset``.
        num_workers: Number of loader workers.
        max_frames: Optional cap on the number of frames. When it is set and
            smaller than the dataset, the loader shuffles and is limited to
            ``max_frames // batch_size`` batches. Otherwise the loader does not
            shuffle and covers ``len(dataset) // batch_size`` batches.

    Returns:
        A ``(data_loader, num_batches)`` tuple.

    Raises:
        ValueError: If ``data_config.repo_id`` is ``None``.
    """
    if data_config.repo_id is None:
        raise ValueError("Data config must have a repo_id")

    dataset = _data_loader.create_torch_dataset(data_config, action_horizon, model_config)
    dataset = _data_loader.TransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            # Drop string-typed entries before the samples are batched.
            RemoveStrings(),
        ],
    )

    # Only subsample (and therefore shuffle) when the cap actually bites.
    subsample = max_frames is not None and max_frames < len(dataset)
    num_batches = (max_frames if subsample else len(dataset)) // batch_size

    data_loader = _data_loader.TorchDataLoader(
        dataset,
        local_batch_size=batch_size,
        num_workers=num_workers,
        shuffle=subsample,
        num_batches=num_batches,
    )
    return data_loader, num_batches
|
|
|
|
def _local_lerobot_episode_paths(repo_id: str) -> list[pathlib.Path]:
    """Return the sorted episode parquet files of a local LeRobot dataset.

    Files are looked up as ``<repo_id>/data/chunk-*/episode_*.parquet``.

    Raises:
        FileNotFoundError: If no file matches that pattern.
    """
    data_dir = pathlib.Path(repo_id) / "data"
    episode_files = list(data_dir.glob("chunk-*/episode_*.parquet"))
    if len(episode_files) == 0:
        raise FileNotFoundError(f"No parquet episodes found under {data_dir}")
    episode_files.sort()
    return episode_files
|
|
|
|
def _local_lerobot_total_frames(repo_id: str) -> int:
    """Read ``total_frames`` from ``<repo_id>/meta/info.json`` and return it as an int."""
    meta_file = pathlib.Path(repo_id, "meta", "info.json")
    info = json.loads(meta_file.read_text())
    return int(info["total_frames"])
|
|
|
|
def _stack_list_column(frame: pl.DataFrame, column: str) -> np.ndarray:
    """Convert one column of ``frame`` into a float32 NumPy array.

    The column is first turned into a Python list and then handed to NumPy.
    NOTE(review): presumably every cell holds a list of the same length, which
    would make the result 2-D (rows x features) — confirm against the dataset.
    """
    rows = frame[column].to_list()
    return np.asarray(rows, dtype=np.float32)
|
|
|
|
def _resolve_state_action_columns(path: pathlib.Path) -> tuple[str, str]:
    """Pick the state and action column names present in a parquet file.

    The state column is the first of ``observation.state`` / ``state`` found in
    the file's schema; the action column is the first of ``action`` /
    ``actions``.

    Returns:
        A ``(state_column, action_column)`` tuple.

    Raises:
        ValueError: If either column cannot be found.
    """
    available = pl.read_parquet_schema(path)

    def first_present(candidates: tuple[str, ...]) -> str | None:
        # Candidates are listed in order of preference.
        for candidate in candidates:
            if candidate in available:
                return candidate
        return None

    state_name = first_present(("observation.state", "state"))
    action_name = first_present(("action", "actions"))
    if state_name is not None and action_name is not None:
        return state_name, action_name

    raise ValueError(
        f"Could not find state/action columns in {path}. "
        f"Available columns: {', '.join(available.keys())}"
    )
|
|
|
|
def _action_chunks(actions: np.ndarray, num_starts: int, action_horizon: int) -> np.ndarray:
    """Gather a window of ``action_horizon`` actions for each of the first ``num_starts`` rows.

    Window ``i`` covers rows ``i .. i + action_horizon - 1`` of ``actions``.
    Indices past the end are clamped to the last row, so windows near the end
    repeat the final action.

    Returns:
        An array whose leading two axes are ``(num_starts, action_horizon)``,
        followed by the trailing axes of ``actions``.
    """
    last_index = len(actions) - 1
    # Entry (i, j) of the index grid is i + j: start row plus offset in the window.
    window = np.add.outer(np.arange(num_starts), np.arange(action_horizon))
    return actions[np.minimum(window, last_index)]
|
|
|
|
def _can_use_fast_local_lerobot_stats(
    config: _config.TrainConfig,
    data_config: _config.DataConfig,
    max_frames: int | None,
) -> bool:
    """Tell whether stats can be computed directly from local parquet files.

    All of the following must hold:
      * ``max_frames`` is not set;
      * ``config.data`` is a ``LeRobotVariousSpeedLiberoDataConfig``;
      * ``data_config.online_sliding_chunks`` is falsy;
      * ``config.data.extra_delta_transform`` is falsy;
      * ``data_config.repo_id`` is set and names an existing directory.
    """
    # ``and`` short-circuits, so ``extra_delta_transform`` is only read once
    # the config type has been confirmed.
    eligible_config = (
        max_frames is None
        and isinstance(config.data, _config.LeRobotVariousSpeedLiberoDataConfig)
        and not data_config.online_sliding_chunks
        and not config.data.extra_delta_transform
    )
    if not eligible_config:
        return False
    return data_config.repo_id is not None and pathlib.Path(data_config.repo_id).is_dir()
|
|
|
|
def _can_use_fast_online_sliding_lerobot_stats(
    config: _config.TrainConfig,
    data_config: _config.DataConfig,
    max_frames: int | None,
) -> bool:
    """Tell whether online-sliding stats can be computed from local parquet files.

    All of the following must hold:
      * ``max_frames`` is not set;
      * ``config.data`` is a ``LeRobotVariousSpeedLiberoDataConfig``;
      * ``data_config.online_sliding_chunks`` is truthy;
      * ``config.data.extra_delta_transform`` is falsy;
      * ``data_config.repo_id`` is set and names an existing directory.
    """
    if max_frames is not None:
        return False

    # The type check comes first so ``extra_delta_transform`` is only read on
    # a config class that is known to be the expected one.
    supported_config = isinstance(config.data, _config.LeRobotVariousSpeedLiberoDataConfig)
    if not (supported_config and data_config.online_sliding_chunks):
        return False
    if config.data.extra_delta_transform:
        return False

    if data_config.repo_id is None:
        return False
    return pathlib.Path(data_config.repo_id).is_dir()
|
|
|
|
def _compute_fast_local_lerobot_stats(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
) -> tuple[dict[str, normalize.RunningStats], int]:
    """Compute stats from local LeRobot parquet data without decoding videos.

    Only the state and action columns of each episode parquet file are read.
    Episodes are visited in sorted path order. Frames are consumed until
    ``(total_frames // batch_size) * batch_size`` have been used, where
    ``total_frames`` comes from ``meta/info.json``.
    NOTE(review): the rounding presumably mirrors a data loader that drops the
    final partial batch — confirm.

    For every consumed frame, the state row feeds the ``"state"`` stats and a
    window of ``action_horizon`` actions starting at that frame feeds the
    ``"actions"`` stats. Windows are clamped to the episode's last action.

    Returns:
        A ``(stats, num_batches)`` tuple. ``stats`` maps ``"state"`` and
        ``"actions"`` to their running statistics, and ``num_batches`` is
        ``total_frames // batch_size``.

    Raises:
        ValueError: If ``data_config.repo_id`` is ``None``.
    """
    repo = data_config.repo_id
    if repo is None:
        raise ValueError("Data config must have a repo_id")

    num_batches = _local_lerobot_total_frames(repo) // batch_size
    # Number of frames that may still be consumed across the remaining episodes.
    budget = num_batches * batch_size
    running = {name: normalize.RunningStats() for name in ("state", "actions")}

    episode_paths = _local_lerobot_episode_paths(repo)
    # Column names are resolved once, from the first episode file.
    state_col, action_col = _resolve_state_action_columns(episode_paths[0])

    progress = tqdm.tqdm(episode_paths, desc="Computing stats from parquet")
    for episode_path in progress:
        if budget <= 0:
            break

        episode = pl.read_parquet(episode_path, columns=[state_col, action_col])
        take = min(len(episode), budget)
        if take <= 0:
            # Empty episode file: nothing to add.
            continue

        episode_states = _stack_list_column(episode, state_col)
        episode_actions = _stack_list_column(episode, action_col)
        running["state"].update(episode_states[:take])
        running["actions"].update(_action_chunks(episode_actions, take, action_horizon))
        budget -= take

    return running, num_batches
|
|
|
|
def _reuse_source_one_x_norm_stats(
    norm_stats: dict[str, normalize.NormStats],
    _speeds: tuple[float, ...] | list[float],
) -> dict[str, normalize.NormStats]:
    """Hand back the source 1.0x stats as-is for online sliding normalization.

    ``_speeds`` is accepted but never read; the returned object is the very
    same ``norm_stats`` mapping that was passed in.
    """
    unchanged = norm_stats
    return unchanged
|
|
|
|
def main(
    config_name: str,
    max_frames: int | None = None,
    *,
    fast_local_lerobot: bool = True,
    repo_id: str | None = None,
    asset_id: str | None = None,
    online_sliding_chunks: bool = False,
    online_sliding_speeds: tuple[float, ...] = (),
    online_sliding_cache_size: int | None = None,
):
    """Compute norm stats.

    Optional overrides ``repo_id`` and ``asset_id`` allow sweep scripts to reuse
    a single TrainConfig name across multiple datasets / asset directories
    without registering one config per ablation.

    Args:
        config_name: Name of the train config to load.
        max_frames: Optional cap on the number of frames used. Setting it
            disables the parquet fast path and is rejected when online sliding
            chunks are enabled.
        fast_local_lerobot: Allow the parquet fast path when the config is
            eligible. Not consulted when online sliding chunks are enabled;
            that mode always uses the parquet path.
        repo_id: If given, replaces ``repo_id`` on the config's data section.
        asset_id: If given, replaces ``asset_id`` on the data section's assets.
        online_sliding_chunks: If true, sets ``online_sliding_chunks=True`` on
            the data section.
        online_sliding_speeds: If non-empty, replaces ``online_sliding_speeds``
            on the data section.
        online_sliding_cache_size: If given, replaces
            ``online_sliding_cache_size`` on the data section.

    Raises:
        ValueError: If an online sliding override is given for a data section
            that is not a ``LeRobotVariousSpeedLiberoDataConfig``; if the
            resolved data config has no ``asset_id``; or if online sliding
            chunks are enabled together with ``max_frames`` or with a config
            that is not eligible for the parquet path.
    """
    import dataclasses as _dc

    config = _config.get_config(config_name)

    # Dataset / asset-directory overrides.
    if repo_id is not None or asset_id is not None:
        new_data = config.data
        if repo_id is not None:
            new_data = _dc.replace(new_data, repo_id=repo_id)
        if asset_id is not None:
            new_assets = _dc.replace(new_data.assets, asset_id=asset_id)
            new_data = _dc.replace(new_data, assets=new_assets)
        config = _dc.replace(config, data=new_data)

    # Online sliding overrides only exist on one data config class.
    if online_sliding_chunks or online_sliding_speeds or online_sliding_cache_size is not None:
        if not isinstance(config.data, _config.LeRobotVariousSpeedLiberoDataConfig):
            raise ValueError("online sliding overrides require LeRobotVariousSpeedLiberoDataConfig")
        new_data = config.data
        if online_sliding_chunks:
            new_data = _dc.replace(new_data, online_sliding_chunks=True)
        if online_sliding_speeds:
            new_data = _dc.replace(new_data, online_sliding_speeds=online_sliding_speeds)
        if online_sliding_cache_size is not None:
            new_data = _dc.replace(new_data, online_sliding_cache_size=online_sliding_cache_size)
        config = _dc.replace(config, data=new_data)

    data_config = config.data.create(config.assets_dirs, config.model)

    # Fail fast: without an asset_id there is nowhere to write the result, so
    # check before the potentially long pass over the dataset.
    if data_config.asset_id is None:
        raise ValueError("Data config must have an asset_id")
    output_path = config.assets_dirs / data_config.asset_id

    keys = ["state", "actions"]
    use_source_one_x_for_online_sliding = data_config.online_sliding_chunks
    if use_source_one_x_for_online_sliding:
        if max_frames is not None:
            raise ValueError("online sliding norm stats reuse source 1.0x stats and do not support max_frames.")
        if not _can_use_fast_online_sliding_lerobot_stats(config, data_config, max_frames):
            raise ValueError(
                "online sliding norm stats reuse source 1.0x local LeRobot parquet stats; "
                "repo_id must be a local LeRobot directory and extra_delta_transform must be False."
            )
        use_parquet_path = True
    else:
        use_parquet_path = fast_local_lerobot and _can_use_fast_local_lerobot_stats(
            config, data_config, max_frames
        )

    if use_parquet_path:
        running_stats, _ = _compute_fast_local_lerobot_stats(
            data_config,
            config.model.action_horizon,
            config.batch_size,
        )
    else:
        data_loader, num_batches = create_torch_dataloader(
            data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames
        )
        running_stats = {key: normalize.RunningStats() for key in keys}
        for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
            for key in keys:
                running_stats[key].update(np.asarray(batch[key]))

    norm_stats = {key: running.get_statistics() for key, running in running_stats.items()}
    if use_source_one_x_for_online_sliding:
        norm_stats = _reuse_source_one_x_norm_stats(norm_stats, tuple(data_config.online_sliding_speeds))
        print(
            "Using source 1.0x norm stats for online sliding: raw LeRobot parquet state/action stats "
            "are saved unchanged; online_sliding_speeds are ignored for normalization."
        )

    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats)
|
|
|
|
if __name__ == "__main__":
    # Run main() through tyro's command-line entry point when executed as a script.
    tyro.cli(main)
|
|