"""
Compute normalization statistics (mean/std) for state and action across the filtered dataset.
Only reads parquet files — no video decoding, so it's fast.
"""
|
|
| import argparse |
| import json |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
|
|
|
|
def compute_stats(data_root: Path, index_path: Path) -> dict:
    """Compute per-dimension mean/std of "observation.state" and "action".

    Streams running sums over every episode parquet listed in the filter
    index, so memory stays O(dim) regardless of dataset size. Episodes
    missing on disk are skipped, matching the original behavior.

    Args:
        data_root: Root directory containing one sub-directory per dataset.
        index_path: JSON file whose "episodes" list holds
            {"dataset": str, "episode_index": int, ...} entries.

    Returns:
        {"observation.state": {"mean": [...], "std": [...]},
         "action": {"mean": [...], "std": [...]}}
        with one entry per vector dimension.

    Raises:
        ValueError: if no episode parquet file could be read at all.
    """
    with open(index_path) as f:
        index = json.load(f)

    # Deduplicate: the index may list the same (dataset, episode) pair
    # more than once.
    episode_set = set()
    for ep in index["episodes"]:
        episode_set.add((ep["dataset"], ep["episode_index"]))

    print(f"Computing stats from {len(episode_set)} episodes...")

    # Running sums / sums of squares. Sized lazily from the first episode
    # read, instead of hard-coding 6 dims, so non-6-DoF embodiments work.
    state_sum = state_sq_sum = None
    action_sum = action_sq_sum = None
    n_state = 0
    n_action = 0

    start = time.time()
    for i, (dataset, ep_idx) in enumerate(sorted(episode_set)):
        parquet_path = data_root / dataset / f"data/chunk-000/episode_{ep_idx:06d}.parquet"
        if not parquet_path.exists():
            # Filtered index can reference episodes not present locally.
            continue

        df = pd.read_parquet(parquet_path)

        # Each cell holds a per-frame vector; stack into (frames, dim).
        states = np.stack(df["observation.state"].values).astype(np.float64)
        actions = np.stack(df["action"].values).astype(np.float64)

        if state_sum is None:
            state_sum = np.zeros(states.shape[1], dtype=np.float64)
            state_sq_sum = np.zeros(states.shape[1], dtype=np.float64)
            action_sum = np.zeros(actions.shape[1], dtype=np.float64)
            action_sq_sum = np.zeros(actions.shape[1], dtype=np.float64)

        state_sum += states.sum(axis=0)
        state_sq_sum += (states ** 2).sum(axis=0)
        n_state += len(states)

        action_sum += actions.sum(axis=0)
        action_sq_sum += (actions ** 2).sum(axis=0)
        n_action += len(actions)

        if (i + 1) % 1000 == 0:
            elapsed = time.time() - start
            rate = (i + 1) / elapsed
            eta = (len(episode_set) - i - 1) / rate
            print(f"  [{i+1}/{len(episode_set)}] {rate:.0f} eps/s, ETA: {eta:.0f}s")

    if n_state == 0 or n_action == 0:
        # Previously this fell through to an opaque ZeroDivisionError /
        # None arithmetic; fail with an actionable message instead.
        raise ValueError(f"No episode parquet files found under {data_root}")

    # std via E[x^2] - E[x]^2; clamp tiny negative variances caused by
    # floating-point cancellation so sqrt never yields NaN.
    state_mean = state_sum / n_state
    state_std = np.sqrt(np.maximum(state_sq_sum / n_state - state_mean ** 2, 0.0))

    action_mean = action_sum / n_action
    action_std = np.sqrt(np.maximum(action_sq_sum / n_action - action_mean ** 2, 0.0))

    elapsed = time.time() - start
    print(f"Done in {elapsed:.1f}s ({n_state:,} state frames, {n_action:,} action frames)")

    print(f"\nState mean: {state_mean}")
    print(f"State std: {state_std}")
    print(f"Action mean: {action_mean}")
    print(f"Action std: {action_std}")

    stats = {
        "observation.state": {
            "mean": state_mean.tolist(),
            "std": state_std.tolist(),
        },
        "action": {
            "mean": action_mean.tolist(),
            "std": action_std.tolist(),
        },
    }
    return stats
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--data-root", type=Path, default=Path.home() / "lap" / "community_dataset_v3") |
| parser.add_argument("--index", type=Path, default=Path(__file__).parent / "filtered_index.json") |
| parser.add_argument("--output", type=Path, default=Path(__file__).parent / "norm_stats.json") |
| args = parser.parse_args() |
|
|
| stats = compute_stats(args.data_root, args.index) |
|
|
| with open(args.output, "w") as f: |
| json.dump(stats, f, indent=2) |
| print(f"\nSaved to {args.output}") |
|
|