pi05-so100-diverse / compute_stats.py
justinstrong's picture
Upload folder using huggingface_hub
cd604b4 verified
#!/usr/bin/env python3
"""
Compute normalization statistics (mean/std) for state and action across the filtered dataset.
Only reads parquet files — no video decoding, so it's fast.
"""
import argparse
import json
import time
from pathlib import Path
import numpy as np
import pandas as pd
class _RunningMoments:
    """Numerically stable running per-column mean and population variance.

    Rows of 2-D ``(frames, dim)`` batches are folded in one batch at a time
    using Chan et al.'s parallel update. This avoids the catastrophic
    cancellation of the naive ``E[x^2] - E[x]^2`` formula, which can go
    negative (and make ``sqrt`` return NaN) for nearly constant channels.
    """

    def __init__(self) -> None:
        self.count = 0
        # Per-column mean; None until the first non-empty batch fixes the dimension.
        self.mean = None
        # Per-column sum of squared deviations from the running mean.
        self.m2 = None

    def update(self, batch: np.ndarray) -> None:
        """Fold every row of ``batch`` (shape ``(frames, dim)``) into the running moments."""
        n_batch = batch.shape[0]
        if n_batch == 0:
            return
        batch_mean = batch.mean(axis=0)
        batch_m2 = ((batch - batch_mean) ** 2).sum(axis=0)
        if self.count == 0:
            self.count, self.mean, self.m2 = n_batch, batch_mean, batch_m2
            return
        total = self.count + n_batch
        delta = batch_mean - self.mean
        self.mean = self.mean + delta * (n_batch / total)
        # The cross term accounts for the shift between the two partial means.
        self.m2 = self.m2 + batch_m2 + delta ** 2 * (self.count * n_batch / total)
        self.count = total

    def finalize(self, name: str) -> tuple:
        """Return ``(mean, std)`` with population std (ddof=0).

        Raises:
            ValueError: if no rows were ever added.
        """
        if self.count == 0:
            raise ValueError(f"No frames found for {name!r}; cannot compute normalization stats")
        return self.mean, np.sqrt(self.m2 / self.count)


def compute_stats(data_root: Path, index_path: Path) -> dict:
    """Compute per-dimension mean/std of ``observation.state`` and ``action``.

    Args:
        data_root: Directory with one subdirectory per dataset. Each episode is
            read from ``<data_root>/<dataset>/data/chunk-000/episode_<idx:06d>.parquet``.
            NOTE(review): only ``chunk-000`` is searched. Episodes stored in
            other chunks are counted as missing; confirm the dataset layout.
        index_path: JSON file whose ``"episodes"`` list holds entries with
            ``"dataset"`` and ``"episode_index"`` keys. Duplicate pairs are
            counted once.

    Returns:
        ``{"observation.state": {"mean": [...], "std": [...]},
        "action": {"mean": [...], "std": [...]}}``. The std is the population
        std (ddof=0). The vector length is inferred from the data.

    Raises:
        ValueError: if no state or action frames could be read.
    """
    with open(index_path, encoding="utf-8") as f:
        index = json.load(f)

    # The set drops duplicate (dataset, episode) entries. Sorting gives a deterministic order.
    episodes = sorted({(ep["dataset"], ep["episode_index"]) for ep in index["episodes"]})
    n_episodes = len(episodes)
    print(f"Computing stats from {n_episodes} episodes...")

    state_moments = _RunningMoments()
    action_moments = _RunningMoments()
    n_missing = 0
    start = time.time()
    for i, (dataset, ep_idx) in enumerate(episodes):
        parquet_path = data_root / dataset / f"data/chunk-000/episode_{ep_idx:06d}.parquet"
        if parquet_path.exists():
            df = pd.read_parquet(parquet_path)
            # np.stack fails on an empty column, so skip episodes with no frames.
            if len(df) > 0:
                # Each cell holds one per-frame vector; stacking yields (frames, dim).
                state_moments.update(np.stack(df["observation.state"].values).astype(np.float64))
                action_moments.update(np.stack(df["action"].values).astype(np.float64))
        else:
            n_missing += 1

        # Report progress even when this particular episode was missing.
        if (i + 1) % 1000 == 0:
            elapsed = time.time() - start
            rate = (i + 1) / elapsed
            eta = (n_episodes - i - 1) / rate
            print(f"  [{i+1}/{n_episodes}] {rate:.0f} eps/s, ETA: {eta:.0f}s")

    if n_missing:
        print(f"Warning: skipped {n_missing} episodes with no parquet file")

    state_mean, state_std = state_moments.finalize("observation.state")
    action_mean, action_std = action_moments.finalize("action")

    elapsed = time.time() - start
    print(
        f"Done in {elapsed:.1f}s ({state_moments.count:,} state frames, "
        f"{action_moments.count:,} action frames)"
    )
    print(f"\nState mean: {state_mean}")
    print(f"State std:  {state_std}")
    print(f"Action mean: {action_mean}")
    print(f"Action std:  {action_std}")

    return {
        "observation.state": {
            "mean": state_mean.tolist(),
            "std": state_std.tolist(),
        },
        "action": {
            "mean": action_mean.tolist(),
            "std": action_std.tolist(),
        },
    }
if __name__ == "__main__":
    # Command-line entry point: compute the stats and write them next to this script by default.
    here = Path(__file__).parent
    cli = argparse.ArgumentParser()
    cli.add_argument("--data-root", type=Path, default=Path.home() / "lap" / "community_dataset_v3")
    cli.add_argument("--index", type=Path, default=here / "filtered_index.json")
    cli.add_argument("--output", type=Path, default=here / "norm_stats.json")
    opts = cli.parse_args()

    result = compute_stats(opts.data_root, opts.index)
    # Same bytes as json.dump(..., indent=2) into a text-mode file handle.
    opts.output.write_text(json.dumps(result, indent=2))
    print(f"\nSaved to {opts.output}")