# minigenie/src/models/blocks.py (commit f805fb3)
# Deploy MiniGenie — interactive flow matching world model demo
"""Building blocks for the flow matching U-Net.
Contains: SinusoidalEmbedding, ResBlock (with AdaGN), SelfAttention.
These are the primitives that the U-Net and VQ-VAE are composed from.
Reference: docs/build_spec.md §2.3 (U-Net architecture, ResBlock with AdaGN,
conditioning embedding) and docs/foundations_guide.md Part 7 (how AdaGN works).
"""
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class SinusoidalEmbedding(nn.Module):
    """Map a scalar t ∈ [0, 1] to a fixed-frequency sinusoidal embedding.

    Uses the same encoding as Vaswani et al. (2017) "Attention Is All You Need",
    but applied to a continuous scalar instead of discrete positions.

    Math: for frequency index i in [0, dim/2):
        embed[i]         = sin(t * 10000^(-2i/dim))
        embed[dim/2 + i] = cos(t * 10000^(-2i/dim))
    Note the sin half and cos half are *concatenated*, not interleaved —
    an equivalent reordering of the classic formulation.

    This gives the flow time a rich, high-dimensional representation where
    nearby t values have similar embeddings (smooth) but the model can still
    distinguish fine differences (high-frequency components).

    Args:
        dim: Output embedding dimension. Must be even.

    Raises:
        ValueError: If ``dim`` is odd.
    """
    def __init__(self, dim: int) -> None:
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"SinusoidalEmbedding dim must be even, got {dim}")
        self.dim = dim
        # Precompute the inverse frequencies 10000^(-2i/dim) for i=0..dim/2-1.
        # Stored as a buffer (not a parameter — no gradients needed) so it
        # follows the module across .to(device)/.cuda() calls.
        half_dim = dim // 2
        exponents = torch.arange(half_dim, dtype=torch.float32) / half_dim  # 2i/dim in [0, 1)
        inv_freq = 1.0 / (10000.0 ** exponents)  # [half_dim]
        self.register_buffer("inv_freq", inv_freq)
    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed scalar timesteps.

        Args:
            t: Flow time values, shape [] (scalar), [B] or [B, 1]. Values in [0, 1].

        Returns:
            Embedding of shape [B, dim] (B = 1 for a 0-dim input).
        """
        # Promote a bare 0-dim scalar to a batch of one.
        if t.ndim == 0:
            t = t.unsqueeze(0)
        t = t.view(-1).float()  # [B]
        # Outer product: [B, 1] * [1, half_dim] → [B, half_dim]
        angles = t.unsqueeze(1) * self.inv_freq.unsqueeze(0)  # [B, half_dim]
        # Concatenate the sin and cos halves → [B, dim]
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)  # [B, dim]
        return emb
class ResBlock(nn.Module):
    """Residual block whose second normalization is conditioned via AdaGN.

    Layout:
        GroupNorm → SiLU → Conv3x3 → GroupNorm · (1 + scale) + shift → SiLU → Conv3x3
        ... plus a skip connection (1×1 conv when channel counts differ).

    The conditioning vector (time + action embedding) is projected by a single
    linear layer into a per-channel scale and shift that modulate the second
    GroupNorm. This lets the block adjust *how* it processes features based on
    the action taken and the current flow time, without the cost of
    cross-attention or the weaker inductive bias of input concatenation.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        cond_dim: Conditioning embedding dimension (time + action).
        num_groups: Number of groups for GroupNorm.
    """
    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        cond_dim: int = 512,
        num_groups: int = 32,
    ) -> None:
        super().__init__()
        # Path 1: norm → activation → conv.
        self.norm1 = nn.GroupNorm(num_groups, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        # Path 2: conditioned norm → activation → conv.
        self.norm2 = nn.GroupNorm(num_groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        # Conditioning projection: first out_ch outputs are the scale,
        # the remaining out_ch are the shift.
        self.adagn = nn.Linear(cond_dim, out_ch * 2)
        # Residual branch: identity unless the channel count changes.
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1)
    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Run the conditioned residual block.

        Args:
            x: Feature map, shape [B, in_ch, H, W].
            cond: Conditioning embedding, shape [B, cond_dim].

        Returns:
            Output feature map, shape [B, out_ch, H, W].
        """
        out = self.conv1(F.silu(self.norm1(x)))
        # Project cond to [B, 2*out_ch, 1, 1] so scale/shift broadcast over H, W.
        modulation = self.adagn(cond)[:, :, None, None]
        scale, shift = modulation.chunk(2, dim=1)
        out = (1 + scale) * self.norm2(out) + shift
        out = self.conv2(F.silu(out))
        return self.skip(x) + out
