File size: 1,970 Bytes
d93804e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Compute norm stats for the Kinova TeddyBear dataset without decoding videos.

This matches the training-time state/action preprocessing for the custom
`pi05_kinova_teddybear` config:
- state stats are computed on raw `observation.state`
- action stats are computed on 16-step action chunks after converting the
  first 6 dimensions from absolute targets to deltas relative to the current state
  while leaving the gripper dimension absolute
"""

from pathlib import Path

import numpy as np
import pandas as pd

from openpi.shared import normalize
from openpi.training import config as config_lib
from openpi.training.kinova_lerobot_v3_dataset import locate_dataset_root


def _delta_action_chunks(
    states: np.ndarray,
    actions: np.ndarray,
    horizon: int,
    delta_dims: int = 6,
) -> np.ndarray:
    """Return per-frame action chunks with the leading dims made state-relative.

    Args:
        states: Per-frame states of one episode; indexed as a 2-D array
            (frames, state_dim) — presumably one 1-D vector per row, TODO confirm.
        actions: Per-frame actions of the same episode, with the same number of
            frames as ``states``.
        horizon: Number of consecutive actions in each chunk.
        delta_dims: Number of leading action dimensions converted from absolute
            values to deltas relative to the chunk's starting state. Remaining
            dimensions are left untouched.

    Returns:
        A new array holding one ``horizon``-long chunk per frame.
    """
    length = len(states)
    # Frame t owns actions t .. t + horizon - 1. Indices running past the end
    # of the episode are clamped, so the last action is repeated as padding.
    offsets = np.arange(length)[:, None] + np.arange(horizon)[None, :]
    offsets = np.clip(offsets, 0, length - 1)
    # Integer-array indexing returns a copy, so the in-place subtraction below
    # never modifies ``actions``.
    chunks = actions[offsets]
    chunks[..., :delta_dims] -= states[:, None, :delta_dims]
    return chunks


def main() -> None:
    """Compute and save state/action norm stats for the Kinova TeddyBear dataset.

    Reads every parquet file under ``<dataset root>/data``, accumulates running
    statistics episode by episode (raw states, and delta-converted action
    chunks), and saves them under ``cfg.assets_dirs / repo_id``.

    Raises:
        FileNotFoundError: If no parquet files exist under the data directory.
    """
    cfg = config_lib.get_config("pi05_kinova_teddybear")
    repo_id = "lsnu/TeddyBearKinovaTestSetLeRobot"
    root = locate_dataset_root()

    data_dir = root / "data"
    parquet_paths = sorted(data_dir.rglob("*.parquet"))
    if not parquet_paths:
        # pd.concat([]) would otherwise fail with "No objects to concatenate",
        # which gives no hint about where the data was expected.
        raise FileNotFoundError(f"No parquet files found under {data_dir}")

    # Only these columns are consumed below; skipping the rest keeps the
    # in-memory frame small.
    columns = ["index", "episode_index", "observation.state", "action"]
    data_df = pd.concat(
        [pd.read_parquet(path, columns=columns) for path in parquet_paths],
        ignore_index=True,
    ).sort_values("index")

    state_stats = normalize.RunningStats()
    action_stats = normalize.RunningStats()
    horizon = cfg.model.action_horizon

    # Rows are already ordered by "index", and groupby keeps that order within
    # each episode, so every group is one episode in frame order.
    for _, ep_rows in data_df.groupby("episode_index", sort=True):
        states = np.stack(ep_rows["observation.state"].to_list()).astype(np.float32)
        actions = np.stack(ep_rows["action"].to_list()).astype(np.float32)

        state_stats.update(states)
        action_stats.update(_delta_action_chunks(states, actions, horizon))

    output_path = cfg.assets_dirs / repo_id
    norm_stats = {
        "state": state_stats.get_statistics(),
        "actions": action_stats.get_statistics(),
    }
    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats)

# Script entry point: run the stats computation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()