File size: 3,219 Bytes
cd604b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#!/usr/bin/env python3
"""
Compute normalization statistics (mean/std) for state and action across the filtered dataset.
Only reads parquet files — no video decoding, so it's fast.
"""

import argparse
import json
import time
from pathlib import Path

import numpy as np
import pandas as pd


def compute_stats(data_root: Path, index_path: Path) -> dict:
    """Compute per-dimension mean/std for ``observation.state`` and ``action``.

    Reads every episode parquet listed in the filtered index, accumulates
    per-dimension sums and sums-of-squares, and derives mean/std from them.
    Episodes whose parquet file is missing on disk are skipped.

    Args:
        data_root: Root directory containing one sub-directory per dataset,
            each with a ``data/chunk-000/episode_XXXXXX.parquet`` layout.
        index_path: JSON file with an ``"episodes"`` list; each entry has
            ``"dataset"`` and ``"episode_index"`` keys.

    Returns:
        Dict mapping ``"observation.state"`` and ``"action"`` to
        ``{"mean": [...], "std": [...]}`` (6 values each).

    Raises:
        ValueError: if no episode contributed any frames (e.g. wrong
            ``data_root`` or an empty/stale index).
    """
    with open(index_path) as f:
        index = json.load(f)

    # Collect all unique (dataset, episode) pairs — the index may repeat them.
    episode_set = {(ep["dataset"], ep["episode_index"]) for ep in index["episodes"]}

    print(f"Computing stats from {len(episode_set)} episodes...")

    # Accumulate sum and sum-of-squares per dimension.
    # NOTE: this is the naive one-pass formula (std = sqrt(E[x^2] - E[x]^2)),
    # NOT Welford's algorithm as a previous comment claimed. Adequate here
    # since robot state/action values are modest in magnitude.
    # Assumes 6-D state/action vectors — TODO confirm against the dataset schema.
    state_sum = np.zeros(6, dtype=np.float64)
    state_sq_sum = np.zeros(6, dtype=np.float64)
    action_sum = np.zeros(6, dtype=np.float64)
    action_sq_sum = np.zeros(6, dtype=np.float64)
    n_state = 0
    n_action = 0

    start = time.time()
    for i, (dataset, ep_idx) in enumerate(sorted(episode_set)):
        parquet_path = data_root / dataset / f"data/chunk-000/episode_{ep_idx:06d}.parquet"
        if not parquet_path.exists():
            # Best-effort: the index may reference episodes pruned from disk.
            continue

        df = pd.read_parquet(parquet_path)

        # Columns hold per-frame vectors; stack into (frames, dims) arrays.
        states = np.stack(df["observation.state"].values).astype(np.float64)
        actions = np.stack(df["action"].values).astype(np.float64)

        state_sum += states.sum(axis=0)
        state_sq_sum += (states ** 2).sum(axis=0)
        n_state += len(states)

        action_sum += actions.sum(axis=0)
        action_sq_sum += (actions ** 2).sum(axis=0)
        n_action += len(actions)

        # Lightweight progress report every 1000 episodes.
        if (i + 1) % 1000 == 0:
            elapsed = time.time() - start
            rate = (i + 1) / elapsed
            eta = (len(episode_set) - i - 1) / rate
            print(f"  [{i+1}/{len(episode_set)}] {rate:.0f} eps/s, ETA: {eta:.0f}s")

    # Fail loudly rather than emit NaN stats from a 0/0 division.
    if n_state == 0 or n_action == 0:
        raise ValueError(
            "No episode data found — check --data-root and the index file."
        )

    state_mean = state_sum / n_state
    # Clamp at 0: E[x^2] - E[x]^2 can dip slightly negative from
    # floating-point cancellation, which would make sqrt return NaN.
    state_std = np.sqrt(np.maximum(state_sq_sum / n_state - state_mean ** 2, 0.0))

    action_mean = action_sum / n_action
    action_std = np.sqrt(np.maximum(action_sq_sum / n_action - action_mean ** 2, 0.0))

    elapsed = time.time() - start
    print(f"Done in {elapsed:.1f}s ({n_state:,} state frames, {n_action:,} action frames)")

    print(f"\nState mean: {state_mean}")
    print(f"State std:  {state_std}")
    print(f"Action mean: {action_mean}")
    print(f"Action std:  {action_std}")

    stats = {
        "observation.state": {
            "mean": state_mean.tolist(),
            "std": state_std.tolist(),
        },
        "action": {
            "mean": action_mean.tolist(),
            "std": action_std.tolist(),
        },
    }
    return stats


if __name__ == "__main__":
    # CLI entry point: compute normalization stats and persist them as JSON.
    script_dir = Path(__file__).parent
    cli = argparse.ArgumentParser()
    cli.add_argument("--data-root", type=Path, default=Path.home() / "lap" / "community_dataset_v3")
    cli.add_argument("--index", type=Path, default=script_dir / "filtered_index.json")
    cli.add_argument("--output", type=Path, default=script_dir / "norm_stats.json")
    opts = cli.parse_args()

    norm_stats = compute_stats(opts.data_root, opts.index)

    opts.output.write_text(json.dumps(norm_stats, indent=2))
    print(f"\nSaved to {opts.output}")