---
license: apache-2.0
tags:
- vision
- compression
- feature-autoencoder
- qwen3-vl
datasets:
- multimodal-reasoning-lab/Zebra-CoT
---

# FAE-Spatial-S4: Feature Auto-Encoder with 4x Spatial Pooling

**Compression**: 576x (576 → 36 tokens, 1152 → 32 channels)

## Overview

A Feature Auto-Encoder (FAE) that compresses Qwen3-VL-8B vision features using CNN spatial pooling (4x per dimension) combined with channel compression. Trained on the Chess subset of Zebra-CoT.

## Architecture

- **Encoder** (39.99M params): Reshape to 2D → 2× [Conv2d(1152, 1152, k=3, s=2, p=1) + GELU] → Self-attention + SwiGLU FFN → Linear(1152→32) + VAE μ/logvar heads
- **Decoder** (138.86M params): Linear(32→1152) → 6-layer ViT (RoPE at 6×6 grid, RMSNorm, SwiGLU) → 2× [ConvTranspose2d upsample] → RMSNorm → [B, 576, 1152]

## Compression Details

| | Input | Latent | Compression |
|---|---|---|---|
| Tokens | 576 (24×24) | 36 (6×6) | 16x spatial |
| Channels | 1152 | 32 | 36x channel |
| **Total values** | **663,552** | **1,152** | **576x** |

## Results

| Metric | Value |
|---|---|
| Eval CosSim (feature reconstruction) | **0.9648** |
| VLM Chess MCQ Accuracy | 3/20 (15%) |
| VLM Agreement with uncompressed baseline | 18/20 (90%) |

The s4 model nearly matches the channel-only d32 baseline (CosSim 0.9670) at 16x higher compression (576x total vs. 36x).
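The ratios in the compression table follow directly from the grid and channel sizes; a quick arithmetic check (plain Python, no dependencies):

```python
# Sanity-check the compression arithmetic from the table above.
in_tokens, in_channels = 576, 1152    # 24x24 ViT grid, embed_dim=1152
out_tokens, out_channels = 36, 32     # 6x6 latent grid, latent_dim=32

spatial = in_tokens // out_tokens      # 4x pooling per dimension -> 16x tokens
channel = in_channels // out_channels  # 36x channel compression
total = (in_tokens * in_channels) // (out_tokens * out_channels)

print(spatial, channel, total)  # 16 36 576
```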
## Training

- **Dataset**: Zebra-CoT Chess (19,983 train / 500 eval images)
- **Base model**: Qwen3-VL-8B-Instruct (ViT embed_dim=1152)
- **Loss**: MSE + β·KL (β=1e-6)
- **Optimizer**: AdamW, lr=1e-4, weight_decay=0.05
- **Epochs**: 100
- **Hardware**: 2x NVIDIA L40S with DDP
- **Feature normalization**: Per-dim mean/std (included in checkpoint)

## Files

- `fae_encoder.pt` — Spatial FAE encoder weights
- `feature_decoder.pt` — Spatial FAE decoder weights
- `training_state.pt` — Training metadata + feature normalization stats
- `fae_spatial.py` — Model architecture source code

## Usage

```python
import torch

from fae_spatial import FAESpatialEncoder, FAESpatialDecoder

# Load training state, including the feature normalization stats
state = torch.load("training_state.pt", map_location="cpu")
feat_mean = state["feat_mean"].cuda()
feat_std = state["feat_std"].cuda()

encoder = FAESpatialEncoder(embed_dim=1152, latent_dim=32, num_heads=16,
                            pool_factor=4, use_vae=True)
encoder.load_state_dict(torch.load("fae_encoder.pt", map_location="cpu"))
encoder = encoder.cuda().eval()

decoder = FAESpatialDecoder(latent_dim=32, output_dim=1152, num_layers=6,
                            num_heads=16, ffn_mult=2.7, pool_factor=4)
decoder.load_state_dict(torch.load("feature_decoder.pt", map_location="cpu"))
decoder = decoder.cuda().eval()

# Compress ViT features [B, 576, 1152] (inference only, so no autograd)
with torch.no_grad():
    vit_features_norm = (vit_features - feat_mean) / feat_std
    z, mu, logvar = encoder(vit_features_norm)  # z: [B, 36, 32]
    reconstructed = decoder(z)                  # [B, 576, 1152]
    reconstructed = reconstructed * feat_std + feat_mean
```
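After running the reconstruction, the reported Eval CosSim can be reproduced as the mean per-token cosine similarity between the original and reconstructed features. This is an assumed definition of the metric; the tensors below are random stand-ins for real ViT features, so the printed value is illustrative only:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Random stand-ins for real ViT features and their FAE reconstruction.
vit_features = torch.randn(2, 576, 1152)
reconstructed = vit_features + 0.05 * torch.randn_like(vit_features)

# Mean per-token cosine similarity along the channel dimension —
# the assumed definition of the "Eval CosSim" metric reported above.
cos = F.cosine_similarity(vit_features, reconstructed, dim=-1).mean()
print(f"CosSim: {cos.item():.4f}")
```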