Spaces:
Runtime error
Runtime error
| """Flow matching U-Net dynamics model. | |
| Predicts the velocity field v(x_t, t, action, context) for generating the | |
| next game frame. Operates in pixel space — takes 128×128×3 images directly. | |
| Architecture (from docs/build_spec.md §2.3): | |
| - Input: [B, 15, H, W] = noisy target (3ch) + 4 context frames (12ch) | |
| - Conditioning: flow time t + action → AdaGN in every ResBlock | |
| - 4-level encoder-decoder with skip connections | |
| - Self-attention only at 16×16 (smallest spatial resolution) | |
| - Output: [B, 3, H, W] predicted velocity field | |
| Reference: docs/build_spec.md §2.3-2.4, docs/foundations_guide.md Part 6-7. | |
| """ | |
| import random | |
| from typing import List, Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from src.models.blocks import ( | |
| Downsample, | |
| ResBlock, | |
| SelfAttention, | |
| SinusoidalEmbedding, | |
| Upsample, | |
| ) | |
class UNet(nn.Module):
    """Flow matching U-Net with action conditioning via AdaGN.

    The U-Net processes a concatenation of the noisy interpolated target and
    the context frames, conditioned on flow time ``t`` and a discrete action.
    A single conditioning vector (time embedding + action embedding) is passed
    to every ResBlock; how it modulates the features (Adaptive GroupNorm per
    the project docs) is implemented in ``src.models.blocks.ResBlock``.

    The topology is fixed at four resolution levels: three downsamples on the
    way down, a middle stage (ResBlock → self-attention → ResBlock) at the
    lowest resolution, and three upsamples on the way up, with one skip
    connection per level concatenated along the channel axis.

    Args:
        in_channels: Input channels (15 = 3 noisy + 4×3 context).
        out_channels: Output channels (3 = velocity field RGB).
        channel_mult: Channel counts at each of the four U-Net levels.
            ``None`` (the default) selects ``[64, 128, 256, 512]``. Only the
            first four entries are used.
        cond_dim: Conditioning embedding dimension (time + action combined).
        num_groups: Groups for GroupNorm in all ResBlocks and the output norm.
        num_actions: Size of discrete action space (Procgen = 15).
        action_embed_dim: Intermediate action embedding size before projection.
        time_embed_dim: Sinusoidal embedding dimension before MLP.
        cfg_dropout: Probability of zeroing the action embedding of a sample
            during training (for classifier-free guidance).

    Raises:
        ValueError: If ``channel_mult`` has fewer than four entries.
    """

    def __init__(
        self,
        in_channels: int = 15,
        out_channels: int = 3,
        channel_mult: Optional[List[int]] = None,
        cond_dim: int = 512,
        num_groups: int = 32,
        num_actions: int = 15,
        action_embed_dim: int = 256,
        time_embed_dim: int = 256,
        cfg_dropout: float = 0.1,
    ) -> None:
        super().__init__()

        # A None sentinel replaces the former list default so that no mutable
        # object is shared between calls; the copy also decouples the model
        # from later mutation of a caller-owned list.
        ch = [64, 128, 256, 512] if channel_mult is None else list(channel_mult)
        if len(ch) < 4:
            # The layer layout below indexes ch[0]..ch[3]; fail early with a
            # clear message instead of an IndexError halfway through setup.
            raise ValueError(
                "channel_mult must provide at least 4 channel counts, "
                f"got {len(ch)}"
            )

        self.cfg_dropout = cfg_dropout
        self.cond_dim = cond_dim

        # --- Conditioning embeddings ---
        # Time: sinusoidal(time_embed_dim) → MLP(time_embed_dim → cond_dim → cond_dim)
        self.time_embed = SinusoidalEmbedding(dim=time_embed_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_embed_dim, cond_dim),
            nn.SiLU(),
            nn.Linear(cond_dim, cond_dim),
        )
        # Action: Embedding(num_actions, action_embed_dim)
        #         → MLP(action_embed_dim → cond_dim → cond_dim)
        self.action_embed = nn.Embedding(num_actions, action_embed_dim)
        self.action_mlp = nn.Sequential(
            nn.Linear(action_embed_dim, cond_dim),
            nn.SiLU(),
            nn.Linear(cond_dim, cond_dim),
        )

        # --- Input convolution ---
        self.input_conv = nn.Conv2d(in_channels, ch[0], kernel_size=3, padding=1)

        # --- Down path ---
        # Level 1: ResBlock(c0→c0) × 2, Downsample(c0)
        self.down1_res1 = ResBlock(ch[0], ch[0], cond_dim, num_groups)
        self.down1_res2 = ResBlock(ch[0], ch[0], cond_dim, num_groups)
        self.down1_downsample = Downsample(ch[0])
        # Level 2: ResBlock(c0→c1), ResBlock(c1→c1), Downsample(c1)
        self.down2_res1 = ResBlock(ch[0], ch[1], cond_dim, num_groups)
        self.down2_res2 = ResBlock(ch[1], ch[1], cond_dim, num_groups)
        self.down2_downsample = Downsample(ch[1])
        # Level 3: ResBlock(c1→c2), ResBlock(c2→c2), Downsample(c2)
        self.down3_res1 = ResBlock(ch[1], ch[2], cond_dim, num_groups)
        self.down3_res2 = ResBlock(ch[2], ch[2], cond_dim, num_groups)
        self.down3_downsample = Downsample(ch[2])
        # Level 4: ResBlock(c2→c3), ResBlock(c3→c3) — no downsample
        self.down4_res1 = ResBlock(ch[2], ch[3], cond_dim, num_groups)
        self.down4_res2 = ResBlock(ch[3], ch[3], cond_dim, num_groups)

        # --- Middle ---
        self.mid_res1 = ResBlock(ch[3], ch[3], cond_dim, num_groups)
        self.mid_attn = SelfAttention(ch[3], num_heads=4, num_groups=num_groups)
        self.mid_res2 = ResBlock(ch[3], ch[3], cond_dim, num_groups)

        # --- Up path (skip concatenation doubles the first ResBlock's input) ---
        # Level 4: cat(c3+c3)→c3, c3→c2, Upsample(c2)
        self.up4_res1 = ResBlock(ch[3] + ch[3], ch[3], cond_dim, num_groups)
        self.up4_res2 = ResBlock(ch[3], ch[2], cond_dim, num_groups)
        self.up4_upsample = Upsample(ch[2])
        # Level 3: cat(c2+c2)→c2, c2→c1, Upsample(c1)
        self.up3_res1 = ResBlock(ch[2] + ch[2], ch[2], cond_dim, num_groups)
        self.up3_res2 = ResBlock(ch[2], ch[1], cond_dim, num_groups)
        self.up3_upsample = Upsample(ch[1])
        # Level 2: cat(c1+c1)→c1, c1→c0, Upsample(c0)
        self.up2_res1 = ResBlock(ch[1] + ch[1], ch[1], cond_dim, num_groups)
        self.up2_res2 = ResBlock(ch[1], ch[0], cond_dim, num_groups)
        self.up2_upsample = Upsample(ch[0])
        # Level 1: cat(c0+c0)→c0, c0→c0 — no upsample
        self.up1_res1 = ResBlock(ch[0] + ch[0], ch[0], cond_dim, num_groups)
        self.up1_res2 = ResBlock(ch[0], ch[0], cond_dim, num_groups)

        # --- Output ---
        self.out_norm = nn.GroupNorm(num_groups, ch[0])
        self.out_conv = nn.Conv2d(ch[0], out_channels, kernel_size=3, padding=1)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        action: torch.Tensor,
    ) -> torch.Tensor:
        """Predict velocity field for flow matching.

        Args:
            x: U-Net input = concat(noisy_target, context_frames),
                shape [B, in_channels, H, W].
            t: Flow time, shape [B]. Values in [0, 1].
            action: Action indices, shape [B]. Integers in [0, num_actions).

        Returns:
            Predicted velocity field, shape [B, out_channels, H, W]
            (assuming Downsample/Upsample from src.models.blocks are exact
            inverses in spatial size — TODO confirm for H, W not divisible
            by 8).
        """
        # --- Build conditioning embedding ---
        # Time: [B] → sinusoidal embedding → MLP → [B, cond_dim]
        t_emb = self.time_mlp(self.time_embed(t))
        # Action: [B] → Embedding → MLP → [B, cond_dim]
        a_emb = self.action_mlp(self.action_embed(action))

        # Classifier-free guidance dropout: during training, zero out the
        # action embedding for roughly a cfg_dropout fraction of the batch.
        if self.training and self.cfg_dropout > 0:
            batch_size = action.shape[0]
            keep = torch.rand(batch_size, device=a_emb.device) >= self.cfg_dropout
            # Cast the mask to a_emb's dtype (not float32) so that a model
            # running in fp16/bf16 does not get its conditioning silently
            # promoted to float32 by type promotion.
            a_emb = a_emb * keep.to(a_emb.dtype).unsqueeze(1)

        # Combined conditioning: additive (both are [B, cond_dim])
        cond = t_emb + a_emb

        # --- Input conv ---
        h = self.input_conv(x)  # [B, ch[0], H, W]

        # --- Down path (save features for skip connections) ---
        # Level 1
        h = self.down1_res1(h, cond)
        h = self.down1_res2(h, cond)
        skip1 = h
        h = self.down1_downsample(h)
        # Level 2
        h = self.down2_res1(h, cond)
        h = self.down2_res2(h, cond)
        skip2 = h
        h = self.down2_downsample(h)
        # Level 3
        h = self.down3_res1(h, cond)
        h = self.down3_res2(h, cond)
        skip3 = h
        h = self.down3_downsample(h)
        # Level 4 (no downsample — this is the bottleneck resolution)
        h = self.down4_res1(h, cond)
        h = self.down4_res2(h, cond)
        skip4 = h

        # --- Middle ---
        h = self.mid_res1(h, cond)
        h = self.mid_attn(h)
        h = self.mid_res2(h, cond)

        # --- Up path (concatenate skip connections on the channel axis) ---
        # Level 4
        h = torch.cat([h, skip4], dim=1)
        h = self.up4_res1(h, cond)
        h = self.up4_res2(h, cond)
        h = self.up4_upsample(h)
        # Level 3
        h = torch.cat([h, skip3], dim=1)
        h = self.up3_res1(h, cond)
        h = self.up3_res2(h, cond)
        h = self.up3_upsample(h)
        # Level 2
        h = torch.cat([h, skip2], dim=1)
        h = self.up2_res1(h, cond)
        h = self.up2_res2(h, cond)
        h = self.up2_upsample(h)
        # Level 1
        h = torch.cat([h, skip1], dim=1)
        h = self.up1_res1(h, cond)
        h = self.up1_res2(h, cond)

        # --- Output ---
        h = F.silu(self.out_norm(h))
        h = self.out_conv(h)  # [B, out_channels, H, W]
        return h