File size: 3,762 Bytes
be75ccd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/env python3

import dataclasses
import json
import copy
from pathlib import Path

import numpy as np
import tyro

import openpi.training.config as _config
import openpi.training.data_loader as _data
import openpi.transforms as _transforms


@dataclasses.dataclass
class Args:
    config_name: str
    repo_id: str | None = None
    index: int = 0


def _array_str(array: np.ndarray) -> str:
    return np.array2string(array, precision=4, suppress_small=False, threshold=10000)


# Indices in the 32-D packed state/action vectors that are expected to hold zero
# padding: blocks [8:16] and [24:32] of the "block_boundaries" layout printed by main().
_PACKED_PAD_POSITIONS: list[int] = [*range(8, 16), *range(24, 32)]


def _print_array(label: str, value) -> None:
    """Print ``value.shape`` as ``<label>_shape`` and the full contents as ``<label>``."""
    print(f"{label}_shape: {value.shape}")
    print(f"{label}:\n{_array_str(np.asarray(value))}")


def _padded_entries(values, positions: list[int]) -> np.ndarray:
    """Return the entries of ``values`` at ``positions`` along the last axis.

    Raises:
        ValueError: if ``values`` is 0-d or its last axis is too short to contain
            every index in ``positions``.
    """
    array = np.asarray(values)
    needed = max(positions) + 1
    if array.ndim == 0 or array.shape[-1] < needed:
        raise ValueError(
            f"Expected a last axis with at least {needed} entries for the padding check, "
            f"got shape {array.shape}."
        )
    return array[..., positions]


def main() -> None:
    """Trace one dataset sample through the input pipeline and print each stage.

    Stages printed:
      * raw: after the repack and data input transforms,
      * normalized: after ``Normalize`` with the data config's norm stats,
      * packed: after the model input transforms.
    Also prints the lengths of the state/action mean and std vectors stored in the
    asset's ``norm_stats.json``, and whether the padding positions of the packed
    state/actions are exactly zero.

    Raises:
        ValueError: if the data config has no ``asset_id`` or no loaded norm stats,
            or if the packed vectors are too short for the expected padding layout.
    """
    args = tyro.cli(Args)
    config = _config.get_config(args.config_name)
    data_config = config.data.create(config.assets_dirs, config.model)
    if args.repo_id is not None:
        # Only the dataset is swapped. The norm stats and asset_id still come from
        # the named config.
        data_config = dataclasses.replace(data_config, repo_id=args.repo_id)

    # Fail fast with a clear message. Otherwise a missing asset_id surfaces later as
    # a TypeError from ``Path / None``, and missing norm stats leave the
    # "normalized" stage meaningless.
    if data_config.asset_id is None:
        raise ValueError(
            f"Config {args.config_name!r} has no asset_id; cannot locate norm_stats.json."
        )
    if data_config.norm_stats is None:
        raise ValueError(
            f"No norm stats loaded for config {args.config_name!r} "
            f"(asset_id={data_config.asset_id!r}); cannot inspect normalization."
        )

    dataset = _data.create_torch_dataset(data_config, config.model.action_horizon, config.model)
    sample = dataset[args.index]

    repack_and_data = _transforms.compose(
        [*data_config.repack_transforms.inputs, *data_config.data_transforms.inputs]
    )
    normalize = _transforms.compose(
        [_transforms.Normalize(data_config.norm_stats, use_quantiles=data_config.use_quantile_norm)]
    )
    model_inputs = _transforms.compose(list(data_config.model_transforms.inputs))

    # Deep copies keep each stage's output intact in case a later transform
    # mutates its input dict in place.
    raw_sample = repack_and_data(sample)
    normalized_sample = normalize(copy.deepcopy(raw_sample))
    packed_sample = model_inputs(copy.deepcopy(normalized_sample))

    norm_stats_path = config.assets_dirs / data_config.asset_id / "norm_stats.json"
    norm_stats = json.loads(Path(norm_stats_path).read_text(encoding="utf-8"))["norm_stats"]

    state_padded = _padded_entries(packed_sample["state"], _PACKED_PAD_POSITIONS)
    actions_padded = _padded_entries(packed_sample["actions"], _PACKED_PAD_POSITIONS)

    print(f"config_name: {args.config_name}")
    print(f"repo_id: {data_config.repo_id}")
    print(f"sample_index: {args.index}")
    print(f"norm_stats_path: {norm_stats_path}")
    print(f"norm_stats_keys: {sorted(norm_stats.keys())}")
    print(
        "norm_stats_lengths: "
        f"state_mean={len(norm_stats['state']['mean'])} "
        f"state_std={len(norm_stats['state']['std'])} "
        f"action_mean={len(norm_stats['actions']['mean'])} "
        f"action_std={len(norm_stats['actions']['std'])}"
    )
    print("block_boundaries: [0:8] [8:16] [16:24] [24:32]")
    for label, value in (
        ("raw_state_16d", raw_sample["state"]),
        ("raw_actions_16d", raw_sample["actions"]),
        ("normalized_state_16d", normalized_sample["state"]),
        ("normalized_actions_16d", normalized_sample["actions"]),
        ("packed_state_32d", packed_sample["state"]),
        ("packed_actions_32d", packed_sample["actions"]),
    ):
        _print_array(label, value)
    print(f"state_padded_zero_count: {int((state_padded == 0).sum())} / {state_padded.size}")
    print(f"actions_padded_zero_count: {int((actions_padded == 0).sum())} / {actions_padded.size}")
    print(f"state_padded_exact_zero: {bool(np.all(state_padded == 0))}")
    print(f"actions_padded_exact_zero: {bool(np.all(actions_padded == 0))}")


# Run the inspection only when executed as a script, not when imported.
if __name__ == "__main__":
    main()