---
tags:
- ml-intern
---
# SAE Encoder Embeddings: End-to-End Sparse Autoencoder Bottleneck for Retrieval
> **Status**: Research & Architecture Design Phase
> **Goal**: Build the first encoder-only embedding model where the representation layer IS a Sparse Autoencoder, trained end-to-end with contrastive loss.
## 🎯 What This Is
A novel embedding architecture that combines:
- **ModernBERT** backbone (SOTA encoder-only with LLM innovations)
- **TopK Sparse Autoencoder** as the embedding bottleneck layer
- **End-to-end contrastive training** (not post-hoc SAE on frozen embeddings)
This produces embeddings that are simultaneously:
1. **Interpretable**: each active dimension corresponds to a learned semantic concept
2. **Steerable**: suppress/amplify specific features to control retrieval
3. **Sparse-indexable**: native sparse vector search (inverted index, not ANN)
4. **Competitive**: trained with modern contrastive objectives + hard negatives
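As a rough illustration of the steerable and sparse-indexable properties, the sketch below scores documents with a sparse dot product and then rescales individual query features. The feature ids (7, 42, 99), the activation values, and the helper names `sparse_dot` and `steer` are all hypothetical; nothing here comes from a trained model.

```python
from typing import Dict

# Sparse vector: feature index -> activation (only active features are stored).
SparseVec = Dict[int, float]


def sparse_dot(a: SparseVec, b: SparseVec) -> float:
    """Dot product that only visits features active in both vectors."""
    if len(a) > len(b):
        a, b = b, a  # iterate over the shorter vector
    return float(sum(val * b[idx] for idx, val in a.items() if idx in b))


def steer(vec: SparseVec, scales: Dict[int, float]) -> SparseVec:
    """Rescale selected features; a scale of 0.0 suppresses a feature entirely."""
    steered = {idx: val * scales.get(idx, 1.0) for idx, val in vec.items()}
    return {idx: val for idx, val in steered.items() if val != 0.0}


# Hypothetical feature ids and activations, made up for illustration.
query = {7: 2.0, 42: 1.0}
doc_a = {7: 1.0, 99: 0.5}
doc_b = {42: 3.0}

base_scores = (sparse_dot(query, doc_a), sparse_dot(query, doc_b))

# Suppress feature 42 and amplify feature 7 in the query.
steered_query = steer(query, {42: 0.0, 7: 1.5})
steered_scores = (sparse_dot(steered_query, doc_a), sparse_dot(steered_query, doc_b))
```

Before steering, `doc_b` outranks `doc_a`; after suppressing feature 42 the ranking flips, which is the kind of control the list above refers to.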
## 🔬 Why This Is Novel
| Approach | Training | Interpretable? | Sparse-native? | End-to-end? |
|----------|----------|---------------|----------------|-------------|
| Dense bi-encoder (e.g., E5, GTE) | Contrastive | ❌ | ❌ | ✅ |
| SPLADE | Distillation + regularizer | ⚠️ (vocab-tied) | ✅ | ✅ |
| Post-hoc SAE on embeddings | Reconstruction only | ✅ | ✅ | ❌ |
| CSR (Beyond Matryoshka) | Contrastive + recon (frozen backbone) | ✅ | ✅ | ❌ (backbone frozen) |
| SPLARE (Mar 2026) | Distillation (KL from cross-encoder) | ✅ | ✅ | ⚠️ (pretrained SAE, frozen LLM) |
| **Ours (this project)** | **Contrastive + recon + FLOPS reg** | ✅ | ✅ | ✅ **(backbone + SAE jointly)** |
**Key differentiator**: All prior SAE-for-retrieval work either freezes the backbone or freezes the SAE. We train both jointly, meaning the backbone learns to produce representations that are *optimally decomposable* into sparse interpretable features.
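The three-part objective named in the table (contrastive + reconstruction + FLOPS regularizer) could be combined roughly as follows. This is a NumPy sketch of the loss arithmetic only; actual joint training would need an autograd framework. The weights `w_recon` and `w_flops`, the temperature, the cosine normalization, and applying the FLOPS term to the concatenated query and document batch are all assumptions rather than values from this project's recipe.

```python
import numpy as np


def info_nce(zq, zd, temperature=0.05, eps=1e-8):
    """In-batch contrastive loss: row i of zq should match row i of zd."""
    zq = zq / (np.linalg.norm(zq, axis=1, keepdims=True) + eps)  # assumed cosine scoring
    zd = zd / (np.linalg.norm(zd, axis=1, keepdims=True) + eps)
    logits = zq @ zd.T / temperature
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))


def recon_mse(v, v_hat):
    """Mean squared error between pooled embeddings and SAE reconstructions."""
    return float(np.mean((v - v_hat) ** 2))


def flops_reg(z):
    """SPLADE-style FLOPS term: sum over features of the squared mean activation."""
    return float(np.sum(np.mean(np.abs(z), axis=0) ** 2))


def combined_loss(zq, zd, v, v_hat, w_recon=0.1, w_flops=1e-3):
    """Weighted sum of the three terms; the weights are hypothetical."""
    z_all = np.concatenate([zq, zd], axis=0)
    return info_nce(zq, zd) + w_recon * recon_mse(v, v_hat) + w_flops * flops_reg(z_all)


# Toy batch: 3 query/document pairs over 4 latent features.
z_q = np.eye(3, 4)
z_d = z_q.copy()
v = np.ones((3, 2))
v_hat = 0.5 * v

aligned = info_nce(z_q, z_d)                      # matching pairs: near zero
shuffled = info_nce(z_q, np.roll(z_d, 1, axis=0))  # mismatched pairs: large
total = combined_loss(z_q, z_d, v, v_hat)
```

Because all three terms depend on the sparse codes or the pooled embedding, gradients from each would reach both the SAE and the backbone under joint training.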
## 📂 Repository Structure
```
├── README.md              # This file
├── ARCHITECTURE.md        # Detailed architecture design
├── PAPERS.md              # Papers bibliography + key findings
├── TRAINING_RECIPE.md     # Full training recipe with hyperparameters
├── src/                   # (future) Implementation code
│   ├── model.py           # SAE bottleneck + ModernBERT
│   ├── loss.py            # Combined loss functions
│   └── train.py           # Training script
└── experiments/           # (future) Training logs and results
```
## 🏗️ Architecture Overview
```
Input text
               │
               ▼
┌─────────────────────────────────┐
│  ModernBERT-base (768-dim)      │ ← Backbone (trainable)
│  - RoPE positional embeddings   │
│  - FlashAttention 2             │
│  - GeGLU activations            │
│  - Alternating local/global attn│
│  - 8192 token context           │
└──────────────┬──────────────────┘
               │ mean-pool → v ∈ ℝ^768
               ▼
┌─────────────────────────────────┐
│  TopK Sparse Autoencoder        │ ← SAE Bottleneck (trainable)
│                                 │
│  Encoder: z = TopK(W_enc(v-b) + b_enc)
│     z ∈ ℝ^16384, ||z||_0 = k (32-128 active)
│                                 │
│  Decoder: v̂ = W_dec·z + b       │ ← For reconstruction loss only
│          (not used at inference)│
└──────────────┬──────────────────┘
               │
               ▼
     z (sparse embedding)
     Used for retrieval via sparse dot product
```
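The encoder and decoder equations in the diagram can be sketched in NumPy as below. This follows the formulas as written, `z = TopK(W_enc(v-b) + b_enc)` and `v̂ = W_dec·z + b`. The class name `TopKSAE`, the random initialization, and the tied decoder initialization are assumptions, and whether a ReLU should be applied to the kept values is left open. The demo uses small dimensions; the design above targets 768 inputs, 16384 latents, and k between 32 and 128.

```python
import numpy as np


def topk_mask(pre: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest entries in each row and zero out the rest."""
    idx = np.argpartition(pre, -k, axis=1)[:, -k:]
    z = np.zeros_like(pre)
    np.put_along_axis(z, idx, np.take_along_axis(pre, idx, axis=1), axis=1)
    return z


class TopKSAE:
    """Forward pass only; training would require an autograd framework."""

    def __init__(self, d_in: int, d_latent: int, k: int, seed: int = 0) -> None:
        rng = np.random.default_rng(seed)
        self.k = k
        self.b = np.zeros(d_in)          # shared pre-encoder / post-decoder bias
        self.b_enc = np.zeros(d_latent)  # encoder bias
        self.W_enc = rng.normal(scale=d_in ** -0.5, size=(d_in, d_latent))
        self.W_dec = self.W_enc.T.copy()  # assumed: decoder starts as encoder transpose

    def encode(self, v: np.ndarray) -> np.ndarray:
        """z = TopK(W_enc (v - b) + b_enc), one row per input."""
        return topk_mask((v - self.b) @ self.W_enc + self.b_enc, self.k)

    def decode(self, z: np.ndarray) -> np.ndarray:
        """v_hat = W_dec z + b, needed only for the reconstruction loss."""
        return z @ self.W_dec + self.b


sae = TopKSAE(d_in=16, d_latent=64, k=4)
v = np.random.default_rng(1).normal(size=(2, 16))  # stand-in for mean-pooled outputs
z = sae.encode(v)       # sparse embedding used for retrieval
v_hat = sae.decode(z)   # reconstruction used only during training
```

At inference only `encode` would run, so the decoder weights add no serving cost.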
## 📊 Key Design Decisions
### Why TopK (not L1)?
- **Exact control** of sparsity (k active features guaranteed)
- **No shrinkage bias**: an L1 penalty pushes all activations toward 0, including the ones that should stay large
- **Better reconstruction/sparsity Pareto frontier** at scale (OpenAI, arXiv:2406.04093)
- **Dead latent prevention** via AuxK loss
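The shrinkage point can be seen on a toy vector. The sketch below uses soft-thresholding, the proximal operator of the L1 penalty, as a stand-in for what L1 regularization does to activations. An L1-trained SAE does not apply this operator literally, so treat it as an illustration of the bias, with made-up input values and an arbitrary `lam`.

```python
import numpy as np


def soft_threshold(x: np.ndarray, lam: float) -> np.ndarray:
    """Proximal operator of the L1 penalty: moves every value toward 0 by lam."""
    return np.sign(x) * np.maximum(np.abs(x) - lam, 0.0)


def top_k(x: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest values unchanged and zero out the rest."""
    out = np.zeros_like(x)
    keep = np.argsort(x)[-k:]
    out[keep] = x[keep]
    return out


pre = np.array([3.0, 1.5, 0.4, 0.2, 0.1])  # hypothetical pre-activations

l1_codes = soft_threshold(pre, lam=0.5)  # surviving values are shrunk by 0.5
topk_codes = top_k(pre, k=2)             # surviving values keep their magnitude
```

Both outputs have two active entries here, but only TopK leaves their magnitudes untouched, and its sparsity level is fixed by `k` rather than by tuning `lam`.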
### Why End-to-End (not frozen backbone)?
- Backbone learns to produce **optimally decomposable** representations
- CSR/SPLARE show frozen backbone limits retrieval performance
- Joint training enables the SAE to develop features that are *useful for retrieval*, not just reconstructive
### Why ModernBERT?
- SOTA encoder-only architecture (surpasses BERT/RoBERTa/DeBERTa)
- LLM innovations: RoPE, FlashAttn, GeGLU, 8k context
- 768-dim base / 1024-dim large: good SAE input dimensions
- Hardware-aware design (efficient on T4/A10/A100)
## 🔗 Key References
| Paper | ArXiv | Relevance |
|-------|-------|-----------|
| ModernBERT | [2412.13663](https://arxiv.org/abs/2412.13663) | Backbone architecture |
| TopK SAE (OpenAI) | [2406.04093](https://arxiv.org/abs/2406.04093) | SAE architecture + dead latent prevention |
| CSR (Beyond Matryoshka) | [2503.01776](https://arxiv.org/abs/2503.01776) | Contrastive sparse coding framework |
| SPLARE | [2603.13277](https://arxiv.org/abs/2603.13277) | SAE for retrieval (closest prior work) |
| SPLADE v2 | [2109.10086](https://arxiv.org/abs/2109.10086) | FLOPS regularizer for sparse retrieval |
| EmbeddingGemma | [2509.20354](https://arxiv.org/abs/2509.20354) | GOR spread-out regularizer |
| Nomic Embed v2 MoE | [2502.07972](https://arxiv.org/abs/2502.07972) | MoE encoder embeddings |
| Ettin | [2507.11412](https://arxiv.org/abs/2507.11412) | Encoder vs Decoder comparison |
| Theoretical Limits | [2508.21038](https://arxiv.org/abs/2508.21038) | Why single-vector has capacity limits |
| Disentangling Embeddings (SAE) | [2408.00657](https://arxiv.org/abs/2408.00657) | SAE interpretability for embeddings |
| Interpretable Embed SAE | [2512.10092](https://arxiv.org/abs/2512.10092) | SAE data analysis toolkit |
| Hypencoder | [2502.05364](https://arxiv.org/abs/2502.05364) | Beyond dot-product retrieval |
| RouterRetriever | [2409.02685](https://arxiv.org/abs/2409.02685) | Router + expert models pattern |
## ⚡ Quick Links
- **Backbone model**: [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
- **Training data**: [sentence-transformers/msmarco-bm25](https://huggingface.co/datasets/sentence-transformers/msmarco-bm25) + [sentence-transformers/all-nli](https://huggingface.co/datasets/sentence-transformers/all-nli)
- **Evaluation**: MTEB benchmark
- **SAE reference impl**: [OpenAI sparse_autoencoder](https://github.com/openai/sparse_autoencoder)
- **SPLARE (closest prior)**: [arxiv:2603.13277](https://arxiv.org/abs/2603.13277)
- **CSR code**: [github.com/Mhz1y/CSR](https://github.com/Mhz1y/CSR)
## 📈 Expected Outcomes
1. **Retrieval quality**: Competitive with dense ModernBERT embeddings on MTEB retrieval tasks
2. **Interpretability**: Each active SAE feature maps to a human-interpretable concept
3. **Steerability**: Users can boost/suppress features to control search results
4. **Efficiency**: Sparse dot product with inverted index, potentially faster than dense ANN
5. **Novel contribution**: First end-to-end jointly-trained SAE embedding encoder
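For the efficiency outcome, a minimal in-memory inverted index over sparse embeddings might look like the following. The class name `InvertedIndex`, the document ids, and the feature ids are hypothetical, and a production system would more likely rely on an existing sparse-retrieval engine.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Sparse vector: feature index -> activation.
SparseVec = Dict[int, float]


class InvertedIndex:
    """Maps each active feature to the documents (and weights) that use it."""

    def __init__(self) -> None:
        self.postings: Dict[int, List[Tuple[str, float]]] = defaultdict(list)

    def add(self, doc_id: str, vec: SparseVec) -> None:
        """Append the document to the posting list of each of its active features."""
        for feat, weight in vec.items():
            self.postings[feat].append((doc_id, weight))

    def search(self, query: SparseVec, top_n: int = 10) -> List[Tuple[str, float]]:
        """Accumulate dot-product scores by visiting only the query's active features."""
        scores: Dict[str, float] = defaultdict(float)
        for feat, q_weight in query.items():
            for doc_id, d_weight in self.postings.get(feat, []):
                scores[doc_id] += q_weight * d_weight
        return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_n]


index = InvertedIndex()
index.add("doc_a", {7: 1.0, 99: 0.5})
index.add("doc_b", {42: 3.0})
index.add("doc_c", {5: 1.0})

results = index.search({7: 2.0, 42: 1.0})
```

Documents that share no active feature with the query (`doc_c` here) are never touched, which is where the potential speedup over scanning dense vectors would come from.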
<!-- ml-intern-provenance -->
## Generated by ML Intern
This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.
- Try ML Intern: https://smolagents-ml-intern.hf.space
- Source code: https://github.com/huggingface/ml-intern
## Usage
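```python
from transformers import AutoModel, AutoTokenizer

model_id = "pauvanbr/sae-encoder-embeddings-research"

# The backbone is encoder-only, so AutoModel is used rather than AutoModelForCausalLM.
# Assumes weights have been published to this repository; the project is
# currently in the research and architecture design phase.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
```
The SAE bottleneck is a custom layer, so loading the full model may additionally require the (future) implementation in `src/model.py`.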