mindchain committed on
Commit bcb5804 · verified · 1 Parent(s): 8983946

Upload t5gemma_sae_colab.ipynb with huggingface_hub

  1. t5gemma_sae_colab.ipynb +257 -0
t5gemma_sae_colab.ipynb ADDED
@@ -0,0 +1,257 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# T5Gemma 2 SAE - Quick Start Guide\n",
+ "\n",
+ "This notebook shows how to use the **T5Gemma 2 Sparse Autoencoders** from [mindchain/t5gemma2-sae-all-layers](https://huggingface.co/mindchain/t5gemma2-sae-all-layers).\n",
+ "\n",
+ "## What are SAEs?\n",
+ "\n",
+ "Sparse Autoencoders (SAEs) decompose a network's internal activations into a larger set of sparsely active features, which makes it easier to interpret what the network has learned. They can be used for:\n",
+ "- **Mechanistic Interpretability** - Understanding model internals\n",
+ "- **Activation Steering** - Modifying model behavior\n",
+ "- **Feature Visualization** - Seeing what concepts each feature detects\n",
+ "\n",
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/README.md)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Install Dependencies\n",
+ "\n",
+ "First, install the required libraries (`accelerate` is needed because the model is later loaded with `device_map=\"auto\"`):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install -q transformers torch huggingface_hub accelerate"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Import Libraries"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from huggingface_hub import hf_hub_download\n",
+ "\n",
+ "print(\"Libraries imported successfully!\")\n",
+ "print(f\"PyTorch version: {torch.__version__}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Load a Trained SAE\n",
+ "\n",
+ "Load one of the 36 trained SAEs (18 encoder + 18 decoder layers)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from huggingface_hub import hf_hub_download\n",
+ "\n",
+ "repo_id = \"mindchain/t5gemma2-sae-all-layers\"\n",
+ "\n",
+ "# Load Encoder Layer 0 SAE\n",
+ "sae_path = hf_hub_download(\n",
+ "    repo_id=repo_id,\n",
+ "    filename=\"encoder/sae_encoder_00.pt\"\n",
+ ")\n",
+ "\n",
+ "sae = torch.load(sae_path, map_location=\"cpu\")\n",
+ "\n",
+ "print(f\"SAE loaded from: {sae_path}\")\n",
+ "print(f\"Model: {sae['model_name']}\")\n",
+ "print(f\"Layer: {sae['layer_type']} {sae['layer_idx']}\")\n",
+ "print(f\"d_in: {sae['d_in']}, d_sae: {sae['d_sae']}\")\n",
+ "\n",
+ "# Show training history\n",
+ "if 'history' in sae:\n",
+ "    print(f\"Final Loss: {sae['history']['loss'][-1]:.6f}\")\n",
+ "    print(f\"Final L0: {sae['history']['l0'][-1]:.1f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. SAE Forward Pass\n",
+ "\n",
+ "Define functions to run activations through the SAE."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def sae_encode(activations, sae):\n",
+ "    \"\"\"Activations to Sparse Features\"\"\"\n",
+ "    acts_f32 = activations.float()\n",
+ "    return torch.relu(acts_f32 @ sae['W_enc'] + sae['b_enc'])\n",
+ "\n",
+ "def sae_decode(features, sae):\n",
+ "    \"\"\"Sparse Features to Activations\"\"\"\n",
+ "    return features @ sae['W_dec'] + sae['b_dec']\n",
+ "\n",
+ "def sae_forward(activations, sae):\n",
+ "    \"\"\"Full SAE forward pass\"\"\"\n",
+ "    features = sae_encode(activations, sae)\n",
+ "    recon = sae_decode(features, sae)\n",
+ "    return recon, features\n",
+ "\n",
+ "print(\"SAE functions defined!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Test the SAE\n",
+ "\n",
+ "Create dummy activations and run them through the SAE. Random inputs are not real model activations, so treat the metrics below as a smoke test of shapes and code paths rather than a measure of reconstruction quality."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch.nn.functional as F\n",
+ "\n",
+ "# Create dummy activations: (batch, tokens, d_in)\n",
+ "dummy_activations = torch.randn(1, 10, sae['d_in'])\n",
+ "\n",
+ "# Run through SAE\n",
+ "recon, features = sae_forward(dummy_activations, sae)\n",
+ "\n",
+ "# Calculate metrics (L0 is averaged per token, not summed over all tokens)\n",
+ "mse = F.mse_loss(recon, dummy_activations).item()\n",
+ "cosine = F.cosine_similarity(\n",
+ "    dummy_activations.flatten(),\n",
+ "    recon.flatten(),\n",
+ "    dim=0\n",
+ ").item()\n",
+ "l0 = (features > 0).float().sum(dim=-1).mean().item()\n",
+ "\n",
+ "print(f\"Input shape: {dummy_activations.shape}\")\n",
+ "print(f\"Features shape: {features.shape}\")\n",
+ "print(f\"\\nReconstruction Quality:\")\n",
+ "print(f\" MSE: {mse:.6f}\")\n",
+ "print(f\" Cosine Similarity: {cosine:.4f}\")\n",
+ "print(f\" L0 (mean active features per token): {l0:.1f} / {features.shape[-1]}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 6. All Available SAEs\n",
+ "\n",
+ "This repository contains **36 SAEs** in total:\n",
+ "\n",
+ "| Layer Type | Range | Count |\n",
+ "|------------|-------|-------|\n",
+ "| Encoder | 0-17 | 18 SAEs |\n",
+ "| Decoder | 0-17 | 18 SAEs |\n",
+ "| **Total** | - | **36 SAEs** |\n",
+ "\n",
+ "To load a different layer:\n",
+ "```python\n",
+ "# Encoder Layer 5\n",
+ "sae_path = hf_hub_download(\n",
+ "    repo_id=\"mindchain/t5gemma2-sae-all-layers\",\n",
+ "    filename=\"encoder/sae_encoder_05.pt\"\n",
+ ")\n",
+ "\n",
+ "# Decoder Layer 10\n",
+ "sae_path = hf_hub_download(\n",
+ "    repo_id=\"mindchain/t5gemma2-sae-all-layers\",\n",
+ "    filename=\"decoder/sae_decoder_10.pt\"\n",
+ ")\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 7. Usage with T5Gemma 2 Model\n",
+ "\n",
+ "To apply the SAEs to real activations instead of dummy tensors, first load the T5Gemma 2 base model and its tokenizer:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
+ "\n",
+ "# Load model\n",
+ "model = AutoModelForSeq2SeqLM.from_pretrained(\n",
+ "    \"google/t5gemma-2-270m-270m\",\n",
+ "    device_map=\"auto\"\n",
+ ")\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"google/t5gemma-2-270m-270m\")\n",
+ "\n",
+ "print(\"Model loaded!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Links\n",
+ "\n",
+ "- **HuggingFace Model**: [mindchain/t5gemma2-sae-all-layers](https://huggingface.co/mindchain/t5gemma2-sae-all-layers)\n",
+ "- **Base Model**: [google/t5gemma-2-270m-270m](https://huggingface.co/google/t5gemma-2-270m-270m)\n",
+ "- **SAELens**: [github.com/decoderesearch/SAELens](https://github.com/decoderesearch/SAELens)\n",
+ "- **Neuronpedia**: [neuronpedia.org](https://neuronpedia.org)\n",
+ "\n",
+ "---\n",
+ "\n",
+ "Trained by [mindchain](https://huggingface.co/mindchain) | December 2025"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }
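
Section 6 of the notebook loads other layers by editing the `filename` string by hand. The sketch below wraps that naming scheme in a small helper. It is only an illustration: the function names `sae_filename` and `all_sae_filenames` are hypothetical, and the pattern `<type>/sae_<type>_<NN>.pt` with layers 0-17 is inferred from the three filenames and the table shown in the notebook, not confirmed against the repository listing.

```python
from typing import List

# Layer types listed in the notebook's table of available SAEs.
LAYER_TYPES = ("encoder", "decoder")


def sae_filename(layer_type: str, layer_idx: int, num_layers: int = 18) -> str:
    """Build the repo-relative path of one SAE checkpoint.

    Assumption: files follow '<type>/sae_<type>_<NN>.pt' with a
    zero-padded two-digit layer index, as in the notebook's examples.
    """
    if layer_type not in LAYER_TYPES:
        raise ValueError(f"layer_type must be one of {LAYER_TYPES}, got {layer_type!r}")
    if not 0 <= layer_idx < num_layers:
        raise ValueError(f"layer_idx must be in [0, {num_layers - 1}], got {layer_idx}")
    return f"{layer_type}/sae_{layer_type}_{layer_idx:02d}.pt"


def all_sae_filenames(num_layers: int = 18) -> List[str]:
    """List the paths of every encoder and decoder SAE under the assumed scheme."""
    return [
        sae_filename(layer_type, idx, num_layers)
        for layer_type in LAYER_TYPES
        for idx in range(num_layers)
    ]


if __name__ == "__main__":
    print(sae_filename("encoder", 5))
    print(sae_filename("decoder", 10))
    print(len(all_sae_filenames()))
```

The returned string can be passed as the `filename` argument of `hf_hub_download`, exactly as the notebook does with the hard-coded paths. Validating the layer type and index up front turns a typo into a clear `ValueError` rather than a failed download.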