# FAE-Spatial-S2: Feature Auto-Encoder with 2x Spatial Pooling

Compression: 144x (576 → 144 tokens, 1152 → 32 channels)
## Overview

A Feature Auto-Encoder (FAE) that compresses Qwen3-VL-8B vision features using CNN spatial pooling (2x per dimension) combined with channel compression. Trained on the Chess subset of Zebra-CoT.
## Architecture

- Encoder (28.05M params): Reshape to 2D → Conv2d(1152, 1152, k=3, s=2, p=1) + GELU → Self-attention + SwiGLU FFN → Linear(1152→32) + VAE μ/logvar heads
- Decoder (117.63M params): Linear(32→1152) → 6-layer ViT (RoPE at 12×12 grid, RMSNorm, SwiGLU) → ConvTranspose2d upsample → RMSNorm → [B, 576, 1152]
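The encoder's spatial-pooling path can be illustrated with a minimal PyTorch sketch. This is an assumption-laden simplification, not the actual `fae_spatial.py` implementation: the class name is hypothetical, and the self-attention + SwiGLU FFN block is omitted for brevity. It shows only the shape flow: 576 tokens reshaped to a 24×24 grid, strided conv down to 12×12, then channel projection to the 32-dim VAE heads.

```python
import torch
import torch.nn as nn


class SpatialPoolEncoderSketch(nn.Module):
    """Illustrative sketch of the FAE-Spatial-S2 encoder shape flow.

    Hypothetical simplification: the real encoder also has a
    self-attention + SwiGLU FFN block between the conv and the heads.
    """

    def __init__(self, embed_dim=1152, latent_dim=32):
        super().__init__()
        # 2x spatial pooling via strided conv: 24x24 grid -> 12x12 grid
        self.pool = nn.Conv2d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        self.act = nn.GELU()
        # Channel compression 1152 -> 32, then VAE mu/logvar heads
        self.proj = nn.Linear(embed_dim, latent_dim)
        self.mu_head = nn.Linear(latent_dim, latent_dim)
        self.logvar_head = nn.Linear(latent_dim, latent_dim)

    def forward(self, x):                           # x: [B, 576, 1152]
        B, N, C = x.shape
        H = W = int(N ** 0.5)                       # 24x24 token grid
        x = x.transpose(1, 2).reshape(B, C, H, W)   # to a 2D feature map
        x = self.act(self.pool(x))                  # [B, 1152, 12, 12]
        x = x.flatten(2).transpose(1, 2)            # [B, 144, 1152]
        h = self.proj(x)                            # [B, 144, 32]
        return self.mu_head(h), self.logvar_head(h)
```

A batch of [B, 576, 1152] features thus comes out as mu/logvar tensors of shape [B, 144, 32].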
## Compression Details

| Dimension | Input | Latent | Compression |
|---|---|---|---|
| Tokens | 576 (24×24) | 144 (12×12) | 4x spatial |
| Channels | 1152 | 32 | 36x channel |
| Total values | 663,552 | 4,608 | 144x |
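The 144x figure is just the product of the spatial and channel ratios; a few lines of arithmetic confirm the table:

```python
# Verify the compression figures from the table above.
tokens_in, tokens_out = 24 * 24, 12 * 12   # 576 -> 144 tokens (4x spatial)
chans_in, chans_out = 1152, 32             # 1152 -> 32 channels (36x)

total_in = tokens_in * chans_in            # 663,552 values per image
total_out = tokens_out * chans_out         # 4,608 values per image
ratio = total_in // total_out              # 144x overall
print(total_in, total_out, ratio)          # 663552 4608 144
```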
## Results
| Metric | Value |
|---|---|
| Eval CosSim (feature reconstruction) | 0.9776 |
| VLM Chess MCQ Accuracy | 3/20 (15%) |
| VLM Agreement with uncompressed baseline | 17/20 (85%) |
The s2 model exceeds the channel-only d32 baseline (CosSim 0.9670) while being 4x more compressed.
## Training
- Dataset: Zebra-CoT Chess (19,983 train / 500 eval images)
- Base model: Qwen3-VL-8B-Instruct (ViT embed_dim=1152)
- Loss: MSE + β·KL (β=1e-6)
- Optimizer: AdamW, lr=1e-4, weight_decay=0.05
- Epochs: 100
- Hardware: 2x NVIDIA L40S with DDP
- Feature normalization: Per-dim mean/std (included in checkpoint)
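The training objective above (MSE reconstruction plus a β-weighted KL term) can be sketched as follows. The function name and reduction choices (mean over all elements) are assumptions; the repository's `fae_spatial.py` is authoritative.

```python
import torch
import torch.nn.functional as F


def fae_loss(recon, target, mu, logvar, beta=1e-6):
    """Sketch of the FAE objective: MSE + beta * KL (beta = 1e-6)."""
    # Reconstruction term on (normalized) ViT features
    mse = F.mse_loss(recon, target)
    # KL divergence of N(mu, sigma^2) from the standard normal prior
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + beta * kl
```

With β as small as 1e-6, the KL term acts as a light regularizer on the latent rather than enforcing a strongly Gaussian posterior.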
## Files

- `fae_encoder.pt`: Spatial FAE encoder weights
- `feature_decoder.pt`: Spatial FAE decoder weights
- `training_state.pt`: Training metadata + feature normalization stats
- `fae_spatial.py`: Model architecture source code
## Usage

```python
import torch
from fae_spatial import FAESpatialEncoder, FAESpatialDecoder

# Load checkpoint and feature normalization stats
state = torch.load("training_state.pt", map_location="cpu")
feat_mean = state["feat_mean"].cuda()
feat_std = state["feat_std"].cuda()

encoder = FAESpatialEncoder(embed_dim=1152, latent_dim=32, num_heads=16, pool_factor=2, use_vae=True)
encoder.load_state_dict(torch.load("fae_encoder.pt", map_location="cpu"))
encoder = encoder.cuda().eval()

decoder = FAESpatialDecoder(latent_dim=32, output_dim=1152, num_layers=6, num_heads=16, ffn_mult=2.7, pool_factor=2)
decoder.load_state_dict(torch.load("feature_decoder.pt", map_location="cpu"))
decoder = decoder.cuda().eval()

# Compress ViT features (vit_features: [B, 576, 1152] from the Qwen3-VL vision tower)
vit_features_norm = (vit_features - feat_mean) / feat_std
z, mu, logvar = encoder(vit_features_norm)  # [B, 144, 32]
reconstructed = decoder(z)                  # [B, 576, 1152]
reconstructed = reconstructed * feat_std + feat_mean
```