"""FAE with CNN spatial pooling for token reduction.

Encoder: CNN downsample (24×24 → H'×W') + self-attention + projection to latent_dim.
Decoder: projection up + ViT layers at the compressed resolution
         + CNN upsample (H'×W' → 24×24).

With the default 24×24 grid:
    pool_factor=2: 576 → 144 tokens (s2)
    pool_factor=4: 576 →  36 tokens (s4)
"""

import math
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

# Make the project root importable so `utils` / `models` resolve when this
# file is imported directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from utils import RMSNorm
from models.feature_decoder import RotaryPositionalEmbedding2D, ViTDecoderBlock


def _num_2x_stages(pool_factor):
    """Return how many 2x resampling stages realize ``pool_factor`` (= log2).

    Raises:
        ValueError: if ``pool_factor`` is not a power of two >= 2.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`
    # (a stripped assert would let e.g. pool_factor=3 silently build 1 stage).
    if pool_factor < 2 or math.log2(pool_factor) % 1 != 0:
        raise ValueError(
            f"pool_factor must be a power of two >= 2 (e.g. 2 or 4), got {pool_factor}"
        )
    return int(math.log2(pool_factor))


def _check_grid_divisible(grid_size, pool_factor):
    """Raise ValueError unless ``grid_size`` is an exact multiple of ``pool_factor``.

    The strided convs produce ceil(grid / pf) per side while ``compressed_grid``
    is floor(grid / pf); they only agree (and the decoder only restores
    grid_size**2 tokens) when the division is exact.
    """
    if grid_size % pool_factor != 0:
        raise ValueError(
            f"grid_size ({grid_size}) must be divisible by pool_factor ({pool_factor})"
        )


class CNNDownsample(nn.Module):
    """Spatial downsampling with strided convolutions.

    Each stage is Conv2d(kernel=3, stride=2, padding=1) + GELU, which halves
    even spatial sizes. log2(pool_factor) stages are stacked.
    """

    def __init__(self, dim, pool_factor):
        super().__init__()
        num_layers = _num_2x_stages(pool_factor)
        layers = []
        for _ in range(num_layers):
            layers.extend([
                nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1),
                nn.GELU(),
            ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """x: [B, C, H, W] → [B, C, H/pf, W/pf] (for H, W divisible by pf)."""
        return self.net(x)


class CNNUpsample(nn.Module):
    """Spatial upsampling with transposed convolutions.

    Each stage is ConvTranspose2d(kernel=4, stride=2, padding=1) + GELU, which
    exactly doubles the spatial size. log2(pool_factor) stages are stacked.
    """

    def __init__(self, dim, pool_factor):
        super().__init__()
        num_layers = _num_2x_stages(pool_factor)
        layers = []
        for _ in range(num_layers):
            layers.extend([
                nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1),
                nn.GELU(),
            ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """x: [B, C, H', W'] → [B, C, H'*pf, W'*pf]"""
        return self.net(x)


class FAESpatialEncoder(nn.Module):
    """FAE encoder with CNN spatial pooling.

    Input:  [B, grid_size**2, embed_dim]   (576 tokens for the default 24×24 grid)
    Output: [B, N_compressed, latent_dim]  where N_compressed = (grid_size / pool_factor)**2

    Args:
        embed_dim: channel width of the input tokens.
        latent_dim: per-token latent width.
        num_heads: attention heads of the single self-attention layer.
        pool_factor: spatial reduction per side (power of two, e.g. 2 or 4).
        grid_size: side length of the square input token grid.
        use_vae: if True, add mu/logvar heads and sample during training.

    Raises:
        ValueError: if pool_factor is invalid or does not divide grid_size.
    """

    def __init__(self, embed_dim=1152, latent_dim=32, num_heads=16, pool_factor=2,
                 grid_size=24, use_vae=True):
        super().__init__()
        _check_grid_divisible(grid_size, pool_factor)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.pool_factor = pool_factor
        self.grid_size = grid_size
        self.compressed_grid = grid_size // pool_factor
        self.use_vae = use_vae

        # CNN spatial downsampling
        self.downsample = CNNDownsample(embed_dim, pool_factor)

        # Self-attention at compressed resolution (pre-norm)
        self.norm1 = RMSNorm(embed_dim)
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

        # SwiGLU FFN: w2(silu(w1(h)) * w3(h))
        self.norm2 = RMSNorm(embed_dim)
        ffn_dim = int(embed_dim * 2.7)
        self.w1 = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, embed_dim, bias=False)
        self.w3 = nn.Linear(embed_dim, ffn_dim, bias=False)

        # Per-token projection to latent dim
        self.proj = nn.Linear(embed_dim, latent_dim)

        # VAE heads (operate on the projected latent)
        if use_vae:
            self.mu_head = nn.Linear(latent_dim, latent_dim)
            self.logvar_head = nn.Linear(latent_dim, latent_dim)

    def forward(self, x):
        """Encode a full-resolution token grid into compressed latents.

        Args:
            x: [B, N, embed_dim] with N = grid_size**2 (576 by default).

        Returns:
            (z_sample, mu, logvar), each [B, N_compressed, latent_dim].
            With use_vae=False this is (z, z, zeros_like(z)). With use_vae=True,
            z_sample is a reparameterized sample in training mode and equals
            mu in eval mode.

        Raises:
            ValueError: if N != grid_size**2.
        """
        B, N, D = x.shape
        expected = self.grid_size * self.grid_size
        if N != expected:
            raise ValueError(
                f"expected {expected} tokens ({self.grid_size}x{self.grid_size} grid), got {N}"
            )

        # Tokens → 2D feature map, then downsample
        x = x.transpose(1, 2).reshape(B, D, self.grid_size, self.grid_size)
        x = self.downsample(x)  # [B, D, H', W']
        x = x.flatten(2).transpose(1, 2)  # [B, N_compressed, D]

        # Pre-norm self-attention with residual
        normed = self.norm1(x)
        x = x + self.self_attn(normed, normed, normed)[0]

        # Pre-norm SwiGLU FFN with residual
        h = self.norm2(x)
        x = x + self.w2(F.silu(self.w1(h)) * self.w3(h))

        # Project to latent
        z = self.proj(x)

        if not self.use_vae:
            return z, z, torch.zeros_like(z)

        mu = self.mu_head(z)
        logvar = self.logvar_head(z)

        if self.training:
            # Reparameterization trick: mu + sigma * eps, sigma = exp(logvar / 2)
            std = torch.exp(0.5 * logvar)
            z_sample = mu + std * torch.randn_like(std)
        else:
            z_sample = mu

        return z_sample, mu, logvar


class FAESpatialDecoder(nn.Module):
    """FAE decoder with CNN spatial upsampling.

    Input:  [B, N_compressed, latent_dim]  where N_compressed = (grid_size / pool_factor)**2
    Output: [B, grid_size**2, output_dim]  (576 tokens for the default 24×24 grid)

    The ViT layers operate at the compressed resolution; the CNN then upsamples.

    Args:
        latent_dim: per-token latent width of the input.
        output_dim: channel width of the reconstructed tokens (and of the ViT layers).
        num_layers: number of ViTDecoderBlock layers.
        num_heads: attention heads per layer; must divide output_dim.
        ffn_mult: FFN width multiplier passed to ViTDecoderBlock.
        pool_factor: spatial upsampling per side (power of two, e.g. 2 or 4).
        grid_size: side length of the square output token grid.

    Raises:
        ValueError: if pool_factor is invalid or does not divide grid_size, or if
            num_heads does not divide output_dim.
    """

    def __init__(self, latent_dim=32, output_dim=1152, num_layers=6, num_heads=16,
                 ffn_mult=2.7, pool_factor=2, grid_size=24):
        super().__init__()
        # Without exact divisibility the upsampled grid would be
        # (grid_size // pf * pf)**2 tokens instead of the documented grid_size**2.
        _check_grid_divisible(grid_size, pool_factor)
        if output_dim % num_heads != 0:
            raise ValueError(
                f"output_dim ({output_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.output_dim = output_dim
        self.pool_factor = pool_factor
        self.grid_size = grid_size
        self.compressed_grid = grid_size // pool_factor

        # Project latent up to full dim
        self.input_proj = nn.Linear(latent_dim, output_dim)

        # 2D RoPE at the compressed grid resolution
        head_dim = output_dim // num_heads
        self.rope = RotaryPositionalEmbedding2D(head_dim, grid_size=self.compressed_grid)

        # Transformer layers at compressed resolution
        self.layers = nn.ModuleList([
            ViTDecoderBlock(output_dim, num_heads, ffn_mult)
            for _ in range(num_layers)
        ])
        self.pre_upsample_norm = RMSNorm(output_dim)

        # CNN spatial upsampling
        self.upsample = CNNUpsample(output_dim, pool_factor)

        # Output normalization after upsampling (a norm only; there is no
        # projection layer here).
        # NOTE(review): CNNUpsample ends with a GELU, so this norm sees
        # GELU-activated features; confirm that is intended for a feature
        # reconstruction target.
        self.final_norm = RMSNorm(output_dim)

    def forward(self, z):
        """Decode compressed latents back to a full-resolution token grid.

        Args:
            z: [B, N_compressed, latent_dim] with N_compressed = compressed_grid**2.

        Returns:
            x_hat: [B, grid_size**2, output_dim]

        Raises:
            ValueError: if the number of input tokens != compressed_grid**2.
        """
        B, N = z.shape[0], z.shape[1]
        expected = self.compressed_grid * self.compressed_grid
        if N != expected:
            raise ValueError(
                f"expected {expected} latent tokens "
                f"({self.compressed_grid}x{self.compressed_grid} grid), got {N}"
            )

        x = self.input_proj(z)  # [B, N_compressed, output_dim]

        rope_cos, rope_sin = self.rope(x.shape[1], x.device)
        for layer in self.layers:
            x = layer(x, rope_cos, rope_sin)

        x = self.pre_upsample_norm(x)

        # Tokens → 2D feature map, then upsample
        x = x.transpose(1, 2).reshape(B, self.output_dim, self.compressed_grid, self.compressed_grid)
        x = self.upsample(x)  # [B, output_dim, grid_size, grid_size]
        x = x.flatten(2).transpose(1, 2)  # [B, N_full, output_dim]

        return self.final_norm(x)