mindchain committed (verified)
Commit c8cdad1 · Parent: f35d776

Upload folder using huggingface_hub
README.md ADDED
@@ -0,0 +1,143 @@
---
title: FunctionGemma 270M SAE
language: en
tags:
- sparse-autoencoder
- sae
- interpretability
- functiongemma
- gemma
license: apache-2.0
---

# FunctionGemma 270M Sparse Autoencoders

Sparse autoencoders (SAEs) trained on the self-attention output projections of all 18 layers of [google/functiongemma-270m-it](https://huggingface.co/google/functiongemma-270m-it), one SAE per layer.

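Each layer has its own checkpoint, `sae_layer_00.pt` through `sae_layer_17.pt` (see the file list at the end of this commit). A quick way to enumerate them:

```python
from huggingface_hub import list_repo_files

# List the 18 per-layer SAE checkpoints in this repo
files = list_repo_files("mindchain/functiongemma-270m-sae")
print(sorted(f for f in files if f.startswith("sae_layer_")))
```
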
## Architecture

- **Base Model**: google/functiongemma-270m-it
- **Layers**: 18 (decoder-only)
- **Hidden Size**: 640
- **SAE Dimension**: 4096 (6.4× expansion)
- **Hook Point**: `self_attn.o_proj` (output projection of self-attention; see the inspection sketch below)

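With the standard Gemma module layout in `transformers`, the hook point for layer `i` lives at `model.model.layers[i].self_attn.o_proj`, and its `out_features` should match the 640-dimensional hidden size above. A minimal sanity-check sketch, assuming that layout:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("google/functiongemma-270m-it")

# The SAEs are trained on the *output* of this projection, so its
# out_features should equal the model's 640-dim hidden size.
o_proj = model.model.layers[0].self_attn.o_proj
print(o_proj.out_features)  # expected: 640
```
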
## Training

- **Epochs**: 5 per layer
- **Batch Size**: 1
- **Learning Rate**: 1e-4
- **Optimizer**: AdamW
- **Loss**: MSE + 0.01 * L1 regularization (sketched after this list)
- **Activation Clipping**: [-10, 10]
- **Gradient Clipping**: max_norm=1.0

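The loss above is a standard reconstruction-plus-sparsity objective. A minimal sketch (the exact reductions over batch and feature dimensions are an assumption):

```python
import torch
import torch.nn.functional as F

def sae_loss(x, recon, features, l1_coeff=0.01):
    # MSE reconstruction term plus an L1 sparsity penalty on the
    # feature activations, matching "MSE + 0.01 * L1" above.
    mse = F.mse_loss(recon, x)
    l1 = features.abs().sum(dim=-1).mean()
    return mse + l1_coeff * l1
```
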
## Checkpoints

Each checkpoint contains:

```python
{
    "model_name": "google/functiongemma-270m-it",
    "layer_idx": int,
    "d_in": 640,
    "d_sae": 4096,
    "W_enc": torch.Tensor,  # (640, 4096)
    "b_enc": torch.Tensor,  # (4096,)
    "W_dec": torch.Tensor,  # (4096, 640)
    "b_dec": torch.Tensor,  # (640,)
    "history": {
        "loss": [...],
        "mse": [...],
        "l0": [...]
    }
}
```

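A downloaded checkpoint can be sanity-checked against these shapes, for example:

```python
import torch

ckpt = torch.load("sae_layer_00.pt", map_location="cpu")
assert ckpt["d_in"] == 640 and ckpt["d_sae"] == 4096
assert ckpt["W_enc"].shape == (ckpt["d_in"], ckpt["d_sae"])
assert ckpt["W_dec"].shape == (ckpt["d_sae"], ckpt["d_in"])
print(ckpt["history"]["loss"][-1])  # final training loss for this layer
```
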
## Usage

```python
import torch
from huggingface_hub import hf_hub_download

# Load the SAE checkpoint for a specific layer
layer_idx = 0
ckpt_path = hf_hub_download(
    "mindchain/functiongemma-270m-sae",
    f"sae_layer_{layer_idx:02d}.pt"
)
sae = torch.load(ckpt_path, map_location="cpu")

# Minimal SAE module. The checkpoints store no JumpReLU threshold,
# so this is a plain ReLU SAE (encode, ReLU, decode).
class ReLUSAE(torch.nn.Module):
    def __init__(self, W_enc, b_enc, W_dec, b_dec):
        super().__init__()
        self.W_enc = torch.nn.Parameter(W_enc)
        self.b_enc = torch.nn.Parameter(b_enc)
        self.W_dec = torch.nn.Parameter(W_dec)
        self.b_dec = torch.nn.Parameter(b_dec)

    def forward(self, x):
        batch, seq, d_in = x.shape
        x_flat = x.view(-1, d_in)
        pre_act = x_flat @ self.W_enc + self.b_enc  # (batch*seq, 4096)
        features = torch.relu(pre_act)
        recon = features @ self.W_dec + self.b_dec  # (batch*seq, 640)
        return recon.view(batch, seq, d_in), features.view(batch, seq, -1)

sae_model = ReLUSAE(
    sae["W_enc"], sae["b_enc"],
    sae["W_dec"], sae["b_dec"]
)

# Get activations from FunctionGemma and encode
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "google/functiongemma-270m-it",
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("google/functiongemma-270m-it")

inputs = tokenizer("What's the weather?", return_tensors="pt").to(model.device)

# Forward hook on o_proj. For an nn.Linear the hook receives the output
# tensor itself, shape (batch, seq, 640), so it is captured whole.
acts = []
def hook(module, inp, out):
    acts.append(out.detach().float())
handle = model.model.layers[layer_idx].self_attn.o_proj.register_forward_hook(hook)
with torch.no_grad():
    _ = model(**inputs)
handle.remove()

# Run the captured activations through the SAE
recon, features = sae_model(acts[0])
print(f"Active features: {(features > 0).sum().item()}")
```

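Continuing from the snippet above, the per-layer metrics reported below can be approximated from the captured activations; `L0` here is taken to be the mean number of active features per token (an assumption consistent with the `l0` history key):

```python
import torch.nn.functional as F

recon, features = sae_model(acts[0])
mse = F.mse_loss(recon, acts[0]).item()
l0 = (features > 0).sum(dim=-1).float().mean().item()
print(f"MSE: {mse:.4f}, mean L0 per token: {l0:.0f}")
```
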
## Training Results

| Layer | Final Loss | Final MSE | L0 (active features) |
|-------|------------|-----------|----------------------|
| 0 | 3.4457 | 3.1244 | 1225 |
| 1 | 2.0052 | 1.9042 | 1386 |
| 2 | 0.1182 | 0.0759 | 1546 |
| 3 | 0.1182 | 0.0758 | 3096 |
| 4 | 0.0361 | 0.0170 | 1635 |
| 5 | 0.0414 | 0.0351 | 399 |
| 6 | 0.0318 | 0.0138 | 1807 |
| 7 | 0.0877 | 0.0661 | 1120 |
| 8 | 0.0733 | 0.0445 | 1379 |
| 9 | 0.0561 | 0.0317 | 1569 |
| 10 | 0.0997 | 0.0852 | 591 |
| 11 | 0.0252 | 0.0097 | 3658 |
| 12 | 0.0565 | 0.0395 | 962 |
| 13 | 0.0924 | 0.0619 | 1403 |
| 14 | 0.2711 | 0.2504 | 709 |
| 15 | 0.1501 | 0.1062 | 1576 |
| 16 | 0.1670 | 0.1426 | 870 |
| 17 | 0.0385 | 0.0218 | 1470 |

## License

Apache 2.0
sae_layer_00.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:68841d7a1a9e9a4dc690c0b0bd5296d1e540dda18f273970df465049b1de11a1
size 20992760

sae_layer_01.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b3949f21bce6742ce35e5c258f855a7ae786fa8be7d4d8a4cec20784a589a989
size 20992760

sae_layer_02.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:25608c0fed52bae9065bd35033428a010c0b79a00b16df3ec5c20132a98508be
size 20992760

sae_layer_03.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:00999bab7630c4492cef40e47d6f2268ca8364fd3a6e3dd335c117239ac571ee
size 20992760

sae_layer_04.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a5d6402ded655b9e7fb62a8df33d45158af3ffef9e1e31c39c392f218425897
size 20992760

sae_layer_05.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:13bda23f02320a65ee466b3f88c0d421e652377fa26cc61491a8587964b6f0ba
size 20992760

sae_layer_06.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2ba06b5797ea7ce20522e8d04cbf93085db7333abf9839b8890d13b7971c0364
size 20992760

sae_layer_07.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0d61060bda5f66d5d6c1d9d25900076fd0092b296412fe9b16485c5a7ff4846d
size 20992760

sae_layer_08.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8cb29be94c5490ad1b3d0ac9746089ab4567e54b27b4beb3a5c98fa57ac9cec9
size 20992760

sae_layer_09.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:02e4e68c3e040c75b2ce73ef5fb052326926e104722a157319f4905c54c68594
size 20992760

sae_layer_10.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f6037efd3b492a493858a1389f26c84a412cc37b1331ae6fbbdd6345bd83305f
size 20992760

sae_layer_11.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1d1109e982c2239c84e70e260bbcf48638618256b4b542e38622a1f07016153f
size 20992760

sae_layer_12.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:62e542f0f548b6027e4a3660ae868f1c44ff96231da7f8b347646ca31dbaa5fc
size 20992760

sae_layer_13.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:67d6bcf3fa7d02c8bff0a484484d6e0dcd460b9ff4184063143410eb475240e0
size 20992760

sae_layer_14.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8d5f2adabdad0c21d54a4828d4d1aa2848a939db40a97b2e609f5a8b55f1473d
size 20992760

sae_layer_15.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa91a0170e4ddf1ab82d9d6a582ab2941458a64f4e81367dd95950dfd65ff238
size 20992760

sae_layer_16.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eb51ce70a5bbf895f469f9c10e1cce34b434cd95852708e2d39b08ce70e07588
size 20992760

sae_layer_17.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f8f923993944b6b37f2cd8b4e86c18b649d3f18556709feaeea3ab7a38791316
size 20992760