FlowMo-WM / experiments /shared /src /models /image_components.py
cccat6's picture
Initial FlowMo-WM public code release
604e535 verified
"""Shared image model components."""
from __future__ import annotations
import torch
from torch import nn
class ImageEncoder(nn.Module):
    """Convolutional encoder that maps a batch of images to embeddings.

    Two coordinate channels (x and y positions, each spanning ``[-1, 1]``)
    are appended to every image in ``forward``, so the first convolution
    receives ``in_channels + 2`` input channels.

    Args:
        emb_dim: Size of the embedding produced for each image.
        in_channels: Number of channels of the incoming images, not counting
            the two coordinate channels. The default of 3 yields the
            5-channel first convolution this encoder has always used.
    """

    def __init__(self, emb_dim: int = 96, in_channels: int = 3):
        super().__init__()
        # The position of each layer inside the Sequential determines the
        # ``net.<index>`` parameter names, so the order is kept fixed.
        self.net = nn.Sequential(
            nn.Conv2d(in_channels + 2, 24, kernel_size=5, stride=2, padding=2),
            nn.SiLU(),
            nn.Conv2d(24, 48, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(48, 64, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.SiLU(),
            # Adaptive pooling pins the feature map to 8x8, which is what the
            # linear layer below is sized for.
            nn.AdaptiveAvgPool2d((8, 8)),
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, emb_dim),
            nn.LayerNorm(emb_dim),
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            images: 4-D tensor laid out as ``(batch, channels, height, width)``.
                Values are cast to float and divided by 255
                (NOTE(review): presumably 8-bit pixel intensities; confirm
                against the data pipeline).

        Returns:
            The output of the layer stack, one ``emb_dim``-sized row per image.
        """
        x = images.float() / 255.0
        if x.is_cuda:
            x = x.contiguous(memory_format=torch.channels_last)
        b, _c, h, w = x.shape
        # Coordinate grids: ``yy`` varies along height, ``xx`` along width.
        yy = torch.linspace(-1.0, 1.0, h, device=x.device, dtype=x.dtype).view(1, 1, h, 1).expand(b, 1, h, w)
        xx = torch.linspace(-1.0, 1.0, w, device=x.device, dtype=x.dtype).view(1, 1, 1, w).expand(b, 1, h, w)
        # Channel order after concatenation: image channels, then x, then y.
        x = torch.cat([x, xx, yy], dim=1)
        return self.net(x)
class MLP(nn.Module):
    """Feed-forward network: SiLU-activated hidden layers plus a linear head.

    Args:
        in_dim: Width of the input features.
        out_dim: Width of the output features.
        hidden_dim: Width of every hidden layer.
        depth: Number of ``Linear -> SiLU`` hidden stages. With 0 the model
            reduces to a single linear layer from ``in_dim`` to ``out_dim``.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 160, depth: int = 2):
        super().__init__()
        # Feature widths seen before each Linear: the input, then one entry
        # per hidden stage.
        widths = [in_dim] + [hidden_dim] * depth
        stages: list[nn.Module] = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages.extend((nn.Linear(fan_in, fan_out), nn.SiLU()))
        # Output projection, with no activation after it.
        stages.append(nn.Linear(widths[-1], out_dim))
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x`` and return the result."""
        return self.net(x)
def encode_image_sequence(encoder: ImageEncoder, images: torch.Tensor) -> torch.Tensor:
    """Encode every frame of a batch of image sequences.

    The two leading axes of ``images`` are merged so that ``encoder`` is
    called once on all frames, then the result is split back into those
    two axes.

    Args:
        encoder: Callable applied to the merged 4-D frame tensor.
        images: 5-D tensor; the first two axes are treated as batch and
            sequence position, the remaining three are passed to ``encoder``
            unchanged.

    Returns:
        Tensor whose first two axes match those of ``images``, with all
        remaining encoder output flattened into a final axis.
    """
    # Unpacking also enforces that the input is exactly 5-D.
    batch, steps, _channels, _height, _width = images.shape
    frames = images.flatten(0, 1)
    per_frame = encoder(frames)
    return per_frame.reshape(batch, steps, -1)