---
license: apache-2.0
tags:
- vision
- compression
- feature-autoencoder
- qwen3-vl
datasets:
- multimodal-reasoning-lab/Zebra-CoT
---
# FAE-Spatial-S2: Feature Auto-Encoder with 2x Spatial Pooling
**Compression**: 144x (576 → 144 tokens, 1152 → 32 channels)
## Overview
A Feature Auto-Encoder (FAE) that compresses Qwen3-VL-8B vision features by combining spatial downsampling (a stride-2 convolution, 2x per dimension) with channel compression. Trained on the Chess subset of Zebra-CoT.
## Architecture
- **Encoder** (28.05M params): Reshape tokens to a 24×24 grid → Conv2d(1152, 1152, k=3, s=2, p=1) + GELU → Self-attention + SwiGLU FFN → Linear(1152→32) + VAE μ/logvar heads
- **Decoder** (117.63M params): Linear(32→1152) → 6-layer ViT (RoPE at 12×12 grid, RMSNorm, SwiGLU) → ConvTranspose2d upsample (12×12 → 24×24) → RMSNorm → [B, 576, 1152]
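The grid sizes in the architecture above can be checked with the output-size formulas given in the PyTorch `Conv2d` and `ConvTranspose2d` documentation. This is a minimal sketch, not the code in `fae_spatial.py`: the helper names are hypothetical, and the decoder's `ConvTranspose2d` settings (`kernel_size=2, stride=2, padding=0`) are an assumption because the card does not state them. Any setting that doubles a 12×12 grid would give the same result.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length of nn.Conv2d along one spatial axis (PyTorch docs formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0,
                         output_padding: int = 0, dilation: int = 1) -> int:
    """Output length of nn.ConvTranspose2d along one spatial axis (PyTorch docs formula)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def trace_shapes(grid: int = 24, channels: int = 1152, latent_dim: int = 32) -> dict:
    """Trace (tokens, channels) through the encoder/decoder described in the card."""
    shapes = {"input": (grid * grid, channels)}
    # Encoder: [576, 1152] -> [1152, 24, 24] -> Conv2d(k=3, s=2, p=1) -> [1152, 12, 12]
    pooled = conv2d_out(grid, kernel=3, stride=2, padding=1)
    shapes["after_conv"] = (pooled * pooled, channels)
    # Linear(1152 -> 32) projects every pooled token down to the latent size
    shapes["latent"] = (pooled * pooled, latent_dim)
    # Decoder: Linear(32 -> 1152) and the ViT keep the 12x12 grid.
    # ASSUMPTION: ConvTranspose2d(k=2, s=2); these values are not given in the card.
    up = conv_transpose2d_out(pooled, kernel=2, stride=2)
    shapes["output"] = (up * up, channels)
    return shapes


for name, shape in trace_shapes().items():
    print(f"{name:>10}: {shape}")
```

The stride-2, padding-1, kernel-3 convolution halves an even-sized grid exactly, which is why 576 tokens become 144 without cropping.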
## Compression Details
| Dimension | Input | Latent | Compression |
|---|---|---|---|
| Tokens | 576 (24×24) | 144 (12×12) | 4x spatial |
| Channels | 1152 | 32 | 36x channel |
| **Total values** | **663,552** | **4,608** | **144x** |
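Each figure in the table is tokens multiplied by channels, and the total ratio is the product of the spatial and channel factors (4 × 36 = 144). The sketch below reproduces the numbers. The d32 shape of 576 tokens × 32 channels is an assumption inferred from the baseline being described as "channel-only" (no spatial pooling); the helper names are hypothetical.

```python
def total_values(tokens: int, channels: int) -> int:
    """Number of scalar values in a [tokens, channels] feature map."""
    return tokens * channels


def compression_ratio(in_shape: tuple, out_shape: tuple) -> float:
    """Ratio of input values to latent values."""
    return total_values(*in_shape) / total_values(*out_shape)


vit = (576, 1152)  # 24x24 tokens, 1152 channels
s2 = (144, 32)     # 12x12 tokens, 32 channels
d32 = (576, 32)    # ASSUMED shape of the channel-only d32 baseline

print(total_values(*vit), total_values(*s2))  # 663552 4608
print(compression_ratio(vit, s2))             # 144.0
print(compression_ratio(vit, d32))            # 36.0
print(compression_ratio(vit, s2) / compression_ratio(vit, d32))  # 4.0
```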
## Results
| Metric | Value |
|---|---|
| Eval CosSim (feature reconstruction) | **0.9776** |
| VLM Chess MCQ Accuracy | 3/20 (15%) |
| VLM Agreement with uncompressed baseline | 17/20 (85%) |
The S2 model's feature-reconstruction CosSim (0.9776) **exceeds** that of the channel-only d32 baseline (0.9670), even though S2 is 4x more compressed.
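A sketch of how the two evaluation numbers could be computed. Both definitions are assumptions, since the card does not spell them out: Eval CosSim is taken as the per-token cosine similarity over the 1152 feature channels, averaged over all tokens and images, and agreement is taken as exact match between the answers the VLM gives with reconstructed features and with the uncompressed baseline. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def mean_token_cosine(original: torch.Tensor, reconstructed: torch.Tensor) -> float:
    """Mean cosine similarity between matching tokens of two [B, N, D] tensors.

    ASSUMPTION: similarity is averaged over all tokens and images; the card
    does not specify how Eval CosSim is aggregated.
    """
    sims = F.cosine_similarity(original, reconstructed, dim=-1)  # [B, N]
    return sims.mean().item()


def answer_agreement(answers_a: list, answers_b: list) -> float:
    """Fraction of questions on which two runs gave the same answer."""
    if len(answers_a) != len(answers_b):
        raise ValueError("answer lists must have the same length")
    matches = sum(a == b for a, b in zip(answers_a, answers_b))
    return matches / len(answers_a)


feats = torch.randn(2, 576, 1152)
print(mean_token_cosine(feats, feats))  # ~1.0 for a perfect reconstruction
print(answer_agreement(["A", "B", "C", "D"], ["A", "B", "C", "A"]))  # 0.75
```

Note that agreement and accuracy measure different things: a compressed model can agree with the baseline on a question that both answer incorrectly.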
## Training
- **Dataset**: Zebra-CoT Chess (19,983 train / 500 eval images)
- **Base model**: Qwen3-VL-8B-Instruct (ViT embed_dim=1152)
- **Loss**: MSE + β·KL (β=1e-6)
- **Optimizer**: AdamW, lr=1e-4, weight_decay=0.05
- **Epochs**: 100
- **Hardware**: 2x NVIDIA L40S with DDP
- **Feature normalization**: Per-dim mean/std (included in checkpoint)
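The objective listed above (MSE + β·KL, β=1e-6) and the per-dim feature normalization can be sketched as follows. This is an illustration under assumptions, not the training script. The function names are hypothetical. Both loss terms are averaged over elements, because the card does not say whether KL is summed or averaged. The KL term assumes a standard-normal prior. The `eps` floor on the std is added for numerical safety.

```python
import torch
import torch.nn.functional as F

BETA = 1e-6  # KL weight stated in the card


def feature_stats(features: torch.Tensor, eps: float = 1e-6):
    """Per-dimension mean/std over all tokens of a [B, N, D] feature tensor.

    ASSUMPTION: the eps floor is added here to avoid division by zero.
    """
    flat = features.reshape(-1, features.shape[-1])
    return flat.mean(dim=0), flat.std(dim=0).clamp(min=eps)


def fae_loss(target: torch.Tensor, recon: torch.Tensor,
             mu: torch.Tensor, logvar: torch.Tensor, beta: float = BETA) -> torch.Tensor:
    """MSE reconstruction + beta * KL(N(mu, sigma^2) || N(0, I)).

    ASSUMPTION: both terms are averaged over elements.
    """
    mse = F.mse_loss(recon, target)
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + beta * kl


# Toy usage: normalize features, then score a perfect reconstruction.
feats = torch.randn(4, 576, 1152) * 3 + 1
mean, std = feature_stats(feats)
target = (feats - mean) / std
mu = torch.zeros(4, 144, 32)
logvar = torch.zeros(4, 144, 32)
print(fae_loss(target, target.clone(), mu, logvar).item())  # 0.0: no error, posterior equals prior
```

With β this small, the reconstruction term is likely to dominate the objective, so the KL term acts only as a light regularizer on the latent.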
## Files
- `fae_encoder.pt` — Spatial FAE encoder weights
- `feature_decoder.pt` — Spatial FAE decoder weights
- `training_state.pt` — Training metadata + feature normalization stats
- `fae_spatial.py` — Model architecture source code
## Usage
```python
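import torch
from fae_spatial import FAESpatialEncoder, FAESpatialDecoder

device = "cuda"

# Load the per-dim feature normalization stats stored with the checkpoint
state = torch.load("training_state.pt", map_location="cpu")
feat_mean = state["feat_mean"].to(device)
feat_std = state["feat_std"].to(device)

# Build the encoder and load its weights
encoder = FAESpatialEncoder(embed_dim=1152, latent_dim=32, num_heads=16, pool_factor=2, use_vae=True)
encoder.load_state_dict(torch.load("fae_encoder.pt", map_location="cpu"))
encoder = encoder.to(device).eval()

# Build the decoder and load its weights
decoder = FAESpatialDecoder(latent_dim=32, output_dim=1152, num_layers=6, num_heads=16, ffn_mult=2.7, pool_factor=2)
decoder.load_state_dict(torch.load("feature_decoder.pt", map_location="cpu"))
decoder = decoder.to(device).eval()

# vit_features: Qwen3-VL ViT features of shape [B, 576, 1152].
# Random placeholder so the snippet runs; substitute real features here.
vit_features = torch.randn(1, 576, 1152, device=device)

with torch.no_grad():
    # Normalize with the training statistics before encoding
    vit_features_norm = (vit_features - feat_mean) / feat_std
    z, mu, logvar = encoder(vit_features_norm)  # z: [B, 144, 32]
    reconstructed = decoder(z)  # [B, 576, 1152]
    # Undo the normalization to return to the original feature space
    reconstructed = reconstructed * feat_std + feat_mean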
```