File size: 5,474 Bytes
08ff31f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 | import dataclasses
import jax
import numpy as np
from openpi.models import pi0_config
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader
class _RawColumns:
    """Tiny column-oriented container used as a fake ``hf_dataset``.

    Holds a mapping of column name -> column values and exposes only the
    members defined below: ``column_names``, ``with_format`` and ``[]`` lookup.
    """

    def __init__(self, columns):
        """Keep a reference to ``columns`` and record its keys in iteration order."""
        self._columns = columns
        # Unpacking a mapping iterates its keys.
        self.column_names = [*columns]

    def with_format(self, _format):
        """Ignore the requested format and return this same object."""
        return self

    def __getitem__(self, key):
        """Return the column stored under ``key`` (the stored object, not a copy)."""
        column = self._columns[key]
        return column
class _OnlineSourceDataset:
    """Synthetic single-episode source dataset for the sliding-chunk tests.

    Frame ``i`` carries:
      * ``actions``: 7 float32 values, the first is 1.0 and the rest 0.0;
      * ``state``: 8 float32 values, all equal to ``i``;
      * ``episode_index`` = 0, ``frame_index`` = ``i``, ``task_index`` = 0.

    The same data is exposed two ways:
      * ``hf_dataset``: a ``_RawColumns`` holding one Python list per column;
      * ``dataset[i]`` / ``len(dataset)``: per-frame dicts of numpy values.

    Args:
        length: Number of frames to generate. ``0`` yields an empty dataset.
    """

    def __init__(self, length):
        actions = np.zeros((length, 7), dtype=np.float32)
        actions[:, 0] = 1.0
        # Row i is filled with the value i. Built by repeating a column vector
        # instead of np.stack over a Python list: np.stack raises ValueError
        # on an empty sequence, which made length == 0 impossible to construct.
        states = np.repeat(np.arange(length, dtype=np.float32)[:, None], 8, axis=1)
        columns = {
            "actions": list(actions),
            "state": list(states),
            "episode_index": list(np.zeros(length, dtype=np.int64)),
            "frame_index": list(np.arange(length, dtype=np.int64)),
            "task_index": list(np.zeros(length, dtype=np.int64)),
        }
        self.hf_dataset = _RawColumns(columns)
        # Per-frame view of the same data; scalar fields are 0-d int64 arrays.
        self._items = [
            {
                "actions": actions[i],
                "state": states[i],
                "episode_index": np.asarray(0, dtype=np.int64),
                "frame_index": np.asarray(i, dtype=np.int64),
                "task_index": np.asarray(0, dtype=np.int64),
            }
            for i in range(length)
        ]

    def __getitem__(self, index):
        """Return the frame dict at ``index``.

        ``index`` may be any object implementing ``__index__``; it is converted
        to a plain ``int`` before the list lookup.
        """
        return self._items[index.__index__()]

    def __len__(self):
        """Return the number of frames."""
        return len(self._items)
def test_online_sliding_chunk_dataset_enumerates_episode_speed_pairs():
    """Dataset length must equal (number of episodes) x (number of speeds).

    The 12-frame fake source contains a single episode, so three speed factors
    are expected to give exactly three entries. As noted by the original
    author, phase and chunk row are random per access, so only the length is
    checked here.
    """
    speed_factors = [0.75, 1.0, 1.25]
    fake_source = _OnlineSourceDataset(12)
    chunked = _data_loader.OnlineSlidingChunkDataset(fake_source, speed_factors, action_horizon=4)

    # 1 episode x 3 speeds
    expected_entries = 3
    assert len(chunked) == expected_entries
def test_online_sliding_chunk_dataset_returns_aligned_chunk_start():
    """A single access at speed 1.25 yields a well-formed sample.

    Checks the action chunk shape (action_horizon x 7), the reported speed and
    its label, an observation mask of 1, and an 8-element state.
    """
    chunked = _data_loader.OnlineSlidingChunkDataset(
        _OnlineSourceDataset(12),
        [1.25],
        action_horizon=4,
    )
    # Seed numpy's global RNG before the access, as the original test does.
    np.random.seed(0)
    sample = chunked[0]

    action_chunk = sample["actions"]
    assert action_chunk.shape == (4, 7)

    reported_speed = float(sample["speed"][0])
    assert reported_speed == 1.25
    assert sample["speed_label"] == "1p25x"

    mask_value = int(sample["observation_mask"])
    assert mask_value == 1

    # The fake source stores state[i] = i * np.ones(8); only the per-frame
    # shape is verified here, not which frame was selected.
    state_vector = sample["state"]
    assert state_vector.shape == (8,)
