"""Building blocks for the flow matching U-Net. Contains: SinusoidalEmbedding, ResBlock (with AdaGN), SelfAttention. These are the primitives that the U-Net and VQ-VAE are composed from. Reference: docs/build_spec.md §2.3 (U-Net architecture, ResBlock with AdaGN, conditioning embedding) and docs/foundations_guide.md Part 7 (how AdaGN works). """ import math from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F class SinusoidalEmbedding(nn.Module): """Map a scalar t ∈ [0, 1] to a fixed-frequency sinusoidal embedding. Uses the same encoding as Vaswani et al. (2017) "Attention Is All You Need", but applied to a continuous scalar instead of discrete positions. Math: for dimension index i in [0, dim): embed[2i] = sin(t * 10000^(-2i/dim)) embed[2i+1] = cos(t * 10000^(-2i/dim)) This gives the flow time a rich, high-dimensional representation where nearby t values have similar embeddings (smooth) but the model can still distinguish fine differences (high-frequency components). Args: dim: Output embedding dimension. Must be even. """ def __init__(self, dim: int) -> None: super().__init__() if dim % 2 != 0: raise ValueError(f"SinusoidalEmbedding dim must be even, got {dim}") self.dim = dim # Precompute the frequency denominators: 10000^(2i/dim) for i=0..dim/2-1 # Stored as a buffer (not a parameter — no gradients needed). half_dim = dim // 2 exponents = torch.arange(half_dim, dtype=torch.float32) / half_dim # [0, 1) inv_freq = 1.0 / (10000.0 ** exponents) # [half_dim] self.register_buffer("inv_freq", inv_freq) def forward(self, t: torch.Tensor) -> torch.Tensor: """Embed scalar timesteps. Args: t: Flow time values, shape [B] or [B, 1]. Values in [0, 1]. Returns: Embedding of shape [B, dim]. """ # Flatten to [B] if t.ndim == 0: t = t.unsqueeze(0) t = t.view(-1).float() # [B] # Outer product: [B, 1] * [1, half_dim] → [B, half_dim] angles = t.unsqueeze(1) * self.inv_freq.unsqueeze(0) # [B, half_dim] # Interleave sin and cos → [B, dim] emb = torch.cat([angles.sin(), angles.cos()], dim=-1) # [B, dim] return emb class ResBlock(nn.Module): """Residual block with Adaptive Group Normalization (AdaGN). Architecture: GroupNorm → SiLU → Conv3x3 → GroupNorm (modulated by cond) → SiLU → Conv3x3 → + skip The conditioning vector (time + action embedding) modulates the second normalization layer via learned scale and shift. This is how the U-Net knows *what action was taken* and *what flow time step we're at* — it adjusts the internal feature processing based on these signals. Why AdaGN and not concatenation or cross-attention? - The action doesn't change what's in the image — it changes *how* the image should change. AdaGN modulates processing, which is the right inductive bias. - Cross-attention is expensive; AdaGN is a single linear layer per block. Reference: docs/build_spec.md §2.3 (ResBlock with AdaGN code), docs/foundations_guide.md Part 7. Args: in_ch: Input channels. out_ch: Output channels. cond_dim: Conditioning embedding dimension (time + action). num_groups: Number of groups for GroupNorm. """ def __init__( self, in_ch: int, out_ch: int, cond_dim: int = 512, num_groups: int = 32, ) -> None: super().__init__() # First conv path: norm → activation → conv self.norm1 = nn.GroupNorm(num_groups, in_ch) self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1) # Second conv path: norm (modulated by AdaGN) → activation → conv self.norm2 = nn.GroupNorm(num_groups, out_ch) self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1) # AdaGN: conditioning → scale and shift for norm2 # Output is 2 * out_ch: first half is scale, second half is shift. self.adagn = nn.Linear(cond_dim, out_ch * 2) # Skip / residual connection: 1x1 conv if channel count changes, else identity self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else nn.Identity() def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: """Forward pass. Args: x: Feature map, shape [B, in_ch, H, W]. cond: Conditioning embedding, shape [B, cond_dim]. Returns: Output feature map, shape [B, out_ch, H, W]. """ h = F.silu(self.norm1(x)) h = self.conv1(h) # Adaptive GroupNorm: modulate normalized features with conditioning # scale and shift are [B, out_ch], need to broadcast to [B, out_ch, H, W] scale, shift = self.adagn(cond).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1) h = self.norm2(h) * (1 + scale) + shift h = F.silu(h) h = self.conv2(h) return h + self.skip(x) class SelfAttention(nn.Module): """Standard QKV self-attention applied at 2D spatial positions. Used only at 16×16 resolution in the U-Net (compute is O(n²) in spatial size — at 128×128 it would be prohibitive: 16384² = 268M attention entries vs 256² = 65K at 16×16). Architecture: GroupNorm → QKV projection → scaled dot-product attention → output projection Args: channels: Number of input/output channels. num_heads: Number of attention heads. channels must be divisible by num_heads. num_groups: Number of groups for GroupNorm. """ def __init__( self, channels: int, num_heads: int = 4, num_groups: int = 32, ) -> None: super().__init__() if channels % num_heads != 0: raise ValueError( f"channels ({channels}) must be divisible by num_heads ({num_heads})" ) self.num_heads = num_heads self.head_dim = channels // num_heads self.norm = nn.GroupNorm(num_groups, channels) # Single projection for Q, K, V concatenated self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1) self.out_proj = nn.Conv2d(channels, channels, kernel_size=1) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply self-attention over spatial positions. Args: x: Feature map, shape [B, C, H, W]. Returns: Attended feature map, shape [B, C, H, W] (residual added). """ residual = x B, C, H, W = x.shape x = self.norm(x) # Project to Q, K, V: [B, 3*C, H, W] qkv = self.qkv(x) # Reshape to [B, 3, num_heads, head_dim, H*W] then permute qkv = qkv.reshape(B, 3, self.num_heads, self.head_dim, H * W) q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2] # each [B, num_heads, head_dim, H*W] # Transpose to [B, num_heads, H*W, head_dim] for attention q = q.permute(0, 1, 3, 2) # [B, heads, H*W, head_dim] k = k.permute(0, 1, 3, 2) # [B, heads, H*W, head_dim] v = v.permute(0, 1, 3, 2) # [B, heads, H*W, head_dim] # Scaled dot-product attention # Using PyTorch's efficient implementation when available attn_out = F.scaled_dot_product_attention(q, k, v) # [B, heads, H*W, head_dim] # Reshape back to [B, C, H, W] attn_out = attn_out.permute(0, 1, 3, 2) # [B, heads, head_dim, H*W] attn_out = attn_out.reshape(B, C, H, W) # Output projection + residual return self.out_proj(attn_out) + residual class Downsample(nn.Module): """Spatial downsampling via strided convolution. Reduces spatial dimensions by 2× using a 4×4 conv with stride 2. Learned downsampling (not just pooling) — the model learns what information to preserve vs discard at each resolution. Args: channels: Number of input/output channels (preserved). """ def __init__(self, channels: int) -> None: super().__init__() self.conv = nn.Conv2d(channels, channels, kernel_size=4, stride=2, padding=1) def forward(self, x: torch.Tensor) -> torch.Tensor: """Downsample by 2×. Args: x: Feature map, shape [B, C, H, W]. Returns: Downsampled feature map, shape [B, C, H//2, W//2]. """ return self.conv(x) class Upsample(nn.Module): """Spatial upsampling via nearest-neighbor interpolation + convolution. Increases spatial dimensions by 2×. Nearest-neighbor avoids checkerboard artifacts that transposed convolutions can produce. The subsequent conv learns to smooth and refine the upsampled features. Args: channels: Number of input/output channels (preserved). """ def __init__(self, channels: int) -> None: super().__init__() self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1) def forward(self, x: torch.Tensor) -> torch.Tensor: """Upsample by 2×. Args: x: Feature map, shape [B, C, H, W]. Returns: Upsampled feature map, shape [B, C, H*2, W*2]. """ x = F.interpolate(x, scale_factor=2, mode="nearest") return self.conv(x)