Upload folder using huggingface_hub
- README.md +143 -0
- sae_layer_00.pt +3 -0
- sae_layer_01.pt +3 -0
- sae_layer_02.pt +3 -0
- sae_layer_03.pt +3 -0
- sae_layer_04.pt +3 -0
- sae_layer_05.pt +3 -0
- sae_layer_06.pt +3 -0
- sae_layer_07.pt +3 -0
- sae_layer_08.pt +3 -0
- sae_layer_09.pt +3 -0
- sae_layer_10.pt +3 -0
- sae_layer_11.pt +3 -0
- sae_layer_12.pt +3 -0
- sae_layer_13.pt +3 -0
- sae_layer_14.pt +3 -0
- sae_layer_15.pt +3 -0
- sae_layer_16.pt +3 -0
- sae_layer_17.pt +3 -0
README.md
ADDED
@@ -0,0 +1,143 @@
---
title: FunctionGemma 270M SAE
language: en
tags:
- sparse-autoencoder
- sae
- interpretability
- functiongemma
- gemma
license: apache-2.0
---

# FunctionGemma 270M Sparse Autoencoders

Sparse autoencoders (SAEs) trained on all 18 layers of [google/functiongemma-270m-it](https://huggingface.co/google/functiongemma-270m-it): one SAE per layer, fitted to the output of each layer's self-attention projection.

## Architecture

- **Base Model**: google/functiongemma-270m-it
- **Layers**: 18 (decoder-only)
- **Hidden Size**: 640
- **SAE Dimension**: 4096 (6.4x expansion)
- **Hook Point**: `self_attn.o_proj` (output projection of self-attention)
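
For scale, a quick back-of-the-envelope check (not from the original card) of the per-layer parameter count these dimensions imply; it is consistent with the ~21 MB size of each `.pt` file:

```python
# Parameters of one SAE: W_enc + b_enc + W_dec + b_dec
# (assumes float32 storage, matching the ~21 MB checkpoints).
d_in, d_sae = 640, 4096

n_params = d_in * d_sae + d_sae + d_sae * d_in + d_in
print(f"{n_params:,} parameters")        # 5,247,616
print(f"{n_params * 4:,} bytes (fp32)")  # 20,990,464
```
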
## Training

- **Epochs**: 5 per layer
- **Batch Size**: 1
- **Learning Rate**: 1e-4
- **Optimizer**: AdamW
- **Loss**: MSE + 0.01 * L1 regularization (sketched below this list)
- **Activation Clipping**: [-10, 10]
- **Gradient Clipping**: max_norm=1.0
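
A minimal sketch of one training step under this recipe (illustrative only: the training script is not included in this repo, and the stand-in weights and activations are random just to make the sketch self-contained):

```python
import torch

d_in, d_sae = 640, 4096

# Stand-in SAE weights with the documented shapes; the Usage section below
# defines the real module. Random tensors make this sketch runnable.
W_enc = torch.nn.Parameter(torch.randn(d_in, d_sae) * 0.01)
b_enc = torch.nn.Parameter(torch.zeros(d_sae))
W_dec = torch.nn.Parameter(torch.randn(d_sae, d_in) * 0.01)
b_dec = torch.nn.Parameter(torch.zeros(d_in))
params = [W_enc, b_enc, W_dec, b_dec]

optimizer = torch.optim.AdamW(params, lr=1e-4)

acts = torch.randn(1, 16, d_in).clamp(-10, 10)   # hooked activations, clipped to [-10, 10]
features = torch.relu(acts @ W_enc + b_enc)      # encode
recon = features @ W_dec + b_dec                 # decode
mse = torch.nn.functional.mse_loss(recon, acts)
l1 = features.abs().sum(dim=-1).mean()           # sparsity penalty
loss = mse + 0.01 * l1                           # MSE + 0.01 * L1

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
optimizer.step()
```
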
## Checkpoints

Each checkpoint contains:
```python
{
    "model_name": "google/functiongemma-270m-it",
    "layer_idx": int,
    "d_in": 640,
    "d_sae": 4096,
    "W_enc": torch.Tensor,  # (640, 4096)
    "b_enc": torch.Tensor,  # (4096,)
    "W_dec": torch.Tensor,  # (4096, 640)
    "b_dec": torch.Tensor,  # (640,)
    "history": {
        "loss": [...],
        "mse": [...],
        "l0": [...]
    }
}
```
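
A downloaded checkpoint can be sanity-checked against this layout (a quick sketch using the documented keys):

```python
import torch
from huggingface_hub import hf_hub_download

# Fetch one checkpoint and verify the documented keys and shapes.
path = hf_hub_download("mindchain/functiongemma-270m-sae", "sae_layer_00.pt")
ckpt = torch.load(path, map_location="cpu")

assert ckpt["W_enc"].shape == (ckpt["d_in"], ckpt["d_sae"])   # (640, 4096)
assert ckpt["W_dec"].shape == (ckpt["d_sae"], ckpt["d_in"])   # (4096, 640)
print(ckpt["layer_idx"], ckpt["history"]["loss"][-1])         # layer, final loss
```
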

## Usage

```python
import torch
from huggingface_hub import hf_hub_download

# Load the SAE checkpoint for a specific layer
layer_idx = 0
ckpt_path = hf_hub_download(
    "mindchain/functiongemma-270m-sae",
    f"sae_layer_{layer_idx:02d}.pt"
)
sae = torch.load(ckpt_path, map_location="cpu")

# Wrap the checkpoint tensors in a module.
# Note: despite the name, the forward pass applies a plain ReLU;
# the checkpoints store no JumpReLU threshold.
class JumpReLUSAE(torch.nn.Module):
    def __init__(self, W_enc, b_enc, W_dec, b_dec):
        super().__init__()
        self.W_enc = torch.nn.Parameter(W_enc)
        self.b_enc = torch.nn.Parameter(b_enc)
        self.W_dec = torch.nn.Parameter(W_dec)
        self.b_dec = torch.nn.Parameter(b_dec)

    def forward(self, x):
        batch, seq, d_in = x.shape
        x_flat = x.view(-1, d_in)
        pre_act = x_flat @ self.W_enc + self.b_enc   # encode: (batch*seq, 4096)
        features = torch.relu(pre_act)
        recon = features @ self.W_dec + self.b_dec   # decode: (batch*seq, 640)
        return recon.view(batch, seq, d_in), features.view(batch, seq, -1)

sae_model = JumpReLUSAE(
    sae["W_enc"], sae["b_enc"],
    sae["W_dec"], sae["b_dec"]
)

# Get activations from FunctionGemma and encode
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "google/functiongemma-270m-it",
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("google/functiongemma-270m-it")

inputs = tokenizer("What's the weather?", return_tensors="pt").to(model.device)

# Hook to capture the o_proj output. o_proj is an nn.Linear, so `out` is the
# full (batch, seq, 640) tensor; move it to CPU to match the SAE's device.
acts = []
def hook(module, inp, out):
    acts.append(out.detach().float().cpu())
handle = model.model.layers[layer_idx].self_attn.o_proj.register_forward_hook(hook)
with torch.no_grad():
    _ = model(**inputs)
handle.remove()

# Run through SAE
recon, features = sae_model(acts[0])
print(f"Active features: {(features > 0).sum().item()}")
```
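
Building on the variables above, a quick way to see which features fire most strongly on the last token (illustrative; interpreting individual features requires further analysis):

```python
# Top-5 features on the final token (continues the snippet above).
last_token_feats = features[0, -1]            # (4096,) feature activations
vals, idxs = last_token_feats.topk(5)
for i, v in zip(idxs.tolist(), vals.tolist()):
    print(f"feature {i:4d}: activation {v:.3f}")
```
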

## Training Results

| Layer | Final Loss | Final MSE | L0   |
|-------|------------|-----------|------|
| 0     | 3.4457     | 3.1244    | 1225 |
| 1     | 2.0052     | 1.9042    | 1386 |
| 2     | 0.1182     | 0.0759    | 1546 |
| 3     | 0.1182     | 0.0758    | 3096 |
| 4     | 0.0361     | 0.0170    | 1635 |
| 5     | 0.0414     | 0.0351    | 399  |
| 6     | 0.0318     | 0.0138    | 1807 |
| 7     | 0.0877     | 0.0661    | 1120 |
| 8     | 0.0733     | 0.0445    | 1379 |
| 9     | 0.0561     | 0.0317    | 1569 |
| 10    | 0.0997     | 0.0852    | 591  |
| 11    | 0.0252     | 0.0097    | 3658 |
| 12    | 0.0565     | 0.0395    | 962  |
| 13    | 0.0924     | 0.0619    | 1403 |
| 14    | 0.2711     | 0.2504    | 709  |
| 15    | 0.1501     | 0.1062    | 1576 |
| 16    | 0.1670     | 0.1426    | 870  |
| 17    | 0.0385     | 0.0218    | 1470 |

## License

Apache 2.0

sae_layer_00.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:68841d7a1a9e9a4dc690c0b0bd5296d1e540dda18f273970df465049b1de11a1
size 20992760
sae_layer_01.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b3949f21bce6742ce35e5c258f855a7ae786fa8be7d4d8a4cec20784a589a989
size 20992760
sae_layer_02.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:25608c0fed52bae9065bd35033428a010c0b79a00b16df3ec5c20132a98508be
size 20992760
sae_layer_03.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:00999bab7630c4492cef40e47d6f2268ca8364fd3a6e3dd335c117239ac571ee
size 20992760
sae_layer_04.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a5d6402ded655b9e7fb62a8df33d45158af3ffef9e1e31c39c392f218425897
size 20992760
sae_layer_05.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:13bda23f02320a65ee466b3f88c0d421e652377fa26cc61491a8587964b6f0ba
size 20992760
sae_layer_06.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2ba06b5797ea7ce20522e8d04cbf93085db7333abf9839b8890d13b7971c0364
size 20992760
sae_layer_07.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0d61060bda5f66d5d6c1d9d25900076fd0092b296412fe9b16485c5a7ff4846d
size 20992760
sae_layer_08.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8cb29be94c5490ad1b3d0ac9746089ab4567e54b27b4beb3a5c98fa57ac9cec9
size 20992760
sae_layer_09.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:02e4e68c3e040c75b2ce73ef5fb052326926e104722a157319f4905c54c68594
size 20992760
sae_layer_10.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f6037efd3b492a493858a1389f26c84a412cc37b1331ae6fbbdd6345bd83305f
size 20992760
sae_layer_11.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1d1109e982c2239c84e70e260bbcf48638618256b4b542e38622a1f07016153f
size 20992760
sae_layer_12.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:62e542f0f548b6027e4a3660ae868f1c44ff96231da7f8b347646ca31dbaa5fc
size 20992760
sae_layer_13.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:67d6bcf3fa7d02c8bff0a484484d6e0dcd460b9ff4184063143410eb475240e0
size 20992760
sae_layer_14.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8d5f2adabdad0c21d54a4828d4d1aa2848a939db40a97b2e609f5a8b55f1473d
size 20992760
sae_layer_15.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa91a0170e4ddf1ab82d9d6a582ab2941458a64f4e81367dd95950dfd65ff238
size 20992760
sae_layer_16.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eb51ce70a5bbf895f469f9c10e1cce34b434cd95852708e2d39b08ce70e07588
size 20992760
sae_layer_17.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f8f923993944b6b37f2cd8b4e86c18b649d3f18556709feaeea3ab7a38791316
size 20992760