File size: 3,818 Bytes
"""Image trajectory dataset for world-model training."""
from __future__ import annotations
import numpy as np
import torch
from torch.utils.data import Dataset
from driftwm.data.generate import ID_TO_BOAT
from experiments.shared.src.vision.clean_renderer import render_clean_boat_array
class ImageTrajectoryDataset(Dataset):
    """Randomly sampled (history, future) windows over stored trajectories.

    The source ``.npz`` archive must provide the arrays ``obs``, ``actions``,
    ``flow_type_ids``, ``traj_type_ids``, ``boat_ids`` and ``states``, all
    indexed by episode along their first axis.  ``actions`` is indexed as
    ``(episode, step, action_dim)`` and ``states`` as
    ``(episode, step, feature)``; only the first six state features are used.

    Every window is identified by ``(ep, t)`` where ``ep`` is relative to
    ``episode_start`` and ``t`` is the index of the most recent history step.
    ``t`` ranges over ``[history_len - 1, steps - horizon)`` so that a full
    history and a full horizon of actions are always available.
    """

    def __init__(
        self,
        source_npz: str,
        history_len: int,
        horizon: int,
        episodes: int,
        max_windows: int,
        seed: int,
        episode_start: int = 0,
        image_size: int = 160,
        visual_scale: float = 2.5,
        render_images: bool = True,
        return_origin: bool = False,
        return_aux: bool = False,
    ):
        """Load an episode slice and draw a random subset of windows.

        Args:
            source_npz: Path to the ``.npz`` archive to read.
            history_len: Number of past steps per sample (at least 1).
            horizon: Number of future steps per sample (non-negative).
            episodes: Number of consecutive episodes to load.
            max_windows: Upper bound on the number of windows kept.
            seed: Seed for the NumPy generator that selects the windows.
            episode_start: Index of the first episode to load.
            image_size: Forwarded to the renderer as ``image_size``.
            visual_scale: Forwarded to the renderer as ``visual_scale``.
            render_images: If true, the observation history is rendered
                from states; otherwise the raw state history is returned.
            return_origin: If true, append the observation at ``t``.
            return_aux: If true, append the observation at ``t``, the
                previous observation, the flow type id and the boat id.
                Takes precedence over ``return_origin``.

        Raises:
            ValueError: If ``history_len < 1``, ``horizon < 0``,
                ``episode_start < 0``, or the requested episode range runs
                past the end of the archive.
        """
        super().__init__()

        self.history_len = int(history_len)
        self.horizon = int(horizon)
        if self.history_len < 1:
            raise ValueError(f"history_len must be >= 1, got {self.history_len}")
        if self.horizon < 0:
            raise ValueError(f"horizon must be >= 0, got {self.horizon}")

        start = int(episode_start)
        count = int(episodes)
        if start < 0:
            # A negative start would silently wrap around in the slices below.
            raise ValueError(f"episode_start must be >= 0, got {start}")
        end = start + count

        # NpzFile keeps the archive open until closed; every array we keep is
        # copied by ``astype`` so nothing references the file afterwards.
        with np.load(source_npz, allow_pickle=False) as src:
            all_obs = src["obs"]
            available = int(all_obs.shape[0])
            if end > available:
                raise ValueError(f"requested episodes [{start}, {end}) but dataset only has {available}")
            self.obs = all_obs[start:end].astype(np.float32)
            self.actions = src["actions"][start:end].astype(np.float32)
            self.flow_type_ids = src["flow_type_ids"][start:end].astype(np.int64)
            self.traj_type_ids = src["traj_type_ids"][start:end].astype(np.int64)
            self.boat_ids = src["boat_ids"][start:end].astype(np.int64)
            self.states = src["states"][start:end].astype(np.float32)

        self.image_size = int(image_size)
        self.visual_scale = float(visual_scale)
        self.render_images = bool(render_images)
        self.return_origin = bool(return_origin)
        self.return_aux = bool(return_aux)

        # Enumerate every valid window (episode-major), then keep a random
        # subset without replacement.  The enumeration order and RNG calls
        # determine which windows a given seed selects, so keep them stable.
        steps = self.actions.shape[1]
        indices = [
            (ep, t)
            for ep in range(count)
            for t in range(self.history_len - 1, steps - self.horizon)
        ]
        rng = np.random.default_rng(seed)
        selected = rng.choice(len(indices), size=min(int(max_windows), len(indices)), replace=False)
        self.indices = [indices[int(i)] for i in selected]

    def __len__(self) -> int:
        """Return the number of selected windows."""
        return len(self.indices)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]:
        """Build the sample for the ``idx``-th selected window.

        Returns:
            ``(observation_hist, action_hist, future_actions, targets)``,
            followed by ``(obs[t], obs[t-1], flow_type_id, boat_id)`` when
            ``return_aux`` is set, or by ``(obs[t],)`` when only
            ``return_origin`` is set.  ``action_hist[i]`` is the action taken
            one step before history step ``i``; the row with no preceding
            action (start of the episode) is all zeros.
        """
        ep, t = self.indices[idx]
        first = t - self.history_len + 1  # >= 0 by construction of self.indices

        state_hist = self.states[ep, first : t + 1, :6].copy()
        if self.render_images:
            boat = ID_TO_BOAT[int(self.boat_ids[ep])]
            rendered = [
                render_clean_boat_array(state, boat, image_size=self.image_size, visual_scale=self.visual_scale)
                for state in state_hist
            ]
            # Move the last axis of each frame in front of the other two
            # (presumably HWC -> CHW; confirm against the renderer).
            observation_hist = np.transpose(np.stack(rendered, axis=0), (0, 3, 1, 2))
        else:
            observation_hist = state_hist

        # Same values as slicing a zero-prepended copy of the episode's
        # actions, without allocating the whole episode for each sample.
        action_hist = np.zeros((self.history_len, self.actions.shape[2]), dtype=np.float32)
        lead = 1 if first == 0 else 0
        action_hist[lead:] = self.actions[ep, first - 1 + lead : t]

        future_actions = self.actions[ep, t : t + self.horizon].copy()
        targets = self.obs[ep, t + 1 : t + 1 + self.horizon].copy()

        sample = (
            torch.from_numpy(observation_hist),
            torch.from_numpy(action_hist),
            torch.from_numpy(future_actions),
            torch.from_numpy(targets),
        )
        if self.return_aux:
            # At t == 0 there is no earlier frame; reuse the current one
            # instead of letting index -1 wrap to the end of the episode.
            prev_t = max(t - 1, 0)
            return (
                *sample,
                torch.from_numpy(self.obs[ep, t].copy()),
                torch.from_numpy(self.obs[ep, prev_t].copy()),
                torch.tensor(self.flow_type_ids[ep], dtype=torch.long),
                torch.tensor(self.boat_ids[ep], dtype=torch.long),
            )
        if self.return_origin:
            return (*sample, torch.from_numpy(self.obs[ep, t].copy()))
        return sample
|