"""Image trajectory dataset for world-model training.""" from __future__ import annotations import numpy as np import torch from torch.utils.data import Dataset from driftwm.data.generate import ID_TO_BOAT from experiments.shared.src.vision.clean_renderer import render_clean_boat_array class ImageTrajectoryDataset(Dataset): def __init__( self, source_npz: str, history_len: int, horizon: int, episodes: int, max_windows: int, seed: int, episode_start: int = 0, image_size: int = 160, visual_scale: float = 2.5, render_images: bool = True, return_origin: bool = False, return_aux: bool = False, ): src = np.load(source_npz, allow_pickle=False) start = int(episode_start) end = start + int(episodes) if end > int(src["obs"].shape[0]): raise ValueError(f"requested episodes [{start}, {end}) but dataset only has {src['obs'].shape[0]}") self.obs = src["obs"][start:end].astype(np.float32) self.actions = src["actions"][start:end].astype(np.float32) self.flow_type_ids = src["flow_type_ids"][start:end].astype(np.int64) self.traj_type_ids = src["traj_type_ids"][start:end].astype(np.int64) self.boat_ids = src["boat_ids"][start:end].astype(np.int64) self.states = src["states"][start:end].astype(np.float32) self.image_size = int(image_size) self.visual_scale = float(visual_scale) self.render_images = bool(render_images) self.history_len = int(history_len) self.horizon = int(horizon) steps = self.actions.shape[1] indices = [(ep, t) for ep in range(episodes) for t in range(self.history_len - 1, steps - self.horizon)] rng = np.random.default_rng(seed) selected = rng.choice(len(indices), size=min(max_windows, len(indices)), replace=False) self.indices = [indices[int(i)] for i in selected] self.return_origin = bool(return_origin) self.return_aux = bool(return_aux) def __len__(self) -> int: return len(self.indices) def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ep, t = self.indices[idx] state_hist = self.states[ep, t - self.history_len + 1 : t + 1, :6].copy() if self.render_images: boat = ID_TO_BOAT[int(self.boat_ids[ep])] rendered = [ render_clean_boat_array(state, boat, image_size=self.image_size, visual_scale=self.visual_scale) for state in state_hist ] observation_hist = np.transpose(np.stack(rendered, axis=0), (0, 3, 1, 2)) else: observation_hist = state_hist padded_actions = np.zeros((self.actions.shape[1] + 1, self.actions.shape[2]), dtype=np.float32) padded_actions[1:] = self.actions[ep] action_hist = padded_actions[t - self.history_len + 1 : t + 1].copy() future_actions = self.actions[ep, t : t + self.horizon].copy() targets = self.obs[ep, t + 1 : t + 1 + self.horizon].copy() sample = ( torch.from_numpy(observation_hist), torch.from_numpy(action_hist), torch.from_numpy(future_actions), torch.from_numpy(targets), ) if self.return_aux: return ( *sample, torch.from_numpy(self.obs[ep, t].copy()), torch.from_numpy(self.obs[ep, t - 1].copy()), torch.tensor(self.flow_type_ids[ep], dtype=torch.long), torch.tensor(self.boat_ids[ep], dtype=torch.long), ) if self.return_origin: return (*sample, torch.from_numpy(self.obs[ep, t].copy())) return sample