| """Shared image model components.""" |
|
|
| from __future__ import annotations |
|
|
| import torch |
| from torch import nn |
|
|
|
|
class ImageEncoder(nn.Module):
    """Convolutional encoder that maps a batch of images to embedding vectors.

    Two coordinate planes (normalized x and y positions in ``[-1, 1]``) are
    appended to every image before the convolutions, so the first layer
    consumes ``in_channels + 2`` channels.

    Args:
        emb_dim: Size of each output embedding.
        in_channels: Number of channels in the incoming images, not counting
            the two coordinate planes added in ``forward``. The default of 3
            reproduces the original 5-channel first layer.
    """

    def __init__(self, emb_dim: int = 96, in_channels: int = 3):
        super().__init__()
        self.net = nn.Sequential(
            # +2 for the xx / yy coordinate planes concatenated in forward().
            nn.Conv2d(in_channels + 2, 24, kernel_size=5, stride=2, padding=2),
            nn.SiLU(),
            nn.Conv2d(24, 48, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(48, 64, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.SiLU(),
            # Pool to a fixed 8x8 grid so the Linear layer's input size does
            # not depend on the image resolution.
            nn.AdaptiveAvgPool2d((8, 8)),
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, emb_dim),
            nn.LayerNorm(emb_dim),
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Encode ``images`` of shape ``(B, in_channels, H, W)`` into ``(B, emb_dim)``.

        The input is cast to float and divided by 255.
        NOTE(review): presumably pixel values are in 0-255 (e.g. uint8) --
        confirm against callers.
        """
        x = images.float() / 255.0
        if x.is_cuda:
            # NOTE(review): the torch.cat below mixes this tensor with expanded
            # coordinate views, so the channels_last layout may not reach the
            # convolutions -- confirm, and consider converting after the cat.
            x = x.contiguous(memory_format=torch.channels_last)
        b, _, h, w = x.shape
        rows = torch.linspace(-1.0, 1.0, h, device=x.device, dtype=x.dtype)
        cols = torch.linspace(-1.0, 1.0, w, device=x.device, dtype=x.dtype)
        # Broadcast the 1-D ramps to full (B, 1, H, W) planes without copying.
        yy = rows.view(1, 1, h, 1).expand(b, 1, h, w)
        xx = cols.view(1, 1, 1, w).expand(b, 1, h, w)
        x = torch.cat([x, xx, yy], dim=1)
        return self.net(x)
|
|
|
|
class MLP(nn.Module):
    """Feed-forward network: ``depth`` Linear+SiLU hidden layers, then a Linear head.

    Args:
        in_dim: Size of the input features.
        out_dim: Size of the output features.
        hidden_dim: Width of every hidden layer.
        depth: Number of hidden Linear+SiLU pairs; 0 leaves only the output
            projection.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 160, depth: int = 2):
        super().__init__()
        # Feature width entering each Linear: the input, then `depth` hidden widths.
        widths = [in_dim] + [hidden_dim] * depth
        stack: list[nn.Module] = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.SiLU()]
        stack.append(nn.Linear(widths[-1], out_dim))
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x``."""
        return self.net(x)
|
|
|
|
def encode_image_sequence(encoder: ImageEncoder, images: torch.Tensor) -> torch.Tensor:
    """Encode every frame of a batch of image sequences.

    The batch and time axes are merged so the encoder sees one flat batch of
    frames, and the result is split back into per-sequence form.

    Args:
        encoder: Module called on a ``(B * T, C, H, W)`` tensor.
        images: Tensor of shape ``(B, T, C, H, W)``.

    Returns:
        Tensor of shape ``(B, T, D)``, where ``D`` is the number of values the
        encoder produces per frame (any trailing dimensions are flattened).

    Raises:
        ValueError: If ``images`` does not have exactly five dimensions.
    """
    b, t, c, h, w = images.shape
    emb = encoder(images.reshape(b * t, c, h, w))
    # Spell out the feature size instead of using -1: reshape cannot infer a
    # -1 dimension when there are zero frames (b == 0 or t == 0).
    feat = emb.shape[1:].numel()
    return emb.reshape(b, t, feat)
|
|