| """Image-input world models used by paper-facing experiments.""" |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| from torch import nn |
|
|
| from experiments.shared.src.models.image_components import ImageEncoder, MLP, encode_image_sequence |
|
|
|
|
@dataclass
class ImageWorldModelConfig:
    """Hyperparameters shared by the image world models in this module.

    The sequence-window fields (``history_len``, ``context_len``,
    ``context_stride``) are validated in ``__post_init__``. They are used as
    negative slice bounds and as ``range`` steps by the models, and a value of
    ``0`` is silently wrong there. For example, ``x[:, -0:]`` selects the
    *whole* sequence rather than an empty window, and a zero stride makes
    ``range`` raise far away from where the bad value was set.
    """

    image_size: int = 96  # not read in this module; presumably the input image side length
    action_dim: int = 3  # width of one action vector (concatenated to embeddings/latents)
    emb_dim: int = 96  # embedding width passed to ImageEncoder
    z_dim: int = 64  # latent state width
    c_dim: int = 16  # context width (read only by FlowMoImageWorldModel in this module)
    hidden_dim: int = 160  # GRU hidden size and MLP hidden width
    history_len: int = 8  # number of most recent steps fed to the state/history encoder
    context_len: int = 32  # span of most recent steps covered by the FlowMo context encoder
    context_stride: int = 4  # subsampling stride inside the context span
    rollout_horizon: int = 8  # not read in this module

    def __post_init__(self) -> None:
        """Reject window sizes/strides that would silently mis-slice sequences.

        Raises:
            ValueError: if ``history_len``, ``context_len`` or ``context_stride`` is < 1.
        """
        for name in ("history_len", "context_len", "context_stride"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")
|
|
|
|
class LeWorldModelImage(nn.Module):
    """Latent world model that conditions only on the most recent frame.

    The newest image is embedded by ``ImageEncoder`` and projected to a latent
    ``z``. Each future action then advances ``z`` through a residual transition
    MLP, and every advanced latent is passed through ``self.decoder``. There is
    no context vector: ``encode`` returns a zero-width ``c``.
    """

    def __init__(self, config: ImageWorldModelConfig):
        super().__init__()
        self.config = config
        z_dim, hidden = config.z_dim, config.hidden_dim
        # Submodules are created in the same order as before so parameter
        # registration order (and presumably seeded initialisation) is unchanged.
        self.encoder = ImageEncoder(config.emb_dim)
        self.to_z = MLP(config.emb_dim, z_dim, hidden, depth=1)
        self.transition = MLP(z_dim + config.action_dim, z_dim, hidden, depth=2)
        self.decoder = MLP(z_dim, 4, hidden, depth=2)

    def encode(self, images: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed only the newest frame and return ``(z, c)`` with an empty ``c``.

        ``actions`` is accepted for interface parity with the other models and
        is not used. Time is assumed to be dim 1 of ``images`` (TODO: confirm).
        """
        newest_frame = images[:, -1:]  # keep a length-1 time axis for the sequence helper
        frame_emb = encode_image_sequence(self.encoder, newest_frame)
        latent = self.to_z(frame_emb[:, -1])
        no_context = latent.new_zeros((latent.shape[0], 0))
        return latent, no_context

    def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Advance ``z`` by one action as ``z + transition([z, action])``; ``c`` is ignored."""
        delta = self.transition(torch.cat((z, action), dim=-1))
        return z + delta

    def rollout(self, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor:
        """Encode the observed frames, then step once per future action and decode.

        ``future_actions`` is iterated along dim 1. The decoder outputs are
        stacked along dim 1, one per future action. An empty horizon is not
        supported, because ``torch.stack`` of an empty list raises.
        """
        state, context = self.encode(images, actions)
        decoded = []
        for action in future_actions.unbind(dim=1):
            state = self.step(state, action, context)
            decoded.append(self.decoder(state))
        return torch.stack(decoded, dim=1)
|
|
|
|
class HistoryImageWorldModel(nn.Module):
    """Latent world model whose state summarises a short frame/action history.

    The last ``history_len`` frames are embedded and concatenated with the last
    ``history_len`` actions, then run through a GRU. The GRU output at the
    final time step is projected to ``z``. Rollout uses a residual,
    action-conditioned transition and no context (``c`` has width 0).
    """

    def __init__(self, config: ImageWorldModelConfig):
        super().__init__()
        self.config = config
        step_width = config.emb_dim + config.action_dim
        hidden = config.hidden_dim
        # Construction order kept identical to preserve parameter ordering.
        self.encoder = ImageEncoder(config.emb_dim)
        self.history = nn.GRU(step_width, hidden, batch_first=True)
        self.to_z = MLP(hidden, config.z_dim, hidden, depth=1)
        self.transition = MLP(config.z_dim + config.action_dim, config.z_dim, hidden, depth=2)
        self.decoder = MLP(config.z_dim, 4, hidden, depth=2)

    def encode(self, images: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Summarise the trailing history window and return ``(z, c)`` with an empty ``c``.

        Time is assumed to be dim 1 of both ``images`` and ``actions`` (TODO:
        confirm). The two windows must have equal length for the concatenation
        to succeed.
        """
        window = self.config.history_len
        frame_emb = encode_image_sequence(self.encoder, images[:, -window:])
        recent_actions = actions[:, -window:]
        gru_out, _ = self.history(torch.cat((frame_emb, recent_actions), dim=-1))
        latent = self.to_z(gru_out[:, -1])  # last time step of the GRU output
        return latent, latent.new_zeros((latent.shape[0], 0))

    def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Return ``z + transition([z, action])``; the empty context ``c`` is unused."""
        joint = torch.cat((z, action), dim=-1)
        return z + self.transition(joint)

    def rollout(self, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor:
        """Open-loop rollout over ``future_actions`` (time along dim 1).

        Returns decoder outputs stacked along dim 1, one per future action.
        """
        latent, context = self.encode(images, actions)
        outputs = []
        for action in future_actions.unbind(dim=1):
            latent = self.step(latent, action, context)
            outputs.append(self.decoder(latent))
        return torch.stack(outputs, dim=1)
|
|
|
|
class FlowMoImageWorldModel(nn.Module):
    """World model with a dense short-term state encoder and a strided context encoder.

    ``encode`` produces a latent state ``z`` from the last ``history_len`` steps
    and a context vector ``c`` from the last ``context_len`` steps subsampled
    every ``context_stride`` steps. Each uses its own GRU. ``step`` advances
    ``z`` by an action-conditioned base delta plus a context residual
    ``residual_delta([z, c]) - residual_delta([z, 0])``. That residual
    therefore cancels when ``c`` is all zeros, given a deterministic
    ``residual_delta``.
    """

    def __init__(self, config: ImageWorldModelConfig):
        super().__init__()
        self.config = config
        self.encoder = ImageEncoder(config.emb_dim)
        self.state_history = nn.GRU(config.emb_dim + config.action_dim, config.hidden_dim, batch_first=True)
        self.context_history = nn.GRU(config.emb_dim + config.action_dim, config.hidden_dim, batch_first=True)
        self.to_z = MLP(config.hidden_dim, config.z_dim, config.hidden_dim, depth=1)
        self.to_c = MLP(config.hidden_dim, config.c_dim, config.hidden_dim, depth=1)
        self.base_delta = MLP(config.z_dim + config.action_dim, config.z_dim, config.hidden_dim, depth=2)
        self.residual_delta = MLP(config.z_dim + config.c_dim, config.z_dim, config.hidden_dim, depth=2)
        self.decoder = MLP(config.z_dim, 4, config.hidden_dim, depth=2)

    def _compact_context_len(self) -> int:
        """Number of context steps kept by the strided context window."""
        return len(range(0, self.config.context_len, self.config.context_stride))

    def selected_history_indices(self, total_length: int) -> list[int]:
        """Return the step indices that make up the compact layout for ``total_length`` steps.

        The result is the strided context indices
        ``range(T - context_len, T, context_stride)`` followed by the state
        indices ``range(T - history_len, T)``. The two parts may overlap.
        Gathering these steps yields a sequence of the compact length that
        ``encode`` accepts.

        NOTE(review): when ``total_length`` is smaller than ``context_len`` (or
        ``history_len``), some indices are negative. Used directly as
        Python/torch indices they would wrap around to the end of the
        sequence. Confirm that callers guarantee enough history or offset the
        indices.
        """
        total = int(total_length)
        context_start = total - self.config.context_len
        context = list(range(context_start, total, self.config.context_stride))
        state = list(range(total - self.config.history_len, total))
        return context + state

    def encode(
        self,
        images: torch.Tensor,
        actions: torch.Tensor,
        *,
        compact: bool | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(z, c)`` from either a full or a compact frame/action sequence.

        Time is assumed to be dim 1 of ``images`` and ``actions`` (TODO: confirm).
        Two layouts are accepted:

        * full: a contiguous sequence. ``z`` uses its last ``history_len``
          steps, and ``c`` uses its last ``context_len`` steps subsampled every
          ``context_stride`` steps.
        * compact: the steps at ``selected_history_indices(T)``, i.e. the
          strided context steps followed by the last ``history_len`` steps.

        Args:
            compact: ``True``/``False`` forces the layout. ``None`` (the default,
                and the original behaviour) infers it from the sequence length.
                Inference is ambiguous when a *full* sequence happens to have the
                compact length, e.g. ``context_stride == 1`` and
                ``T == context_len + history_len``. Pass ``compact=False`` there.

        Raises:
            ValueError: if ``compact=True`` but the sequence length differs from
                the compact length.
        """
        cfg = self.config
        n_context = self._compact_context_len()
        compact_len = n_context + cfg.history_len
        seq_len = images.shape[1]
        if compact is None:
            compact = seq_len == compact_len
        elif compact and seq_len != compact_len:
            raise ValueError(
                f"compact=True expects {compact_len} time steps "
                f"({n_context} context + {cfg.history_len} state), got {seq_len}"
            )

        # In both layouts the state window is the trailing ``history_len`` steps.
        state_images = images[:, -cfg.history_len :]
        state_actions = actions[:, -cfg.history_len :]
        if compact:
            # Compact layout: the leading ``n_context`` steps are the pre-strided context.
            context_images = images[:, :n_context]
            context_actions = actions[:, :n_context]
        else:
            context_images = images[:, -cfg.context_len :]
            context_actions = actions[:, -cfg.context_len :]
            if cfg.context_stride > 1:
                context_images = context_images[:, :: cfg.context_stride]
                context_actions = context_actions[:, :: cfg.context_stride]

        state_emb = encode_image_sequence(self.encoder, state_images)
        context_emb = encode_image_sequence(self.encoder, context_images)
        state_h, _ = self.state_history(torch.cat([state_emb, state_actions], dim=-1))
        context_h, _ = self.context_history(torch.cat([context_emb, context_actions], dim=-1))
        return self.to_z(state_h[:, -1]), self.to_c(context_h[:, -1])

    def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Advance ``z`` one step: base action delta plus context residual relative to ``c = 0``."""
        base = self.base_delta(torch.cat([z, action], dim=-1))
        r = self.residual_delta(torch.cat([z, c], dim=-1))
        # Subtract the zero-context residual so ``c`` only contributes a deviation.
        r0 = self.residual_delta(torch.cat([z, torch.zeros_like(c)], dim=-1))
        return z + base + r - r0

    def rollout(
        self,
        images: torch.Tensor,
        actions: torch.Tensor,
        future_actions: torch.Tensor,
        *,
        compact: bool | None = None,
    ) -> torch.Tensor:
        """Encode the observed sequence, then step once per future action and decode.

        ``compact`` is forwarded to ``encode`` (see there). ``future_actions`` is
        iterated along dim 1, and decoder outputs are stacked along dim 1.
        """
        if compact is None:
            # Call without the keyword so subclasses overriding ``encode(images, actions)`` keep working.
            z, c = self.encode(images, actions)
        else:
            z, c = self.encode(images, actions, compact=compact)
        preds = []
        cur = z
        for t in range(future_actions.shape[1]):
            cur = self.step(cur, future_actions[:, t], c)
            preds.append(self.decoder(cur))
        return torch.stack(preds, dim=1)
|
|
|
|
class RSSMImageWorldModel(nn.Module):
    """RSSM-style model with a recurrent state ``h`` and a second latent ``s``.

    While filtering observed frames, ``s`` comes from the ``posterior`` MLP
    (conditioned on ``h`` and the frame embedding). During rollout it comes
    from the ``prior`` MLP (conditioned on ``h`` only). Both are plain MLP
    outputs; no sampling happens in this class. The public "state" is
    ``cat([h, s], dim=-1)``, of width ``hidden_dim + z_dim``, and that is what
    the decoder consumes.
    """

    def __init__(self, config: ImageWorldModelConfig):
        super().__init__()
        self.config = config
        hidden, latent = config.hidden_dim, config.z_dim
        # Construction order kept identical to preserve parameter ordering.
        self.encoder = ImageEncoder(config.emb_dim)
        self.recurrent = nn.GRUCell(latent + config.action_dim, hidden)
        self.posterior = MLP(hidden + config.emb_dim, latent, hidden, depth=1)
        self.prior = MLP(hidden, latent, hidden, depth=1)
        self.decoder = MLP(hidden + latent, 4, hidden, depth=2)

    def encode(self, images: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Filter the last ``history_len`` frames and return ``(state, c)`` with an empty ``c``.

        ``h`` and ``s`` start at zero. At step ``t`` the GRU cell consumes the
        previous ``s`` together with ``actions[t]`` (within the trailing window),
        and the posterior then combines the new ``h`` with the embedding of
        frame ``t``.

        NOTE(review): the loop runs over the image window. It assumes the
        action window has at least as many steps and is aligned with it;
        confirm with callers.
        """
        window = self.config.history_len
        frame_emb = encode_image_sequence(self.encoder, images[:, -window:])
        recent_actions = actions[:, -window:]
        batch = frame_emb.shape[0]
        h = frame_emb.new_zeros((batch, self.config.hidden_dim))
        s = frame_emb.new_zeros((batch, self.config.z_dim))
        for t in range(frame_emb.shape[1]):
            gru_input = torch.cat([s, recent_actions[:, t]], dim=-1)
            h = self.recurrent(gru_input, h)
            s = self.posterior(torch.cat([h, frame_emb[:, t]], dim=-1))
        state = torch.cat([h, s], dim=-1)
        return state, state.new_zeros((batch, 0))

    def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Advance the packed state ``z = [h, s]`` by one action using the prior; ``c`` is unused."""
        h_prev, s_prev = z.split([self.config.hidden_dim, self.config.z_dim], dim=-1)
        h_next = self.recurrent(torch.cat([s_prev, action], dim=-1), h_prev)
        return torch.cat([h_next, self.prior(h_next)], dim=-1)

    def rollout(self, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor:
        """Filter the observed frames, then imagine forward with the prior and decode each state.

        ``future_actions`` is iterated along dim 1, and decoder outputs are
        stacked along dim 1.
        """
        state, context = self.encode(images, actions)
        decoded = []
        for action in future_actions.unbind(dim=1):
            state = self.step(state, action, context)
            decoded.append(self.decoder(state))
        return torch.stack(decoded, dim=1)
|
|
|
|
class TDMPC2ImageWorldModel(nn.Module):
    """GRU-history latent model used as the TD-MPC2-labelled baseline.

    In this module its architecture is identical to ``HistoryImageWorldModel``:
    a GRU over the trailing frame/action window produces ``z``, followed by a
    residual action-conditioned transition and no context. It is presumably
    kept as a separate class so experiments can refer to it by name.
    """

    def __init__(self, config: ImageWorldModelConfig):
        super().__init__()
        self.config = config
        hidden = config.hidden_dim
        transition_in = config.z_dim + config.action_dim
        # Construction order kept identical to preserve parameter ordering.
        self.encoder = ImageEncoder(config.emb_dim)
        self.history = nn.GRU(config.emb_dim + config.action_dim, hidden, batch_first=True)
        self.to_z = MLP(hidden, config.z_dim, hidden, depth=1)
        self.transition = MLP(transition_in, config.z_dim, hidden, depth=2)
        self.decoder = MLP(config.z_dim, 4, hidden, depth=2)

    def encode(self, images: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(z, c)``: ``z`` from the GRU's last output over the trailing window, ``c`` empty."""
        span = self.config.history_len
        window_actions = actions[:, -span:]
        embedded = encode_image_sequence(self.encoder, images[:, -span:])
        sequence_out, _ = self.history(torch.cat([embedded, window_actions], dim=-1))
        z = self.to_z(sequence_out[:, -1])
        empty_c = z.new_zeros((z.shape[0], 0))
        return z, empty_c

    def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """One latent step: ``z`` plus the transition MLP applied to ``[z, action]``."""
        increment = self.transition(torch.cat([z, action], dim=-1))
        return z + increment

    def rollout(self, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor:
        """Roll out one step per entry of ``future_actions`` (dim 1) and stack decoded outputs on dim 1."""
        latent, context = self.encode(images, actions)
        horizon = future_actions.shape[1]
        outputs = []
        for step_index in range(horizon):
            latent = self.step(latent, future_actions.select(1, step_index), context)
            outputs.append(self.decoder(latent))
        return torch.stack(outputs, dim=1)
|
|