---
license: apache-2.0
tags:
- vision
- compression
- feature-autoencoder
- qwen3-vl
datasets:
- multimodal-reasoning-lab/Zebra-CoT
---

# FAE-Spatial-S4: Feature Auto-Encoder with 4x Spatial Pooling

**Compression**: 576x (576 → 36 tokens, 1152 → 32 channels)

## Overview

Feature Auto-Encoder (FAE) that compresses Qwen3-VL-8B vision features using CNN spatial pooling (4x per dimension) combined with channel compression. Trained on the Chess subset of Zebra-CoT.

## Architecture

- **Encoder** (39.99M params): Reshape to 2D → 2× [Conv2d(1152, 1152, k=3, s=2, p=1) + GELU] → Self-attention + SwiGLU FFN → Linear(1152→32) + VAE μ/logvar heads
- **Decoder** (138.86M params): Linear(32→1152) → 6-layer ViT (RoPE at 6×6 grid, RMSNorm, SwiGLU) → 2× [ConvTranspose2d upsample] → RMSNorm → [B, 576, 1152]

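The two stride-2 convolutions account for the 24 → 6 reduction of the token grid. A quick shape sanity check using the standard Conv2d output-size formula (pure arithmetic, not the model code itself):

```python
def conv2d_out(size, kernel=3, stride=2, padding=1):
    """Standard Conv2d output-size formula: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

grid = 24                 # 576 input tokens = 24x24 patch grid
for _ in range(2):        # two stride-2 conv stages in the encoder
    grid = conv2d_out(grid)
print(grid, grid * grid)  # 6 36 -> the 36-token latent grid
```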
## Compression Details

| | Input | Latent | Compression |
|---|---|---|---|
| Tokens | 576 (24×24) | 36 (6×6) | 16x spatial |
| Channels | 1152 | 32 | 36x channel |
| **Total values** | **663,552** | **1,152** | **576x** |

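The totals in the table follow directly from the shapes:

```python
input_values = 576 * 1152   # 24x24 token grid, 1152 channels
latent_values = 36 * 32     # 6x6 latent grid, 32 channels
print(input_values, latent_values, input_values // latent_values)
# 663552 1152 576
```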
## Results

| Metric | Value |
|---|---|
| Eval CosSim (feature reconstruction) | **0.9648** |
| VLM Chess MCQ Accuracy | 3/20 (15%) |
| VLM Agreement with uncompressed baseline | 18/20 (90%) |

The s4 model nearly matches the channel-only d32 baseline (CosSim 0.9670) at 16x higher compression (576x total vs. 36x).

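The CosSim metric above is ordinary cosine similarity between original and reconstructed feature vectors; a minimal pure-Python sketch (averaging per token and the exact reduction used in evaluation are assumptions):

```python
import math

def cosine_sim(a, b):
    # dot(a, b) / (|a| * |b|)
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

# identical vectors -> 1.0; orthogonal vectors -> 0.0
print(cosine_sim([1.0, 0.0], [1.0, 0.0]), cosine_sim([1.0, 0.0], [0.0, 1.0]))
```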
## Training

- **Dataset**: Zebra-CoT Chess (19,983 train / 500 eval images)
- **Base model**: Qwen3-VL-8B-Instruct (ViT embed_dim=1152)
- **Loss**: MSE + β·KL (β=1e-6)
- **Optimizer**: AdamW, lr=1e-4, weight_decay=0.05
- **Epochs**: 100
- **Hardware**: 2x NVIDIA L40S with DDP
- **Feature normalization**: Per-dim mean/std (included in checkpoint)

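The MSE + β·KL objective can be sketched as follows; this is a minimal pure-Python version over flattened values, where the mean reductions and the standard diagonal-Gaussian KL form are assumptions:

```python
import math

def vae_loss(recon, target, mu, logvar, beta=1e-6):
    # Reconstruction term: mean squared error over feature values
    mse = sum((r - t) ** 2 for r, t in zip(recon, target)) / len(recon)
    # KL(N(mu, sigma^2) || N(0, 1)) per latent dim, averaged
    kl = sum(-0.5 * (1.0 + lv - m * m - math.exp(lv)) for m, lv in zip(mu, logvar)) / len(mu)
    return mse + beta * kl
```

With β=1e-6 the KL term acts as a light regularizer; the loss is dominated by feature reconstruction.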
## Files

- `fae_encoder.pt` – Spatial FAE encoder weights
- `feature_decoder.pt` – Spatial FAE decoder weights
- `training_state.pt` – Training metadata + feature normalization stats
- `fae_spatial.py` – Model architecture source code

## Usage

```python
import torch
from fae_spatial import FAESpatialEncoder, FAESpatialDecoder

# Load checkpoint and feature normalization stats
state = torch.load("training_state.pt", map_location="cpu")
feat_mean = state["feat_mean"].cuda()
feat_std = state["feat_std"].cuda()

encoder = FAESpatialEncoder(embed_dim=1152, latent_dim=32, num_heads=16, pool_factor=4, use_vae=True)
encoder.load_state_dict(torch.load("fae_encoder.pt", map_location="cpu"))
encoder = encoder.cuda().eval()

decoder = FAESpatialDecoder(latent_dim=32, output_dim=1152, num_layers=6, num_heads=16, ffn_mult=2.7, pool_factor=4)
decoder.load_state_dict(torch.load("feature_decoder.pt", map_location="cpu"))
decoder = decoder.cuda().eval()

# Compress ViT features [B, 576, 1152]
vit_features_norm = (vit_features - feat_mean) / feat_std
z, mu, logvar = encoder(vit_features_norm)  # [B, 36, 32]
reconstructed = decoder(z)                  # [B, 576, 1152]
reconstructed = reconstructed * feat_std + feat_mean
```
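Since `use_vae=True`, the encoder returns `(z, mu, logvar)`; at inference it is common to use `mu` directly rather than a stochastic sample. The reparameterization trick used during VAE training can be sketched as follows (a generic sketch, not the model source):

```python
import math
import random

def reparameterize(mu, logvar, rng=random.Random(0)):
    # z = mu + sigma * eps, with eps ~ N(0, 1) and sigma = exp(0.5 * logvar)
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]
```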