|
|
---
tags:
- mechanistic-interpretability
- transcoding
- bilinear
- pythia
- mlp
library_name: pytorch
license: mit
---
|
|
|
|
|
# Pythia-410m Bilinear MLP Transcoders |
|
|
|
|
|
This repository contains bilinear transcoder models trained to approximate the MLP layers of [EleutherAI/pythia-410m](https://huggingface.co/EleutherAI/pythia-410m). |
|
|
|
|
|
## Overview |
|
|
|
|
|
**Transcoders** are auxiliary models that learn to approximate the behavior of transformer components (in this case, MLPs) using simpler architectures. These bilinear transcoders use a Hadamard neural network architecture to approximate each of the 24 MLP layers in Pythia-410m. |
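Concretely, each transcoder is trained to reproduce its layer's MLP output from the MLP's input. In our notation (the training code is not reproduced here), the per-layer objective is:

```latex
% Per-layer objective (notation ours): minimize the mean squared error between
% the true MLP output and the transcoder's reconstruction, over activations x
% drawn from the training data.
\min_{\theta_\ell}\; \mathbb{E}_{x}\left[\, \bigl\| \mathrm{MLP}_\ell(x) - T_{\theta_\ell}(x) \bigr\|_2^2 \,\right]
```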
|
|
|
|
|
## Model Architecture |
|
|
|
|
|
- **Base Model**: EleutherAI/pythia-410m (24 layers) |
|
|
- **Transcoder Type**: Bilinear (Hadamard Neural Network) |
|
|
- **Architecture**: `output = W_left @ (x ⊙ (W_right @ x)) + bias`
|
|
- Input dimension: 1024 (d_model) |
|
|
- Hidden dimension: 4096 (4x expansion) |
|
|
- Output dimension: 1024 (d_model) |
|
|
- **Training**: 3000 batches, batch size 512, Muon optimizer (lr=0.02) |
|
|
- **Dataset**: monology/pile-uncopyrighted |
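As a quick sanity check of the dimensions above, the sketch below counts the parameters of a single transcoder, assuming the `Bilinear` module layout shown in the Usage section (a bias-free 1024→4096 projection plus a 4096→1024 projection with bias):

```python
# Hypothetical parameter count for one transcoder, based on the dimensions
# listed above and the Bilinear module layout from the Usage section.
d_model, d_hidden = 1024, 4096

w_right = d_model * d_hidden            # input -> hidden projection (no bias)
w_left = d_hidden * d_model + d_model   # hidden -> output projection (+ bias)

print(f"Parameters per transcoder: {w_right + w_left:,}")  # 8,389,632 (~8.4M)
```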
|
|
|
|
|
## Performance Summary |
|
|
|
|
|
All 24 layers achieve >82% variance explained, with most layers >93%: |
|
|
|
|
|
| Layer | Final FVU | Variance Explained | Notes |
|-------|-----------|--------------------|-------|
| 0 | 0.0075 | 99.2% | Best performance |
| 1-2 | 0.167-0.174 | 82.6-83.2% | Hardest to approximate |
| 3-22 | 0.037-0.066 | 93.4-96.3% | Consistent performance |
| 23 | 0.0259 | 97.4% | Second-best |
|
|
|
|
|
**Average across all layers**: 93.4% variance explained (FVU = 0.0657) |
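FVU here is the fraction of variance unexplained, so variance explained is simply `1 - FVU`. A minimal sketch of the relationship (our helper; the training code may compute it slightly differently):

```python
import torch

def fvu(target: torch.Tensor, prediction: torch.Tensor) -> torch.Tensor:
    """Fraction of variance unexplained: residual variance divided by target
    variance. Variance explained = 1 - FVU."""
    residual = (target - prediction).pow(2).mean()
    total = (target - target.mean(dim=0, keepdim=True)).pow(2).mean()
    return residual / total

# An average FVU of 0.0657 corresponds to 1 - 0.0657 ≈ 93.4% variance explained.
```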
|
|
|
|
|
## Repository Structure |
|
|
|
|
|
```
.
├── layer_0/
│   ├── transcoder_weights_l0_bilinear_muon_3000b.pt
│   └── config.yaml
├── layer_1/
│   ├── transcoder_weights_l1_bilinear_muon_3000b.pt
│   └── config.yaml
...
├── layer_23/
│   ├── transcoder_weights_l23_bilinear_muon_3000b.pt
│   └── config.yaml
├── figures/
│   ├── all_layers_comparison.png
│   ├── training_curves_overlaid_layers_0_5.png
│   ├── training_curves_overlaid_layers_6_11.png
│   ├── training_curves_overlaid_layers_12_17.png
│   └── training_curves_overlaid_layers_18_23.png
└── README.md
```
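To fetch a single layer's checkpoint without cloning the whole repository, something like the following should work (the repo id below is a placeholder; substitute the actual repository name):

```python
import torch
from huggingface_hub import hf_hub_download

layer_idx = 3
# "your-username/pythia-410m-bilinear-transcoders" is a placeholder repo id.
path = hf_hub_download(
    repo_id="your-username/pythia-410m-bilinear-transcoders",
    filename=f"layer_{layer_idx}/transcoder_weights_l{layer_idx}_bilinear_muon_3000b.pt",
)
# weights_only=False because the checkpoint stores a pickled config object.
checkpoint = torch.load(path, map_location="cpu", weights_only=False)
```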
|
|
|
|
|
## Usage |
|
|
|
|
|
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m")

# Load the transcoder checkpoint for layer 3
layer_idx = 3
checkpoint = torch.load(
    f"layer_{layer_idx}/transcoder_weights_l{layer_idx}_bilinear_muon_3000b.pt",
    map_location="cpu",
    weights_only=False,  # the checkpoint stores a pickled config object
)

# Extract configuration
config = checkpoint['config']
print(f"Input dim: {config.n_inputs}")
print(f"Hidden dim: {config.n_hidden}")
print(f"Output dim: {config.n_outputs}")

# Reconstruct the transcoder (example; prefer the exact Bilinear class
# definition from the training code if it is available)
class Bilinear(torch.nn.Module):
    def __init__(self, n_inputs, n_hidden, n_outputs, bias=True):
        super().__init__()
        self.W_left = torch.nn.Linear(n_hidden, n_outputs, bias=bias)   # hidden -> output
        self.W_right = torch.nn.Linear(n_inputs, n_hidden, bias=False)  # input -> hidden

    def forward(self, x):
        right = self.W_right(x)
        # Outer product between the input and its hidden projection,
        # summed over the input dimension before the output projection.
        hadamard = x.unsqueeze(-1) * right.unsqueeze(-2)
        return self.W_left(hadamard.sum(dim=-2))

transcoder = Bilinear(config.n_inputs, config.n_hidden, config.n_outputs, config.bias)
transcoder.load_state_dict(checkpoint['model_state_dict'])
transcoder.eval()

# Use the transcoder to approximate the MLP
with torch.no_grad():
    inputs = tokenizer("Hello world", return_tensors="pt")
    outputs = model(**inputs, output_hidden_states=True)
    # Residual stream entering layer 3. Note: in GPT-NeoX the MLP sees this
    # after the layer's post-attention LayerNorm, so treat it as an
    # approximation of the true MLP input.
    mlp_input = outputs.hidden_states[layer_idx]

    # Approximate the MLP output with the transcoder
    transcoded_output = transcoder(mlp_input)
```
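To compare the transcoder against the actual MLP (rather than the raw residual stream), the MLP module can be hooked directly. A sketch continuing from the snippet above, assuming the standard `transformers` GPT-NeoX module layout (`model.gpt_neox.layers[i].mlp`) and that the transcoder consumes the MLP input as-is (re-apply the training-time normalization if required):

```python
# Continues from the snippet above (model, tokenizer, transcoder, layer_idx).
captured = {}

def capture_mlp(module, inputs, output):
    # inputs is a tuple; its first element is the MLP's input activation.
    captured["mlp_in"] = inputs[0].detach()
    captured["mlp_out"] = output.detach()

handle = model.gpt_neox.layers[layer_idx].mlp.register_forward_hook(capture_mlp)
with torch.no_grad():
    model(**tokenizer("Hello world", return_tensors="pt"))
handle.remove()

with torch.no_grad():
    approx = transcoder(captured["mlp_in"])

# Fraction of variance unexplained on this single (tiny) prompt.
fvu = (captured["mlp_out"] - approx).pow(2).mean() / captured["mlp_out"].var()
print(f"FVU on this prompt: {fvu.item():.4f}")
```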
|
|
|
|
|
## Training Details |
|
|
|
|
|
- **Optimizer**: Muon (momentum-based optimizer) |
|
|
- **Learning Rate**: 0.02 (hardcoded for Muon) |
|
|
- **Batch Size**: 512 |
|
|
- **Total Batches**: 3000 per layer |
|
|
- **Training Time**: ~75 minutes per layer on A100 |
|
|
- **Normalization**: Per-batch z-score normalization (see the sketch below)
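The per-batch normalization can be implemented as a plain z-score over each activation batch; a minimal sketch (our implementation, since the training code is not included in this repository):

```python
import torch

def zscore_per_batch(acts: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Normalize a batch of activations to zero mean and unit variance,
    computed over the batch dimension for each feature."""
    mean = acts.mean(dim=0, keepdim=True)
    std = acts.std(dim=0, keepdim=True)
    return (acts - mean) / (std + eps)
```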
|
|
|
|
|
## Checkpoint Contents |
|
|
|
|
|
Each checkpoint (`.pt` file) contains: |
|
|
- `model_state_dict`: Model weights |
|
|
- `optimizer_state_dict`: Optimizer state |
|
|
- `config`: Configuration object with dimensions |
|
|
- `mse_losses`: List of MSE losses per batch |
|
|
- `variance_explained`: List of variance explained per batch |
|
|
- `fvu_values`: List of FVU values per batch |
|
|
- `layer_idx`: Layer index (0-23) |
|
|
- `d_model`: Model dimension (1024) |
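The stored metric lists can be used to reproduce the training curves in `figures/`, for example (matplotlib assumed to be installed):

```python
import torch
import matplotlib.pyplot as plt

layer_idx = 3
checkpoint = torch.load(
    f"layer_{layer_idx}/transcoder_weights_l{layer_idx}_bilinear_muon_3000b.pt",
    map_location="cpu",
    weights_only=False,
)

plt.plot(checkpoint["fvu_values"])
plt.xlabel("Batch")
plt.ylabel("FVU")
plt.title(f"Layer {checkpoint['layer_idx']} bilinear transcoder training")
plt.show()

print("Final FVU:", checkpoint["fvu_values"][-1])
print("Final variance explained:", checkpoint["variance_explained"][-1])
```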
|
|
|
|
|
## Key Findings |
|
|
|
|
|
1. **Layer 0 is dramatically easier to approximate** (99.2% VE) - nearly perfect reconstruction |
|
|
2. **Layers 1-2 are hardest** (~83% VE) - contain complex transformations |
|
|
3. **Middle layers (3-22) are remarkably consistent** (93-96% VE) - homogeneous structure |
|
|
4. **Final layer is highly learnable** (97.4% VE) |
|
|
|
|
|
This suggests that input and output layers have more structured patterns, while early layers (1-2) perform more complex transformations that are difficult for bilinear models to capture. |
|
|
|
|
|
## Citation |
|
|
|
|
|
If you use these transcoders in your research, please cite: |
|
|
|
|
|
```bibtex
@misc{pythia410m-bilinear-transcoders,
  title={Bilinear MLP Transcoders for Pythia-410m},
  author={[Your Name]},
  year={2025},
  publisher={Hugging Face},
  url={https://huggingface.co/[your-username]/pythia-410m-bilinear-transcoders}
}
```
|
|
|
|
|
## License |
|
|
|
|
|
MIT License |
|
|
|
|
|
## Acknowledgments |
|
|
|
|
|
- Base model: [EleutherAI/pythia-410m](https://huggingface.co/EleutherAI/pythia-410m) |
|
|
- Training dataset: [monology/pile-uncopyrighted](https://huggingface.co/datasets/monology/pile-uncopyrighted) |
|
|
|