"""FAE with CNN spatial pooling for token reduction.

Encoder: CNN downsample (24×24 → H'×W') + self-attention + project to latent_dim
Decoder: project up + ViT layers at compressed resolution + CNN upsample (H'×W' → 24×24)

pool_factor=2: 576 → 144 tokens (s2)
pool_factor=4: 576 → 36 tokens (s4)
"""
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import math |
| | import sys, os |
| | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| | from utils import RMSNorm |
| | from models.feature_decoder import RotaryPositionalEmbedding2D, ViTDecoderBlock |
| |
|
| |
|
class CNNDownsample(nn.Module):
    """Halve spatial resolution log2(pool_factor) times with strided 3x3 convs.

    Each stage is a stride-2 Conv2d followed by GELU; the channel count is
    kept constant throughout, so only H and W shrink.
    """

    def __init__(self, dim, pool_factor):
        super().__init__()
        assert pool_factor in (2, 4), f"pool_factor must be 2 or 4, got {pool_factor}"
        stages = []
        for _ in range(int(math.log2(pool_factor))):
            stages.append(nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1))
            stages.append(nn.GELU())
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Map [B, C, H, W] to [B, C, H/pool_factor, W/pool_factor]."""
        return self.net(x)
| |
|
| |
|
class CNNUpsample(nn.Module):
    """Double spatial resolution log2(pool_factor) times with transposed convs.

    Each stage is a stride-2 ConvTranspose2d (kernel 4, padding 1 — an exact
    2x upsample) followed by GELU; channel count stays fixed.
    """

    def __init__(self, dim, pool_factor):
        super().__init__()
        assert pool_factor in (2, 4), f"pool_factor must be 2 or 4, got {pool_factor}"
        blocks = []
        for _ in range(int(math.log2(pool_factor))):
            blocks += [
                nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1),
                nn.GELU(),
            ]
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        """Map [B, C, H', W'] to [B, C, H'*pool_factor, W'*pool_factor]."""
        return self.net(x)
| |
|
| |
|
class FAESpatialEncoder(nn.Module):
    """FAE encoder: CNN spatial pooling, one attention+FFN block, latent projection.

    Input:  [B, grid_size**2, embed_dim] (576 tokens by default)
    Output: [B, (grid_size/pool_factor)**2, latent_dim]
    """

    def __init__(self, embed_dim=1152, latent_dim=32, num_heads=16,
                 pool_factor=2, grid_size=24, use_vae=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.pool_factor = pool_factor
        self.grid_size = grid_size
        self.compressed_grid = grid_size // pool_factor
        self.use_vae = use_vae

        # strided-conv token reduction
        self.downsample = CNNDownsample(embed_dim, pool_factor)

        # pre-norm self-attention over the compressed token sequence
        self.norm1 = RMSNorm(embed_dim)
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

        # SwiGLU feed-forward (w2(silu(w1(x)) * w3(x)))
        self.norm2 = RMSNorm(embed_dim)
        hidden = int(embed_dim * 2.7)
        self.w1 = nn.Linear(embed_dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, embed_dim, bias=False)
        self.w3 = nn.Linear(embed_dim, hidden, bias=False)

        # down-project to the latent space
        self.proj = nn.Linear(embed_dim, latent_dim)

        # VAE heads produce the posterior parameters from the latent
        if use_vae:
            self.mu_head = nn.Linear(latent_dim, latent_dim)
            self.logvar_head = nn.Linear(latent_dim, latent_dim)

    def forward(self, x):
        """Compress tokens spatially and project into the latent space.

        Args:
            x: [B, N, embed_dim] with N == grid_size**2.
        Returns:
            Tuple (z_sample, mu, logvar), each
            [B, (grid_size/pool_factor)**2, latent_dim]. With use_vae=False,
            returns (z, z, zeros) instead.
        """
        batch, _, dim = x.shape

        # tokens -> 2D feature map -> CNN downsample -> tokens
        grid = x.transpose(1, 2).reshape(batch, dim, self.grid_size, self.grid_size)
        tokens = self.downsample(grid).flatten(2).transpose(1, 2)

        # pre-norm self-attention with residual
        attn_in = self.norm1(tokens)
        tokens = tokens + self.self_attn(attn_in, attn_in, attn_in)[0]

        # SwiGLU FFN with residual
        ffn_in = self.norm2(tokens)
        tokens = tokens + self.w2(F.silu(self.w1(ffn_in)) * self.w3(ffn_in))

        z = self.proj(tokens)

        if not self.use_vae:
            # deterministic path: mu == z, logvar == 0
            return z, z, torch.zeros_like(z)

        mu = self.mu_head(z)
        logvar = self.logvar_head(z)

        # reparameterization trick during training; eval uses the posterior mean
        if self.training:
            noise = torch.randn_like(mu)
            z_sample = mu + torch.exp(0.5 * logvar) * noise
        else:
            z_sample = mu

        return z_sample, mu, logvar
| |
|
| |
|
class FAESpatialDecoder(nn.Module):
    """FAE decoder: ViT blocks at compressed resolution, then CNN upsample.

    Input:  [B, N_compressed, latent_dim]
    Output: [B, grid_size**2, output_dim]
    """

    def __init__(self, latent_dim=32, output_dim=1152, num_layers=6,
                 num_heads=16, ffn_mult=2.7, pool_factor=2, grid_size=24):
        super().__init__()
        self.output_dim = output_dim
        self.pool_factor = pool_factor
        self.grid_size = grid_size
        self.compressed_grid = grid_size // pool_factor

        # latent -> model width
        self.input_proj = nn.Linear(latent_dim, output_dim)

        # 2D rotary embeddings sized for the compressed grid
        self.rope = RotaryPositionalEmbedding2D(output_dim // num_heads,
                                                grid_size=self.compressed_grid)

        # transformer stack operating on the compressed token sequence
        self.layers = nn.ModuleList(
            ViTDecoderBlock(output_dim, num_heads, ffn_mult)
            for _ in range(num_layers)
        )
        self.pre_upsample_norm = RMSNorm(output_dim)

        # transposed-conv expansion back to the full grid
        self.upsample = CNNUpsample(output_dim, pool_factor)

        self.final_norm = RMSNorm(output_dim)

    def forward(self, z):
        """Decode latents back to the full token grid.

        Args:
            z: [B, N_compressed, latent_dim]
        Returns:
            [B, grid_size**2, output_dim]
        """
        bsz = z.shape[0]
        h = self.input_proj(z)

        cos, sin = self.rope(h.shape[1], h.device)
        for block in self.layers:
            h = block(h, cos, sin)
        h = self.pre_upsample_norm(h)

        # tokens -> 2D feature map -> CNN upsample -> tokens
        fmap = h.transpose(1, 2).reshape(bsz, self.output_dim,
                                         self.compressed_grid, self.compressed_grid)
        h = self.upsample(fmap).flatten(2).transpose(1, 2)

        return self.final_norm(h)
| |
|