---
tags:
- mechanistic-interpretability
- transcoding
- bilinear
- pythia
- mlp
library_name: pytorch
license: mit
---
# Pythia-410m Bilinear MLP Transcoders
This repository contains bilinear transcoder models trained to approximate the MLP layers of [EleutherAI/pythia-410m](https://huggingface.co/EleutherAI/pythia-410m).
## Overview
**Transcoders** are auxiliary models that learn to approximate the behavior of transformer components (in this case, MLPs) using simpler architectures. These bilinear transcoders use a Hadamard neural network architecture to approximate each of the 24 MLP layers in Pythia-410m.
## Model Architecture
- **Base Model**: EleutherAI/pythia-410m (24 layers)
- **Transcoder Type**: Bilinear (Hadamard Neural Network)
- **Architecture**: `output = W_left @ (x ⊙ (W_right @ x)) + bias`
- Input dimension: 1024 (d_model)
- Hidden dimension: 4096 (4x expansion)
- Output dimension: 1024 (d_model)
- **Training**: 3000 batches, batch size 512, Muon optimizer (lr=0.02)
- **Dataset**: monology/pile-uncopyrighted
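Given the dimensions above, the per-layer parameter count is easy to sanity-check. A minimal sketch, assuming just the two weight matrices and output bias used in the usage example below:

```python
d_model, d_hidden = 1024, 4096

# W_right: d_model -> d_hidden (no bias); W_left: d_hidden -> d_model (+ bias)
n_params = d_model * d_hidden + d_hidden * d_model + d_model
print(f"{n_params:,} parameters per transcoder")  # 8,389,632 (~8.4M)
```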
## Performance Summary
All 24 layers achieve >82% variance explained, with most layers >93%:
| Layer | Final FVU | Variance Explained | Notes |
|-------|-----------|-------------------|-------|
| 0 | 0.0075 | 99.2% | Best performance |
| 1-2 | 0.167-0.174 | 82.6-83.2% | Hardest to approximate |
| 3-22 | 0.037-0.066 | 93.4-96.3% | Consistent performance |
| 23 | 0.0259 | 97.4% | Second-best |
**Average across all layers**: 93.4% variance explained (FVU = 0.0657)
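Here, FVU (fraction of variance unexplained) and variance explained are complements: variance explained = 1 - FVU. A minimal sketch of the usual definition (the training code's exact computation may differ, e.g. in how the mean is taken):

```python
import torch

def fvu(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Fraction of variance unexplained: MSE divided by target variance."""
    mse = (pred - target).pow(2).mean()
    var = (target - target.mean(dim=0)).pow(2).mean()
    return mse / var

# variance explained = 1 - FVU, e.g. FVU = 0.0657 -> 93.4% explained
```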
## Repository Structure
```
.
├── layer_0/
│   ├── transcoder_weights_l0_bilinear_muon_3000b.pt
│   └── config.yaml
├── layer_1/
│   ├── transcoder_weights_l1_bilinear_muon_3000b.pt
│   └── config.yaml
...
├── layer_23/
│   ├── transcoder_weights_l23_bilinear_muon_3000b.pt
│   └── config.yaml
├── figures/
│   ├── all_layers_comparison.png
│   ├── training_curves_overlaid_layers_0_5.png
│   ├── training_curves_overlaid_layers_6_11.png
│   ├── training_curves_overlaid_layers_12_17.png
│   └── training_curves_overlaid_layers_18_23.png
└── README.md
```
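Individual checkpoints can also be fetched without cloning the whole repository via `huggingface_hub`. A sketch (the repo id below is a placeholder):

```python
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="[your-username]/pythia-410m-bilinear-transcoders",  # placeholder
    filename="layer_3/transcoder_weights_l3_bilinear_muon_3000b.pt",
)
```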
## Usage
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load base model
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m")
# Load transcoder for layer 3
layer_idx = 3
checkpoint = torch.load(
    f"layer_{layer_idx}/transcoder_weights_l{layer_idx}_bilinear_muon_3000b.pt",
    map_location="cpu",
    weights_only=False,  # the checkpoint stores a pickled config object, not just tensors
)
# Extract configuration
config = checkpoint['config']
print(f"Input dim: {config.n_inputs}")
print(f"Hidden dim: {config.n_hidden}")
print(f"Output dim: {config.n_outputs}")
# Reconstruct the model (example: you'll need a Bilinear class whose
# state dict layout matches the checkpoint's; this sketch mirrors the
# architecture described above)
class Bilinear(torch.nn.Module):
    def __init__(self, n_inputs, n_hidden, n_outputs, bias=True):
        super().__init__()
        self.W_left = torch.nn.Linear(n_hidden, n_outputs, bias=bias)   # output projection
        self.W_right = torch.nn.Linear(n_inputs, n_hidden, bias=False)  # right projection

    def forward(self, x):
        right = self.W_right(x)  # (..., n_hidden)
        # pairwise products between x and W_right(x), summed over the input dim
        hadamard = x.unsqueeze(-1) * right.unsqueeze(-2)  # (..., n_inputs, n_hidden)
        return self.W_left(hadamard.sum(dim=-2))
transcoder = Bilinear(config.n_inputs, config.n_hidden, config.n_outputs, config.bias)
transcoder.load_state_dict(checkpoint['model_state_dict'])
transcoder.eval()
# Use the transcoder to approximate the MLP
with torch.no_grad():
    inputs = tokenizer("Hello world", return_tensors="pt")
    outputs = model(**inputs, output_hidden_states=True)
    # hidden_states[layer_idx] is the residual stream entering block
    # layer_idx; the exact MLP input is the post-LayerNorm activation,
    # which a forward hook can capture (see the sketch below)
    mlp_input = outputs.hidden_states[layer_idx]

    # Approximate the MLP output with the transcoder
    transcoded_output = transcoder(mlp_input)
```
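To compare against the true MLP input and output rather than the residual stream, a forward hook on the MLP module works. A sketch, assuming the standard Hugging Face GPT-NeoX module layout (`model.gpt_neox.layers[i].mlp`); note that training used per-batch z-score normalization, so FVU on raw activations will not exactly match the reported numbers:

```python
# Capture the true MLP input/output for the chosen layer with a forward hook
captured = {}

def hook(module, inputs, output):
    captured["mlp_in"] = inputs[0].detach()  # post-LayerNorm MLP input
    captured["mlp_out"] = output.detach()    # true MLP output

handle = model.gpt_neox.layers[layer_idx].mlp.register_forward_hook(hook)
with torch.no_grad():
    model(**tokenizer("Hello world", return_tensors="pt"))
handle.remove()

# Reconstruction quality of the transcoder on this prompt
with torch.no_grad():
    pred = transcoder(captured["mlp_in"])
fvu_value = (pred - captured["mlp_out"]).pow(2).mean() / captured["mlp_out"].var()
print(f"FVU on this prompt: {fvu_value:.4f}")
```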
## Training Details
- **Optimizer**: Muon (momentum-based, with Newton-Schulz orthogonalization of updates)
- **Learning Rate**: 0.02 (fixed across all layers in the Muon setup)
- **Batch Size**: 512
- **Total Batches**: 3000 per layer
- **Training Time**: ~75 minutes per layer on A100
- **Normalization**: Per-batch z-score normalization
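The exact normalization code is not included in this repository; a per-batch z-score typically looks like the following sketch (`mlp_inputs`/`mlp_outputs` are illustrative names, not from the training code):

```python
import torch

def zscore(acts: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Normalize each feature to zero mean, unit variance within the batch."""
    return (acts - acts.mean(dim=0)) / (acts.std(dim=0) + eps)

x = zscore(mlp_inputs)   # hypothetical: transcoder inputs for one batch
y = zscore(mlp_outputs)  # hypothetical: regression targets
```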
## Checkpoint Contents
Each checkpoint (`.pt` file) contains:
- `model_state_dict`: Model weights
- `optimizer_state_dict`: Optimizer state
- `config`: Configuration object with dimensions
- `mse_losses`: List of MSE losses per batch
- `variance_explained`: List of variance explained per batch
- `fvu_values`: List of FVU values per batch
- `layer_idx`: Layer index (0-23)
- `d_model`: Model dimension (1024)
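For example, to inspect a checkpoint's training history (a sketch; `weights_only=False` is required on recent PyTorch versions because `config` is a pickled Python object):

```python
import torch

ckpt = torch.load(
    "layer_3/transcoder_weights_l3_bilinear_muon_3000b.pt",
    map_location="cpu",
    weights_only=False,  # `config` is a pickled object, not just tensors
)
print(f"layer {ckpt['layer_idx']}, d_model = {ckpt['d_model']}")
print(f"final FVU: {ckpt['fvu_values'][-1]:.4f}")
print(f"final variance explained: {ckpt['variance_explained'][-1]:.4f}")
```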
## Key Findings
1. **Layer 0 is dramatically easier to approximate** (99.2% VE) - nearly perfect reconstruction
2. **Layers 1-2 are hardest** (~83% VE) - contain complex transformations
3. **Middle layers (3-22) are remarkably consistent** (93-96% VE) - homogeneous structure
4. **Final layer is highly learnable** (97.4% VE)
This suggests that the first and last layers compute more structured, easier-to-approximate transformations, while the early layers (1-2) perform more complex transformations that bilinear models struggle to capture.
## Citation
If you use these transcoders in your research, please cite:
```bibtex
@misc{pythia410m-bilinear-transcoders,
title={Bilinear MLP Transcoders for Pythia-410m},
author={[Your Name]},
year={2025},
publisher={Hugging Face},
url={https://huggingface.co/[your-username]/pythia-410m-bilinear-transcoders}
}
```
## License
MIT License
## Acknowledgments
- Base model: [EleutherAI/pythia-410m](https://huggingface.co/EleutherAI/pythia-410m)
- Training dataset: [monology/pile-uncopyrighted](https://huggingface.co/datasets/monology/pile-uncopyrighted)