---
tags:
- mechanistic-interpretability
- transcoding
- bilinear
- pythia
- mlp
library_name: pytorch
license: mit
---

# Pythia-410m Bilinear MLP Transcoders

This repository contains bilinear transcoder models trained to approximate the MLP layers of [EleutherAI/pythia-410m](https://huggingface.co/EleutherAI/pythia-410m).

## Overview

**Transcoders** are auxiliary models that learn to approximate the behavior of transformer components (in this case, MLPs) using simpler architectures. These bilinear transcoders use a Hadamard neural network architecture to approximate each of the 24 MLP layers in Pythia-410m.
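
In training terms, each transcoder is fit by regressing the true MLP output on the MLP input. A minimal sketch of the objective (function names here are illustrative, not taken from the training code):

```python
import torch

def transcoder_loss(transcoder, mlp_in, mlp_out):
    # MSE between the transcoder's prediction and the true MLP output;
    # the FVU reported below is this error normalized by output variance.
    return torch.nn.functional.mse_loss(transcoder(mlp_in), mlp_out)
```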

## Model Architecture

- **Base Model**: EleutherAI/pythia-410m (24 layers)
- **Transcoder Type**: Bilinear (Hadamard Neural Network)
- **Architecture**: `output = W_left @ (x ⊙ (W_right @ x)) + bias`
  - Input dimension: 1024 (d_model)
  - Hidden dimension: 4096 (4x expansion)
  - Output dimension: 1024 (d_model)
- **Training**: 3000 batches, batch size 512, Muon optimizer (lr=0.02)
- **Dataset**: monology/pile-uncopyrighted
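
With these dimensions, a quick parameter count per transcoder (assuming just the two weight matrices described above plus an output bias):

```python
d_model, d_hidden = 1024, 4096
w_left = d_hidden * d_model   # 4096 x 1024 down-projection: 4,194,304
w_right = d_model * d_hidden  # 1024 x 4096 up-projection:   4,194,304
bias = d_model                # output bias:                      1,024
print(f"{w_left + w_right + bias:,} parameters per layer")  # 8,389,632
```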

## Performance Summary

All 24 layers achieve >82% variance explained, with most layers >93%:

| Layer | Final FVU | Variance Explained | Notes |
|-------|-----------|-------------------|-------|
| 0 | 0.0075 | 99.2% | Best performance |
| 1-2 | 0.167-0.174 | 82.6-83.2% | Hardest to approximate |
| 3-22 | 0.037-0.066 | 93.4-96.3% | Consistent performance |
| 23 | 0.0259 | 97.4% | Second-best |

**Average across all layers**: 93.4% variance explained (FVU = 0.0657)
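
FVU here is the fraction of variance unexplained, so variance explained is simply 1 - FVU (the table's columns are consistent with this, e.g. FVU 0.0657 gives 93.4% VE). A sketch of the relationship:

```python
import torch

def fvu(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Fraction of variance unexplained: reconstruction MSE normalized
    # by the variance of the target activations.
    mse = (pred - target).pow(2).mean()
    var = (target - target.mean(dim=0)).pow(2).mean()
    return mse / var

# variance explained = 1 - fvu(pred, target)
```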

## Repository Structure

```
.
β”œβ”€β”€ layer_0/
β”‚   β”œβ”€β”€ transcoder_weights_l0_bilinear_muon_3000b.pt
β”‚   └── config.yaml
β”œβ”€β”€ layer_1/
β”‚   β”œβ”€β”€ transcoder_weights_l1_bilinear_muon_3000b.pt
β”‚   └── config.yaml
...
β”œβ”€β”€ layer_23/
β”‚   β”œβ”€β”€ transcoder_weights_l23_bilinear_muon_3000b.pt
β”‚   └── config.yaml
β”œβ”€β”€ figures/
β”‚   β”œβ”€β”€ all_layers_comparison.png
β”‚   β”œβ”€β”€ training_curves_overlaid_layers_0_5.png
β”‚   β”œβ”€β”€ training_curves_overlaid_layers_6_11.png
β”‚   β”œβ”€β”€ training_curves_overlaid_layers_12_17.png
β”‚   └── training_curves_overlaid_layers_18_23.png
└── README.md
```

## Usage

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load base model
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m")

# Load the transcoder checkpoint for layer 3
layer_idx = 3
path = f"layer_{layer_idx}/transcoder_weights_l{layer_idx}_bilinear_muon_3000b.pt"
checkpoint = torch.load(path, weights_only=False)  # config is a pickled object

# Extract configuration
config = checkpoint['config']
print(f"Input dim: {config.n_inputs}")
print(f"Hidden dim: {config.n_hidden}")
print(f"Output dim: {config.n_outputs}")

# Reconstruct the transcoder. This Bilinear class is a sketch: the exact
# module definition must match the training code so that the checkpoint's
# state-dict keys (W_left, W_right) line up.
class Bilinear(torch.nn.Module):
    def __init__(self, n_inputs, n_hidden, n_outputs, bias=True):
        super().__init__()
        self.W_left = torch.nn.Linear(n_hidden, n_outputs, bias=bias)   # down-projection
        self.W_right = torch.nn.Linear(n_inputs, n_hidden, bias=False)  # up-projection

    def forward(self, x):
        right = self.W_right(x)  # (..., n_hidden)
        # Outer product between the input and its up-projection, summed over
        # the input dimension, then mapped back to d_model by W_left
        hadamard = x.unsqueeze(-1) * right.unsqueeze(-2)  # (..., n_inputs, n_hidden)
        return self.W_left(hadamard.sum(dim=-2))

transcoder = Bilinear(config.n_inputs, config.n_hidden, config.n_outputs, config.bias)
transcoder.load_state_dict(checkpoint['model_state_dict'])
transcoder.eval()

# Use the transcoder to approximate the MLP
with torch.no_grad():
    inputs = tokenizer("Hello world", return_tensors="pt")
    outputs = model(**inputs, output_hidden_states=True)
    # hidden_states[layer_idx] is the residual stream entering block
    # `layer_idx`. Note that GPTNeoX feeds its MLP from the block's
    # post_attention_layernorm, so apply that LayerNorm first if the
    # transcoders were trained on the MLP's true input.
    mlp_input = outputs.hidden_states[layer_idx]

    # Approximate the MLP's output with the transcoder
    transcoded_output = transcoder(mlp_input)
```
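
To check fidelity against the real MLP, you can capture the MLP's actual input and output with a forward hook and compare. A sketch continuing from the snippet above (`model.gpt_neox.layers[layer_idx].mlp` is the Hugging Face GPTNeoX module path; if the transcoders were trained on normalized activations, apply the same normalization before comparing):

```python
# Capture the true MLP input/output for layer `layer_idx`
captured = {}

def hook(module, args, output):
    captured["mlp_in"] = args[0]   # positional input to the MLP
    captured["mlp_out"] = output

handle = model.gpt_neox.layers[layer_idx].mlp.register_forward_hook(hook)
with torch.no_grad():
    model(**tokenizer("Hello world", return_tensors="pt"))
handle.remove()

# Fraction of variance unexplained on this example
with torch.no_grad():
    pred = transcoder(captured["mlp_in"])
    fvu = (pred - captured["mlp_out"]).pow(2).mean() / captured["mlp_out"].var()
    print(f"FVU: {fvu.item():.4f}")
```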

## Training Details

- **Optimizer**: Muon (momentum with Newton-Schulz orthogonalization of the updates)
- **Learning Rate**: 0.02 (fixed in the Muon training setup)
- **Batch Size**: 512
- **Total Batches**: 3000 per layer
- **Training Time**: ~75 minutes per layer on an A100
- **Normalization**: Per-batch z-score normalization
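
A sketch of the per-batch z-score normalization described above (`raw_mlp_in`/`raw_mlp_out` are placeholder activation tensors; the exact scheme, e.g. whether inputs and targets are normalized independently, lives in the training code):

```python
import torch

def z_score(batch: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize each feature to zero mean, unit variance within the batch
    return (batch - batch.mean(dim=0)) / (batch.std(dim=0) + eps)

mlp_in = z_score(raw_mlp_in)    # transcoder input
target = z_score(raw_mlp_out)   # regression target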

## Checkpoint Contents

Each checkpoint (`.pt` file) contains:
- `model_state_dict`: Model weights
- `optimizer_state_dict`: Optimizer state
- `config`: Configuration object with dimensions
- `mse_losses`: List of MSE losses per batch
- `variance_explained`: List of variance explained per batch
- `fvu_values`: List of FVU values per batch
- `layer_idx`: Layer index (0-23)
- `d_model`: Model dimension (1024)
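
For example, to inspect a checkpoint's training trajectory using the fields listed above (assuming the per-batch metrics are stored as plain floats, with variance explained as a fraction):

```python
import torch
import matplotlib.pyplot as plt

ckpt = torch.load("layer_3/transcoder_weights_l3_bilinear_muon_3000b.pt",
                  weights_only=False)  # config is a pickled object

print(f"layer {ckpt['layer_idx']}, d_model={ckpt['d_model']}")
print(f"final FVU: {ckpt['fvu_values'][-1]:.4f}")
print(f"final variance explained: {ckpt['variance_explained'][-1]:.1%}")

plt.plot(ckpt['fvu_values'])
plt.xlabel("batch")
plt.ylabel("FVU")
plt.title("Layer 3 transcoder training curve")
plt.show()
```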

## Key Findings

1. **Layer 0 is dramatically easier to approximate** (99.2% VE) - nearly perfect reconstruction
2. **Layers 1-2 are hardest** (~83% VE) - contain complex transformations
3. **Middle layers (3-22) are remarkably consistent** (93-96% VE) - homogeneous structure
4. **Final layer is highly learnable** (97.4% VE)

This pattern suggests that the first and last layers implement relatively structured, easily approximated transformations, while the early layers (1-2) perform more complex computations that a bilinear model struggles to capture.

## Citation

If you use these transcoders in your research, please cite:

```bibtex
@misc{pythia410m-bilinear-transcoders,
  title={Bilinear MLP Transcoders for Pythia-410m},
  author={[Your Name]},
  year={2025},
  publisher={Hugging Face},
  url={https://huggingface.co/[your-username]/pythia-410m-bilinear-transcoders}
}
```

## License

MIT License

## Acknowledgments

- Base model: [EleutherAI/pythia-410m](https://huggingface.co/EleutherAI/pythia-410m)
- Training dataset: [monology/pile-uncopyrighted](https://huggingface.co/datasets/monology/pile-uncopyrighted)