#!/usr/bin/env python3
"""
Compute normalization statistics (mean/std) for state and action
across the filtered dataset.

Only reads parquet files — no video decoding, so it's fast.
"""

import argparse
import json
import time
from pathlib import Path

import numpy as np
import pandas as pd


class _RunningMoments:
    """Running sum and sum-of-squares for one vector-valued column.

    The per-frame feature shape is taken from the first batch that is added,
    so the accumulator is not tied to a fixed dimensionality.
    """

    def __init__(self, column: str) -> None:
        self.column = column
        self.total = None
        self.sq_total = None
        self.count = 0

    def add(self, rows: np.ndarray, source: Path) -> None:
        """Fold a (frames, ...) float64 batch read from *source* into the sums.

        Raises:
            ValueError: if the per-frame shape differs from earlier batches.
        """
        feature_shape = rows.shape[1:]
        if self.total is None:
            self.total = np.zeros(feature_shape, dtype=np.float64)
            self.sq_total = np.zeros(feature_shape, dtype=np.float64)
        elif feature_shape != self.total.shape:
            # Without this check a width-1 batch would silently broadcast
            # into every dimension of the accumulator.
            raise ValueError(
                f"{source}: column {self.column!r} has per-frame shape "
                f"{feature_shape}, expected {self.total.shape}"
            )
        self.total += rows.sum(axis=0)
        self.sq_total += (rows ** 2).sum(axis=0)
        self.count += len(rows)

    def mean_std(self):
        """Return (mean, population std) per dimension; requires count > 0."""
        mean = self.total / self.count
        # E[x^2] - E[x]^2 can come out slightly negative from float round-off
        # on (near-)constant dimensions; clamp so sqrt never produces NaN.
        variance = np.maximum(self.sq_total / self.count - mean ** 2, 0.0)
        return mean, np.sqrt(variance)


def compute_stats(data_root: Path, index_path: Path) -> dict:
    """Compute per-dimension mean and std of state and action vectors.

    Args:
        data_root: Directory containing one sub-directory per dataset; each
            episode is read from
            ``<dataset>/data/chunk-000/episode_<idx:06d>.parquet``.
        index_path: JSON file with an ``"episodes"`` list whose entries carry
            ``"dataset"`` and ``"episode_index"`` keys.

    Returns:
        ``{"observation.state": {"mean": [...], "std": [...]},
        "action": {"mean": [...], "std": [...]}}`` with plain Python lists.

    Raises:
        ValueError: if no frames could be read (empty index, or every
            referenced parquet file is missing or empty), or if the feature
            shape of a column changes between episodes.
    """
    with open(index_path, encoding="utf-8") as f:
        index = json.load(f)

    # Collect all unique (dataset, episode) pairs
    episode_set = {(ep["dataset"], ep["episode_index"]) for ep in index["episodes"]}
    n_episodes = len(episode_set)

    print(f"Computing stats from {n_episodes} episodes...")

    # Single-pass accumulation of sum and sum of squares in float64
    # (std is derived at the end as sqrt(E[x^2] - E[x]^2)).
    state_acc = _RunningMoments("observation.state")
    action_acc = _RunningMoments("action")
    n_missing = 0

    start = time.time()
    for i, (dataset, ep_idx) in enumerate(sorted(episode_set)):
        parquet_path = data_root / dataset / f"data/chunk-000/episode_{ep_idx:06d}.parquet"
        if not parquet_path.exists():
            n_missing += 1
        else:
            df = pd.read_parquet(parquet_path)
            # np.stack rejects an empty sequence, so skip zero-row episodes.
            if len(df) > 0:
                states = np.stack(df["observation.state"].values).astype(np.float64)
                actions = np.stack(df["action"].values).astype(np.float64)
                state_acc.add(states, parquet_path)
                action_acc.add(actions, parquet_path)

        if (i + 1) % 1000 == 0:
            elapsed = time.time() - start
            # Guard against a zero elapsed time on coarse clocks.
            rate = (i + 1) / elapsed if elapsed > 0 else float("inf")
            eta = (n_episodes - i - 1) / rate
            print(f" [{i+1}/{n_episodes}] {rate:.0f} eps/s, ETA: {eta:.0f}s")

    if n_missing:
        print(f"Warning: {n_missing} of {n_episodes} episode files were missing and skipped")

    if state_acc.count == 0 or action_acc.count == 0:
        # Dividing by zero here would yield all-NaN stats, which json.dump
        # would then write as invalid JSON.
        raise ValueError(
            f"No frames found: {n_episodes} episodes indexed in {index_path}, "
            f"{n_missing} parquet files missing under {data_root}"
        )

    state_mean, state_std = state_acc.mean_std()
    action_mean, action_std = action_acc.mean_std()
    n_state = state_acc.count
    n_action = action_acc.count

    elapsed = time.time() - start
    print(f"Done in {elapsed:.1f}s ({n_state:,} state frames, {n_action:,} action frames)")
    print(f"\nState mean: {state_mean}")
    print(f"State std: {state_std}")
    print(f"Action mean: {action_mean}")
    print(f"Action std: {action_std}")

    stats = {
        "observation.state": {
            "mean": state_mean.tolist(),
            "std": state_std.tolist(),
        },
        "action": {
            "mean": action_mean.tolist(),
            "std": action_std.tolist(),
        },
    }
    return stats


def main() -> None:
    """Parse command-line arguments, compute the stats and save them as JSON."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", type=Path,
                        default=Path.home() / "lap" / "community_dataset_v3")
    parser.add_argument("--index", type=Path,
                        default=Path(__file__).parent / "filtered_index.json")
    parser.add_argument("--output", type=Path,
                        default=Path(__file__).parent / "norm_stats.json")
    args = parser.parse_args()

    stats = compute_stats(args.data_root, args.index)

    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)
    print(f"\nSaved to {args.output}")


if __name__ == "__main__":
    main()