# FlowMo-WM / experiments/shared/src/models/image_world_models.py
# Committed by: cccat6
# Commit message: "Initial FlowMo-WM public code release"
# Revision: 604e535 (verified)
"""Image-input world models used by paper-facing experiments."""
from __future__ import annotations
from dataclasses import dataclass
import torch
from torch import nn
from experiments.shared.src.models.image_components import ImageEncoder, MLP, encode_image_sequence
@dataclass
class ImageWorldModelConfig:
    """Hyper-parameters shared by the image world models in this module.

    ``history_len``, ``context_len`` and ``context_stride`` must be >= 1.
    The models slice sequences with ``x[:, -n:]``, and because ``-0 == 0``
    a zero length would silently select the *whole* sequence instead of an
    empty or short window. A non-positive stride breaks the context
    subsampling in ``FlowMoImageWorldModel.encode``. These values are
    therefore rejected up front with a ``ValueError``.
    """

    image_size: int = 96  # presumably the input image side length; not read by the models in this module
    action_dim: int = 3  # width of each action vector (concatenated with embeddings / latents)
    emb_dim: int = 96  # per-frame embedding width produced by ImageEncoder
    z_dim: int = 64  # width of the latent state z
    c_dim: int = 16  # width of the FlowMo context vector c
    hidden_dim: int = 160  # hidden width of the GRUs and MLPs
    history_len: int = 8  # number of most recent steps used to infer the state
    context_len: int = 32  # length of the trailing window FlowMo samples its context from
    context_stride: int = 4  # subsampling stride inside the context window
    rollout_horizon: int = 8  # presumably the default number of rollout steps; not read by the models in this module

    def __post_init__(self) -> None:
        """Validate the window lengths and stride used for sequence slicing.

        Raises:
            ValueError: if ``history_len``, ``context_len`` or
                ``context_stride`` is smaller than 1.
        """
        for name in ("history_len", "context_len", "context_stride"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")
class LeWorldModelImage(nn.Module):
    """Single-frame image world model.

    ``encode`` embeds only the most recent frame and maps it to a ``z_dim``
    latent. ``step`` applies a residual MLP transition on ``[z, action]``.
    ``rollout`` chains ``step`` open-loop and decodes every predicted latent.
    No context vector is used: ``c`` is an empty ``(B, 0)`` tensor.
    """

    # Width of each decoded prediction (output size of ``decoder``).
    _PRED_DIM = 4

    def __init__(self, config: ImageWorldModelConfig):
        super().__init__()
        self.config = config
        self.encoder = ImageEncoder(config.emb_dim)
        self.to_z = MLP(config.emb_dim, config.z_dim, config.hidden_dim, depth=1)
        self.transition = MLP(config.z_dim + config.action_dim, config.z_dim, config.hidden_dim, depth=2)
        self.decoder = MLP(config.z_dim, self._PRED_DIM, config.hidden_dim, depth=2)

    def encode(self, images: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Infer ``(z, c)`` from the last frame of ``images``.

        ``images`` is indexed with time on dim 1 (assumed ``(B, T, ...)``;
        TODO confirm against ``encode_image_sequence``). ``actions`` is
        accepted for interface parity with the other models and is not used.
        ``c`` is an empty ``(B, 0)`` tensor.
        """
        emb = encode_image_sequence(self.encoder, images[:, -1:])
        z = self.to_z(emb[:, -1])
        c = z.new_zeros((z.shape[0], 0))
        return z, c

    def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Advance the latent one step as ``z + transition([z, action])``; ``c`` is unused."""
        return z + self.transition(torch.cat([z, action], dim=-1))

    def rollout(self, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor:
        """Encode the observed frames, then roll the latent forward open-loop.

        Args:
            images: observed frames, time on dim 1.
            actions: actions aligned with ``images`` (unused by this model's encoder).
            future_actions: actions to apply, time on dim 1 (assumed
                ``(B, H, action_dim)``).

        Returns:
            Decoder outputs stacked on dim 1, one per rolled-out step. When
            ``H == 0``, an empty ``(B, 0, 4)`` tensor is returned instead of
            raising.
        """
        z, c = self.encode(images, actions)
        if future_actions.shape[1] == 0:
            # torch.stack rejects an empty list, so build the empty result directly.
            return z.new_zeros((z.shape[0], 0, self._PRED_DIM))
        preds = []
        cur = z
        for action in future_actions.unbind(dim=1):
            cur = self.step(cur, action, c)
            preds.append(self.decoder(cur))
        return torch.stack(preds, dim=1)
class HistoryImageWorldModel(nn.Module):
    """Image world model that infers its latent from a short frame history.

    A GRU reads ``[embedding, action]`` for the last ``history_len`` steps,
    and its final hidden state is mapped to a ``z_dim`` latent. ``step`` is a
    residual MLP on ``[z, action]``. No context vector is used: ``c`` is an
    empty ``(B, 0)`` tensor.
    """

    # Width of each decoded prediction (output size of ``decoder``).
    _PRED_DIM = 4

    def __init__(self, config: ImageWorldModelConfig):
        super().__init__()
        self.config = config
        self.encoder = ImageEncoder(config.emb_dim)
        self.history = nn.GRU(config.emb_dim + config.action_dim, config.hidden_dim, batch_first=True)
        self.to_z = MLP(config.hidden_dim, config.z_dim, config.hidden_dim, depth=1)
        self.transition = MLP(config.z_dim + config.action_dim, config.z_dim, config.hidden_dim, depth=2)
        self.decoder = MLP(config.z_dim, self._PRED_DIM, config.hidden_dim, depth=2)

    def encode(self, images: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Infer ``(z, c)`` from the last ``history_len`` frames and actions.

        ``images`` and ``actions`` are indexed with time on dim 1; the GRU is
        batch-first. ``c`` is an empty ``(B, 0)`` tensor.
        """
        window = self.config.history_len
        emb = encode_image_sequence(self.encoder, images[:, -window:])
        act = actions[:, -window:]
        h, _ = self.history(torch.cat([emb, act], dim=-1))
        z = self.to_z(h[:, -1])
        c = z.new_zeros((z.shape[0], 0))
        return z, c

    def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Advance the latent one step as ``z + transition([z, action])``; ``c`` is unused."""
        return z + self.transition(torch.cat([z, action], dim=-1))

    def rollout(self, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor:
        """Encode the observed history, then roll the latent forward open-loop.

        Args:
            images: observed frames, time on dim 1.
            actions: actions aligned with ``images``.
            future_actions: actions to apply, time on dim 1 (assumed
                ``(B, H, action_dim)``).

        Returns:
            Decoder outputs stacked on dim 1, one per rolled-out step. When
            ``H == 0``, an empty ``(B, 0, 4)`` tensor is returned instead of
            raising.
        """
        z, c = self.encode(images, actions)
        if future_actions.shape[1] == 0:
            # torch.stack rejects an empty list, so build the empty result directly.
            return z.new_zeros((z.shape[0], 0, self._PRED_DIM))
        preds = []
        cur = z
        for action in future_actions.unbind(dim=1):
            cur = self.step(cur, action, c)
            preds.append(self.decoder(cur))
        return torch.stack(preds, dim=1)
class FlowMoImageWorldModel(nn.Module):
    """Image world model with a separate, strided context stream.

    Two GRUs read image embeddings concatenated with actions:

    * ``state_history`` reads the last ``history_len`` steps and yields the
      latent ``z`` (``z_dim``).
    * ``context_history`` reads a strided sample of the last ``context_len``
      steps and yields the context ``c`` (``c_dim``).

    ``step`` adds a base delta from ``[z, action]`` plus a context residual
    ``residual_delta([z, c]) - residual_delta([z, 0])``.
    """

    # Width of each decoded prediction (output size of ``decoder``).
    _PRED_DIM = 4

    def __init__(self, config: ImageWorldModelConfig):
        super().__init__()
        self.config = config
        self.encoder = ImageEncoder(config.emb_dim)
        self.state_history = nn.GRU(config.emb_dim + config.action_dim, config.hidden_dim, batch_first=True)
        self.context_history = nn.GRU(config.emb_dim + config.action_dim, config.hidden_dim, batch_first=True)
        self.to_z = MLP(config.hidden_dim, config.z_dim, config.hidden_dim, depth=1)
        self.to_c = MLP(config.hidden_dim, config.c_dim, config.hidden_dim, depth=1)
        self.base_delta = MLP(config.z_dim + config.action_dim, config.z_dim, config.hidden_dim, depth=2)
        self.residual_delta = MLP(config.z_dim + config.c_dim, config.z_dim, config.hidden_dim, depth=2)
        self.decoder = MLP(config.z_dim, self._PRED_DIM, config.hidden_dim, depth=2)

    def selected_history_indices(self, total_length: int) -> list[int]:
        """Return the frame indices that form a compact input for ``encode``.

        The result is the strided context indices over the last
        ``context_len`` steps, followed by the last ``history_len`` indices.
        It always has ``len(range(0, context_len, context_stride)) + history_len``
        entries, which is the length ``encode`` treats as the compact layout.

        For sequences shorter than either window, positions before the start
        are clamped to 0, so the first frame is repeated as padding. A
        negative index would instead wrap around to the end of the sequence
        when used to gather frames.

        Raises:
            ValueError: if ``total_length`` is less than 1 (there is no frame
                to select).
        """
        total = int(total_length)
        if total < 1:
            raise ValueError(f"total_length must be >= 1, got {total_length}")
        cfg = self.config
        context = range(total - cfg.context_len, total, cfg.context_stride)
        state = range(total - cfg.history_len, total)
        # Clamp rather than drop so the length always matches the compact length encode() detects.
        return [max(0, index) for index in (*context, *state)]

    def encode(self, images: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Infer ``(z, c)`` from either a compact or a raw frame sequence.

        ``images`` and ``actions`` are indexed with time on dim 1 and must
        have the same length on that dim.

        * Compact layout (length ``compact_context + history_len``, as built
          by ``selected_history_indices``): the first ``compact_context``
          entries are the already-subsampled context.
        * Raw layout (any other length): the trailing ``context_len`` steps
          are subsampled here with ``context_stride``.

        In both layouts the state stream reads the last ``history_len``
        steps.

        NOTE(review): a raw sequence whose length happens to equal the
        compact length (16 with the default config) is treated as compact.
        """
        cfg = self.config
        compact_context = len(range(0, cfg.context_len, cfg.context_stride))
        compact_len = compact_context + cfg.history_len
        # The state window is identical in both layouts.
        state_images = images[:, -cfg.history_len :]
        state_actions = actions[:, -cfg.history_len :]
        if images.shape[1] == compact_len:
            context_images = images[:, :compact_context]
            context_actions = actions[:, :compact_context]
        else:
            context_images = images[:, -cfg.context_len :: cfg.context_stride]
            context_actions = actions[:, -cfg.context_len :: cfg.context_stride]
        state_emb = encode_image_sequence(self.encoder, state_images)
        context_emb = encode_image_sequence(self.encoder, context_images)
        state_h, _ = self.state_history(torch.cat([state_emb, state_actions], dim=-1))
        context_h, _ = self.context_history(torch.cat([context_emb, context_actions], dim=-1))
        return self.to_z(state_h[:, -1]), self.to_c(context_h[:, -1])

    def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Advance ``z`` one step under ``action`` and context ``c``.

        ``base_delta`` depends on ``[z, action]``. The context enters only
        through ``residual_delta([z, c]) - residual_delta([z, 0])``, so an
        all-zero ``c`` adds no residual. This assumes ``residual_delta`` is
        deterministic (e.g. no active dropout), since it is called twice.
        """
        base = self.base_delta(torch.cat([z, action], dim=-1))
        r = self.residual_delta(torch.cat([z, c], dim=-1))
        r0 = self.residual_delta(torch.cat([z, torch.zeros_like(c)], dim=-1))
        return z + base + r - r0

    def rollout(self, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor:
        """Encode the observed frames, then roll ``z`` forward with a fixed ``c``.

        Args:
            images: observed frames, time on dim 1 (compact or raw layout; see ``encode``).
            actions: actions aligned with ``images``.
            future_actions: actions to apply, time on dim 1 (assumed
                ``(B, H, action_dim)``).

        Returns:
            Decoder outputs stacked on dim 1, one per rolled-out step. When
            ``H == 0``, an empty ``(B, 0, 4)`` tensor is returned instead of
            raising.
        """
        z, c = self.encode(images, actions)
        if future_actions.shape[1] == 0:
            # torch.stack rejects an empty list, so build the empty result directly.
            return z.new_zeros((z.shape[0], 0, self._PRED_DIM))
        preds = []
        cur = z
        for action in future_actions.unbind(dim=1):
            cur = self.step(cur, action, c)
            preds.append(self.decoder(cur))
        return torch.stack(preds, dim=1)
class RSSMImageWorldModel(nn.Module):
    """RSSM-style image world model with a recurrent and a latent part.

    The state is ``[h, z]`` of width ``hidden_dim + z_dim``. ``h`` is updated
    by a ``GRUCell`` on ``[z_prev, action]``. ``z`` comes from ``posterior``
    when a frame is observed, and from ``prior`` during rollout. No sampling
    is performed here: ``z`` is the deterministic output of those MLPs. No
    context vector is used: ``c`` is an empty ``(B, 0)`` tensor.
    """

    # Width of each decoded prediction (output size of ``decoder``).
    _PRED_DIM = 4

    def __init__(self, config: ImageWorldModelConfig):
        super().__init__()
        self.config = config
        self.encoder = ImageEncoder(config.emb_dim)
        self.recurrent = nn.GRUCell(config.z_dim + config.action_dim, config.hidden_dim)
        self.posterior = MLP(config.hidden_dim + config.emb_dim, config.z_dim, config.hidden_dim, depth=1)
        self.prior = MLP(config.hidden_dim, config.z_dim, config.hidden_dim, depth=1)
        # The decoder reads the full state [h, z].
        self.decoder = MLP(config.hidden_dim + config.z_dim, self._PRED_DIM, config.hidden_dim, depth=2)

    def encode(self, images: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Filter the last ``history_len`` frames into a state ``[h, z]``.

        Starting from zero ``h`` and ``z``, each step t computes
        ``h = GRUCell([z, actions_t], h)`` and then ``z = posterior([h, emb_t])``.
        NOTE(review): this pairs ``actions[:, t]`` with the frame at the same
        index; confirm this matches the dataset's action alignment.

        Returns:
            ``(state, c)``, where ``state`` has width ``hidden_dim + z_dim``
            and ``c`` is an empty ``(B, 0)`` tensor.
        """
        emb = encode_image_sequence(self.encoder, images[:, -self.config.history_len :])
        act = actions[:, -self.config.history_len :]
        h = emb.new_zeros((emb.shape[0], self.config.hidden_dim))
        z = emb.new_zeros((emb.shape[0], self.config.z_dim))
        for t in range(emb.shape[1]):
            h = self.recurrent(torch.cat([z, act[:, t]], dim=-1), h)
            z = self.posterior(torch.cat([h, emb[:, t]], dim=-1))
        state = torch.cat([h, z], dim=-1)
        c = state.new_zeros((state.shape[0], 0))
        return state, c

    def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Advance the full state ``z = [h, stochastic]`` one step using the prior; ``c`` is unused."""
        h, stochastic = torch.split(z, [self.config.hidden_dim, self.config.z_dim], dim=-1)
        h_next = self.recurrent(torch.cat([stochastic, action], dim=-1), h)
        stochastic_next = self.prior(h_next)
        return torch.cat([h_next, stochastic_next], dim=-1)

    def rollout(self, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor:
        """Filter the observed frames, then roll the state forward with the prior.

        Args:
            images: observed frames, time on dim 1.
            actions: actions aligned with ``images``.
            future_actions: actions to apply, time on dim 1 (assumed
                ``(B, H, action_dim)``).

        Returns:
            Decoder outputs stacked on dim 1, one per rolled-out step. When
            ``H == 0``, an empty ``(B, 0, 4)`` tensor is returned instead of
            raising.
        """
        z, c = self.encode(images, actions)
        if future_actions.shape[1] == 0:
            # torch.stack rejects an empty list, so build the empty result directly.
            return z.new_zeros((z.shape[0], 0, self._PRED_DIM))
        preds = []
        cur = z
        for action in future_actions.unbind(dim=1):
            cur = self.step(cur, action, c)
            preds.append(self.decoder(cur))
        return torch.stack(preds, dim=1)
class TDMPC2ImageWorldModel(nn.Module):
    """Image world model with a GRU history encoder and a residual MLP dynamics.

    As written, the modules and forward logic match ``HistoryImageWorldModel``
    in this module. A GRU over ``[embedding, action]`` for the last
    ``history_len`` steps yields ``z``, and ``step`` is
    ``z + transition([z, action])``. No reward or value heads are defined
    here. No context vector is used: ``c`` is an empty ``(B, 0)`` tensor.
    """

    # Width of each decoded prediction (output size of ``decoder``).
    _PRED_DIM = 4

    def __init__(self, config: ImageWorldModelConfig):
        super().__init__()
        self.config = config
        self.encoder = ImageEncoder(config.emb_dim)
        self.history = nn.GRU(config.emb_dim + config.action_dim, config.hidden_dim, batch_first=True)
        self.to_z = MLP(config.hidden_dim, config.z_dim, config.hidden_dim, depth=1)
        self.transition = MLP(config.z_dim + config.action_dim, config.z_dim, config.hidden_dim, depth=2)
        self.decoder = MLP(config.z_dim, self._PRED_DIM, config.hidden_dim, depth=2)

    def encode(self, images: torch.Tensor, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Infer ``(z, c)`` from the last ``history_len`` frames and actions.

        ``images`` and ``actions`` are indexed with time on dim 1; the GRU is
        batch-first. ``c`` is an empty ``(B, 0)`` tensor.
        """
        window = self.config.history_len
        emb = encode_image_sequence(self.encoder, images[:, -window:])
        act = actions[:, -window:]
        h, _ = self.history(torch.cat([emb, act], dim=-1))
        z = self.to_z(h[:, -1])
        c = z.new_zeros((z.shape[0], 0))
        return z, c

    def step(self, z: torch.Tensor, action: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Advance the latent one step as ``z + transition([z, action])``; ``c`` is unused."""
        return z + self.transition(torch.cat([z, action], dim=-1))

    def rollout(self, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor:
        """Encode the observed history, then roll the latent forward open-loop.

        Args:
            images: observed frames, time on dim 1.
            actions: actions aligned with ``images``.
            future_actions: actions to apply, time on dim 1 (assumed
                ``(B, H, action_dim)``).

        Returns:
            Decoder outputs stacked on dim 1, one per rolled-out step. When
            ``H == 0``, an empty ``(B, 0, 4)`` tensor is returned instead of
            raising.
        """
        z, c = self.encode(images, actions)
        if future_actions.shape[1] == 0:
            # torch.stack rejects an empty list, so build the empty result directly.
            return z.new_zeros((z.shape[0], 0, self._PRED_DIM))
        preds = []
        cur = z
        for action in future_actions.unbind(dim=1):
            cur = self.step(cur, action, c)
            preds.append(self.decoder(cur))
        return torch.stack(preds, dim=1)