---
license: apache-2.0
tags:
- vision
- compression
- feature-autoencoder
- qwen3-vl
datasets:
- multimodal-reasoning-lab/Zebra-CoT
---
# FAE-Spatial-S4: Feature Auto-Encoder with 4x Spatial Pooling
**Compression**: 576x (576 → 36 tokens, 1152 → 32 channels)
## Overview
A Feature Auto-Encoder (FAE) that compresses Qwen3-VL-8B vision features using CNN spatial pooling (4x per spatial dimension) combined with channel compression. Trained on the Chess subset of Zebra-CoT.
## Architecture
- **Encoder** (39.99M params): Reshape to 2D → 2× [Conv2d(1152, 1152, k=3, s=2, p=1) + GELU] → Self-attention + SwiGLU FFN → Linear(1152→32) + VAE μ/logvar heads
- **Decoder** (138.86M params): Linear(32→1152) → 6-layer ViT (RoPE at 6×6 grid, RMSNorm, SwiGLU) → 2× [ConvTranspose2d upsample] → RMSNorm → [B, 576, 1152]
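The encoder's spatial pooling can be sketched in a few lines: two stride-2 convolutions take the 24×24 token grid to 6×6, and a linear head drops 1152 channels to 32. This is a minimal sketch under assumed names (`PoolEncoderSketch` is illustrative; the released `fae_spatial.py` additionally has attention, SwiGLU, and VAE heads):

```python
import torch
import torch.nn as nn

class PoolEncoderSketch(nn.Module):
    """Illustrative spatial-pooling front end, not the released architecture."""
    def __init__(self, embed_dim=1152, latent_dim=32, pool_stages=2):
        super().__init__()
        # Each Conv2d(k=3, s=2, p=1) halves the grid side: 24 -> 12 -> 6
        layers = []
        for _ in range(pool_stages):
            layers += [nn.Conv2d(embed_dim, embed_dim, 3, stride=2, padding=1), nn.GELU()]
        self.pool = nn.Sequential(*layers)
        self.proj = nn.Linear(embed_dim, latent_dim)  # channel compression 1152 -> 32

    def forward(self, x):                                  # x: [B, 576, 1152]
        B, N, C = x.shape
        side = int(N ** 0.5)                               # 24
        x = x.transpose(1, 2).reshape(B, C, side, side)    # [B, 1152, 24, 24]
        x = self.pool(x)                                   # [B, 1152, 6, 6]
        x = x.flatten(2).transpose(1, 2)                   # [B, 36, 1152]
        return self.proj(x)                                # [B, 36, 32]

enc = PoolEncoderSketch()
z = enc(torch.randn(2, 576, 1152))
print(z.shape)  # torch.Size([2, 36, 32])
```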
## Compression Details
| | Input | Latent | Compression |
|---|---|---|---|
| Tokens | 576 (24×24) | 36 (6×6) | 16x spatial |
| Channels | 1152 | 32 | 36x channel |
| **Total values** | **663,552** | **1,152** | **576x** |
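The total compression factor is just the product of the spatial and channel factors, which the table figures confirm:

```python
# Compression arithmetic from the table above
tokens_in, ch_in = 576, 1152    # 24x24 grid, ViT embed_dim
tokens_out, ch_out = 36, 32     # 6x6 grid, latent_dim

spatial = tokens_in // tokens_out                       # 16x
channel = ch_in // ch_out                               # 36x
total = (tokens_in * ch_in) // (tokens_out * ch_out)    # 576x

print(spatial, channel, total)  # 16 36 576
assert total == spatial * channel
```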
## Results
| Metric | Value |
|---|---|
| Eval CosSim (feature reconstruction) | **0.9648** |
| VLM Chess MCQ Accuracy | 3/20 (15%) |
| VLM Agreement with uncompressed baseline | 18/20 (90%) |
The s4 model nearly matches the channel-only d32 baseline (CosSim 0.9670) with 16x more compression (576x total vs 36x).
## Training
- **Dataset**: Zebra-CoT Chess (19,983 train / 500 eval images)
- **Base model**: Qwen3-VL-8B-Instruct (ViT embed_dim=1152)
- **Loss**: MSE + β·KL (β=1e-6)
- **Optimizer**: AdamW, lr=1e-4, weight_decay=0.05
- **Epochs**: 100
- **Hardware**: 2x NVIDIA L40S with DDP
- **Feature normalization**: Per-dim mean/std (included in checkpoint)
## Files
- `fae_encoder.pt` — Spatial FAE encoder weights
- `feature_decoder.pt` — Spatial FAE decoder weights
- `training_state.pt` — Training metadata + feature normalization stats
- `fae_spatial.py` — Model architecture source code
## Usage
```python
import torch
from fae_spatial import FAESpatialEncoder, FAESpatialDecoder
# Load checkpoint
state = torch.load("training_state.pt", map_location="cpu")
feat_mean = state["feat_mean"].cuda()
feat_std = state["feat_std"].cuda()
encoder = FAESpatialEncoder(embed_dim=1152, latent_dim=32, num_heads=16, pool_factor=4, use_vae=True)
encoder.load_state_dict(torch.load("fae_encoder.pt", map_location="cpu"))
encoder = encoder.cuda().eval()
decoder = FAESpatialDecoder(latent_dim=32, output_dim=1152, num_layers=6, num_heads=16, ffn_mult=2.7, pool_factor=4)
decoder.load_state_dict(torch.load("feature_decoder.pt", map_location="cpu"))
decoder = decoder.cuda().eval()
# Compress ViT features [B, 576, 1152] from the Qwen3-VL vision tower
with torch.no_grad():
    vit_features_norm = (vit_features - feat_mean) / feat_std
    z, mu, logvar = encoder(vit_features_norm)   # [B, 36, 32]
    reconstructed = decoder(z)                   # [B, 576, 1152]
    reconstructed = reconstructed * feat_std + feat_mean
```
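To reproduce the eval cosine-similarity metric reported above, a per-token cosine similarity averaged over tokens and batch is a natural reading (a sketch; the exact reduction used for the 0.9648 figure is an assumption):

```python
import torch
import torch.nn.functional as F

def feature_cos_sim(recon, target):
    # Cosine similarity per token (over the channel dim), then mean over batch x tokens
    return F.cosine_similarity(recon, target, dim=-1).mean()

x = torch.randn(2, 576, 1152)
sim = feature_cos_sim(x, x)
print(sim.item())  # 1.0 for perfect reconstruction
```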