
Temporal Crosscoders - NLP (Gemma 2 2B)

StackedSAE and TemporalCrosscoder dictionaries (expansion factor 8) trained on Gemma 2 2B activations extracted from FineWeb.

Architecture

  • Base model: google/gemma-2-2b
  • Dictionary size: 18432 (2304 × 8, expansion factor 8)
  • Sparsity: TopK activation
  • Loss: MSE reconstruction (no L1 penalty)
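The TopK sparsity above keeps only the k largest pre-activations per sample and zeroes the rest, so no L1 penalty is needed. A minimal illustrative sketch (not the repository's actual implementation):

```python
import torch

def topk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest pre-activations along the last dim; zero the rest."""
    vals, idx = pre_acts.topk(k, dim=-1)
    out = torch.zeros_like(pre_acts)
    out.scatter_(-1, idx, vals)  # write the kept values back into place
    return out

# e.g. d_sae=18432 latents, k=50 active per sample
acts = topk_activation(torch.randn(2, 18432), k=50)
```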

Results

| Model       | Layer      | k   | T | Loss        | FVU    |
|-------------|------------|-----|---|-------------|--------|
| txcdr       | final_attn | 100 | 2 | 3.2184      | 0.2127 |
| txcdr       | final_attn | 100 | 5 | 3.4018      | 0.2240 |
| stacked_sae | final_attn | 100 | 2 | 3.4747      | 0.2308 |
| stacked_sae | final_attn | 100 | 5 | 3.4855      | 0.2358 |
| txcdr       | final_attn | 50  | 2 | 4.3518      | 0.3051 |
| stacked_sae | final_attn | 50  | 2 | 4.4255      | 0.3044 |
| txcdr       | final_attn | 50  | 5 | 4.5597      | 0.2870 |
| stacked_sae | final_attn | 50  | 5 | 4.9816      | 0.3276 |
| txcdr       | mid_attn   | 100 | 2 | 5.7788      | 0.1555 |
| txcdr       | mid_attn   | 100 | 5 | 5.8233      | 0.1728 |
| stacked_sae | mid_attn   | 100 | 2 | 6.8184      | 0.1985 |
| stacked_sae | mid_attn   | 100 | 5 | 7.3851      | 0.2308 |
| txcdr       | mid_attn   | 50  | 5 | 7.8913      | 0.2313 |
| txcdr       | mid_attn   | 50  | 2 | 8.7045      | 0.2535 |
| stacked_sae | mid_attn   | 50  | 2 | 9.4244      | 0.2551 |
| stacked_sae | mid_attn   | 50  | 5 | 10.3315     | 0.3229 |
| txcdr       | mid_res    | 100 | 2 | 6360.9038   | 0.0537 |
| stacked_sae | mid_res    | 100 | 2 | 6537.5220   | 0.0358 |
| stacked_sae | mid_res    | 100 | 5 | 6742.7378   | 0.1235 |
| txcdr       | mid_res    | 100 | 5 | 7737.1909   | 0.0858 |
| txcdr       | mid_res    | 50  | 2 | 7841.6440   | 0.0740 |
| stacked_sae | mid_res    | 50  | 2 | 8230.1406   | 0.0448 |
| txcdr       | mid_res    | 50  | 5 | 8676.4756   | 0.0865 |
| stacked_sae | mid_res    | 50  | 5 | 8857.8545   | 0.1450 |
| txcdr       | final_res  | 100 | 2 | 98917.4219  | 0.1369 |
| stacked_sae | final_res  | 100 | 5 | 105476.8281 | 0.1989 |
| stacked_sae | final_res  | 100 | 2 | 106508.5625 | 0.1634 |
| txcdr       | final_res  | 50  | 2 | 115348.6406 | 0.1406 |
| txcdr       | final_res  | 100 | 5 | 118512.2266 | 0.2216 |
| stacked_sae | final_res  | 50  | 2 | 126440.3750 | 0.2046 |
| stacked_sae | final_res  | 50  | 5 | 127936.1641 | 0.2611 |
| txcdr       | final_res  | 50  | 5 | 135606.6094 | 0.2593 |
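The FVU column presumably reports the standard fraction of variance unexplained: reconstruction error normalized by the variance of the target activations, so 0.0 is a perfect reconstruction and 1.0 is no better than predicting the mean. A minimal sketch of that definition (the repository's exact normalization may differ):

```python
import torch

def fvu(x: torch.Tensor, x_hat: torch.Tensor) -> float:
    """Fraction of variance unexplained: squared error over target variance."""
    sq_err = (x - x_hat).pow(2).sum()
    var = (x - x.mean(dim=0)).pow(2).sum()
    return (sq_err / var).item()
```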

Usage

```python
import torch
from temporal_crosscoders.models import StackedSAE, TemporalCrosscoder

# Load a checkpoint (weights_only=True avoids unpickling arbitrary objects)
state = torch.load("stacked_sae__mid_res__k50__T10.pt", weights_only=True)
model = StackedSAE(d_in=2304, d_sae=18432, T=10, k=50)
model.load_state_dict(state)
```