---
license: apache-2.0
tags:
- vision
- compression
- feature-autoencoder
- qwen3-vl
datasets:
- multimodal-reasoning-lab/Zebra-CoT
---
# FAE-Spatial-S4: Feature Auto-Encoder with 4x Spatial Pooling
**Compression**: 576x (576 → 36 tokens, 1152 → 32 channels)
## Overview
FAE-Spatial-S4 is a Feature Auto-Encoder (FAE) that compresses Qwen3-VL-8B vision features by combining CNN spatial pooling (4x per spatial dimension, 24×24 → 6×6 tokens) with channel compression (1152 → 32). It was trained on the Chess subset of Zebra-CoT.
## Architecture
- **Encoder** (39.99M params): Reshape to 2D → 2× [Conv2d(1152, 1152, k=3, s=2, p=1) + GELU] → Self-attention + SwiGLU FFN → Linear(1152→32) + VAE μ/logvar heads
- **Decoder** (138.86M params): Linear(32→1152) → 6-layer ViT (RoPE at 6×6 grid, RMSNorm, SwiGLU) → 2× [ConvTranspose2d upsample] → RMSNorm → [B, 576, 1152]
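
The spatial reduction in the encoder follows from the standard Conv2d output-size formula. The sketch below traces the token grid through the two stride-2 convolutions listed above; the helper names `conv2d_out` and `encoder_grid` are illustrative and are not part of `fae_spatial.py`.

```python
def conv2d_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output side length of a Conv2d layer (dilation 1) applied to a square input."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_grid(side: int = 24, num_convs: int = 2) -> list[int]:
    """Side length of the token grid before and after each stride-2 convolution."""
    sides = [side]
    for _ in range(num_convs):
        sides.append(conv2d_out(sides[-1]))
    return sides


sides = encoder_grid()
print(sides)                   # [24, 12, 6]
print([s * s for s in sides])  # [576, 144, 36]
```

Each convolution halves the grid side, so two of them give the 4x-per-dimension pooling (576 → 36 tokens) that the decoder's two upsampling stages later reverse.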
## Compression Details
| | Input | Latent | Compression |
|---|---|---|---|
| Tokens | 576 (24×24) | 36 (6×6) | 16x spatial |
| Channels | 1152 | 32 | 36x channel |
| **Total values** | **663,552** | **1,152** | **576x** |
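
The ratios in the table can be checked with a few lines of arithmetic. This is only a worked example of the numbers above; the variable names are illustrative.

```python
# Input and latent shapes taken from the table above
tokens_in, channels_in = 24 * 24, 1152
tokens_latent, channels_latent = 6 * 6, 32

spatial_ratio = tokens_in // tokens_latent        # token-count reduction
channel_ratio = channels_in // channels_latent    # per-token channel reduction
values_in = tokens_in * channels_in               # total values per image, input
values_latent = tokens_latent * channels_latent   # total values per image, latent
total_ratio = values_in // values_latent

print(spatial_ratio, channel_ratio, total_ratio)  # 16 36 576
print(values_in, values_latent)                   # 663552 1152
```

The full latent for one image (1,152 values) is therefore the same size as a single uncompressed input token.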
## Results
| Metric | Value |
|---|---|
| Eval CosSim (feature reconstruction) | **0.9648** |
| VLM Chess MCQ Accuracy | 3/20 (15%) |
| VLM Agreement with uncompressed baseline | 18/20 (90%) |

At 576x total compression, the s4 model reaches a CosSim of 0.9648, close to the channel-only d32 baseline (CosSim 0.9670 at 36x compression) while compressing 16x further.
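
This card does not state how the CosSim metric is aggregated. A minimal sketch, assuming it is the cosine similarity between each original and reconstructed token vector, averaged over tokens; the function names and toy vectors below are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def mean_token_cossim(original, reconstructed):
    """Average per-token cosine similarity (assumed aggregation, not confirmed by the card)."""
    sims = [cosine_similarity(o, r) for o, r in zip(original, reconstructed)]
    return sum(sims) / len(sims)


# Toy example: two tokens with three channels each
original = [[1.0, 0.0, 2.0], [0.5, 0.5, 0.0]]
reconstructed = [[0.9, 0.1, 2.1], [0.5, 0.4, 0.1]]
print(round(mean_token_cossim(original, reconstructed), 4))
```

With real features the inputs would be the `[576, 1152]` token matrices before and after reconstruction.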
## Training
- **Dataset**: Zebra-CoT Chess (19,983 train / 500 eval images)
- **Base model**: Qwen3-VL-8B-Instruct (ViT embed_dim=1152)
- **Loss**: MSE + β·KL (β=1e-6)
- **Optimizer**: AdamW, lr=1e-4, weight_decay=0.05
- **Epochs**: 100
- **Hardware**: 2x NVIDIA L40S with DDP
- **Feature normalization**: Per-dim mean/std (included in checkpoint)
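
The loss listed above has the usual β-weighted VAE form. The sketch below assumes the KL term is the standard closed-form divergence between a diagonal Gaussian and a unit Gaussian, summed over latent dimensions; the exact reduction used in training is not stated in this card, and the function names are illustrative.

```python
import math


def mse(pred, target):
    """Mean squared error between two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, 1)), summed over latent dimensions (assumed form)."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def fae_loss(pred, target, mu, logvar, beta=1e-6):
    """Reconstruction MSE plus a beta-weighted KL regularizer."""
    return mse(pred, target) + beta * kl_to_standard_normal(mu, logvar)


# Toy example with a 2-dimensional latent
loss = fae_loss(pred=[0.9, 2.1], target=[1.0, 2.0], mu=[0.3, -0.2], logvar=[0.0, -0.1])
print(loss)
```

With β=1e-6 the KL term acts as a very light regularizer, so the objective is dominated by reconstruction error.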
## Files
- `fae_encoder.pt`: Spatial FAE encoder weights
- `feature_decoder.pt`: Spatial FAE decoder weights
- `training_state.pt`: Training metadata + feature normalization stats
- `fae_spatial.py`: Model architecture source code
## Usage
```python
import torch
from fae_spatial import FAESpatialEncoder, FAESpatialDecoder

# Load training metadata, which includes the per-dim feature normalization stats
state = torch.load("training_state.pt", map_location="cpu")
feat_mean = state["feat_mean"].cuda()
feat_std = state["feat_std"].cuda()

# Build the encoder and load its weights
encoder = FAESpatialEncoder(embed_dim=1152, latent_dim=32, num_heads=16, pool_factor=4, use_vae=True)
encoder.load_state_dict(torch.load("fae_encoder.pt", map_location="cpu"))
encoder = encoder.cuda().eval()

# Build the decoder and load its weights
decoder = FAESpatialDecoder(latent_dim=32, output_dim=1152, num_layers=6, num_heads=16, ffn_mult=2.7, pool_factor=4)
decoder.load_state_dict(torch.load("feature_decoder.pt", map_location="cpu"))
decoder = decoder.cuda().eval()

# vit_features: Qwen3-VL-8B ViT features of shape [B, 576, 1152], already on the GPU
with torch.no_grad():
    # Normalize, compress, reconstruct, then undo the normalization
    vit_features_norm = (vit_features - feat_mean) / feat_std
    z, mu, logvar = encoder(vit_features_norm)  # z: [B, 36, 32]
    reconstructed = decoder(z)                  # [B, 576, 1152]
    reconstructed = reconstructed * feat_std + feat_mean
```