---
tags:
- mechanistic-interpretability
- transcoding
- bilinear
- pythia
- mlp
library_name: pytorch
license: mit
---
# Pythia-410m Bilinear MLP Transcoders
This repository contains bilinear transcoder models trained to approximate the MLP layers of [EleutherAI/pythia-410m](https://huggingface.co/EleutherAI/pythia-410m).
## Overview
**Transcoders** are auxiliary models trained to reproduce the output of a transformer component (here, an MLP) from that component's input, using a simpler architecture. This repository provides one bilinear transcoder, built as a Hadamard neural network, for each of the 24 MLP layers in Pythia-410m.
## Model Architecture
- **Base Model**: EleutherAI/pythia-410m (24 layers)
- **Transcoder Type**: Bilinear (Hadamard Neural Network)
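- **Architecture**: `output = W_left @ (sum(x) * (W_right @ x)) + bias`, i.e. the outer product of `x` and `W_right @ x` summed over the input axis, as computed by the `Bilinear` class in the Usage section
  - Input dimension: 1024 (d_model)
  - Hidden dimension: 4096 (4x expansion)
  - Output dimension: 1024 (d_model)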
- **Training**: 3000 batches, batch size 512, Muon optimizer (lr=0.02)
- **Dataset**: monology/pile-uncopyrighted
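As a rough size check, the parameter count of one transcoder can be worked out from its layer shapes. This sketch assumes the shapes of the `Bilinear` class in the Usage section (`W_right`: 1024 → 4096 without bias, `W_left`: 4096 → 1024 with bias) and compares them with Pythia's own MLP projections (`dense_h_to_4h`, `dense_4h_to_h`), which both carry biases.

```python
# Parameter counts, assuming the Bilinear shapes from the Usage section:
# W_right = Linear(1024 -> 4096, bias=False), W_left = Linear(4096 -> 1024, bias=True).
D_MODEL = 1024
D_HIDDEN = 4 * D_MODEL  # 4096


def linear_params(n_in: int, n_out: int, bias: bool) -> int:
    """Number of parameters in torch.nn.Linear(n_in, n_out, bias=bias)."""
    return n_in * n_out + (n_out if bias else 0)


transcoder_params = (
    linear_params(D_MODEL, D_HIDDEN, bias=False)   # W_right
    + linear_params(D_HIDDEN, D_MODEL, bias=True)  # W_left plus output bias
)

# Pythia's MLP: dense_h_to_4h and dense_4h_to_h, both with biases.
pythia_mlp_params = (
    linear_params(D_MODEL, D_HIDDEN, bias=True)
    + linear_params(D_HIDDEN, D_MODEL, bias=True)
)

print(transcoder_params)  # 8389632
print(pythia_mlp_params)  # 8393728
```

Under these assumptions a transcoder is almost exactly the size of the MLP it approximates; the 4096-parameter gap is the hidden-layer bias that `W_right` omits.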
## Performance Summary
All 24 layers achieve >82% variance explained, and 22 of the 24 exceed 93%:
| Layer | Final FVU | Variance Explained | Notes |
|-------|-----------|-------------------|-------|
| 0 | 0.0075 | 99.2% | Best performance |
| 1-2 | 0.167-0.174 | 82.6-83.2% | Hardest to approximate |
| 3-22 | 0.037-0.066 | 93.4-96.3% | Consistent performance |
| 23 | 0.0259 | 97.4% | Second-best |
**Average across all layers**: 93.4% variance explained (FVU = 0.0657)
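The table reports FVU (fraction of variance unexplained), and variance explained is simply 1 - FVU, so layer 0's FVU of 0.0075 corresponds to about 99.2%. The README doesn't state the exact FVU formula, so the sketch below uses a common definition as an assumption: the residual sum of squares divided by the target's total sum of squares around its per-feature mean. Predicting the mean gives FVU = 1; a perfect reconstruction gives FVU = 0.

```python
import torch


def fvu(target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Fraction of variance unexplained for a batch of activations (assumed definition).

    target, pred: tensors of shape (n_samples, d_model).
    """
    residual = (target - pred).pow(2).sum()
    total = (target - target.mean(dim=0, keepdim=True)).pow(2).sum()
    return residual / total


def variance_explained(target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Variance explained is the complement of FVU."""
    return 1.0 - fvu(target, pred)


# Tiny worked example: one feature, two samples whose mean is 1.0.
target = torch.tensor([[0.0], [2.0]])
pred = torch.tensor([[0.5], [1.5]])
print(fvu(target, pred).item())                 # 0.25
print(variance_explained(target, pred).item())  # 0.75
```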
## Repository Structure
```
.
├── layer_0/
│   ├── transcoder_weights_l0_bilinear_muon_3000b.pt
│   └── config.yaml
├── layer_1/
│   ├── transcoder_weights_l1_bilinear_muon_3000b.pt
│   └── config.yaml
...
├── layer_23/
│   ├── transcoder_weights_l23_bilinear_muon_3000b.pt
│   └── config.yaml
├── figures/
│   ├── all_layers_comparison.png
│   ├── training_curves_overlaid_layers_0_5.png
│   ├── training_curves_overlaid_layers_6_11.png
│   ├── training_curves_overlaid_layers_12_17.png
│   └── training_curves_overlaid_layers_18_23.png
└── README.md
```
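Because every layer follows the same naming scheme, a single layer's files can be located directly, or fetched individually with `huggingface_hub.hf_hub_download` instead of cloning the whole repository. The helper names below are hypothetical, and the Hub repository id is left as a parameter because this README doesn't state it.

```python
N_LAYERS = 24  # Pythia-410m has 24 transformer layers


def _check_layer(layer_idx: int) -> None:
    """Reject layer indices outside 0..23."""
    if not 0 <= layer_idx < N_LAYERS:
        raise ValueError(f"layer_idx must be in [0, {N_LAYERS - 1}], got {layer_idx}")


def checkpoint_path(layer_idx: int) -> str:
    """Repo-relative path of a layer's transcoder checkpoint, following the tree above."""
    _check_layer(layer_idx)
    return f"layer_{layer_idx}/transcoder_weights_l{layer_idx}_bilinear_muon_3000b.pt"


def config_path(layer_idx: int) -> str:
    """Repo-relative path of a layer's config.yaml."""
    _check_layer(layer_idx)
    return f"layer_{layer_idx}/config.yaml"


def download_checkpoint(repo_id: str, layer_idx: int) -> str:
    """Fetch one layer's checkpoint from the Hub and return its local path.

    `repo_id` must be supplied by the caller; it is not stated in this README.
    """
    from huggingface_hub import hf_hub_download  # imported lazily: only needed here

    return hf_hub_download(repo_id=repo_id, filename=checkpoint_path(layer_idx))
```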
## Usage
```python
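import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load base model
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m")
model.eval()

# Load transcoder for layer 3 (path relative to a local copy of this repository).
# weights_only=False is needed because the checkpoint also pickles a `config`
# object; unpickling it requires the class that defined it to be importable.
layer_idx = 3
checkpoint = torch.load(
    f"layer_{layer_idx}/transcoder_weights_l{layer_idx}_bilinear_muon_3000b.pt",
    map_location="cpu",
    weights_only=False,
)

# Extract configuration (expected: 1024 / 4096 / 1024)
config = checkpoint["config"]
print(f"Input dim: {config.n_inputs}")
print(f"Hidden dim: {config.n_hidden}")
print(f"Output dim: {config.n_outputs}")


# Reconstruct model; attribute names must match the keys in model_state_dict
class Bilinear(torch.nn.Module):
    def __init__(self, n_inputs, n_hidden, n_outputs, bias=True):
        super().__init__()
        self.W_left = torch.nn.Linear(n_hidden, n_outputs, bias=bias)
        self.W_right = torch.nn.Linear(n_inputs, n_hidden, bias=False)

    def forward(self, x):
        right = self.W_right(x)  # (..., n_hidden)
        # Outer product of x and right, summed over the input axis; this equals
        # x.sum(dim=-1, keepdim=True) * right, with shape (..., n_hidden)
        hadamard = x.unsqueeze(-1) * right.unsqueeze(-2)
        return self.W_left(hadamard.sum(dim=-2))


transcoder = Bilinear(config.n_inputs, config.n_hidden, config.n_outputs, config.bias)
transcoder.load_state_dict(checkpoint["model_state_dict"])
transcoder.eval()

# hidden_states[layer_idx] is the residual stream entering the block, not the
# MLP input (Pythia applies post_attention_layernorm first), so capture the
# MLP's actual input and output with a forward hook instead.
captured = {}


def save_mlp_io(module, inputs, output):
    captured["mlp_input"] = inputs[0]
    captured["mlp_output"] = output


handle = model.gpt_neox.layers[layer_idx].mlp.register_forward_hook(save_mlp_io)

with torch.no_grad():
    inputs = tokenizer("Hello world", return_tensors="pt")
    model(**inputs)
    # Approximate MLP output with transcoder. If training used z-scored
    # activations (see Training Details), normalize the input the same way first.
    transcoded_output = transcoder(captured["mlp_input"])
handle.remove()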
```
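A common next step is to splice the transcoder into the model in place of the real MLP and compare the language-modeling loss with and without it. One way to sketch that relies on PyTorch forward hooks: a forward hook that returns a non-None value replaces the module's output. `splice_module` is a hypothetical helper name, and this sketch still runs the original MLP (its output is simply discarded). If the transcoder was trained on z-scored activations, its inputs and outputs would need the same normalization, which this sketch does not apply.

```python
import torch
from torch import nn


def splice_module(module: nn.Module, replacement: nn.Module):
    """Make `module` return `replacement(x)` instead of its own output.

    Uses a forward hook: PyTorch substitutes any non-None value a forward hook
    returns for the module's output. Returns the handle; call .remove() to undo.
    """

    def hook(mod, inputs, output):
        return replacement(inputs[0])

    return module.register_forward_hook(hook)


# Self-contained check on a toy module: an identity "transcoder" replaces a Linear layer.
toy = nn.Linear(3, 3)
x = torch.randn(4, 3)
handle = splice_module(toy, nn.Identity())
spliced_out = toy(x)   # now equals x
handle.remove()
original_out = toy(x)  # back to the Linear layer's own output

# With the objects from the Usage example (not run here):
# handle = splice_module(model.gpt_neox.layers[layer_idx].mlp, transcoder)
# with torch.no_grad():
#     spliced_loss = model(**inputs, labels=inputs["input_ids"]).loss
# handle.remove()
```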
## Training Details
- **Optimizer**: Muon (a momentum-based optimizer that orthogonalizes each weight matrix's update with Newton-Schulz iterations)
- **Learning Rate**: 0.02 (hardcoded for Muon)
- **Batch Size**: 512
- **Total Batches**: 3000 per layer
- **Training Time**: ~75 minutes per layer on A100
- **Normalization**: Per-batch z-score normalization
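The README doesn't specify which axis the per-batch z-score normalization uses, or whether it is applied to inputs, targets, or both. A minimal sketch, assuming each feature is standardized with the mean and standard deviation of the current batch, and keeping those statistics so predictions can be mapped back to the original scale (`zscore` and `unzscore` are hypothetical names):

```python
import torch


def zscore(x: torch.Tensor, eps: float = 1e-6):
    """Standardize each feature of a (batch, d_model) tensor using batch statistics.

    Returns the normalized tensor plus the mean and std used, so that the
    transform can be inverted with `unzscore`.
    """
    mean = x.mean(dim=0, keepdim=True)
    std = x.std(dim=0, keepdim=True)
    return (x - mean) / (std + eps), mean, std


def unzscore(z: torch.Tensor, mean: torch.Tensor, std: torch.Tensor, eps: float = 1e-6):
    """Invert `zscore` using the stored batch statistics."""
    return z * (std + eps) + mean


# Example on a synthetic batch of 512 rows with d_model = 1024.
torch.manual_seed(0)
acts = 3.0 * torch.randn(512, 1024) + 5.0
z, mu, sigma = zscore(acts)
```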
## Checkpoint Contents
Each checkpoint (`.pt` file) contains:
- `model_state_dict`: Model weights
- `optimizer_state_dict`: Optimizer state
- `config`: Configuration object with dimensions
- `mse_losses`: List of MSE losses per batch
- `variance_explained`: List of variance explained per batch
- `fvu_values`: List of FVU values per batch
- `layer_idx`: Layer index (0-23)
- `d_model`: Model dimension (1024)
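Since the per-batch metric lists are stored alongside the weights, training progress can be inspected directly from a checkpoint. A sketch with hypothetical helper names, assuming the list entries are Python floats or scalar tensors; averaging the last `k` batches is one illustrative way to smooth the noisy per-batch values and is not necessarily how the "Final FVU" figures in the table were computed.

```python
import torch


def load_checkpoint(path: str) -> dict:
    """Load a transcoder checkpoint on CPU.

    weights_only=False is required because the file pickles a `config` object
    (and Python lists); only use it on files you trust. Unpickling `config`
    also needs its defining class to be importable.
    """
    return torch.load(path, map_location="cpu", weights_only=False)


def tail_mean(values, k: int = 100) -> float:
    """Mean of the last k per-batch values, a less noisy summary than the final batch."""
    tail = [float(v) for v in values[-k:]]
    return sum(tail) / len(tail)


def summarize(ckpt: dict, k: int = 100) -> dict:
    """Summarize the training metrics stored in a checkpoint dict."""
    return {
        "layer_idx": ckpt["layer_idx"],
        "d_model": ckpt["d_model"],
        "n_batches": len(ckpt["fvu_values"]),
        "last_fvu": float(ckpt["fvu_values"][-1]),
        "tail_fvu": tail_mean(ckpt["fvu_values"], k),
        "tail_variance_explained": tail_mean(ckpt["variance_explained"], k),
        "tail_mse": tail_mean(ckpt["mse_losses"], k),
    }


# Example with a small synthetic dict standing in for a real checkpoint.
fake = {
    "layer_idx": 3,
    "d_model": 1024,
    "mse_losses": [0.9, 0.5, 0.3, 0.2],
    "variance_explained": [0.2, 0.6, 0.9, 0.94],
    "fvu_values": [0.8, 0.4, 0.1, 0.06],
}
summary = summarize(fake, k=2)
```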
## Key Findings
1. **Layer 0 is by far the easiest to approximate** (99.2% VE), with nearly perfect reconstruction
2. **Layers 1-2 are the hardest** (~83% VE)
3. **Layers 3-22 are remarkably consistent** (93-96% VE), suggesting similar structure across them
4. **The final layer (23) is also highly learnable** (97.4% VE)

This suggests that the first and last MLP layers have more structured patterns, while layers 1-2 perform more complex transformations that are difficult for bilinear models to capture.
## Citation
If you use these transcoders in your research, please cite:
```bibtex
@misc{pythia410m-bilinear-transcoders,
title={Bilinear MLP Transcoders for Pythia-410m},
author={[Your Name]},
year={2025},
publisher={Hugging Face},
url={https://huggingface.co/[your-username]/pythia-410m-bilinear-transcoders}
}
```
## License
MIT License
## Acknowledgments
- Base model: [EleutherAI/pythia-410m](https://huggingface.co/EleutherAI/pythia-410m)
- Training dataset: [monology/pile-uncopyrighted](https://huggingface.co/datasets/monology/pile-uncopyrighted)