FAE-Spatial-S2: Feature Auto-Encoder with 2x Spatial Pooling

Compression: 144x (576 β†’ 144 tokens, 1152 β†’ 32 channels)

Overview

A Feature Auto-Encoder (FAE) that compresses Qwen3-VL-8B vision features by combining CNN spatial pooling (2x per dimension) with channel compression. Trained on the Chess subset of Zebra-CoT.

Architecture

  • Encoder (28.05M params): Reshape to 2D β†’ Conv2d(1152, 1152, k=3, s=2, p=1) + GELU β†’ Self-attention + SwiGLU FFN β†’ Linear(1152β†’32) + VAE ΞΌ/logvar heads
  • Decoder (117.63M params): Linear(32β†’1152) β†’ 6-layer ViT (RoPE at 12Γ—12 grid, RMSNorm, SwiGLU) β†’ ConvTranspose2d upsample β†’ RMSNorm β†’ [B, 576, 1152]
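
The encoder's data flow can be sketched as a minimal PyTorch module. This is a hypothetical illustration, not the shipped code (which lives in fae_spatial.py); the SwiGLU FFN is replaced by a plain GELU MLP for brevity, and layer counts/norm placement are assumptions.

```python
import torch
import torch.nn as nn

class EncoderSketch(nn.Module):
    """Sketch of the FAE-Spatial-S2 encoder data flow (names hypothetical)."""
    def __init__(self, embed_dim=1152, latent_dim=32, num_heads=16):
        super().__init__()
        # 2x spatial pooling: 24x24 grid -> 12x12 grid
        self.pool = nn.Conv2d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        self.act = nn.GELU()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Stand-in for the SwiGLU FFN of the real model
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, 3072), nn.GELU(), nn.Linear(3072, embed_dim)
        )
        self.to_latent = nn.Linear(embed_dim, latent_dim)   # 1152 -> 32
        self.mu_head = nn.Linear(latent_dim, latent_dim)    # VAE mu head
        self.logvar_head = nn.Linear(latent_dim, latent_dim)  # VAE logvar head

    def forward(self, x):  # x: [B, 576, 1152]
        B, N, C = x.shape
        side = int(N ** 0.5)  # 24
        x = x.transpose(1, 2).reshape(B, C, side, side)  # [B, 1152, 24, 24]
        x = self.act(self.pool(x))                       # [B, 1152, 12, 12]
        x = x.flatten(2).transpose(1, 2)                 # [B, 144, 1152]
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]                    # self-attention block
        x = x + self.ffn(self.norm2(x))                  # FFN block
        z = self.to_latent(x)                            # [B, 144, 32]
        mu, logvar = self.mu_head(z), self.logvar_head(z)
        # Reparameterization trick: sample z from N(mu, sigma^2)
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()
        return z, mu, logvar
```

The decoder reverses this path: a Linear(32→1152) projection, a 6-layer ViT over the 12×12 grid, then a ConvTranspose2d upsample back to 24×24.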

Compression Details

               Input          Latent         Compression
Tokens         576 (24×24)    144 (12×12)    4x spatial
Channels       1152           32             36x channel
Total values   663,552        4,608          144x
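
The table's accounting can be verified with a few lines of arithmetic:

```python
# Compression accounting for FAE-Spatial-S2
tokens_in, channels_in = 576, 1152    # 24x24 ViT grid, embed_dim 1152
tokens_out, channels_out = 144, 32    # 12x12 grid after 2x pooling, latent_dim 32

spatial = tokens_in // tokens_out                                  # 4x spatial
channel = channels_in // channels_out                              # 36x channel
total = (tokens_in * channels_in) // (tokens_out * channels_out)   # 144x overall

assert (spatial, channel, total) == (4, 36, 144)
assert tokens_in * channels_in == 663_552
assert tokens_out * channels_out == 4_608
```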

Results

Metric                                      Value
Eval CosSim (feature reconstruction)        0.9776
VLM Chess MCQ Accuracy                      3/20 (15%)
VLM Agreement with uncompressed baseline    17/20 (85%)

The s2 model exceeds the channel-only d32 baseline (CosSim 0.9670) while being 4x more compressed.
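
For reference, a per-token cosine similarity averaged over tokens and batch is the natural reading of the "Eval CosSim" metric; the exact reduction used during evaluation is an assumption here.

```python
import torch
import torch.nn.functional as F

def mean_token_cossim(orig: torch.Tensor, recon: torch.Tensor) -> float:
    """Cosine similarity per token (over the channel dim), averaged
    over all tokens and batch elements."""
    return F.cosine_similarity(orig, recon, dim=-1).mean().item()

# Sanity check: identical features score exactly 1.0
x = torch.randn(2, 576, 1152)
assert abs(mean_token_cossim(x, x) - 1.0) < 1e-6
```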

Training

  • Dataset: Zebra-CoT Chess (19,983 train / 500 eval images)
  • Base model: Qwen3-VL-8B-Instruct (ViT embed_dim=1152)
  • Loss: MSE + Ξ²Β·KL (Ξ²=1e-6)
  • Optimizer: AdamW, lr=1e-4, weight_decay=0.05
  • Epochs: 100
  • Hardware: 2x NVIDIA L40S with DDP
  • Feature normalization: Per-dim mean/std (included in checkpoint)
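
The MSE + β·KL objective follows the standard VAE form; this sketch assumes mean reduction over all elements, which the card does not specify.

```python
import torch
import torch.nn.functional as F

BETA = 1e-6  # KL weight from the training config above

def fae_loss(recon, target, mu, logvar):
    """MSE reconstruction + beta-weighted Gaussian KL divergence
    against a standard-normal prior (standard VAE objective)."""
    mse = F.mse_loss(recon, target)
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + BETA * kl

# Shapes as in this model: [B, 576, 1152] features, [B, 144, 32] latents.
# Perfect reconstruction with a standard-normal latent gives zero loss.
target = torch.randn(2, 576, 1152)
mu, logvar = torch.zeros(2, 144, 32), torch.zeros(2, 144, 32)
assert fae_loss(target, target, mu, logvar).item() == 0.0
```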

Files

  • fae_encoder.pt β€” Spatial FAE encoder weights
  • feature_decoder.pt β€” Spatial FAE decoder weights
  • training_state.pt β€” Training metadata + feature normalization stats
  • fae_spatial.py β€” Model architecture source code

Usage

import torch
from fae_spatial import FAESpatialEncoder, FAESpatialDecoder

# Load checkpoint
state = torch.load("training_state.pt", map_location="cpu")
feat_mean = state["feat_mean"].cuda()
feat_std = state["feat_std"].cuda()

encoder = FAESpatialEncoder(embed_dim=1152, latent_dim=32, num_heads=16, pool_factor=2, use_vae=True)
encoder.load_state_dict(torch.load("fae_encoder.pt", map_location="cpu"))
encoder = encoder.cuda().eval()

decoder = FAESpatialDecoder(latent_dim=32, output_dim=1152, num_layers=6, num_heads=16, ffn_mult=2.7, pool_factor=2)
decoder.load_state_dict(torch.load("feature_decoder.pt", map_location="cpu"))
decoder = decoder.cuda().eval()

# Compress ViT features [B, 576, 1152] -> latents [B, 144, 32] -> reconstruction
with torch.no_grad():
    vit_features_norm = (vit_features - feat_mean) / feat_std
    z, mu, logvar = encoder(vit_features_norm)  # z: [B, 144, 32]
    reconstructed = decoder(z)                  # [B, 576, 1152]
    reconstructed = reconstructed * feat_std + feat_mean  # undo normalization

Dataset used to train mk322/fae-spatial-s2-chess: Zebra-CoT (Chess subset)