"""Image-input world models used by paper-facing experiments.""" from __future__ import annotations from dataclasses import dataclass import torch from torch import nn from experiments.shared.src.models.image_components import ImageEncoder, MLP, encode_image_sequence @dataclass class ImageWorldModelConfig: image_size: int = 96 action_dim: int = 3 emb_dim: int = 96 z_dim: int = 64 c_dim: int = 16 hidden_dim: int = 160 history_len: int = 8 context_len: int = 32 context_stride: int = 4 rollout_horizon: int = 8 class LeWorldModelImage(nn.Module): def __init__(self, config: ImageWorldModelConfig): super().__init__() self.config = config self.encoder = ImageEncoder(config.emb_dim) self.to_z = MLP(config.emb_dim, config.z_dim, config.hidden_dim, depth=1) self.transition = MLP(config.z_dim + config.action_dim, config.z_dim, config.hidden_dim, depth=2) self.decoder = MLP(config.z_dim, 4, config.hidden_dim, depth=2) def encode(self, images: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: emb = encode_image_sequence(self.encoder, images[:, -1:]) z = self.to_z(emb[:, -1]) c = z.new_zeros((z.shape[0], 0)) return z, c def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor: return z + self.transition(torch.cat([z, action], dim=-1)) def rollout(self, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor: z, c = self.encode(images, actions) preds = [] cur = z for t in range(future_actions.shape[1]): cur = self.step(cur, future_actions[:, t], c) preds.append(self.decoder(cur)) return torch.stack(preds, dim=1) class HistoryImageWorldModel(nn.Module): def __init__(self, config: ImageWorldModelConfig): super().__init__() self.config = config self.encoder = ImageEncoder(config.emb_dim) self.history = nn.GRU(config.emb_dim + config.action_dim, config.hidden_dim, batch_first=True) self.to_z = MLP(config.hidden_dim, config.z_dim, config.hidden_dim, depth=1) self.transition = MLP(config.z_dim + config.action_dim, config.z_dim, config.hidden_dim, depth=2) self.decoder = MLP(config.z_dim, 4, config.hidden_dim, depth=2) def encode(self, images: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: emb = encode_image_sequence(self.encoder, images[:, -self.config.history_len :]) act = actions[:, -self.config.history_len :] h, _ = self.history(torch.cat([emb, act], dim=-1)) z = self.to_z(h[:, -1]) c = z.new_zeros((z.shape[0], 0)) return z, c def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor: return z + self.transition(torch.cat([z, action], dim=-1)) def rollout(self, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor: z, c = self.encode(images, actions) preds = [] cur = z for t in range(future_actions.shape[1]): cur = self.step(cur, future_actions[:, t], c) preds.append(self.decoder(cur)) return torch.stack(preds, dim=1) class FlowMoImageWorldModel(nn.Module): def __init__(self, config: ImageWorldModelConfig): super().__init__() self.config = config self.encoder = ImageEncoder(config.emb_dim) self.state_history = nn.GRU(config.emb_dim + config.action_dim, config.hidden_dim, batch_first=True) self.context_history = nn.GRU(config.emb_dim + config.action_dim, config.hidden_dim, batch_first=True) self.to_z = MLP(config.hidden_dim, config.z_dim, config.hidden_dim, depth=1) self.to_c = MLP(config.hidden_dim, config.c_dim, config.hidden_dim, depth=1) self.base_delta = MLP(config.z_dim + config.action_dim, config.z_dim, config.hidden_dim, depth=2) self.residual_delta = MLP(config.z_dim + config.c_dim, config.z_dim, config.hidden_dim, depth=2) self.decoder = MLP(config.z_dim, 4, config.hidden_dim, depth=2) def selected_history_indices(self, total_length: int) -> list[int]: total = int(total_length) context_start = total - self.config.context_len context = list(range(context_start, total, self.config.context_stride)) state = list(range(total - self.config.history_len, total)) return context + state def encode(self, images: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: compact_context = len(range(0, self.config.context_len, self.config.context_stride)) compact_len = compact_context + self.config.history_len if images.shape[1] == compact_len: context_images = images[:, :compact_context] context_actions = actions[:, :compact_context] state_images = images[:, -self.config.history_len :] state_actions = actions[:, -self.config.history_len :] else: state_images = images[:, -self.config.history_len :] state_actions = actions[:, -self.config.history_len :] context_images = images[:, -self.config.context_len :] context_actions = actions[:, -self.config.context_len :] if self.config.context_stride > 1: context_images = context_images[:, :: self.config.context_stride] context_actions = context_actions[:, :: self.config.context_stride] state_emb = encode_image_sequence(self.encoder, state_images) context_emb = encode_image_sequence(self.encoder, context_images) state_h, _ = self.state_history(torch.cat([state_emb, state_actions], dim=-1)) context_h, _ = self.context_history(torch.cat([context_emb, context_actions], dim=-1)) return self.to_z(state_h[:, -1]), self.to_c(context_h[:, -1]) def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor: base = self.base_delta(torch.cat([z, action], dim=-1)) r = self.residual_delta(torch.cat([z, c], dim=-1)) r0 = self.residual_delta(torch.cat([z, torch.zeros_like(c)], dim=-1)) return z + base + r - r0 def rollout(self, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor: z, c = self.encode(images, actions) preds = [] cur = z for t in range(future_actions.shape[1]): cur = self.step(cur, future_actions[:, t], c) preds.append(self.decoder(cur)) return torch.stack(preds, dim=1) class RSSMImageWorldModel(nn.Module): def __init__(self, config: ImageWorldModelConfig): super().__init__() self.config = config self.encoder = ImageEncoder(config.emb_dim) self.recurrent = nn.GRUCell(config.z_dim + config.action_dim, config.hidden_dim) self.posterior = MLP(config.hidden_dim + config.emb_dim, config.z_dim, config.hidden_dim, depth=1) self.prior = MLP(config.hidden_dim, config.z_dim, config.hidden_dim, depth=1) self.decoder = MLP(config.hidden_dim + config.z_dim, 4, config.hidden_dim, depth=2) def encode(self, images: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: emb = encode_image_sequence(self.encoder, images[:, -self.config.history_len :]) act = actions[:, -self.config.history_len :] h = emb.new_zeros((emb.shape[0], self.config.hidden_dim)) z = emb.new_zeros((emb.shape[0], self.config.z_dim)) for t in range(emb.shape[1]): h = self.recurrent(torch.cat([z, act[:, t]], dim=-1), h) z = self.posterior(torch.cat([h, emb[:, t]], dim=-1)) state = torch.cat([h, z], dim=-1) c = state.new_zeros((state.shape[0], 0)) return state, c def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor: h, stochastic = torch.split(z, [self.config.hidden_dim, self.config.z_dim], dim=-1) h_next = self.recurrent(torch.cat([stochastic, action], dim=-1), h) stochastic_next = self.prior(h_next) return torch.cat([h_next, stochastic_next], dim=-1) def rollout(self, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor: z, c = self.encode(images, actions) preds = [] cur = z for t in range(future_actions.shape[1]): cur = self.step(cur, future_actions[:, t], c) preds.append(self.decoder(cur)) return torch.stack(preds, dim=1) class TDMPC2ImageWorldModel(nn.Module): def __init__(self, config: ImageWorldModelConfig): super().__init__() self.config = config self.encoder = ImageEncoder(config.emb_dim) self.history = nn.GRU(config.emb_dim + config.action_dim, config.hidden_dim, batch_first=True) self.to_z = MLP(config.hidden_dim, config.z_dim, config.hidden_dim, depth=1) self.transition = MLP(config.z_dim + config.action_dim, config.z_dim, config.hidden_dim, depth=2) self.decoder = MLP(config.z_dim, 4, config.hidden_dim, depth=2) def encode(self, images: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: emb = encode_image_sequence(self.encoder, images[:, -self.config.history_len :]) act = actions[:, -self.config.history_len :] h, _ = self.history(torch.cat([emb, act], dim=-1)) z = self.to_z(h[:, -1]) c = z.new_zeros((z.shape[0], 0)) return z, c def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor: return z + self.transition(torch.cat([z, action], dim=-1)) def rollout(self, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor: z, c = self.encode(images, actions) preds = [] cur = z for t in range(future_actions.shape[1]): cur = self.step(cur, future_actions[:, t], c) preds.append(self.decoder(cur)) return torch.stack(preds, dim=1)