File size: 1,397 Bytes
604e535 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 | from __future__ import annotations
import numpy as np
from experiments.shared.src.data.image_dataset import ImageTrajectoryDataset
def test_image_dataset_action_history_uses_executed_actions(tmp_path) -> None:
    """Check that the action history holds the actions executed before each observation.

    Writes a one-episode ``.npz`` fixture whose actions are all distinct
    (``np.arange``), loads it through ``ImageTrajectoryDataset`` and compares
    the first sampled window ``(ep, t)`` against hand-built expectations:

    * ``action_hist`` must equal the actions that led into observations
      ``t - history_len + 1 .. t``, where a zero action precedes observation 0;
    * ``future_actions`` must equal ``actions[t : t + horizon]``.
    """
    path = tmp_path / "image_tiny.npz"
    steps = 8
    action_width = 3
    history_len = 4
    horizon = 2

    # Distinct value per (step, dim) so any off-by-one shift is detected.
    actions = np.arange(steps * action_width, dtype=np.float32).reshape(
        1, steps, action_width
    )
    obs = np.zeros((1, steps + 1, 4), dtype=np.float32)
    states = np.zeros((1, steps + 1, 9), dtype=np.float32)
    np.savez_compressed(
        path,
        obs=obs,
        actions=actions,
        states=states,
        true_flow=np.zeros((1, steps + 1, 2), dtype=np.float32),
        boat_ids=np.zeros((1,), dtype=np.int64),
        # NOTE(review): action_dims is 2 while the action array is
        # action_width (3) wide; presumably the dataset stores actions padded
        # to a fixed width — confirm this is intentional.
        action_dims=np.full((1,), 2, dtype=np.int64),
        flow_type_ids=np.zeros((1,), dtype=np.int64),
        flow_ids=np.zeros((1,), dtype=np.int64),
        traj_type_ids=np.zeros((1,), dtype=np.int64),
    )

    ds = ImageTrajectoryDataset(
        path,
        history_len=history_len,
        horizon=horizon,
        episodes=1,
        max_windows=1,
        seed=0,
        render_images=False,
    )
    _states, action_hist, future_actions, _targets = ds[0]
    _ep, t = ds.indices[0]

    # executed[lead + i] is the action that led into observation i: zero for
    # i == 0, actions[0, i - 1] otherwise. The extra `lead` zero rows on the
    # left keep the history slice start non-negative. A bare `t - lead` start
    # would wrap to the array's end when t < lead. If the dataset ever samples
    # such a t, this expects zero left-padding (assumption — confirm).
    lead = history_len - 1
    executed = np.zeros((lead + steps + 1, action_width), dtype=np.float32)
    executed[lead + 1 :] = actions[0]
    expected_hist = executed[t : t + history_len]

    np.testing.assert_allclose(action_hist.numpy(), expected_hist)
    np.testing.assert_allclose(future_actions.numpy(), actions[0, t : t + horizon])
|