Trained Sparse Autoencoders on Pythia 2.8B

I trained SAEs on the MLP_out activations of the Pythia 2.8B model. Training used github.com/magikarp01/facts-sae, a fork of github.com/saprmarks/dictionary_learning designed for efficient multi-GPU (not yet multi-node) training. Checkpoints were saved every 10k steps, but I have not uploaded all of them; message me if you want more intermediate checkpoints.

The original goal was to analyze how well these SAEs contribute to performance on a Sports Facts dataset. I'm currently working on other projects and haven't had time to do this yet, but hopefully some results will come out of these SAEs in the future.

SAE Setup

  • Training Dataset: Uncopyrighted Pile, at monology/pile-uncopyrighted
  • Model: 32-layer Pythia 2.8B
  • Activation: MLP_out, so d_model of 2560
  • Layers Trained: 0, 1, 2, 15
  • Batch Size: 2048 for layer 15, 2560 for layers 0, 1, 2
  • Training Tokens: about 1e9 for layers 0, 2, and 15; slightly less than 2e9 for layer 1
  • Training Steps: 4e5 for layers 0 and 2, 5e5 for layer 15, 7.5e5 for layer 1
  • Dictionary Size: 16x activation, so 40960
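The token counts above follow from batch size times step count, assuming each batch element is one token's activation vector (an assumption that is consistent with the listed numbers, not something stated outright). A quick check, with variable and function names that are purely illustrative:

```python
# Per-layer settings copied from the setup list above.
# Assumption: batch size counts activation vectors, one per token.
D_MODEL = 2560
EXPANSION = 16

runs = {
    0: {"batch_size": 2560, "steps": 400_000},
    1: {"batch_size": 2560, "steps": 750_000},
    2: {"batch_size": 2560, "steps": 400_000},
    15: {"batch_size": 2048, "steps": 500_000},
}


def tokens_seen(batch_size: int, steps: int) -> int:
    """Total activations (tokens) consumed over a training run."""
    return batch_size * steps


dictionary_size = D_MODEL * EXPANSION
token_counts = {layer: tokens_seen(**cfg) for layer, cfg in runs.items()}

print("dictionary size:", dictionary_size)
for layer, n in token_counts.items():
    print(f"layer {layer}: {n:.3e} tokens")
```

Under that assumption, layers 0, 2, and 15 each come to roughly 1.02e9 tokens and layer 1 to 1.92e9, which matches the figures listed.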

Training Hyperparameters

  • Learning Rate: 3e-4
  • Sparsity Penalty: 1e-3
  • Warmup Steps: 5000
  • Resample Steps: 50000
  • Optimizer: Constrained Adam
  • Scheduler: LambdaLR, with the learning rate warming up linearly between step 0 and warmup_steps
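As a rough sketch of the schedule described above: `torch.optim.lr_scheduler.LambdaLR` multiplies the base learning rate by whatever a user-supplied function of the step returns. The function below, `warmup_factor`, is a hypothetical name and an assumption about the shape of the warmup (a linear ramp, then constant); the training repo's actual code may differ, and this sketch does not model any interaction with resampling.

```python
LEARNING_RATE = 3e-4  # from the list above
WARMUP_STEPS = 5000   # from the list above


def warmup_factor(step: int, warmup_steps: int = WARMUP_STEPS) -> float:
    """Multiplier on the base learning rate: ramps linearly from 0 to 1, then stays at 1."""
    return min(step / warmup_steps, 1.0)


def lr_at(step: int) -> float:
    """Effective learning rate at a given step under the assumed schedule."""
    return LEARNING_RATE * warmup_factor(step)


# With PyTorch, the same function would be passed as lr_lambda, e.g.
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_factor)

for s in (0, 2500, 5000, 100_000):
    print(s, lr_at(s))
```

Keeping the multiplier as a plain function of the step makes it easy to inspect the schedule without constructing an optimizer.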

SAE Metrics

[Figures: metric plots for 2.8b Layer 0 (390k steps), Layer 1 (390k steps), Layer 1 (740k steps), Layer 2 (390k steps), and Layer 15 (490k steps)]

Thanks

Thanks to Nat Friedman/NFDG for letting me use H100s from the Andromeda Cluster during downtime, and thanks to Sam Marks/NDIF for the original SAE training repo and for helping me distribute the SAEs. This work was done late in my MATS training phase with Neel Nanda.
