# fae-spatial-s4-chess / fae_spatial.py
# mk322's picture
# Upload fae_spatial.py with huggingface_hub
# cb5938b verified
"""FAE with CNN spatial pooling for token reduction.
Encoder: CNN downsample (24×24 → H'×W') + self-attention + project to latent_dim
Decoder: project up + ViT layers at compressed resolution + CNN upsample (H'×W' → 24×24)
pool_factor=2: 576 → 144 tokens (s2)
pool_factor=4: 576 → 36 tokens (s4)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import RMSNorm
from models.feature_decoder import RotaryPositionalEmbedding2D, ViTDecoderBlock
class CNNDownsample(nn.Module):
    """Spatial downsampling with strided convolutions.

    Stacks ``log2(pool_factor)`` layers, each a stride-2 3x3 conv followed by
    GELU, so every layer halves the spatial resolution while keeping the
    channel count fixed.

    Args:
        dim: number of channels (unchanged through the stack).
        pool_factor: total spatial reduction factor. Any power of two >= 1
            (1 yields an identity module). Generalized from the original
            (2, 4) whitelist — the construction is just log2(pool_factor)
            halving layers, so any power of two is valid.
    """
    def __init__(self, dim, pool_factor):
        super().__init__()
        # Power-of-two check via the bit trick: x & (x - 1) == 0 for x = 2^k.
        assert pool_factor >= 1 and (pool_factor & (pool_factor - 1)) == 0, \
            f"pool_factor must be a power of 2, got {pool_factor}"
        num_layers = int(math.log2(pool_factor))
        layers = []
        for _ in range(num_layers):
            layers.extend([
                nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1),
                nn.GELU(),
            ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """x: [B, C, H, W] -> [B, C, H/pool_factor, W/pool_factor]"""
        return self.net(x)
class CNNUpsample(nn.Module):
    """Spatial upsampling with transposed convolutions.

    Stacks ``log2(pool_factor)`` layers, each a stride-2 4x4 transposed conv
    followed by GELU, so every layer doubles the spatial resolution while
    keeping the channel count fixed.

    Args:
        dim: number of channels (unchanged through the stack).
        pool_factor: total spatial magnification factor. Any power of two
            >= 1 (1 yields an identity module). Generalized from the original
            (2, 4) whitelist — the construction is just log2(pool_factor)
            doubling layers, so any power of two is valid.
    """
    def __init__(self, dim, pool_factor):
        super().__init__()
        # Power-of-two check via the bit trick: x & (x - 1) == 0 for x = 2^k.
        assert pool_factor >= 1 and (pool_factor & (pool_factor - 1)) == 0, \
            f"pool_factor must be a power of 2, got {pool_factor}"
        num_layers = int(math.log2(pool_factor))
        layers = []
        for _ in range(num_layers):
            layers.extend([
                # kernel=4, stride=2, padding=1 gives an exact 2x upsample.
                nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1),
                nn.GELU(),
            ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """x: [B, C, H', W'] -> [B, C, H'*pool_factor, W'*pool_factor]"""
        return self.net(x)
class FAESpatialEncoder(nn.Module):
    """FAE encoder: CNN spatial pooling, one attention block, latent heads.

    A [B, grid_size**2, embed_dim] token sequence is reshaped to a 2D grid,
    spatially downsampled by ``pool_factor`` with strided convs, refined by a
    single pre-norm self-attention + SwiGLU block at the compressed
    resolution, then projected per-token to ``latent_dim``. With
    ``use_vae=True`` the latent goes through mean / log-variance heads and is
    reparameterized during training.
    """

    def __init__(self, embed_dim=1152, latent_dim=32, num_heads=16,
                 pool_factor=2, grid_size=24, use_vae=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.pool_factor = pool_factor
        self.grid_size = grid_size
        self.compressed_grid = grid_size // pool_factor
        self.use_vae = use_vae
        # Strided-conv spatial pooling: grid_size^2 -> compressed_grid^2 tokens.
        self.downsample = CNNDownsample(embed_dim, pool_factor)
        # Pre-norm self-attention at the compressed resolution.
        self.norm1 = RMSNorm(embed_dim)
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        # SwiGLU feed-forward: w1/w3 are the gate pair, w2 projects back down.
        self.norm2 = RMSNorm(embed_dim)
        ffn_dim = int(embed_dim * 2.7)
        self.w1 = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, embed_dim, bias=False)
        self.w3 = nn.Linear(embed_dim, ffn_dim, bias=False)
        # Per-token projection into the latent space.
        self.proj = nn.Linear(embed_dim, latent_dim)
        # Optional VAE heads on top of the latent projection.
        if use_vae:
            self.mu_head = nn.Linear(latent_dim, latent_dim)
            self.logvar_head = nn.Linear(latent_dim, latent_dim)

    def forward(self, x):
        """Encode a full-resolution token sequence into compressed latents.

        Args:
            x: [B, N, embed_dim] with N == grid_size ** 2.
        Returns:
            Tuple (z_sample, mu, logvar), each of shape
            [B, compressed_grid**2, latent_dim]. Without VAE heads,
            z_sample == mu == z and logvar is all zeros.
        """
        batch, _, dim = x.shape
        # Tokens -> 2D feature map -> strided-conv downsample -> tokens.
        grid = x.transpose(1, 2).reshape(batch, dim, self.grid_size, self.grid_size)
        tokens = self.downsample(grid).flatten(2).transpose(1, 2)
        # Residual pre-norm self-attention.
        attn_in = self.norm1(tokens)
        attn_out, _ = self.self_attn(attn_in, attn_in, attn_in)
        tokens = tokens + attn_out
        # Residual pre-norm SwiGLU FFN.
        gated = self.norm2(tokens)
        tokens = tokens + self.w2(F.silu(self.w1(gated)) * self.w3(gated))
        z = self.proj(tokens)
        if not self.use_vae:
            # Deterministic path mirrors the VAE return signature.
            return z, z, torch.zeros_like(z)
        mu = self.mu_head(z)
        logvar = self.logvar_head(z)
        if not self.training:
            # Eval: return the mean as the sample.
            return mu, mu, logvar
        # Training: reparameterization trick, z = mu + sigma * eps.
        noise = torch.randn_like(mu)
        return mu + noise * torch.exp(0.5 * logvar), mu, logvar
class FAESpatialDecoder(nn.Module):
    """FAE decoder: ViT blocks at compressed resolution, then CNN upsample.

    Mirrors the encoder: latent tokens of shape
    [B, compressed_grid**2, latent_dim] are lifted to ``output_dim``, refined
    by ``num_layers`` ViT decoder blocks with 2D RoPE sized for the
    compressed grid, then spatially upsampled back to the full
    [B, grid_size**2, output_dim] sequence.
    """

    def __init__(self, latent_dim=32, output_dim=1152, num_layers=6,
                 num_heads=16, ffn_mult=2.7, pool_factor=2, grid_size=24):
        super().__init__()
        self.output_dim = output_dim
        self.pool_factor = pool_factor
        self.grid_size = grid_size
        self.compressed_grid = grid_size // pool_factor
        # Lift latent tokens to the full feature width.
        self.input_proj = nn.Linear(latent_dim, output_dim)
        # 2D rotary position embeddings sized for the compressed grid.
        head_dim = output_dim // num_heads
        self.rope = RotaryPositionalEmbedding2D(head_dim, grid_size=self.compressed_grid)
        # The transformer stack runs at the cheap, compressed resolution.
        self.layers = nn.ModuleList(
            ViTDecoderBlock(output_dim, num_heads, ffn_mult)
            for _ in range(num_layers)
        )
        self.pre_upsample_norm = RMSNorm(output_dim)
        # Transposed-conv upsampling back to grid_size x grid_size.
        self.upsample = CNNUpsample(output_dim, pool_factor)
        # Final feature normalization after upsampling.
        self.final_norm = RMSNorm(output_dim)

    def forward(self, z):
        """Decode latent tokens back to a full-resolution feature sequence.

        Args:
            z: [B, compressed_grid**2, latent_dim]
        Returns:
            x_hat: [B, grid_size**2, output_dim]
        """
        batch = z.shape[0]
        h = self.input_proj(z)
        cos, sin = self.rope(h.shape[1], h.device)
        for block in self.layers:
            h = block(h, cos, sin)
        h = self.pre_upsample_norm(h)
        # Tokens -> 2D feature map -> transposed-conv upsample -> tokens.
        h = h.transpose(1, 2).reshape(
            batch, self.output_dim, self.compressed_grid, self.compressed_grid)
        h = self.upsample(h).flatten(2).transpose(1, 2)
        return self.final_norm(h)