# pi05tests-openpi-multiarm / openpi / scripts / inspect_twin_packed_batch.py
# Uploaded by lsnu using the upload-large-folder tool (commit be75ccd, verified).
#!/usr/bin/env python3
import dataclasses
import json
import copy
from pathlib import Path
import numpy as np
import tyro
import openpi.training.config as _config
import openpi.training.data_loader as _data
import openpi.transforms as _transforms
@dataclasses.dataclass
class Args:
    """Command-line arguments for inspecting one packed training sample."""

    # Name of the training config to load via openpi.training.config.get_config.
    config_name: str
    # Optional override for the dataset repo id taken from the config.
    repo_id: str | None = None
    # Index of the dataset sample to inspect.
    index: int = 0
def _array_str(array: np.ndarray) -> str:
return np.array2string(array, precision=4, suppress_small=False, threshold=10000)
def main() -> None:
    """Trace one training sample through the transform pipeline and print it.

    Prints the raw 16-dim state/actions, the normalized versions, and the
    final packed 32-dim tensors, then reports whether the padding blocks
    ([8:16] and [24:32]) of the packed layout are exactly zero.
    """
    args = tyro.cli(Args)
    train_config = _config.get_config(args.config_name)
    data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
    if args.repo_id is not None:
        data_config = dataclasses.replace(data_config, repo_id=args.repo_id)

    dataset = _data.create_torch_dataset(data_config, train_config.model.action_horizon, train_config.model)
    sample = dataset[args.index]

    # Stage 1: repack + data-level transforms (output still in the 16-dim layout).
    stage_repack = _transforms.compose(
        [*data_config.repack_transforms.inputs, *data_config.data_transforms.inputs]
    )
    # Stage 2: normalization with the dataset's statistics.
    stage_norm = _transforms.compose(
        [_transforms.Normalize(data_config.norm_stats, use_quantiles=data_config.use_quantile_norm)]
    )
    # Stage 3: model-input transforms (pack 16-dim vectors into 32 dims).
    stage_model = _transforms.compose(list(data_config.model_transforms.inputs))

    # Deep-copy between stages so every intermediate survives for printing.
    raw_sample = stage_repack(sample)
    normalized_sample = stage_norm(copy.deepcopy(raw_sample))
    packed_sample = stage_model(copy.deepcopy(normalized_sample))

    norm_stats_path = train_config.assets_dirs / data_config.asset_id / "norm_stats.json"
    norm_stats = json.loads(Path(norm_stats_path).read_text())["norm_stats"]

    # Positions expected to be zero padding in the packed 32-dim layout.
    packed_zero_positions = [*range(8, 16), *range(24, 32)]
    state_padded = packed_sample["state"][packed_zero_positions]
    actions_padded = packed_sample["actions"][:, packed_zero_positions]

    print(f"config_name: {args.config_name}")
    print(f"repo_id: {data_config.repo_id}")
    print(f"sample_index: {args.index}")
    print(f"norm_stats_path: {norm_stats_path}")
    print(f"norm_stats_keys: {sorted(norm_stats.keys())}")
    print(
        "norm_stats_lengths: "
        f"state_mean={len(norm_stats['state']['mean'])} "
        f"state_std={len(norm_stats['state']['std'])} "
        f"action_mean={len(norm_stats['actions']['mean'])} "
        f"action_std={len(norm_stats['actions']['std'])}"
    )
    print("block_boundaries: [0:8] [8:16] [16:24] [24:32]")

    # Raw and normalized samples share the 16-dim print format.
    for label, stage_sample in (("raw", raw_sample), ("normalized", normalized_sample)):
        print(f"{label}_state_16d_shape: {stage_sample['state'].shape}")
        print(f"{label}_state_16d:\n{_array_str(np.asarray(stage_sample['state']))}")
        print(f"{label}_actions_16d_shape: {stage_sample['actions'].shape}")
        print(f"{label}_actions_16d:\n{_array_str(np.asarray(stage_sample['actions']))}")

    print(f"packed_state_32d_shape: {packed_sample['state'].shape}")
    print(f"packed_state_32d:\n{_array_str(np.asarray(packed_sample['state']))}")
    print(f"packed_actions_32d_shape: {packed_sample['actions'].shape}")
    print(f"packed_actions_32d:\n{_array_str(np.asarray(packed_sample['actions']))}")
    print(f"state_padded_zero_count: {int((state_padded == 0).sum())} / {state_padded.size}")
    print(f"actions_padded_zero_count: {int((actions_padded == 0).sum())} / {actions_padded.size}")
    print(f"state_padded_exact_zero: {bool(np.all(state_padded == 0))}")
    print(f"actions_padded_exact_zero: {bool(np.all(actions_padded == 0))}")
# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()