---
title: FunctionGemma 270M SAE
language: en
tags:
- sparse-autoencoder
- sae
- interpretability
- functiongemma
- gemma
license: apache-2.0
---
|
|
|
|
|
# FunctionGemma 270M Sparse Autoencoders

Sparse Autoencoders (SAEs) trained on all 18 layers of [google/functiongemma-270m-it](https://huggingface.co/google/functiongemma-270m-it).
|
|
|
|
|
## Architecture

- **Base Model**: google/functiongemma-270m-it
- **Layers**: 18 (decoder-only)
- **Hidden Size**: 640
- **SAE Dimension**: 4096 (6.4x expansion)
- **Hook Point**: `self_attn.o_proj` (output projection of self-attention)
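
For orientation, each 640-dimensional `o_proj` output is mapped to a 4096-dimensional sparse feature vector and reconstructed back. A minimal shape sketch with dummy weights (the real weights come from the checkpoints described below):

```python
import torch

d_in, d_sae = 640, 4096        # hidden size and SAE dimension from the list above
x = torch.randn(1, 8, d_in)    # (batch, seq, hidden) as captured at self_attn.o_proj

# Dummy weights with the documented shapes; a real checkpoint supplies these.
W_enc, b_enc = torch.randn(d_in, d_sae), torch.zeros(d_sae)
W_dec, b_dec = torch.randn(d_sae, d_in), torch.zeros(d_in)

features = torch.relu(x @ W_enc + b_enc)   # sparse features: (1, 8, 4096)
recon = features @ W_dec + b_dec           # reconstruction:  (1, 8, 640)
print(features.shape, recon.shape)
```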
|
|
|
|
|
## Training

- **Epochs**: 5 per layer
- **Batch Size**: 1
- **Learning Rate**: 1e-4
- **Optimizer**: AdamW
- **Loss**: MSE + 0.01 * L1 regularization
- **Activation Clipping**: [-10, 10]
- **Gradient Clipping**: max_norm=1.0
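
A minimal sketch of how these hyperparameters combine in a single training step; the data pipeline, weight initialization, and the mean reduction on the L1 term are assumptions, not taken from the released training code:

```python
import torch

# Dummy SAE parameters for 640-dim activations (see Architecture).
d_in, d_sae = 640, 4096
W_enc = torch.nn.Parameter(torch.randn(d_in, d_sae) * 0.01)
b_enc = torch.nn.Parameter(torch.zeros(d_sae))
W_dec = torch.nn.Parameter(torch.randn(d_sae, d_in) * 0.01)
b_dec = torch.nn.Parameter(torch.zeros(d_in))
params = [W_enc, b_enc, W_dec, b_dec]
opt = torch.optim.AdamW(params, lr=1e-4)

# One batch of hooked activations, clipped to [-10, 10].
acts = torch.randn(1, 16, d_in).clamp(-10, 10)

features = torch.relu(acts @ W_enc + b_enc)
recon = features @ W_dec + b_dec

mse = torch.nn.functional.mse_loss(recon, acts)
l1 = features.abs().mean()          # L1 sparsity term (mean reduction assumed)
loss = mse + 0.01 * l1              # Loss: MSE + 0.01 * L1

opt.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)  # gradient clipping
opt.step()
```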
|
|
|
|
|
## Checkpoints

Each checkpoint contains:

```python
{
    "model_name": "google/functiongemma-270m-it",
    "layer_idx": int,
    "d_in": 640,
    "d_sae": 4096,
    "W_enc": torch.Tensor,  # (640, 4096)
    "b_enc": torch.Tensor,  # (4096,)
    "W_dec": torch.Tensor,  # (4096, 640)
    "b_dec": torch.Tensor,  # (640,)
    "history": {
        "loss": [...],
        "mse": [...],
        "l0": [...]
    }
}
```
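
For example, to download all 18 checkpoints and sanity-check the schema (filenames follow the `sae_layer_{layer_idx:02d}.pt` pattern used in the Usage section below):

```python
import torch
from huggingface_hub import hf_hub_download

for layer_idx in range(18):
    path = hf_hub_download(
        "mindchain/functiongemma-270m-sae",
        f"sae_layer_{layer_idx:02d}.pt",
    )
    ckpt = torch.load(path, map_location="cpu")
    # Weight shapes should match the recorded dimensions.
    assert ckpt["W_enc"].shape == (ckpt["d_in"], ckpt["d_sae"])
    print(layer_idx, ckpt["d_in"], ckpt["d_sae"], len(ckpt["history"]["loss"]))
```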
|
|
|
|
|
## Usage

```python
import torch
from huggingface_hub import hf_hub_download

# Load the SAE checkpoint for a specific layer
layer_idx = 0
ckpt_path = hf_hub_download(
    "mindchain/functiongemma-270m-sae",
    f"sae_layer_{layer_idx:02d}.pt"
)
sae = torch.load(ckpt_path, map_location="cpu")

# Rebuild the SAE. The checkpoints store no JumpReLU threshold,
# so this is a standard ReLU SAE (matching the MSE + L1 training).
class ReLUSAE(torch.nn.Module):
    def __init__(self, W_enc, b_enc, W_dec, b_dec):
        super().__init__()
        self.W_enc = torch.nn.Parameter(W_enc)
        self.b_enc = torch.nn.Parameter(b_enc)
        self.W_dec = torch.nn.Parameter(W_dec)
        self.b_dec = torch.nn.Parameter(b_dec)

    def forward(self, x):
        batch, seq, d_in = x.shape
        x_flat = x.view(-1, d_in)
        pre_act = x_flat @ self.W_enc + self.b_enc
        features = torch.relu(pre_act)
        recon = features @ self.W_dec + self.b_dec
        return recon.view(batch, seq, d_in), features.view(batch, seq, -1)

sae_model = ReLUSAE(
    sae["W_enc"], sae["b_enc"],
    sae["W_dec"], sae["b_dec"]
)

# Get activations from FunctionGemma and encode
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "google/functiongemma-270m-it",
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("google/functiongemma-270m-it")

inputs = tokenizer("What's the weather?", return_tensors="pt").to(model.device)

# Forward hook to capture activations. o_proj is a Linear, so `out` is a
# (batch, seq, hidden) tensor; move it to CPU to match the SAE's device.
acts = []
def hook(module, inp, out):
    acts.append(out.detach().float().cpu())

handle = model.model.layers[layer_idx].self_attn.o_proj.register_forward_hook(hook)
with torch.no_grad():
    _ = model(**inputs)
handle.remove()

# Run the captured activations through the SAE
recon, features = sae_model(acts[0])
print(f"Active features: {(features > 0).sum().item()}")
```
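
To inspect individual features, e.g. the strongest activations on the final token of the prompt:

```python
# Top-10 feature indices and activation values for the last token
vals, idxs = features[0, -1].topk(10)
for i, v in zip(idxs.tolist(), vals.tolist()):
    print(f"feature {i}: {v:.3f}")
```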
|
|
|
|
|
## Training Results

Final per-layer metrics after 5 epochs. L0 is the number of nonzero SAE features (out of 4096).

| Layer | Final Loss | Final MSE | Final L0 |
|-------|------------|-----------|----------|
| 0 | 3.4457 | 3.1244 | 1225 |
| 1 | 2.0052 | 1.9042 | 1386 |
| 2 | 0.1182 | 0.0759 | 1546 |
| 3 | 0.1182 | 0.0758 | 3096 |
| 4 | 0.0361 | 0.0170 | 1635 |
| 5 | 0.0414 | 0.0351 | 399 |
| 6 | 0.0318 | 0.0138 | 1807 |
| 7 | 0.0877 | 0.0661 | 1120 |
| 8 | 0.0733 | 0.0445 | 1379 |
| 9 | 0.0561 | 0.0317 | 1569 |
| 10 | 0.0997 | 0.0852 | 591 |
| 11 | 0.0252 | 0.0097 | 3658 |
| 12 | 0.0565 | 0.0395 | 962 |
| 13 | 0.0924 | 0.0619 | 1403 |
| 14 | 0.2711 | 0.2504 | 709 |
| 15 | 0.1501 | 0.1062 | 1576 |
| 16 | 0.1670 | 0.1426 | 870 |
| 17 | 0.0385 | 0.0218 | 1470 |
|
|
|
|
|
## License

Apache 2.0
|
|
|