def test_online_sliding_chunk_dataset_speed_one_fast_path():
    """At speed 1.0 the returned action chunk matches the source actions.

    Per the original note, speed=1.0 is expected to bypass transform_episode
    and read the source verbatim; this test observes that through the first
    action component, which is 1.0 for every frame of the fake source.
    """
    chunked = _data_loader.OnlineSlidingChunkDataset(
        _OnlineSourceDataset(12),
        [1.0],
        action_horizon=4,
    )
    np.random.seed(0)
    sample = chunked[0]

    reported_speed = float(sample["speed"][0])
    assert reported_speed == 1.0

    mask_value = int(sample["observation_mask"])
    assert mask_value == 1

    assert sample["actions"].shape == (4, 7)

    # All actions in the synthetic episode have action[:, 0] == 1.0
    first_component = sample["actions"][:, 0]
    assert np.all(first_component == 1.0)
def test_torch_data_loader():
    """TorchDataLoader yields ``num_batches`` batches whose leaves have the batch size first."""
    model_config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)
    fake_data = _data_loader.FakeDataset(model_config, 16)

    collected = list(
        _data_loader.TorchDataLoader(
            fake_data,
            local_batch_size=4,
            num_batches=2,
        )
    )

    assert len(collected) == 2
    for batch in collected:
        # Every array in the batch pytree must lead with the local batch size.
        leading_dims_ok = all(leaf.shape[0] == 4 for leaf in jax.tree.leaves(batch))
        assert leading_dims_ok
def test_torch_data_loader_infinite():
    """Without ``num_batches`` the loader keeps producing batches.

    The fake dataset is built with size argument 4 and the batch size is 4;
    drawing ten batches in a row must not raise.
    """
    model_config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)
    fake_data = _data_loader.FakeDataset(model_config, 4)
    batch_stream = iter(_data_loader.TorchDataLoader(fake_data, local_batch_size=4))

    draws_left = 10
    while draws_left > 0:
        # The batch itself is irrelevant; only successful retrieval matters.
        next(batch_stream)
        draws_left -= 1
def test_torch_data_loader_parallel():
    """Same batch-count and batch-size checks as the basic test, with ``num_workers=2``."""
    model_config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)
    fake_data = _data_loader.FakeDataset(model_config, 10)
    parallel_loader = _data_loader.TorchDataLoader(
        fake_data,
        local_batch_size=4,
        num_batches=2,
        num_workers=2,
    )

    collected = list(parallel_loader)

    assert len(collected) == 2
    for batch in collected:
        # Every array in the batch pytree must lead with the local batch size.
        leading_dims_ok = all(leaf.shape[0] == 4 for leaf in jax.tree.leaves(batch))
        assert leading_dims_ok
def test_with_fake_dataset():
    """``create_data_loader`` on the "debug" config yields correctly shaped batches.

    Two batches are requested; every leaf must lead with ``config.batch_size``
    and the actions element of each batch must have shape
    (batch_size, action_horizon, action_dim).
    """
    config = _config.get_config("debug")
    collected = list(_data_loader.create_data_loader(config, skip_norm_stats=True, num_batches=2))

    assert len(collected) == 2

    # First pass: leading dimension of every leaf.
    for batch in collected:
        leading_dims_ok = all(leaf.shape[0] == config.batch_size for leaf in jax.tree.leaves(batch))
        assert leading_dims_ok

    # Second pass: full shape of the actions element (second item of each batch).
    for _observation, actions in collected:
        expected_shape = (config.batch_size, config.model.action_horizon, config.model.action_dim)
        assert actions.shape == expected_shape
def test_with_real_dataset():
    """``create_data_loader`` on the "pi0_aloha_sim" config with a batch size of 4.

    Verifies that the loader exposes the data config (matching ``repo_id``),
    that exactly two batches are produced, and that the actions element has
    shape (batch_size, action_horizon, action_dim).
    NOTE(review): presumably needs access to the real dataset — confirm before
    running in an offline environment.
    """
    base_config = _config.get_config("pi0_aloha_sim")
    config = dataclasses.replace(base_config, batch_size=4)

    loader = _data_loader.create_data_loader(
        config,
        skip_norm_stats=True,  # Skipped since the data may not be available.
        num_batches=2,
        shuffle=True,
    )

    # Make sure the data config is reachable through the loader.
    loader_repo_id = loader.data_config().repo_id
    assert loader_repo_id == config.data.repo_id

    collected = list(loader)
    assert len(collected) == 2

    for _observation, actions in collected:
        expected_shape = (config.batch_size, config.model.action_horizon, config.model.action_dim)
        assert actions.shape == expected_shape
|