| import dataclasses |
|
|
| import jax |
| import numpy as np |
|
|
| from openpi.models import pi0_config |
| from openpi.training import config as _config |
| from openpi.training import data_loader as _data_loader |
|
|
|
|
class _RawColumns:
    """Tiny in-memory column container used as a test double.

    Holds a mapping of column name -> column values and exposes the three
    things defined here: a ``column_names`` list, a no-op ``with_format``,
    and lookup of a column by name.

    NOTE(review): presumably mirrors the subset of a Hugging Face dataset API
    that the loader under test touches -- confirm against the loader.
    """

    def __init__(self, columns):
        self._columns = columns
        # Iterating a mapping yields its keys, i.e. the column names.
        self.column_names = [*columns]

    def __getitem__(self, key):
        """Return the column stored under ``key`` (the stored object itself)."""
        return self._columns[key]

    def with_format(self, _format):
        """Ignore the requested format and hand back this same object."""
        return self
|
|
|
|
class _OnlineSourceDataset:
    """Synthetic source dataset with ``length`` frames, all in episode 0.

    Two views of the same data are built:

    * ``hf_dataset`` -- a ``_RawColumns`` wrapper around per-column lists.
    * item access (``dataset[i]`` / ``len(dataset)``) -- one dict per frame.

    Every action row is 7 wide with 1.0 in column 0 and zeros elsewhere;
    state row ``i`` is 8 wide and filled with the value ``i``.
    """

    def __init__(self, length):
        actions = np.zeros((length, 7), dtype=np.float32)
        actions[:, 0] = 1.0  # marker value in the first action dimension
        states = np.stack(
            [np.full(8, frame, dtype=np.float32) for frame in range(length)],
            axis=0,
        )

        # Column-oriented view: each column is a plain Python list of rows/scalars.
        self.hf_dataset = _RawColumns(
            {
                "actions": list(actions),
                "state": list(states),
                "episode_index": list(np.zeros(length, dtype=np.int64)),
                "frame_index": list(np.arange(length, dtype=np.int64)),
                "task_index": list(np.zeros(length, dtype=np.int64)),
            }
        )

        # Row-oriented view: one dict per frame, index fields as 0-d int64 arrays.
        per_frame = []
        for frame in range(length):
            per_frame.append(
                {
                    "actions": actions[frame],
                    "state": states[frame],
                    "episode_index": np.asarray(0, dtype=np.int64),
                    "frame_index": np.asarray(frame, dtype=np.int64),
                    "task_index": np.asarray(0, dtype=np.int64),
                }
            )
        self._items = per_frame

    def __len__(self):
        """Return the number of frames."""
        return len(self._items)

    def __getitem__(self, index):
        """Return the frame dict at ``index``.

        ``index`` is converted through ``__index__`` first, so integer-like
        objects (e.g. NumPy integer scalars) are accepted.
        """
        position = index.__index__()
        return self._items[position]
|
|
|
|
def test_online_sliding_chunk_dataset_enumerates_episode_speed_pairs():
    """Dataset length equals the number of (episode, speed) pairs.

    The source has 12 frames that all belong to episode 0, and three speeds
    are requested, so exactly three samples are expected.
    """
    speeds = [0.75, 1.0, 1.25]
    dataset = _data_loader.OnlineSlidingChunkDataset(
        _OnlineSourceDataset(12),
        speeds,
        action_horizon=4,
    )

    # 1 episode x 3 speeds
    assert len(dataset) == 3
|
|
|
|
def test_online_sliding_chunk_dataset_returns_aligned_chunk_start():
    """A single access at speed 1.25 yields a well-formed sample.

    Checks the action chunk shape, the speed value and its label, the
    observation mask, and the state shape of ``dataset[0]``.
    """
    dataset = _data_loader.OnlineSlidingChunkDataset(
        _OnlineSourceDataset(12),
        [1.25],
        action_horizon=4,
    )

    # Seed NumPy's global RNG before the access.
    # NOTE(review): presumably the dataset samples from it -- confirm.
    np.random.seed(0)
    sample = dataset[0]

    assert sample["actions"].shape == (4, 7)  # (action_horizon, action width)
    assert float(sample["speed"][0]) == 1.25
    assert sample["speed_label"] == "1p25x"
    assert int(sample["observation_mask"]) == 1
    assert sample["state"].shape == (8,)
