---
title: FunctionGemma 270M SAE
language: en
tags:
  - sparse-autoencoder
  - sae
  - interpretability
  - functiongemma
  - gemma
license: apache-2.0
---

# FunctionGemma 270M Sparse Autoencoders

Sparse Autoencoders (SAEs) trained on all 18 layers of [google/functiongemma-270m-it](https://huggingface.co/google/functiongemma-270m-it), one SAE per layer.

## Architecture

- **Base Model**: google/functiongemma-270m-it
- **Layers**: 18 (decoder-only)
- **Hidden Size**: 640
- **SAE Dimension**: 4096 (6.4x expansion)
- **Hook Point**: `self_attn.o_proj` (output projection of self-attention)

## Training

- **Epochs**: 5 per layer
- **Batch Size**: 1
- **Learning Rate**: 1e-4
- **Optimizer**: AdamW
- **Loss**: MSE + 0.01 * L1 sparsity penalty on the feature activations
- **Activation Clipping**: [-10, 10]
- **Gradient Clipping**: max_norm=1.0

A sketch of this objective is given in the appendix at the end of this card.

## Checkpoints

Each checkpoint (`sae_layer_XX.pt`) contains:

```python
{
    "model_name": "google/functiongemma-270m-it",
    "layer_idx": int,
    "d_in": 640,
    "d_sae": 4096,
    "W_enc": torch.Tensor,  # (640, 4096)
    "b_enc": torch.Tensor,  # (4096,)
    "W_dec": torch.Tensor,  # (4096, 640)
    "b_dec": torch.Tensor,  # (640,)
    "history": {"loss": [...], "mse": [...], "l0": [...]},
}
```

## Usage

```python
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the SAE checkpoint for a specific layer
layer_idx = 0
ckpt_path = hf_hub_download(
    "mindchain/functiongemma-270m-sae",
    f"sae_layer_{layer_idx:02d}.pt",
)
sae = torch.load(ckpt_path, map_location="cpu")


# ReLU sparse autoencoder matching the checkpoint layout above
class SparseAutoencoder(torch.nn.Module):
    def __init__(self, W_enc, b_enc, W_dec, b_dec):
        super().__init__()
        self.W_enc = torch.nn.Parameter(W_enc)
        self.b_enc = torch.nn.Parameter(b_enc)
        self.W_dec = torch.nn.Parameter(W_dec)
        self.b_dec = torch.nn.Parameter(b_dec)

    def forward(self, x):
        batch, seq, d_in = x.shape
        x_flat = x.view(-1, d_in)
        pre_act = x_flat @ self.W_enc + self.b_enc
        features = torch.relu(pre_act)
        recon = features @ self.W_dec + self.b_dec
        return recon.view(batch, seq, d_in), features.view(batch, seq, -1)


sae_model = SparseAutoencoder(
    sae["W_enc"], sae["b_enc"], sae["W_dec"], sae["b_dec"]
)

# Get activations from FunctionGemma and encode them
model = AutoModelForCausalLM.from_pretrained(
    "google/functiongemma-270m-it",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
tokenizer = AutoTokenizer.from_pretrained("google/functiongemma-270m-it")

inputs = tokenizer("What's the weather?", return_tensors="pt").to(model.device)

# Forward hook on o_proj; since o_proj is an nn.Linear, `out` is the
# (batch, seq, 640) activation tensor itself. Move it to CPU to match the SAE.
acts = []
def hook(module, inp, out):
    acts.append(out.detach().float().cpu())

handle = model.model.layers[layer_idx].self_attn.o_proj.register_forward_hook(hook)
with torch.no_grad():
    _ = model(**inputs)
handle.remove()

# Run the captured activations through the SAE
recon, features = sae_model(acts[0])
print(f"Active features: {(features > 0).sum().item()}")
```

## Training Results

L0 is the number of active (non-zero) SAE features out of 4096.

| Layer | Final Loss | Final MSE | Final L0 |
|-------|------------|-----------|----------|
| 0     | 3.4457     | 3.1244    | 1225     |
| 1     | 2.0052     | 1.9042    | 1386     |
| 2     | 0.1182     | 0.0759    | 1546     |
| 3     | 0.1182     | 0.0758    | 3096     |
| 4     | 0.0361     | 0.0170    | 1635     |
| 5     | 0.0414     | 0.0351    | 399      |
| 6     | 0.0318     | 0.0138    | 1807     |
| 7     | 0.0877     | 0.0661    | 1120     |
| 8     | 0.0733     | 0.0445    | 1379     |
| 9     | 0.0561     | 0.0317    | 1569     |
| 10    | 0.0997     | 0.0852    | 591      |
| 11    | 0.0252     | 0.0097    | 3658     |
| 12    | 0.0565     | 0.0395    | 962      |
| 13    | 0.0924     | 0.0619    | 1403     |
| 14    | 0.2711     | 0.2504    | 709      |
| 15    | 0.1501     | 0.1062    | 1576     |
| 16    | 0.1670     | 0.1426    | 870      |
| 17    | 0.0385     | 0.0218    | 1470     |

## License

Apache 2.0
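
## Appendix: Training Objective (Sketch)

The training script itself is not included in this repository. The sketch below is a minimal, illustrative implementation of the objective listed under Training (MSE reconstruction + 0.01 * L1 on the feature activations, activation clipping to [-10, 10], gradient clipping at max_norm=1.0). It reuses the `SparseAutoencoder` class from the Usage section; the exact L1 normalization and the ordering of the clipping steps are assumptions, not the original code.

```python
import torch

# Illustrative training step (not the original script). Assumes `sae` is a
# SparseAutoencoder (see Usage) and `acts` is a (batch, seq, 640) activation
# tensor captured from self_attn.o_proj.
def train_step(sae, acts, optimizer, l1_coeff=0.01):
    acts = acts.clamp(-10.0, 10.0)                   # activation clipping to [-10, 10]
    recon, features = sae(acts)
    mse = torch.nn.functional.mse_loss(recon, acts)  # reconstruction term
    l1 = features.abs().sum(dim=-1).mean()           # L1 penalty (assumed: sum over features, mean over tokens)
    loss = mse + l1_coeff * l1
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(sae.parameters(), max_norm=1.0)  # gradient clipping
    optimizer.step()
    # l0 = average number of active features per token, matching the `history` keys
    l0 = (features > 0).float().sum(dim=-1).mean()
    return loss.item(), mse.item(), l0.item()


# Example wiring with the hyperparameters from the Training section:
# optimizer = torch.optim.AdamW(sae_model.parameters(), lr=1e-4)
# loss, mse, l0 = train_step(sae_model, acts[0], optimizer)
```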