---
license: apache-2.0
tags:
  - vision
  - compression
  - feature-autoencoder
  - qwen3-vl
datasets:
  - multimodal-reasoning-lab/Zebra-CoT
---

# FAE-Spatial-S2: Feature Auto-Encoder with 2x Spatial Pooling

**Compression**: 144x (576 → 144 tokens, 1152 → 32 channels)

## Overview

A Feature Auto-Encoder (FAE) that compresses Qwen3-VL-8B vision features by combining CNN spatial pooling (2x per dimension) with channel compression. It was trained on the Chess subset of Zebra-CoT.

## Architecture

- **Encoder** (28.05M params): Reshape tokens to a 24×24 grid → Conv2d(1152, 1152, k=3, s=2, p=1) + GELU (24×24 → 12×12) → Self-attention + SwiGLU FFN → Linear(1152→32) + VAE μ/logvar heads
- **Decoder** (117.63M params): Linear(32→1152) → 6-layer ViT (RoPE at 12×12 grid, RMSNorm, SwiGLU) → ConvTranspose2d upsample (12×12 → 24×24) → RMSNorm → [B, 576, 1152]
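
The token-grid shapes above can be traced with the standard PyTorch output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` (dilation 1, no output padding). The encoder convolution settings come from this card. The decoder's `ConvTranspose2d` kernel, stride, and padding are not listed, so the k=2, s=2, p=0 used below is a hypothetical choice that happens to double 12×12 back to 24×24.

```python
import math


def conv2d_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of nn.Conv2d along one spatial dimension (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of nn.ConvTranspose2d along one dimension (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


num_tokens, embed_dim, latent_dim = 576, 1152, 32
side = math.isqrt(num_tokens)  # 576 tokens reshaped to a 24x24 grid

# Encoder downsampling: Conv2d(1152, 1152, k=3, s=2, p=1), as listed in this card.
pooled = conv2d_out(side, kernel=3, stride=2, padding=1)

# Decoder upsampling: HYPOTHETICAL k=2, s=2, p=0 (not stated in the card).
restored = conv_transpose2d_out(pooled, kernel=2, stride=2, padding=0)

print(f"input  [B, {side * side}, {embed_dim}]  ({side}x{side} grid)")
print(f"latent [B, {pooled * pooled}, {latent_dim}]  ({pooled}x{pooled} grid)")
print(f"output [B, {restored * restored}, {embed_dim}]  ({restored}x{restored} grid)")
```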

## Compression Details

| Dimension | Input | Latent | Compression |
|---|---|---|---|
| Tokens | 576 (24×24) | 144 (12×12) | 4x spatial |
| Channels | 1152 | 32 | 36x channel |
| **Total values** | **663,552** | **4,608** | **144x** |
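
The total ratio is the product of the spatial and channel ratios, since the latent shrinks both axes of the `[tokens, channels]` feature map. A short check of the table's arithmetic:

```python
# Reproduce the compression ratios in the table from the tensor shapes.
input_tokens, input_channels = 24 * 24, 1152
latent_tokens, latent_channels = 12 * 12, 32

input_values = input_tokens * input_channels      # values per image before compression
latent_values = latent_tokens * latent_channels   # values per image in the latent

spatial_ratio = input_tokens // latent_tokens      # 576 / 144
channel_ratio = input_channels // latent_channels  # 1152 / 32
total_ratio = input_values // latent_values        # 663,552 / 4,608

# Spatial and channel compression multiply into the total ratio.
assert spatial_ratio * channel_ratio == total_ratio
print(f"{input_values:,} -> {latent_values:,} values: "
      f"{spatial_ratio}x spatial * {channel_ratio}x channel = {total_ratio}x")
```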

## Results

| Metric | Value |
|---|---|
| Eval CosSim (feature reconstruction) | **0.9776** |
| VLM Chess MCQ Accuracy | 3/20 (15%) |
| VLM answer agreement with uncompressed baseline | 17/20 (85%) |

The s2 model achieves a higher reconstruction CosSim than the channel-only d32 baseline (0.9776 vs. 0.9670) while using 4x fewer latent values.
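
This card does not spell out how Eval CosSim is aggregated. A common definition, assumed in the sketch below, is the cosine similarity between each original and reconstructed 1152-dim token vector, averaged over all tokens (and, in practice, over all eval images). Whether it is measured on normalized or de-normalized features is also not stated. Plain Python is used so the sketch runs without PyTorch.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def cosine_similarity(a: Vector, b: Vector, eps: float = 1e-8) -> float:
    """Cosine similarity between two feature vectors (eps guards zero norms)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def mean_token_cossim(original: Sequence[Vector], reconstructed: Sequence[Vector]) -> float:
    """Per-token cosine similarity averaged over tokens (assumed aggregation)."""
    if len(original) != len(reconstructed):
        raise ValueError("token counts differ")
    scores = [cosine_similarity(o, r) for o, r in zip(original, reconstructed)]
    return sum(scores) / len(scores)


# Toy example: 2 tokens of dimension 3 (real features are 576 tokens x 1152 dims).
orig = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
recon = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
print(round(mean_token_cossim(orig, recon), 4))  # (1 + 1/sqrt(2)) / 2
```

Cosine similarity ignores per-token scale, so it complements rather than replaces the MSE used for training.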

## Training

- **Dataset**: Zebra-CoT Chess (19,983 train / 500 eval images)
- **Base model**: Qwen3-VL-8B-Instruct (ViT embed_dim=1152)
- **Loss**: MSE + β·KL (β=1e-6)
- **Optimizer**: AdamW, lr=1e-4, weight_decay=0.05
- **Epochs**: 100
- **Hardware**: 2x NVIDIA L40S with DDP
- **Feature normalization**: Per-dim mean/std (included in checkpoint)
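
The MSE + β·KL objective is the standard VAE recipe: the encoder's μ/logvar heads parameterize a diagonal Gaussian, a latent is sampled with the reparameterization trick, and the KL term pulls that Gaussian toward N(0, 1). The sketch below is a minimal plain-Python illustration on flattened vectors, not the training code in `fae_spatial.py`. The reduction (mean over values) and applying the MSE to normalized features are assumptions; the card only states the loss form and β=1e-6.

```python
import math
import random
from typing import List, Sequence, Tuple


def reparameterize(mu: Sequence[float], logvar: Sequence[float], rng: random.Random) -> List[float]:
    """Sample z = mu + exp(0.5 * logvar) * eps, with eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL(N(mu, exp(logvar)) || N(0, 1)), averaged over latent values (assumed reduction)."""
    terms = [-0.5 * (1.0 + lv - m * m - math.exp(lv)) for m, lv in zip(mu, logvar)]
    return sum(terms) / len(terms)


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared reconstruction error over all feature values."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def fae_loss(recon: Sequence[float], target: Sequence[float],
             mu: Sequence[float], logvar: Sequence[float],
             beta: float = 1e-6) -> Tuple[float, float, float]:
    """Total loss MSE + beta * KL, returned together with its two components."""
    rec = mse(recon, target)
    kl = kl_to_standard_normal(mu, logvar)
    return rec + beta * kl, rec, kl


# Toy usage: short flattened vectors stand in for the [B, 576, 1152] features
# and the [B, 144, 32] latent statistics.
rng = random.Random(0)
mu, logvar = [0.5, -0.2, 0.1], [-1.0, -0.5, 0.0]
z = reparameterize(mu, logvar, rng)  # this sample would be fed to the decoder
total, rec, kl = fae_loss([1.0, 2.0], [1.0, 0.0], mu, logvar)
print(f"total={total:.6f}  mse={rec:.6f}  kl={kl:.6f}")
```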

## Files

- `fae_encoder.pt` — Spatial FAE encoder weights
- `feature_decoder.pt` — Spatial FAE decoder weights  
- `training_state.pt` — Training metadata + feature normalization stats
- `fae_spatial.py` — Model architecture source code

## Usage

```python
import torch
from fae_spatial import FAESpatialEncoder, FAESpatialDecoder

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load feature normalization stats from the training checkpoint
state = torch.load("training_state.pt", map_location="cpu")
feat_mean = state["feat_mean"].to(device)
feat_std = state["feat_std"].to(device)

encoder = FAESpatialEncoder(embed_dim=1152, latent_dim=32, num_heads=16, pool_factor=2, use_vae=True)
encoder.load_state_dict(torch.load("fae_encoder.pt", map_location="cpu"))
encoder = encoder.to(device).eval()

decoder = FAESpatialDecoder(latent_dim=32, output_dim=1152, num_layers=6, num_heads=16, ffn_mult=2.7, pool_factor=2)
decoder.load_state_dict(torch.load("feature_decoder.pt", map_location="cpu"))
decoder = decoder.to(device).eval()

# vit_features: Qwen3-VL ViT features of shape [B, 576, 1152].
# Random placeholder for illustration; replace with real extracted features.
vit_features = torch.randn(1, 576, 1152, device=device)

with torch.no_grad():
    # Normalize with the per-dim stats, then compress
    vit_features_norm = (vit_features - feat_mean) / feat_std
    z, mu, logvar = encoder(vit_features_norm)  # z: [B, 144, 32]
    # Decode and undo the normalization
    reconstructed = decoder(z)  # [B, 576, 1152]
    reconstructed = reconstructed * feat_std + feat_mean
```
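
The normalize/de-normalize pair in the usage snippet relies on per-dimension statistics stored in `training_state.pt`. The sketch below shows, in plain Python, how such statistics are typically computed over a set of token vectors and that de-normalization inverts normalization. The population std and the small `eps` are assumptions; the card does not say exactly how `feat_mean` and `feat_std` were computed.

```python
import math
from typing import List, Sequence, Tuple


def per_dim_stats(features: Sequence[Sequence[float]],
                  eps: float = 1e-6) -> Tuple[List[float], List[float]]:
    """Per-channel mean and std over all tokens (population std plus eps, both assumed)."""
    n = len(features)
    dims = len(features[0])
    mean = [sum(f[d] for f in features) / n for d in range(dims)]
    std = [math.sqrt(sum((f[d] - mean[d]) ** 2 for f in features) / n) + eps
           for d in range(dims)]
    return mean, std


def normalize(x: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> List[float]:
    """Map a token vector to zero-mean, unit-std coordinates per channel."""
    return [(v - m) / s for v, m, s in zip(x, mean, std)]


def denormalize(x: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> List[float]:
    """Inverse of normalize: rescale and shift back to the original feature space."""
    return [v * s + m for v, m, s in zip(x, mean, std)]


# Toy example: 2 tokens with 2 channels (real features have 1152 channels).
tokens = [[1.0, 10.0], [3.0, 30.0]]
mean, std = per_dim_stats(tokens)
x_norm = normalize(tokens[1], mean, std)
x_back = denormalize(x_norm, mean, std)
print(mean, [round(s, 3) for s in std], [round(v, 3) for v in x_back])
```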