| """Image trajectory dataset for world-model training.""" |
|
|
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
|
|
| from driftwm.data.generate import ID_TO_BOAT |
| from experiments.shared.src.vision.clean_renderer import render_clean_boat_array |
|
|
|
|
class ImageTrajectoryDataset(Dataset):
    """Sliding-window samples drawn from recorded trajectories in an ``.npz`` archive.

    Each sample is anchored at a step ``t`` of one loaded episode and contains, in order:

    * ``observation_hist`` -- the states ``t - history_len + 1 .. t`` (first 6 state
      components). When ``render_images`` is true each state is rendered with
      ``render_clean_boat_array`` and the stack is transposed with axes (0, 3, 1, 2)
      (presumably HxWxC frames -> (T, C, H, W); confirm against the renderer).
      Otherwise the raw state vectors are returned.
    * ``action_hist`` -- for each history step ``s``, ``actions[s - 1]``; the step-0
      state has no preceding action and gets a zero row.
    * ``future_actions`` -- ``actions[t : t + horizon]``.
    * ``targets`` -- ``obs[t + 1 : t + 1 + horizon]``.

    Up to ``max_windows`` anchors are sampled without replacement using a generator
    seeded with ``seed``, so the window set is reproducible.
    """

    def __init__(
        self,
        source_npz: str,
        history_len: int,
        horizon: int,
        episodes: int,
        max_windows: int,
        seed: int,
        episode_start: int = 0,
        image_size: int = 160,
        visual_scale: float = 2.5,
        render_images: bool = True,
        return_origin: bool = False,
        return_aux: bool = False,
    ):
        """Load episodes ``[episode_start, episode_start + episodes)`` and sample windows.

        Args:
            source_npz: Path to an ``.npz`` archive with arrays ``obs``, ``actions``,
                ``flow_type_ids``, ``traj_type_ids``, ``boat_ids`` and ``states``,
                all indexed by episode along axis 0.
            history_len: Number of steps (ending at the anchor ``t``) in each history.
            horizon: Number of future actions/targets per sample.
            episodes: Number of consecutive episodes to load.
            max_windows: Maximum number of windows to keep after sampling.
            seed: Seed for the window-sampling generator.
            episode_start: Index of the first episode to load.
            image_size: Forwarded to ``render_clean_boat_array``.
            visual_scale: Forwarded to ``render_clean_boat_array``.
            render_images: If false, return raw state vectors instead of images.
            return_origin: Append ``obs[t]`` to every sample.
            return_aux: Append ``obs[t]``, the previous observation, the flow type id
                and the boat id to every sample. Takes precedence over
                ``return_origin``, whose extra item is already included.

        Raises:
            ValueError: If ``history_len < 1``, ``horizon < 0``, ``episode_start < 0``,
                or the requested episode range exceeds the archive.
        """
        self.history_len = int(history_len)
        self.horizon = int(horizon)
        if self.history_len < 1:
            raise ValueError(f"history_len must be >= 1, got {self.history_len}")
        if self.horizon < 0:
            raise ValueError(f"horizon must be >= 0, got {self.horizon}")
        start = int(episode_start)
        end = start + int(episodes)
        if start < 0:
            raise ValueError(f"episode_start must be >= 0, got {start}")

        # np.load on an archive returns an NpzFile that keeps the file open and
        # re-reads an array on every key access. Read each key once and close it.
        with np.load(source_npz, allow_pickle=False) as src:
            obs_all = src["obs"]
            n_total = int(obs_all.shape[0])
            if end > n_total:
                raise ValueError(f"requested episodes [{start}, {end}) but dataset only has {n_total}")
            self.obs = obs_all[start:end].astype(np.float32)
            self.actions = src["actions"][start:end].astype(np.float32)
            self.flow_type_ids = src["flow_type_ids"][start:end].astype(np.int64)
            self.traj_type_ids = src["traj_type_ids"][start:end].astype(np.int64)
            self.boat_ids = src["boat_ids"][start:end].astype(np.int64)
            self.states = src["states"][start:end].astype(np.float32)

        self.image_size = int(image_size)
        self.visual_scale = float(visual_scale)
        self.render_images = bool(render_images)
        self.return_origin = bool(return_origin)
        self.return_aux = bool(return_aux)

        steps = self.actions.shape[1]
        # An anchor t needs history_len states ending at t, plus horizon future steps.
        # NOTE(review): the last anchor is steps - horizon - 1. If ``obs`` holds
        # steps + 1 entries, anchor steps - horizon would also be valid -- confirm.
        indices = [
            (ep, t)
            for ep in range(end - start)
            for t in range(self.history_len - 1, steps - self.horizon)
        ]
        rng = np.random.default_rng(seed)
        selected = rng.choice(len(indices), size=min(max_windows, len(indices)), replace=False)
        self.indices = [indices[int(i)] for i in selected]

    def __len__(self) -> int:
        """Return the number of sampled windows."""
        return len(self.indices)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]:
        """Build the sample for window ``idx``.

        Returns:
            ``(observation_hist, action_hist, future_actions, targets)``, extended with
            ``(obs[t], obs[max(t - 1, 0)], flow_type_id, boat_id)`` when ``return_aux``
            is set, or with ``(obs[t],)`` when only ``return_origin`` is set.
        """
        ep, t = self.indices[idx]
        first = t - self.history_len + 1  # first step of the history window (>= 0)

        state_hist = self.states[ep, first : t + 1, :6].copy()
        if self.render_images:
            boat = ID_TO_BOAT[int(self.boat_ids[ep])]
            frames = [
                render_clean_boat_array(state, boat, image_size=self.image_size, visual_scale=self.visual_scale)
                for state in state_hist
            ]
            observation_hist = np.transpose(np.stack(frames, axis=0), (0, 3, 1, 2))
        else:
            observation_hist = state_hist

        # action_hist[i] = actions[first + i - 1]: the action leading into state first + i.
        # State 0 has no preceding action, so its row stays zero. Only the window is
        # allocated, not a zero-padded copy of the whole episode.
        action_hist = np.zeros((self.history_len, self.actions.shape[2]), dtype=np.float32)
        lead = max(0, 1 - first)  # 1 when the window starts at step 0, else 0
        action_hist[lead:] = self.actions[ep, first + lead - 1 : t]

        future_actions = self.actions[ep, t : t + self.horizon].copy()
        targets = self.obs[ep, t + 1 : t + 1 + self.horizon].copy()
        sample = (
            torch.from_numpy(observation_hist),
            torch.from_numpy(action_hist),
            torch.from_numpy(future_actions),
            torch.from_numpy(targets),
        )
        if self.return_aux:
            # At t == 0 there is no previous observation. Clamp to t instead of letting
            # index -1 wrap around to the episode's final observation.
            prev_t = max(t - 1, 0)
            return (
                *sample,
                torch.from_numpy(self.obs[ep, t].copy()),
                torch.from_numpy(self.obs[ep, prev_t].copy()),
                torch.tensor(self.flow_type_ids[ep], dtype=torch.long),
                torch.tensor(self.boat_ids[ep], dtype=torch.long),
            )
        if self.return_origin:
            return (*sample, torch.from_numpy(self.obs[ep, t].copy()))
        return sample
|
|