---
license: apache-2.0
tags:
- vision
- compression
- feature-autoencoder
- qwen3-vl
datasets:
- multimodal-reasoning-lab/Zebra-CoT
---

# FAE-Spatial-S2: Feature Auto-Encoder with 2x Spatial Pooling

**Compression**: 144x (576 → 144 tokens, 1152 → 32 channels)

## Overview

Feature Auto-Encoder (FAE) that compresses Qwen3-VL-8B vision features using CNN spatial pooling (2x per dimension) combined with channel compression. Trained on the Chess subset of Zebra-CoT.

## Architecture

- **Encoder** (28.05M params): Reshape to 2D → Conv2d(1152, 1152, k=3, s=2, p=1) + GELU → Self-attention + SwiGLU FFN → Linear(1152→32) + VAE μ/logvar heads
- **Decoder** (117.63M params): Linear(32→1152) → 6-layer ViT (RoPE at 12×12 grid, RMSNorm, SwiGLU) → ConvTranspose2d upsample → RMSNorm → [B, 576, 1152]
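The encoder's spatial-pooling front end can be sketched in a few lines. This is a shape-level sketch only (the attention block and VAE heads are omitted, and the layer names here are illustrative, not the checkpoint's module names):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, N, D, d = 2, 576, 1152, 32
g = int(N ** 0.5)  # 24x24 token grid

# Strided conv halves each spatial dimension: 24x24 -> 12x12
pool = nn.Conv2d(D, D, kernel_size=3, stride=2, padding=1)
proj = nn.Linear(D, d)  # channel compression: 1152 -> 32

x = torch.randn(B, N, D)                    # [B, 576, 1152] ViT features
h = x.transpose(1, 2).reshape(B, D, g, g)   # [B, 1152, 24, 24]
h = F.gelu(pool(h))                         # [B, 1152, 12, 12]
h = h.flatten(2).transpose(1, 2)            # [B, 144, 1152]
z = proj(h)                                 # [B, 144, 32]
print(z.shape)  # torch.Size([2, 144, 32])
```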

## Compression Details

| | Input | Latent | Compression |
|---|---|---|---|
| Tokens | 576 (24×24) | 144 (12×12) | 4x spatial |
| Channels | 1152 | 32 | 36x channel |
| **Total values** | **663,552** | **4,608** | **144x** |
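The totals in the table follow directly from the grid and channel sizes:

```python
# Compression arithmetic from the table above
in_tokens, in_ch = 24 * 24, 1152    # ViT output: 576 tokens x 1152 channels
out_tokens, out_ch = 12 * 12, 32    # FAE latent: 144 tokens x 32 channels

spatial = in_tokens / out_tokens                      # 4.0 (2x per dimension)
channel = in_ch / out_ch                              # 36.0
total = (in_tokens * in_ch) / (out_tokens * out_ch)   # 663,552 / 4,608
print(spatial, channel, total)  # 4.0 36.0 144.0
```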

## Results

| Metric | Value |
|---|---|
| Eval CosSim (feature reconstruction) | **0.9776** |
| VLM Chess MCQ Accuracy | 3/20 (15%) |
| VLM Agreement with uncompressed baseline | 17/20 (85%) |

The s2 model **exceeds** the channel-only d32 baseline (CosSim 0.9670) while being 4x more compressed.

## Training

- **Dataset**: Zebra-CoT Chess (19,983 train / 500 eval images)
- **Base model**: Qwen3-VL-8B-Instruct (ViT embed_dim=1152)
- **Loss**: MSE + β·KL (β=1e-6)
- **Optimizer**: AdamW, lr=1e-4, weight_decay=0.05
- **Epochs**: 100
- **Hardware**: 2x NVIDIA L40S with DDP
- **Feature normalization**: Per-dim mean/std (included in checkpoint)
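The objective above (MSE plus a small KL term) matches the standard VAE formulation. A minimal sketch, assuming the usual diagonal-Gaussian KL against N(0, I); the function name and averaging scheme are illustrative:

```python
import torch
import torch.nn.functional as F

def fae_loss(recon, target, mu, logvar, beta=1e-6):
    """MSE reconstruction + beta * KL, with beta = 1e-6 as listed above."""
    mse = F.mse_loss(recon, target)
    # KL( N(mu, diag(exp(logvar))) || N(0, I) ), averaged over all elements
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + beta * kl
```

With a perfect reconstruction and a latent posterior equal to the prior (mu = 0, logvar = 0), both terms vanish and the loss is exactly zero.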

## Files

- `fae_encoder.pt` – Spatial FAE encoder weights
- `feature_decoder.pt` – Spatial FAE decoder weights
- `training_state.pt` – Training metadata + feature normalization stats
- `fae_spatial.py` – Model architecture source code

## Usage

```python
import torch
from fae_spatial import FAESpatialEncoder, FAESpatialDecoder

# Load checkpoint and per-dimension feature normalization stats
state = torch.load("training_state.pt", map_location="cpu")
feat_mean = state["feat_mean"].cuda()
feat_std = state["feat_std"].cuda()

encoder = FAESpatialEncoder(embed_dim=1152, latent_dim=32, num_heads=16, pool_factor=2, use_vae=True)
encoder.load_state_dict(torch.load("fae_encoder.pt", map_location="cpu"))
encoder = encoder.cuda().eval()

decoder = FAESpatialDecoder(latent_dim=32, output_dim=1152, num_layers=6, num_heads=16, ffn_mult=2.7, pool_factor=2)
decoder.load_state_dict(torch.load("feature_decoder.pt", map_location="cpu"))
decoder = decoder.cuda().eval()

# Compress and reconstruct ViT features [B, 576, 1152]
with torch.no_grad():
    vit_features_norm = (vit_features - feat_mean) / feat_std
    z, mu, logvar = encoder(vit_features_norm)        # [B, 144, 32]
    reconstructed = decoder(z)                        # [B, 576, 1152]
    reconstructed = reconstructed * feat_std + feat_mean
```
|