FAE-Spatial-S4: Feature Auto-Encoder with 4x Spatial Pooling

Compression: 576x (576 → 36 tokens, 1152 → 32 channels)

Overview

Feature Auto-Encoder (FAE) that compresses Qwen3-VL-8B vision features using CNN spatial pooling (4x per dimension) combined with channel compression. Trained on the Chess subset of Zebra-CoT.

Architecture

  • Encoder (39.99M params): Reshape to 2D → 2× [Conv2d(1152, 1152, k=3, s=2, p=1) + GELU] → Self-attention + SwiGLU FFN → Linear(1152→32) + VAE μ/logvar heads
  • Decoder (138.86M params): Linear(32→1152) → 6-layer ViT (RoPE at 6×6 grid, RMSNorm, SwiGLU) → 2× [ConvTranspose2d upsample] → RMSNorm → [B, 576, 1152]
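The two strided convolutions in the encoder are what take the 24×24 token grid down to 6×6. A minimal shape walk-through (pure arithmetic, not the actual fae_spatial implementation) confirms the numbers:

```python
# Shape walk-through of the encoder's spatial path (illustrative only).
def conv2d_out(size, kernel=3, stride=2, padding=1):
    """Output spatial size of a Conv2d with the given hyperparameters."""
    return (size + 2 * padding - kernel) // stride + 1

grid = 24                      # 576 tokens reshaped to a 24x24 grid
for _ in range(2):             # two strided Conv2d(k=3, s=2, p=1) stages
    grid = conv2d_out(grid)

tokens = grid * grid           # latent token count
channels = 32                  # after the Linear(1152 -> 32) projection
print(grid, tokens, channels)  # 6 36 32
```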

Compression Details

              Input          Latent     Compression
Tokens        576 (24×24)    36 (6×6)   16x spatial
Channels      1152           32         36x channel
Total values  663,552        1,152      576x
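The ratios in the table multiply out directly, which a few lines of arithmetic verify:

```python
# Verify the compression ratios quoted above.
in_tokens, in_ch = 576, 1152    # 24x24 grid, ViT embed_dim
lat_tokens, lat_ch = 36, 32     # 6x6 grid, latent_dim

spatial = in_tokens // lat_tokens                      # token compression
channel = in_ch // lat_ch                              # channel compression
total = (in_tokens * in_ch) // (lat_tokens * lat_ch)   # overall

print(spatial, channel, total)  # 16 36 576
```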

Results

Metric                                     Value
Eval CosSim (feature reconstruction)       0.9648
VLM Chess MCQ accuracy                     3/20 (15%)
VLM agreement with uncompressed baseline   18/20 (90%)

The s4 model nearly matches the channel-only d32 baseline (CosSim 0.9670) at 16x higher compression (576x total vs. 36x).

Training

  • Dataset: Zebra-CoT Chess (19,983 train / 500 eval images)
  • Base model: Qwen3-VL-8B-Instruct (ViT embed_dim=1152)
  • Loss: MSE + β·KL (β=1e-6)
  • Optimizer: AdamW, lr=1e-4, weight_decay=0.05
  • Epochs: 100
  • Hardware: 2x NVIDIA L40S with DDP
  • Feature normalization: Per-dim mean/std (included in checkpoint)
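The objective combines reconstruction MSE with a lightly weighted KL term. A minimal sketch, assuming the standard diagonal-Gaussian KL against N(0, I) (the repo's exact loss code may differ):

```python
import numpy as np

def fae_loss(recon, target, mu, logvar, beta=1e-6):
    """MSE + beta * KL for a diagonal-Gaussian VAE latent (illustrative)."""
    mse = np.mean((recon - target) ** 2)
    kl = -0.5 * np.mean(1.0 + logvar - mu**2 - np.exp(logvar))
    return mse + beta * kl

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 36, 32))
# Perfect reconstruction with a latent matching the prior gives zero loss.
loss = fae_loss(x, x, np.zeros_like(x), np.zeros_like(x))
print(loss)  # 0.0
```

The tiny β keeps the latent roughly Gaussian without sacrificing reconstruction fidelity.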

Files

  • fae_encoder.pt — Spatial FAE encoder weights
  • feature_decoder.pt — Spatial FAE decoder weights
  • training_state.pt — Training metadata + feature normalization stats
  • fae_spatial.py — Model architecture source code

Usage

import torch
from fae_spatial import FAESpatialEncoder, FAESpatialDecoder

# Load checkpoint
state = torch.load("training_state.pt", map_location="cpu")
feat_mean = state["feat_mean"].cuda()
feat_std = state["feat_std"].cuda()

encoder = FAESpatialEncoder(embed_dim=1152, latent_dim=32, num_heads=16, pool_factor=4, use_vae=True)
encoder.load_state_dict(torch.load("fae_encoder.pt", map_location="cpu"))
encoder = encoder.cuda().eval()

decoder = FAESpatialDecoder(latent_dim=32, output_dim=1152, num_layers=6, num_heads=16, ffn_mult=2.7, pool_factor=4)
decoder.load_state_dict(torch.load("feature_decoder.pt", map_location="cpu"))
decoder = decoder.cuda().eval()

# Compress ViT features [B, 576, 1152]
vit_features_norm = (vit_features - feat_mean) / feat_std
z, mu, logvar = encoder(vit_features_norm)  # [B, 36, 32]
reconstructed = decoder(z)  # [B, 576, 1152]
reconstructed = reconstructed * feat_std + feat_mean
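
To check reconstruction quality, the "Eval CosSim" metric can be reproduced with a per-token cosine similarity averaged over tokens and batch; this is a sketch of that assumed computation, shown in NumPy for self-containment:

```python
import numpy as np

def mean_token_cossim(a, b, eps=1e-8):
    """Mean cosine similarity between matching tokens of [B, T, D] arrays."""
    num = np.sum(a * b, axis=-1)
    den = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + eps
    return float(np.mean(num / den))

x = np.random.default_rng(0).standard_normal((1, 576, 1152))
print(mean_token_cossim(x, x))  # ~1.0 for a perfect reconstruction
```

Applied to the original and reconstructed ViT features (after de-normalization), a value near 0.96 matches the reported eval result.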