---
license: apache-2.0
tags:
  - vision
  - compression
  - feature-autoencoder
  - qwen3-vl
datasets:
  - multimodal-reasoning-lab/Zebra-CoT
---

# FAE-Spatial-S2: Feature Auto-Encoder with 2x Spatial Pooling

**Compression: 144x** (576 → 144 tokens, 1152 → 32 channels)

## Overview

A Feature Auto-Encoder (FAE) that compresses Qwen3-VL-8B vision features using strided CNN spatial pooling (2x per spatial dimension) combined with channel compression (1152 → 32). Trained on the Chess subset of Zebra-CoT.

## Architecture

- **Encoder** (28.05M params): reshape to 2D grid → Conv2d(1152, 1152, k=3, s=2, p=1) + GELU → self-attention + SwiGLU FFN → Linear(1152→32) + VAE μ/logvar heads
- **Decoder** (117.63M params): Linear(32→1152) → 6-layer ViT (RoPE at 12×12 grid, RMSNorm, SwiGLU) → ConvTranspose2d upsample → RMSNorm → [B, 576, 1152]
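
The 4x spatial reduction comes entirely from the encoder's stride-2 convolution. As a quick sanity check (an illustrative helper, not part of `fae_spatial.py`), the standard convolution output-size formula confirms that kernel 3, stride 2, padding 1 maps the 24×24 token grid to 12×12:

```python
# Standard conv output size: floor((size + 2*padding - kernel) / stride) + 1.
# Illustrative helper, not part of fae_spatial.py.
def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    return (size + 2 * padding - kernel) // stride + 1

grid = conv_out(24)       # each spatial dimension: 24 -> 12
print(grid, grid * grid)  # 12 144  (576 tokens -> 144 tokens)
```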

## Compression Details

|              | Input        | Latent       | Compression |
|--------------|--------------|--------------|-------------|
| Tokens       | 576 (24×24)  | 144 (12×12)  | 4x spatial  |
| Channels     | 1152         | 32           | 36x channel |
| Total values | 663,552      | 4,608        | 144x        |
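
The figures in the table are internally consistent, as a quick arithmetic check shows:

```python
# Sanity check of the compression ratios stated above.
in_tokens, in_channels = 576, 1152
out_tokens, out_channels = 144, 32

spatial = in_tokens // out_tokens      # 4x  (2x per spatial dimension)
channel = in_channels // out_channels  # 36x
total = (in_tokens * in_channels) // (out_tokens * out_channels)

print(in_tokens * in_channels, out_tokens * out_channels, total)  # 663552 4608 144
```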

## Results

| Metric                                   | Value       |
|------------------------------------------|-------------|
| Eval CosSim (feature reconstruction)     | 0.9776      |
| VLM Chess MCQ accuracy                   | 3/20 (15%)  |
| VLM agreement with uncompressed baseline | 17/20 (85%) |

The s2 model exceeds the channel-only d32 baseline (CosSim 0.9670) in reconstruction fidelity while achieving 4x higher compression.

## Training

- **Dataset**: Zebra-CoT Chess (19,983 train / 500 eval images)
- **Base model**: Qwen3-VL-8B-Instruct (ViT embed_dim=1152)
- **Loss**: MSE + β·KL (β=1e-6)
- **Optimizer**: AdamW, lr=1e-4, weight_decay=0.05
- **Epochs**: 100
- **Hardware**: 2x NVIDIA L40S with DDP
- **Feature normalization**: per-dim mean/std (included in checkpoint)
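
The objective above can be sketched as follows — a minimal pure-Python illustration of MSE + β·KL against a standard-normal prior (the actual training code presumably operates on torch tensors, but the arithmetic is the same):

```python
import math

# Illustrative sketch of the training loss: MSE reconstruction term plus
# a KL term between N(mu, sigma^2) and the standard normal prior,
# weighted by beta = 1e-6 as stated in the training config.
def vae_loss(recon, target, mu, logvar, beta=1e-6):
    n = len(recon)
    mse = sum((r - t) ** 2 for r, t in zip(recon, target)) / n
    # KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * (1 + logvar - mu^2 - exp(logvar))
    kl = -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar)) / len(mu)
    return mse + beta * kl

loss = vae_loss(recon=[0.1, 0.2], target=[0.0, 0.25], mu=[0.05, -0.02], logvar=[0.0, 0.1])
print(loss)
```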

## Files

- `fae_encoder.pt` — spatial FAE encoder weights
- `feature_decoder.pt` — spatial FAE decoder weights
- `training_state.pt` — training metadata + feature normalization stats
- `fae_spatial.py` — model architecture source code

## Usage

```python
import torch
from fae_spatial import FAESpatialEncoder, FAESpatialDecoder

# Load checkpoint and the per-dim feature normalization stats
state = torch.load("training_state.pt", map_location="cpu")
feat_mean = state["feat_mean"].cuda()
feat_std = state["feat_std"].cuda()

encoder = FAESpatialEncoder(embed_dim=1152, latent_dim=32, num_heads=16, pool_factor=2, use_vae=True)
encoder.load_state_dict(torch.load("fae_encoder.pt", map_location="cpu"))
encoder = encoder.cuda().eval()

decoder = FAESpatialDecoder(latent_dim=32, output_dim=1152, num_layers=6, num_heads=16, ffn_mult=2.7, pool_factor=2)
decoder.load_state_dict(torch.load("feature_decoder.pt", map_location="cpu"))
decoder = decoder.cuda().eval()

# Compress ViT features [B, 576, 1152]
vit_features_norm = (vit_features - feat_mean) / feat_std
z, mu, logvar = encoder(vit_features_norm)  # [B, 144, 32]
reconstructed = decoder(z)                  # [B, 576, 1152]
reconstructed = reconstructed * feat_std + feat_mean
```
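
To sanity-check reconstructions against the reported Eval CosSim, one can average per-token cosine similarity between the original and reconstructed features. The helper below is an illustrative pure-Python sketch over small lists; in practice you would use `torch.nn.functional.cosine_similarity` over the last dimension of the `[B, 576, 1152]` tensors:

```python
import math

# Cosine similarity between two feature vectors; averaged per token
# this mirrors the Eval CosSim metric reported in Results.
def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

# Toy 2-token, 2-channel example (stand-in for [576, 1152] features).
original = [[1.0, 0.0], [0.6, 0.8]]
recon = [[1.0, 0.0], [0.8, 0.6]]
mean_cossim = sum(cosine(a, b) for a, b in zip(original, recon)) / len(original)
print(round(mean_cossim, 4))  # cosine([.6,.8],[.8,.6]) = 0.96, so mean is 0.98
```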