File size: 3,818 Bytes
604e535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Image trajectory dataset for world-model training."""

from __future__ import annotations

import numpy as np
import torch
from torch.utils.data import Dataset

from driftwm.data.generate import ID_TO_BOAT
from experiments.shared.src.vision.clean_renderer import render_clean_boat_array


class ImageTrajectoryDataset(Dataset):
    """Sliding-window samples drawn from a contiguous range of stored episodes.

    The source ``.npz`` archive must contain the arrays ``obs``, ``actions``,
    ``flow_type_ids``, ``traj_type_ids``, ``boat_ids`` and ``states``, all
    indexed by episode along axis 0.  ``actions`` is indexed as
    ``(episode, step, action_dim)``.
    NOTE(review): ``obs`` and ``states`` are presumably one step longer than
    ``actions`` along axis 1 (targets are read up to index ``steps``) --
    confirm against the generator.

    A window is identified by ``(ep, t)`` where ``ep`` is the episode index
    relative to ``episode_start`` and ``t`` is the last history step, with
    ``history_len - 1 <= t < steps - horizon``.  At most ``max_windows``
    windows are drawn without replacement using ``seed``.
    """

    def __init__(
        self,
        source_npz: str,
        history_len: int,
        horizon: int,
        episodes: int,
        max_windows: int,
        seed: int,
        episode_start: int = 0,
        image_size: int = 160,
        visual_scale: float = 2.5,
        render_images: bool = True,
        return_origin: bool = False,
        return_aux: bool = False,
    ):
        """Load the episode range and sample the window index.

        Args:
            source_npz: Path to the ``.npz`` archive.
            history_len: Number of past steps per sample (must be >= 1).
            horizon: Number of future steps per sample (must be >= 0).
            episodes: Number of episodes to use, starting at ``episode_start``.
            max_windows: Upper bound on the number of sampled windows.
            seed: Seed for the window-selection RNG.
            episode_start: Index of the first episode to use (must be >= 0).
            image_size: Forwarded to ``render_clean_boat_array``.
            visual_scale: Forwarded to ``render_clean_boat_array``.
            render_images: If true, the observation history is rendered from
                states; otherwise the raw state history is returned.
            return_origin: If true, append the observation at ``t``.
            return_aux: If true, append observation at ``t``, observation at
                ``t - 1``, flow-type id and boat id (takes precedence over
                ``return_origin``).

        Raises:
            ValueError: On an invalid ``history_len``/``horizon``/episode
                range, or when the archive holds too few episodes.
        """
        super().__init__()
        self.history_len = int(history_len)
        self.horizon = int(horizon)
        if self.history_len < 1:
            raise ValueError(f"history_len must be >= 1, got {history_len}")
        if self.horizon < 0:
            raise ValueError(f"horizon must be >= 0, got {horizon}")
        start = int(episode_start)
        n_episodes = int(episodes)
        if start < 0 or n_episodes < 0:
            raise ValueError(
                f"episode_start and episodes must be non-negative, got {episode_start} and {episodes}"
            )
        end = start + n_episodes

        # NpzFile keeps the archive open until closed; ``astype`` copies the
        # data, so the arrays remain valid after the ``with`` block exits.
        with np.load(source_npz, allow_pickle=False) as src:
            if end > int(src["obs"].shape[0]):
                raise ValueError(f"requested episodes [{start}, {end}) but dataset only has {src['obs'].shape[0]}")
            self.obs = src["obs"][start:end].astype(np.float32)
            self.actions = src["actions"][start:end].astype(np.float32)
            self.flow_type_ids = src["flow_type_ids"][start:end].astype(np.int64)
            self.traj_type_ids = src["traj_type_ids"][start:end].astype(np.int64)
            self.boat_ids = src["boat_ids"][start:end].astype(np.int64)
            self.states = src["states"][start:end].astype(np.float32)

        self.image_size = int(image_size)
        self.visual_scale = float(visual_scale)
        self.render_images = bool(render_images)
        self.return_origin = bool(return_origin)
        self.return_aux = bool(return_aux)

        steps = self.actions.shape[1]
        first_t = self.history_len - 1
        windows_per_episode = max(0, steps - self.horizon - first_t)
        total = n_episodes * windows_per_episode
        rng = np.random.default_rng(seed)
        selected = rng.choice(total, size=min(int(max_windows), total), replace=False)
        # Flat index i enumerates windows episode-major, then by t; decode it
        # instead of materialising every (ep, t) pair up front.
        self.indices = []
        for flat in selected:
            ep, offset = divmod(int(flat), windows_per_episode)
            self.indices.append((ep, first_t + offset))

    def __len__(self) -> int:
        """Return the number of sampled windows."""
        return len(self.indices)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]:
        """Build the sample for window ``idx``.

        Returns:
            ``(observation_hist, action_hist, future_actions, targets)``,
            followed by ``(obs[t], obs[t-1], flow_type_id, boat_id)`` when
            ``return_aux`` is set, or by ``(obs[t],)`` when only
            ``return_origin`` is set.

            ``action_hist[i]`` is the action taken immediately *before*
            history step ``i``; the row for step 0 of an episode is zeros.
            ``future_actions`` are the ``horizon`` actions starting at ``t``
            and ``targets`` the ``horizon`` observations starting at ``t + 1``.
        """
        ep, t = self.indices[idx]
        hist_start = t - self.history_len + 1  # >= 0 by construction of self.indices
        state_hist = self.states[ep, hist_start : t + 1, :6].copy()
        if self.render_images:
            boat = ID_TO_BOAT[int(self.boat_ids[ep])]
            rendered = [
                render_clean_boat_array(state, boat, image_size=self.image_size, visual_scale=self.visual_scale)
                for state in state_hist
            ]
            # Move the last axis of each rendered frame in front of the two
            # middle ones (presumably HWC -> CHW; confirm against renderer).
            observation_hist = np.transpose(np.stack(rendered, axis=0), (0, 3, 1, 2))
        else:
            observation_hist = state_hist

        # Actions shifted by one step with a zero row before step 0; only the
        # rows needed for this window are built.
        action_hist = np.zeros((self.history_len, self.actions.shape[2]), dtype=np.float32)
        if hist_start == 0:
            action_hist[1:] = self.actions[ep, :t]
        else:
            action_hist[:] = self.actions[ep, hist_start - 1 : t]

        future_actions = self.actions[ep, t : t + self.horizon].copy()
        targets = self.obs[ep, t + 1 : t + 1 + self.horizon].copy()
        sample = (
            torch.from_numpy(observation_hist),
            torch.from_numpy(action_hist),
            torch.from_numpy(future_actions),
            torch.from_numpy(targets),
        )
        if self.return_aux:
            # At t == 0 there is no earlier frame; reuse the current one
            # rather than letting index -1 wrap to the end of the episode.
            prev_t = max(t - 1, 0)
            return (
                *sample,
                torch.from_numpy(self.obs[ep, t].copy()),
                torch.from_numpy(self.obs[ep, prev_t].copy()),
                torch.tensor(self.flow_type_ids[ep], dtype=torch.long),
                torch.tensor(self.boat_ids[ep], dtype=torch.long),
            )
        if self.return_origin:
            return (*sample, torch.from_numpy(self.obs[ep, t].copy()))
        return sample