File size: 1,992 Bytes
604e535 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 | """Shared image model components."""
from __future__ import annotations
import torch
from torch import nn
class ImageEncoder(nn.Module):
    """Convolutional encoder that maps a batch of images to embeddings.

    Two coordinate channels (one varying along the width, one along the
    height, each spanning ``[-1, 1]``) are concatenated to the input before
    the convolutional stack, so the first convolution consumes
    ``in_channels + 2`` channels.  The feature map is adaptively pooled to
    ``8 x 8`` before the final projection, so the linear layer's input size
    does not depend on the image resolution.

    Args:
        emb_dim: Size of the embedding produced for each image.
        in_channels: Number of channels in the input images.  The default of
            3 yields the same 5-channel first convolution as before this
            parameter existed.
    """

    def __init__(self, emb_dim: int = 96, in_channels: int = 3):
        super().__init__()
        self.in_channels = in_channels
        self.net = nn.Sequential(
            # +2 accounts for the x/y coordinate channels added in forward().
            nn.Conv2d(in_channels + 2, 24, kernel_size=5, stride=2, padding=2),
            nn.SiLU(),
            nn.Conv2d(24, 48, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(48, 64, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.SiLU(),
            # Fixed 8x8 spatial size keeps the Linear input at 64 * 8 * 8.
            nn.AdaptiveAvgPool2d((8, 8)),
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, emb_dim),
            nn.LayerNorm(emb_dim),
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            images: Tensor of shape ``(batch, in_channels, height, width)``.
                Values are cast to float and divided by 255, so pixel values
                in the 0-255 range are presumably expected (confirm with
                callers).

        Returns:
            Tensor of shape ``(batch, emb_dim)``.

        Raises:
            ValueError: If ``images`` is not 4-dimensional.
        """
        if images.dim() != 4:
            raise ValueError(
                "expected a 4-D (batch, channels, height, width) tensor, "
                f"got shape {tuple(images.shape)}"
            )
        x = images.float() / 255.0
        if x.is_cuda:
            # NOTE(review): presumably chosen for faster CUDA convolutions.
            x = x.contiguous(memory_format=torch.channels_last)
        b, _c, h, w = x.shape
        # Normalised coordinate grids, broadcast to one channel per image.
        yy = torch.linspace(-1.0, 1.0, h, device=x.device, dtype=x.dtype).view(1, 1, h, 1).expand(b, 1, h, w)
        xx = torch.linspace(-1.0, 1.0, w, device=x.device, dtype=x.dtype).view(1, 1, 1, w).expand(b, 1, h, w)
        x = torch.cat([x, xx, yy], dim=1)
        return self.net(x)
class MLP(nn.Module):
    """Feed-forward network: ``depth`` hidden Linear+SiLU stages, then a Linear head.

    Args:
        in_dim: Size of the input features.
        out_dim: Size of the output features.
        hidden_dim: Width of every hidden layer.
        depth: Number of hidden Linear+SiLU stages placed before the output
            layer; with ``depth=0`` the network is a single Linear layer.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 160, depth: int = 2):
        super().__init__()
        # Widths seen along the hidden stack: input first, then `depth` hidden widths.
        widths = [in_dim] + [hidden_dim] * depth
        stack: list[nn.Module] = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.SiLU()]
        # Output projection has no activation after it.
        stack.append(nn.Linear(widths[-1], out_dim))
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x`` and return the result."""
        out = self.net(x)
        return out
def encode_image_sequence(encoder: ImageEncoder, images: torch.Tensor) -> torch.Tensor:
    """Run ``encoder`` over every frame of a batch of image sequences.

    The first two axes of ``images`` are merged so that the encoder sees one
    flat batch of frames; its output is then split back into those two axes,
    with all remaining encoder output dimensions flattened into the last one.

    Args:
        encoder: Module called on the flattened frames.
        images: 5-D tensor; unpacked as ``(b, t, c, h, w)``.

    Returns:
        Tensor of shape ``(b, t, -1)`` holding the per-frame encoder outputs.
    """
    batch, steps, channels, height, width = images.shape
    flat_frames = images.reshape(batch * steps, channels, height, width)
    per_frame = encoder(flat_frames)
    return per_frame.reshape(batch, steps, -1)
|