File size: 1,397 Bytes
604e535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from __future__ import annotations

import numpy as np

from experiments.shared.src.data.image_dataset import ImageTrajectoryDataset


def test_image_dataset_action_history_uses_executed_actions(tmp_path) -> None:
    """Check the action history / future-action alignment of one sampled window.

    A single synthetic episode with distinct per-(step, dim) action values is
    saved to an ``.npz`` file and loaded through ``ImageTrajectoryDataset``.
    For the first sampled window starting at step ``t`` the test expects:

    * ``action_hist`` -- the ``history_len`` most recent actions executed
      before step ``t`` (``actions[t - history_len : t]``), with an all-zero
      row standing in for the non-existent action preceding step 0;
    * ``future_actions`` -- the ``horizon`` actions ``actions[t : t + horizon]``.
    """
    path = tmp_path / "image_tiny.npz"
    steps = 8
    action_dim = 3
    history_len = 4
    horizon = 2
    # Distinct values per (step, dim) so any off-by-one shows up as a mismatch.
    actions = np.arange(steps * action_dim, dtype=np.float32).reshape(
        1, steps, action_dim
    )
    obs = np.zeros((1, steps + 1, 4), dtype=np.float32)
    states = np.zeros((1, steps + 1, 9), dtype=np.float32)
    np.savez_compressed(
        path,
        obs=obs,
        actions=actions,
        states=states,
        true_flow=np.zeros((1, steps + 1, 2), dtype=np.float32),
        boat_ids=np.zeros((1,), dtype=np.int64),
        action_dims=np.full((1,), 2, dtype=np.int64),
        flow_type_ids=np.zeros((1,), dtype=np.int64),
        flow_ids=np.zeros((1,), dtype=np.int64),
        traj_type_ids=np.zeros((1,), dtype=np.int64),
    )
    ds = ImageTrajectoryDataset(
        path,
        history_len=history_len,
        horizon=horizon,
        episodes=1,
        max_windows=1,
        seed=0,
        render_images=False,
    )
    _states, action_hist, future_actions, _targets = ds[0]
    _ep, t = ds.indices[0]
    # The expected slices below need a full window on both sides. Fail loudly
    # here rather than letting a negative start index wrap around the array
    # (or a short tail slice) and surface as an opaque shape mismatch.
    assert history_len - 1 <= t <= steps - horizon, (
        f"unexpected window start t={t} for history_len={history_len}, "
        f"horizon={horizon}, steps={steps}"
    )
    # padded[k] is the action executed just before step k; padded[0] is the
    # zero placeholder for the step before the episode begins.
    padded = np.zeros((steps + 1, action_dim), dtype=np.float32)
    padded[1:] = actions[0]
    np.testing.assert_allclose(
        action_hist.numpy(), padded[t - history_len + 1 : t + 1]
    )
    np.testing.assert_allclose(
        future_actions.numpy(), actions[0, t : t + horizon]
    )