# FAE-Spatial-S4: Feature Auto-Encoder with 4x Spatial Pooling

Compression: 576x (576 → 36 tokens, 1152 → 32 channels)
## Overview

A Feature Auto-Encoder (FAE) that compresses Qwen3-VL-8B vision features by combining CNN spatial pooling (4x per spatial dimension) with channel compression. Trained on the Chess subset of Zebra-CoT.
## Architecture

- Encoder (39.99M params): Reshape to 2D → 2× [Conv2d(1152, 1152, k=3, s=2, p=1) + GELU] → Self-attention + SwiGLU FFN → Linear(1152→32) + VAE μ/logvar heads
- Decoder (138.86M params): Linear(32→1152) → 6-layer ViT (RoPE at 6×6 grid, RMSNorm, SwiGLU) → 2× [ConvTranspose2d upsample] → RMSNorm → [B, 576, 1152]
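The two stride-2 convolutions in the encoder account for the 24×24 → 6×6 token grid. A quick sketch of the standard conv output-size arithmetic, with k, s, p matching the encoder spec above:

```python
def conv_out(n, k=3, s=2, p=1):
    """Output length along one spatial dim for a strided conv (standard formula)."""
    return (n + 2 * p - k) // s + 1

side = 24                 # 576 tokens = 24x24 grid
for _ in range(2):        # two stride-2 convs in the encoder
    side = conv_out(side)
print(side, side * side)  # 6 36
```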
## Compression Details

| Dimension | Input | Latent | Compression |
|---|---|---|---|
| Tokens | 576 (24×24) | 36 (6×6) | 16x spatial |
| Channels | 1152 | 32 | 36x channel |
| Total values | 663,552 | 1,152 | 576x |
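The table's ratios follow directly from the shapes; a quick arithmetic check using the values above:

```python
tokens_in, tokens_out = 576, 36   # 24x24 grid -> 6x6 grid
dims_in, dims_out = 1152, 32      # channel compression

spatial = tokens_in // tokens_out                       # 16x
channel = dims_in // dims_out                           # 36x
total = (tokens_in * dims_in) // (tokens_out * dims_out)
print(spatial, channel, total)  # 16 36 576
```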
## Results
| Metric | Value |
|---|---|
| Eval CosSim (feature reconstruction) | 0.9648 |
| VLM Chess MCQ Accuracy | 3/20 (15%) |
| VLM Agreement with uncompressed baseline | 18/20 (90%) |
The s4 model nearly matches the channel-only d32 baseline (CosSim 0.9670) at 16x higher compression (576x total vs. 36x).
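The CosSim metric is cosine similarity between original and reconstructed feature vectors (presumably averaged per token; the exact reduction is an assumption here). A minimal plain-Python reference, for illustration only:

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

print(cosine_similarity([1.0, 2.0], [2.0, 4.0]))  # 1.0 (parallel vectors)
```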
## Training
- Dataset: Zebra-CoT Chess (19,983 train / 500 eval images)
- Base model: Qwen3-VL-8B-Instruct (ViT embed_dim=1152)
- Loss: MSE + β·KL (β=1e-6)
- Optimizer: AdamW, lr=1e-4, weight_decay=0.05
- Epochs: 100
- Hardware: 2x NVIDIA L40S with DDP
- Feature normalization: Per-dim mean/std (included in checkpoint)
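The training objective above (MSE plus a β-weighted KL term pulling the latent toward N(0, I)) can be sketched in plain Python; the real training code operates on tensors, and this scalar-list version is illustrative only:

```python
import math

def vae_loss(recon, target, mu, logvar, beta=1e-6):
    """MSE reconstruction + beta * KL(q(z|x) || N(0, I)).

    recon/target/mu/logvar are flat lists of floats standing in for tensors.
    """
    n = len(recon)
    mse = sum((r - t) ** 2 for r, t in zip(recon, target)) / n
    # Closed-form KL for a diagonal Gaussian against a standard normal
    kl = -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))
    return mse + beta * kl
```

With μ=0 and logvar=0 the KL term vanishes and the loss reduces to plain MSE, which is why the small β mostly regularizes rather than dominates.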
## Files

- `fae_encoder.pt`: Spatial FAE encoder weights
- `feature_decoder.pt`: Spatial FAE decoder weights
- `training_state.pt`: Training metadata + feature normalization stats
- `fae_spatial.py`: Model architecture source code
## Usage

```python
import torch

from fae_spatial import FAESpatialEncoder, FAESpatialDecoder

# Load training metadata, including the feature normalization stats
state = torch.load("training_state.pt", map_location="cpu")
feat_mean = state["feat_mean"].cuda()
feat_std = state["feat_std"].cuda()

# Encoder: [B, 576, 1152] features -> [B, 36, 32] latents
encoder = FAESpatialEncoder(embed_dim=1152, latent_dim=32, num_heads=16, pool_factor=4, use_vae=True)
encoder.load_state_dict(torch.load("fae_encoder.pt", map_location="cpu"))
encoder = encoder.cuda().eval()

# Decoder: [B, 36, 32] latents -> reconstructed [B, 576, 1152] features
decoder = FAESpatialDecoder(latent_dim=32, output_dim=1152, num_layers=6, num_heads=16, ffn_mult=2.7, pool_factor=4)
decoder.load_state_dict(torch.load("feature_decoder.pt", map_location="cpu"))
decoder = decoder.cuda().eval()

# Compress ViT features [B, 576, 1152] from the Qwen3-VL vision tower;
# normalize with the stored stats before encoding, de-normalize after decoding
vit_features_norm = (vit_features - feat_mean) / feat_std
z, mu, logvar = encoder(vit_features_norm)  # [B, 36, 32]
reconstructed = decoder(z)                  # [B, 576, 1152]
reconstructed = reconstructed * feat_std + feat_mean
```