---
title: FunctionGemma 270M SAE
language: en
tags:
- sparse-autoencoder
- sae
- interpretability
- functiongemma
- gemma
license: apache-2.0
---
# FunctionGemma 270M Sparse Autoencoders
Sparse Autoencoders (SAEs) trained on all 18 layers of [google/functiongemma-270m-it](https://huggingface.co/google/functiongemma-270m-it).
## Architecture
- **Base Model**: google/functiongemma-270m-it
- **Layers**: 18 (decoder-only)
- **Hidden Size**: 640
- **SAE Dimension**: 4096 (6.4x expansion)
- **Hook Point**: `self_attn.o_proj` (output projection of self-attention)
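The forward pass in the Usage section can be written out for a single activation vector $\mathbf{x}$ taken at the hook point. Shapes come from the checkpoint layout. The ReLU encoder is read off the usage code, not from a separate specification.

$$
\begin{aligned}
\mathbf{f} &= \operatorname{ReLU}\left(\mathbf{x} W_{\text{enc}} + \mathbf{b}_{\text{enc}}\right), && \mathbf{x} \in \mathbb{R}^{640},\; W_{\text{enc}} \in \mathbb{R}^{640 \times 4096},\; \mathbf{f} \in \mathbb{R}^{4096} \\
\hat{\mathbf{x}} &= \mathbf{f} W_{\text{dec}} + \mathbf{b}_{\text{dec}}, && W_{\text{dec}} \in \mathbb{R}^{4096 \times 640},\; \hat{\mathbf{x}} \in \mathbb{R}^{640} \\
\text{expansion} &= 4096 / 640 = 6.4
\end{aligned}
$$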
## Training
- **Epochs**: 5 per layer
- **Batch Size**: 1
- **Learning Rate**: 1e-4
- **Optimizer**: AdamW
- **Loss**: MSE + 0.01 * L1 regularization
- **Activation Clipping**: [-10, 10]
- **Gradient Clipping**: max_norm=1.0
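The hyperparameters above fit together roughly as in the following sketch of a single training step. It rests on several assumptions that the card does not confirm:

- the L1 penalty is applied to the feature activations, summed over features and averaged over tokens;
- clipping is applied to the SAE's input activations.

The initialization and the random stand-in batch are hypothetical.

```python
import torch
import torch.nn.functional as F

D_IN, D_SAE = 640, 4096
L1_COEFF = 0.01  # from the card: loss = MSE + 0.01 * L1
CLIP = 10.0      # from the card: activation clipping to [-10, 10]


def init_params(d_in=D_IN, d_sae=D_SAE, seed=0):
    # Hypothetical initialization; the card does not say how weights were initialized
    g = torch.Generator().manual_seed(seed)
    W_enc = torch.nn.Parameter(torch.randn(d_in, d_sae, generator=g) * 0.01)
    b_enc = torch.nn.Parameter(torch.zeros(d_sae))
    W_dec = torch.nn.Parameter(torch.randn(d_sae, d_in, generator=g) * 0.01)
    b_dec = torch.nn.Parameter(torch.zeros(d_in))
    return [W_enc, b_enc, W_dec, b_dec]


def train_step(params, optimizer, x):
    """One optimization step on a (tokens, d_in) batch of o_proj activations."""
    W_enc, b_enc, W_dec, b_dec = params
    x = x.clamp(-CLIP, CLIP)  # assumption: clipping applies to the SAE inputs
    features = torch.relu(x @ W_enc + b_enc)
    recon = features @ W_dec + b_dec
    mse = F.mse_loss(recon, x)
    # Assumption: L1 of the (non-negative) features, summed per token, averaged over tokens
    l1 = features.sum(dim=-1).mean()
    loss = mse + L1_COEFF * l1
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
    optimizer.step()
    # L0: mean number of active features per token
    l0 = (features > 0).float().sum(dim=-1).mean()
    return {"loss": loss.item(), "mse": mse.item(), "l0": l0.item()}


params = init_params()
optimizer = torch.optim.AdamW(params, lr=1e-4)
# Batch size 1: one prompt's activations; random stand-in data with 12 tokens
x = torch.randn(12, D_IN) * 3
stats = train_step(params, optimizer, x)
print(stats)
```

Running this for 5 epochs over a prompt dataset, one layer at a time, would correspond to the schedule listed above. The dataset itself is not described in this card.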
## Checkpoints
Each checkpoint contains:
```python
{
    "model_name": "google/functiongemma-270m-it",
    "layer_idx": int,
    "d_in": 640,
    "d_sae": 4096,
    "W_enc": torch.Tensor,  # (640, 4096)
    "b_enc": torch.Tensor,  # (4096,)
    "W_dec": torch.Tensor,  # (4096, 640)
    "b_dec": torch.Tensor,  # (640,)
    "history": {
        "loss": [...],
        "mse": [...],
        "l0": [...],
    },
}
```
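A downloaded checkpoint can be checked against the layout above before use. The sketch below does this with a hypothetical helper, `check_sae_checkpoint`, and exercises it on a synthetic checkpoint round-tripped through `torch.save`/`torch.load`. The `layer_idx < 18` bound assumes the 18-layer model described above.

```python
import io

import torch

EXPECTED_SHAPES = {
    "W_enc": (640, 4096),
    "b_enc": (4096,),
    "W_dec": (4096, 640),
    "b_dec": (640,),
}
REQUIRED_KEYS = {"model_name", "layer_idx", "d_in", "d_sae", "history", *EXPECTED_SHAPES}


def check_sae_checkpoint(ckpt):
    """Raise ValueError if `ckpt` does not match the documented layout."""
    missing = REQUIRED_KEYS - set(ckpt)
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    for name, shape in EXPECTED_SHAPES.items():
        got = tuple(ckpt[name].shape)
        if got != shape:
            raise ValueError(f"{name}: expected {shape}, got {got}")
    if not 0 <= ckpt["layer_idx"] < 18:
        raise ValueError(f"layer_idx out of range: {ckpt['layer_idx']}")
    for key in ("loss", "mse", "l0"):
        if key not in ckpt["history"]:
            raise ValueError(f"history is missing '{key}'")


# Synthetic checkpoint with the documented keys and shapes (zeros, not real weights)
fake = {
    "model_name": "google/functiongemma-270m-it",
    "layer_idx": 0,
    "d_in": 640,
    "d_sae": 4096,
    **{name: torch.zeros(shape) for name, shape in EXPECTED_SHAPES.items()},
    "history": {"loss": [1.0], "mse": [0.9], "l0": [100.0]},
}
buf = io.BytesIO()
torch.save(fake, buf)
buf.seek(0)
loaded = torch.load(buf, map_location="cpu")
check_sae_checkpoint(loaded)
print("checkpoint layout OK")
```

For a real file, pass the result of `torch.load(ckpt_path, map_location="cpu")` instead of `loaded`.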
## Usage
```python
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the SAE checkpoint for a specific layer
layer_idx = 0
ckpt_path = hf_hub_download(
    "mindchain/functiongemma-270m-sae",
    f"sae_layer_{layer_idx:02d}.pt",
)
sae = torch.load(ckpt_path, map_location="cpu")


class ReLUSAE(torch.nn.Module):
    """Plain ReLU SAE (the checkpoint stores no JumpReLU threshold)."""

    def __init__(self, W_enc, b_enc, W_dec, b_dec):
        super().__init__()
        self.W_enc = torch.nn.Parameter(W_enc.float())  # (d_in, d_sae)
        self.b_enc = torch.nn.Parameter(b_enc.float())  # (d_sae,)
        self.W_dec = torch.nn.Parameter(W_dec.float())  # (d_sae, d_in)
        self.b_dec = torch.nn.Parameter(b_dec.float())  # (d_in,)

    def forward(self, x):
        # x: (batch, seq, d_in); flatten tokens for the matmuls
        batch, seq, d_in = x.shape
        x_flat = x.reshape(-1, d_in)
        features = torch.relu(x_flat @ self.W_enc + self.b_enc)
        recon = features @ self.W_dec + self.b_dec
        return recon.view(batch, seq, d_in), features.view(batch, seq, -1)


sae_model = ReLUSAE(sae["W_enc"], sae["b_enc"], sae["W_dec"], sae["b_dec"]).eval()

# Load FunctionGemma
model = AutoModelForCausalLM.from_pretrained(
    "google/functiongemma-270m-it",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
tokenizer = AutoTokenizer.from_pretrained("google/functiongemma-270m-it")
inputs = tokenizer("What's the weather?", return_tensors="pt").to(model.device)

# Forward hook on o_proj. It is an nn.Linear, so `out` is a single
# (batch, seq, 640) tensor, not a tuple. Move it to CPU/float32 to match the SAE.
acts = []


def hook(module, inp, out):
    acts.append(out.detach().float().cpu())


handle = model.model.layers[layer_idx].self_attn.o_proj.register_forward_hook(hook)
with torch.no_grad():
    _ = model(**inputs)
handle.remove()

# Run the captured activations through the SAE
with torch.no_grad():
    recon, features = sae_model(acts[0])
active = features > 0
print(f"Active (token, feature) pairs: {active.sum().item()}")
print(f"Mean active features per token: {active.sum(dim=-1).float().mean().item():.1f}")
```
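To inspect which features fire on each token, you can take the top-k entries of `features` along the last dimension. The helper `top_features` below is hypothetical. The small hand-written tensor stands in for the `(batch, seq, 4096)` output of the SAE. With real data, `tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])` gives the token strings to pair with each position.

```python
import torch


def top_features(features, k=5):
    """Return (values, indices) of the k strongest SAE features for each token."""
    # features: (batch, seq, d_sae), as returned by the SAE forward pass
    return torch.topk(features, k, dim=-1)


# Stand-in for real SAE output: 1 prompt, 3 tokens, 8 features
features = torch.tensor([[
    [0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 3.0, 0.0],
    [0.5, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
]])
values, indices = top_features(features, k=2)
for pos in range(features.shape[1]):
    # Keep only features that are actually active (> 0) at this position
    active = [
        (i.item(), round(v.item(), 2))
        for v, i in zip(values[0, pos], indices[0, pos])
        if v > 0
    ]
    print(pos, active)
```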
## Training Results
| Layer | Final Loss | Final MSE | L0 |
|-------|------------|-----------|-----|
| 0 | 3.4457 | 3.1244 | 1225 |
| 1 | 2.0052 | 1.9042 | 1386 |
| 2 | 0.1182 | 0.0759 | 1546 |
| 3 | 0.1182 | 0.0758 | 3096 |
| 4 | 0.0361 | 0.0170 | 1635 |
| 5 | 0.0414 | 0.0351 | 399 |
| 6 | 0.0318 | 0.0138 | 1807 |
| 7 | 0.0877 | 0.0661 | 1120 |
| 8 | 0.0733 | 0.0445 | 1379 |
| 9 | 0.0561 | 0.0317 | 1569 |
| 10 | 0.0997 | 0.0852 | 591 |
| 11 | 0.0252 | 0.0097 | 3658 |
| 12 | 0.0565 | 0.0395 | 962 |
| 13 | 0.0924 | 0.0619 | 1403 |
| 14 | 0.2711 | 0.2504 | 709 |
| 15 | 0.1501 | 0.1062 | 1576 |
| 16 | 0.1670 | 0.1426 | 870 |
| 17 | 0.0385 | 0.0218 | 1470 |
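A table row could be reproduced from a checkpoint's `history` with a sketch like the one below. It assumes the table lists the last logged entry of each list and that every logged loss equals `mse + 0.01 * l1` for the same step. Under that assumption, the difference `loss - mse` gives the implied L1 term. `summarize_history` is a hypothetical helper, and it is fed the layer 0 numbers from the table above.

```python
def summarize_history(ckpt, l1_coeff=0.01):
    """Final loss, MSE and L0 from a checkpoint's `history`, plus the implied L1 term."""
    hist = ckpt["history"]
    final_loss, final_mse, final_l0 = hist["loss"][-1], hist["mse"][-1], hist["l0"][-1]
    # Assumption: loss = mse + l1_coeff * l1 holds for the same logged step
    implied_l1 = (final_loss - final_mse) / l1_coeff
    return {
        "layer": ckpt["layer_idx"],
        "loss": final_loss,
        "mse": final_mse,
        "l0": final_l0,
        "implied_l1": implied_l1,
    }


# Layer 0 values from the table above, in place of a real checkpoint
row = summarize_history(
    {"layer_idx": 0, "history": {"loss": [3.4457], "mse": [3.1244], "l0": [1225]}}
)
print(row)
```

For a real checkpoint, call `summarize_history(torch.load(ckpt_path, map_location="cpu"))`.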
## License
Apache 2.0