# Temporal Crosscoders: NLP (Gemma 2 2B)

StackedSAE and TemporalCrosscoder dictionaries (expansion factor 8) trained on Gemma 2 2B activations extracted from FineWeb.
## Architecture
- Base model: google/gemma-2-2b
- Dictionary size: 18432 (2304 x 8, expansion factor 8)
- Sparsity: TopK activation
- Loss: MSE reconstruction (no L1 penalty)
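The TopK sparsity above can be sketched as follows. This is a generic TopK activation (keep the k largest pre-activations per sample, zero the rest), not necessarily the exact variant used in training:

```python
import torch

def topk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest pre-activations per sample; zero out the rest."""
    vals, idx = pre_acts.topk(k, dim=-1)
    sparse = torch.zeros_like(pre_acts)
    sparse.scatter_(-1, idx, vals)
    return sparse

# Toy check: d_sae = 18432 features, k = 50 active features per sample
pre = torch.randn(4, 18432)
acts = topk_activation(pre, k=50)
```

With a TopK activation the sparsity level is enforced architecturally, which is why the loss needs no L1 penalty.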
## Results

Lower is better for both loss and FVU (fraction of variance unexplained).
| Model | Layer | k | T | MSE loss | FVU |
|---|---|---|---|---|---|
| txcdr | final_attn | 100 | 2 | 3.2184 | 0.2127 |
| txcdr | final_attn | 100 | 5 | 3.4018 | 0.224 |
| stacked_sae | final_attn | 100 | 2 | 3.4747 | 0.2308 |
| stacked_sae | final_attn | 100 | 5 | 3.4855 | 0.2358 |
| txcdr | final_attn | 50 | 2 | 4.3518 | 0.3051 |
| stacked_sae | final_attn | 50 | 2 | 4.4255 | 0.3044 |
| txcdr | final_attn | 50 | 5 | 4.5597 | 0.287 |
| stacked_sae | final_attn | 50 | 5 | 4.9816 | 0.3276 |
| txcdr | mid_attn | 100 | 2 | 5.7788 | 0.1555 |
| txcdr | mid_attn | 100 | 5 | 5.8233 | 0.1728 |
| stacked_sae | mid_attn | 100 | 2 | 6.8184 | 0.1985 |
| stacked_sae | mid_attn | 100 | 5 | 7.3851 | 0.2308 |
| txcdr | mid_attn | 50 | 5 | 7.8913 | 0.2313 |
| txcdr | mid_attn | 50 | 2 | 8.7045 | 0.2535 |
| stacked_sae | mid_attn | 50 | 2 | 9.4244 | 0.2551 |
| stacked_sae | mid_attn | 50 | 5 | 10.3315 | 0.3229 |
| txcdr | mid_res | 100 | 2 | 6360.9038 | 0.0537 |
| stacked_sae | mid_res | 100 | 2 | 6537.5220 | 0.0358 |
| stacked_sae | mid_res | 100 | 5 | 6742.7378 | 0.1235 |
| txcdr | mid_res | 100 | 5 | 7737.1909 | 0.0858 |
| txcdr | mid_res | 50 | 2 | 7841.6440 | 0.074 |
| stacked_sae | mid_res | 50 | 2 | 8230.1406 | 0.0448 |
| txcdr | mid_res | 50 | 5 | 8676.4756 | 0.0865 |
| stacked_sae | mid_res | 50 | 5 | 8857.8545 | 0.145 |
| txcdr | final_res | 100 | 2 | 98917.4219 | 0.1369 |
| stacked_sae | final_res | 100 | 5 | 105476.8281 | 0.1989 |
| stacked_sae | final_res | 100 | 2 | 106508.5625 | 0.1634 |
| txcdr | final_res | 50 | 2 | 115348.6406 | 0.1406 |
| txcdr | final_res | 100 | 5 | 118512.2266 | 0.2216 |
| stacked_sae | final_res | 50 | 2 | 126440.3750 | 0.2046 |
| stacked_sae | final_res | 50 | 5 | 127936.1641 | 0.2611 |
| txcdr | final_res | 50 | 5 | 135606.6094 | 0.2593 |
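FVU in the table is the fraction of variance unexplained. A common definition, assumed here (the repo may normalize slightly differently), is the squared reconstruction error divided by the total variance of the activations:

```python
import torch

def fvu(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
    """Fraction of variance unexplained: squared reconstruction error
    normalized by the variance of the activations around their mean."""
    resid = (x - x_hat).pow(2).sum()
    total = (x - x.mean(dim=0, keepdim=True)).pow(2).sum()
    return resid / total

x = torch.randn(1024, 2304)
mean_pred = x.mean(dim=0, keepdim=True).expand_as(x)
```

A perfect reconstruction gives FVU = 0, while always predicting the mean activation gives FVU = 1; unlike the raw MSE loss, FVU is comparable across layers with very different activation scales (note the residual-stream losses above are orders of magnitude larger than the attention ones).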
## Usage
```python
import torch

from temporal_crosscoders.models import StackedSAE, TemporalCrosscoder

# Load a checkpoint (weights_only avoids unpickling arbitrary objects)
state = torch.load("stacked_sae__mid_res__k50__T10.pt", weights_only=True)

# Dimensions must match the checkpoint: d_in=2304 (Gemma 2 2B hidden size),
# d_sae=18432 (expansion factor 8)
model = StackedSAE(d_in=2304, d_sae=18432, T=10, k=50)
model.load_state_dict(state)
```
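One plausible way to batch sequential activations for a temporal model that consumes windows of T timesteps is a sliding window over the sequence dimension. The helper below is a sketch under that assumption; its name and window layout are not part of the repo's API:

```python
import torch

def make_windows(acts: torch.Tensor, T: int) -> torch.Tensor:
    """Slide a length-T window over a [seq_len, d_in] activation matrix,
    producing a [seq_len - T + 1, T, d_in] batch of overlapping windows."""
    return acts.unfold(0, T, 1).transpose(1, 2)

seq = torch.randn(128, 2304)      # one sequence of Gemma 2 2B activations
batch = make_windows(seq, T=10)   # matches the T=10 checkpoint above
```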