class SelfAttention(nn.Module):
    """Multi-head self-attention across the spatial positions of a feature map.

    Applied only at the 16×16 resolution in the U-Net: attention cost grows as
    O(n²) in the number of spatial positions, so 128×128 would be prohibitive
    (16384² ≈ 268M attention entries) while 16×16 is cheap (256² = 65K).

    Pipeline: GroupNorm → fused 1×1 QKV projection → scaled dot-product
    attention → 1×1 output projection → residual add.

    Args:
        channels: Number of input/output channels.
        num_heads: Number of attention heads. channels must be divisible by num_heads.
        num_groups: Number of groups for GroupNorm.

    Raises:
        ValueError: If channels is not divisible by num_heads.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int = 4,
        num_groups: int = 32,
    ) -> None:
        super().__init__()
        if channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.norm = nn.GroupNorm(num_groups, channels)
        # One fused 1×1 conv emits Q, K and V stacked along the channel axis.
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)
        self.out_proj = nn.Conv2d(channels, channels, kernel_size=1)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over all H*W positions and add the result to the input.

        Args:
            x: Feature map, shape [B, C, H, W].

        Returns:
            Feature map of the same shape, with attention output added residually.
        """
        shortcut = x
        B, C, H, W = x.shape
        n = H * W
        # [B, 3C, H, W] → [3, B, heads, n, head_dim] in one reshape + permute.
        fused = self.qkv(self.norm(x))
        fused = fused.view(B, 3, self.num_heads, self.head_dim, n).permute(1, 0, 2, 4, 3)
        q, k, v = fused.unbind(0)  # each [B, heads, n, head_dim]
        # PyTorch's fused softmax(QKᵀ/√d)V kernel.
        attended = F.scaled_dot_product_attention(q, k, v)  # [B, heads, n, head_dim]
        # [B, heads, head_dim, n] → [B, C, H, W]
        attended = attended.transpose(-1, -2).reshape(B, C, H, W)
        return self.out_proj(attended) + shortcut
class Downsample(nn.Module):
    """Halve spatial resolution with a learned strided convolution.

    A 4×4 conv with stride 2 (padding 1) shrinks H and W by 2×. Unlike plain
    pooling, the convolution's weights are learned, so the model decides what
    information to keep or drop at each resolution.

    Args:
        channels: Number of input/output channels (preserved).
    """
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=4, stride=2, padding=1)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce spatial size by 2×.

        Args:
            x: Feature map, shape [B, C, H, W].

        Returns:
            Feature map of shape [B, C, H//2, W//2].
        """
        downsampled = self.conv(x)
        return downsampled
class Upsample(nn.Module):
    """Double spatial resolution with nearest-neighbor interpolation + conv.

    Nearest-neighbor upsampling followed by a 3×3 conv avoids the checkerboard
    artifacts that transposed convolutions can introduce; the conv then learns
    to smooth and refine the enlarged features.

    Args:
        channels: Number of input/output channels (preserved).
    """
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Increase spatial size by 2×.

        Args:
            x: Feature map, shape [B, C, H, W].

        Returns:
            Feature map of shape [B, C, H*2, W*2].
        """
        enlarged = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(enlarged)