|
|
|
|
def test_online_sliding_chunk_dataset_speed_one_fast_path():
    """At speed 1.0 the returned actions keep the source's marker column.

    ``_OnlineSourceDataset`` writes 1.0 into column 0 of every action row;
    this test asserts that column is still all 1.0 in the sampled chunk.

    NOTE(review): the original author describes this as a fast path that
    bypasses ``transform_episode``; that cannot be verified from this file.
    """
    dataset = _data_loader.OnlineSlidingChunkDataset(
        _OnlineSourceDataset(12),
        [1.0],
        action_horizon=4,
    )

    np.random.seed(0)
    sample = dataset[0]

    assert float(sample["speed"][0]) == 1.0
    assert int(sample["observation_mask"]) == 1
    assert sample["actions"].shape == (4, 7)

    first_action_dim = sample["actions"][:, 0]
    assert np.all(first_action_dim == 1.0)
|
|
|
|
def test_torch_data_loader():
    """A loader capped at two batches yields two batches of leading size 4."""
    model_config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)
    loader = _data_loader.TorchDataLoader(
        _data_loader.FakeDataset(model_config, 16),
        local_batch_size=4,
        num_batches=2,
    )

    batches = list(loader)
    assert len(batches) == 2

    # Every leaf array of every batch must have the local batch size as dim 0.
    for batch in batches:
        for leaf in jax.tree.leaves(batch):
            assert leaf.shape[0] == 4
|
|
|
|
def test_torch_data_loader_infinite():
    """Without ``num_batches`` the loader can be advanced ten times.

    The fake dataset has only 4 items and the batch size is 4, so drawing
    10 batches requires the loader to keep producing past one pass.
    """
    model_config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)
    loader = _data_loader.TorchDataLoader(
        _data_loader.FakeDataset(model_config, 4),
        local_batch_size=4,
    )

    stream = iter(loader)
    for _ in range(10):
        # Only success of the draw matters; the batch itself is discarded.
        next(stream)
|
|
|
|
def test_torch_data_loader_parallel():
    """With two workers the loader still yields two batches of leading size 4."""
    model_config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)
    loader = _data_loader.TorchDataLoader(
        _data_loader.FakeDataset(model_config, 10),
        local_batch_size=4,
        num_batches=2,
        num_workers=2,
    )

    batches = list(loader)
    assert len(batches) == 2

    # Every leaf array of every batch must have the local batch size as dim 0.
    for batch in batches:
        for leaf in jax.tree.leaves(batch):
            assert leaf.shape[0] == 4
|
|
|
|
def test_with_fake_dataset():
    """``create_data_loader`` on the ``debug`` config yields correctly shaped batches.

    Checks the batch count, the leading dimension of every leaf, and the
    full shape of the actions element of each batch.
    """
    config = _config.get_config("debug")
    loader = _data_loader.create_data_loader(config, skip_norm_stats=True, num_batches=2)

    batches = list(loader)
    assert len(batches) == 2

    batch_size = config.batch_size
    for batch in batches:
        for leaf in jax.tree.leaves(batch):
            assert leaf.shape[0] == batch_size

    # Each batch unpacks into two elements; the second is the actions array.
    expected_actions_shape = (batch_size, config.model.action_horizon, config.model.action_dim)
    for _, actions in batches:
        assert actions.shape == expected_actions_shape
|
|
|
|
def test_with_real_dataset():
    """``create_data_loader`` on ``pi0_aloha_sim`` (batch size 4) yields shaped action batches.

    Also checks that the loader reports the same ``repo_id`` as the config.

    NOTE(review): presumably this needs access to the real dataset named by
    the config (download or local cache) -- confirm before running in CI.
    """
    base_config = _config.get_config("pi0_aloha_sim")
    # Shrink the batch size on a copy; the registered config is left untouched.
    config = dataclasses.replace(base_config, batch_size=4)

    loader = _data_loader.create_data_loader(
        config,
        skip_norm_stats=True,
        num_batches=2,
        shuffle=True,
    )
    assert loader.data_config().repo_id == config.data.repo_id

    batches = list(loader)
    assert len(batches) == 2

    # Each batch unpacks into two elements; the second is the actions array.
    expected_actions_shape = (
        config.batch_size,
        config.model.action_horizon,
        config.model.action_dim,
    )
    for _, actions in batches:
        assert actions.shape == expected_actions_shape
|
|