"""Shared image model components.""" from __future__ import annotations import torch from torch import nn class ImageEncoder(nn.Module): def __init__(self, emb_dim: int = 96): super().__init__() self.net = nn.Sequential( nn.Conv2d(5, 24, kernel_size=5, stride=2, padding=2), nn.SiLU(), nn.Conv2d(24, 48, kernel_size=3, stride=2, padding=1), nn.SiLU(), nn.Conv2d(48, 64, kernel_size=3, stride=2, padding=1), nn.SiLU(), nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.SiLU(), nn.AdaptiveAvgPool2d((8, 8)), nn.Flatten(), nn.Linear(64 * 8 * 8, emb_dim), nn.LayerNorm(emb_dim), ) def forward(self, images: torch.Tensor) -> torch.Tensor: x = images.float() / 255.0 if x.is_cuda: x = x.contiguous(memory_format=torch.channels_last) b, _c, h, w = x.shape yy = torch.linspace(-1.0, 1.0, h, device=x.device, dtype=x.dtype).view(1, 1, h, 1).expand(b, 1, h, w) xx = torch.linspace(-1.0, 1.0, w, device=x.device, dtype=x.dtype).view(1, 1, 1, w).expand(b, 1, h, w) x = torch.cat([x, xx, yy], dim=1) return self.net(x) class MLP(nn.Module): def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 160, depth: int = 2): super().__init__() layers: list[nn.Module] = [] dim = in_dim for _ in range(depth): layers.append(nn.Linear(dim, hidden_dim)) layers.append(nn.SiLU()) dim = hidden_dim layers.append(nn.Linear(dim, out_dim)) self.net = nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) def encode_image_sequence(encoder: ImageEncoder, images: torch.Tensor) -> torch.Tensor: b, t, c, h, w = images.shape emb = encoder(images.reshape(b * t, c, h, w)) return emb.reshape(b, t, -1)