---
license: apache-2.0
tags:
- vision
- compression
- feature-autoencoder
- qwen3-vl
datasets:
- multimodal-reasoning-lab/Zebra-CoT
---

# FAE-Spatial-S2: Feature Auto-Encoder with 2x Spatial Pooling

**Compression**: 144x (576 → 144 tokens, 1152 → 32 channels)

## Overview

A Feature Auto-Encoder (FAE) that compresses Qwen3-VL-8B vision features by combining CNN spatial pooling (2x per dimension) with channel compression. Trained on the Chess subset of Zebra-CoT.

## Architecture

- **Encoder** (28.05M params): Reshape to 2D → Conv2d(1152, 1152, k=3, s=2, p=1) + GELU → Self-attention + SwiGLU FFN → Linear(1152→32) + VAE μ/logvar heads
- **Decoder** (117.63M params): Linear(32→1152) → 6-layer ViT (RoPE at 12×12 grid, RMSNorm, SwiGLU) → ConvTranspose2d upsample → RMSNorm → [B, 576, 1152]

## Compression Details

| | Input | Latent | Compression |
|---|---|---|---|
| Tokens | 576 (24×24) | 144 (12×12) | 4x spatial |
| Channels | 1152 | 32 | 36x channel |
| **Total values** | **663,552** | **4,608** | **144x** |

## Results

| Metric | Value |
|---|---|
| Eval CosSim (feature reconstruction) | **0.9776** |
| VLM Chess MCQ accuracy | 3/20 (15%) |
| VLM agreement with uncompressed baseline | 17/20 (85%) |

The s2 model **exceeds** the channel-only d32 baseline (CosSim 0.9670) while being 4x more compressed.
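The figures in the compression table work out as follows. This is a quick arithmetic check for the reader, not part of the released code:

```python
# Input: a 24x24 grid of ViT features with 1152 channels each
tokens_in, channels_in = 576, 1152
# Latent: a 12x12 grid (2x pooling per dimension) with 32 channels each
tokens_out, channels_out = 144, 32

values_in = tokens_in * channels_in     # 663,552 values per image
values_out = tokens_out * channels_out  # 4,608 values per image

print(values_in // values_out)  # 144 -> the stated 144x compression
```

Note that the 4x token reduction and 36x channel reduction multiply, which is why halving each spatial dimension contributes a full 4x to the overall ratio.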
## Training

- **Dataset**: Zebra-CoT Chess (19,983 train / 500 eval images)
- **Base model**: Qwen3-VL-8B-Instruct (ViT embed_dim=1152)
- **Loss**: MSE + β·KL (β=1e-6)
- **Optimizer**: AdamW, lr=1e-4, weight_decay=0.05
- **Epochs**: 100
- **Hardware**: 2x NVIDIA L40S with DDP
- **Feature normalization**: Per-dim mean/std (included in checkpoint)

## Files

- `fae_encoder.pt` — Spatial FAE encoder weights
- `feature_decoder.pt` — Spatial FAE decoder weights
- `training_state.pt` — Training metadata + feature normalization stats
- `fae_spatial.py` — Model architecture source code

## Usage

```python
import torch

from fae_spatial import FAESpatialEncoder, FAESpatialDecoder

# Load the feature normalization stats stored in the training checkpoint
state = torch.load("training_state.pt", map_location="cpu")
feat_mean = state["feat_mean"].cuda()
feat_std = state["feat_std"].cuda()

# Encoder: [B, 576, 1152] features -> [B, 144, 32] latents
encoder = FAESpatialEncoder(embed_dim=1152, latent_dim=32, num_heads=16,
                            pool_factor=2, use_vae=True)
encoder.load_state_dict(torch.load("fae_encoder.pt", map_location="cpu"))
encoder = encoder.cuda().eval()

# Decoder: [B, 144, 32] latents -> [B, 576, 1152] reconstructed features
decoder = FAESpatialDecoder(latent_dim=32, output_dim=1152, num_layers=6,
                            num_heads=16, ffn_mult=2.7, pool_factor=2)
decoder.load_state_dict(torch.load("feature_decoder.pt", map_location="cpu"))
decoder = decoder.cuda().eval()

# Compress ViT features [B, 576, 1152] (normalize first, as in training)
vit_features_norm = (vit_features - feat_mean) / feat_std
z, mu, logvar = encoder(vit_features_norm)  # [B, 144, 32]

# Reconstruct and de-normalize back to the original feature space
reconstructed = decoder(z)                  # [B, 576, 1152]
reconstructed = reconstructed * feat_std + feat_mean
```
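For reference, the MSE + β·KL training objective described above can be sketched as follows. The function name, the mean reduction over all elements, and the standard Gaussian-prior KL form are assumptions for illustration, not the released training code:

```python
import torch
import torch.nn.functional as F


def fae_loss(recon, target, mu, logvar, beta=1e-6):
    """MSE reconstruction loss plus a lightly weighted KL term.

    recon/target: [B, 576, 1152] normalized features.
    mu/logvar:    [B, 144, 32] VAE posterior parameters from the encoder.
    """
    # Reconstruction error in normalized feature space
    mse = F.mse_loss(recon, target)
    # KL divergence of N(mu, sigma^2) from the standard normal prior N(0, I)
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + beta * kl
```

With β=1e-6 the KL term acts as a mild regularizer on the latent distribution, so the objective is dominated by feature reconstruction — consistent with the high eval cosine similarity reported above.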