metadata
license: mit
library_name: transformers
base_model: google/t5gemma-2-270m-270m
colab:
- link: >-
    https://colab.research.google.com/github/haddock-development/t5gemma2-sae-all-layers/blob/main/colab.ipynb
  title: T5Gemma 2 SAE - Quick Start
model_index:
- name: google/t5gemma-2-270m-270m
  results:
  - task:
      type: mechanistic-interpretability
    dataset:
      type: custom
      name: T5Gemma 2 SAE
    metrics:
    - name: SAE Checkpoints
      value: 36
tags:
- sae
- sparse-autoencoder
- t5gemma
- t5gemma2
- mechanistic-interpretability
- activation-steering
- steering
- neuronpedia
- gemma-scope
- sae-lens
- llm-interpretability
- explainable-ai
- xai
- model-steering
- feature-engineering
- representation-learning
T5Gemma 2 Sparse Autoencoders (All 36 Layers)
Sparse Autoencoders (SAEs) trained on all 36 layers of google/t5gemma-2-270m-270m for mechanistic interpretability and activation steering.
Quick Start
from huggingface_hub import hf_hub_download
import torch

# Download one checkpoint (encoder layer 0) and load it on the CPU
sae_path = hf_hub_download(
    repo_id="mindchain/t5gemma2-sae-all-layers",
    filename="encoder/sae_encoder_00.pt",
)
sae = torch.load(sae_path, map_location="cpu")

# Forward pass
activations = ...  # Your activations from self_attn.o_proj, shape (..., 640)
features = torch.relu(activations.float() @ sae['W_enc'] + sae['b_enc'])  # (..., 4096)
reconstruction = features @ sae['W_dec'] + sae['b_dec']  # (..., 640)
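To sanity-check a loaded SAE, one option is to compare the reconstruction against the input. The sketch below is an illustration, not the author's evaluation code: `reconstruction_metrics` is a hypothetical helper, and the random tensors only stand in for a real checkpoint and real activations.

```python
import torch
import torch.nn.functional as F


def reconstruction_metrics(activations, sae):
    """Return MSE, mean cosine similarity, and mean L0 for one batch.

    `sae` is a dict with 'W_enc', 'b_enc', 'W_dec', 'b_dec', laid out like
    the checkpoint files in this repository.
    """
    x = activations.float()
    features = torch.relu(x @ sae["W_enc"] + sae["b_enc"])
    recon = features @ sae["W_dec"] + sae["b_dec"]
    mse = F.mse_loss(recon, x).item()
    cosine = F.cosine_similarity(recon, x, dim=-1).mean().item()
    # L0 = number of non-zero features per activation vector, averaged.
    l0 = (features > 0).sum(dim=-1).float().mean().item()
    return {"mse": mse, "cosine": cosine, "l0": l0}


# Stand-in data: random weights with the documented shapes (NOT a trained SAE).
torch.manual_seed(0)
d_in, d_sae = 640, 4096
fake_sae = {
    "W_enc": torch.randn(d_in, d_sae) * 0.02,
    "b_enc": torch.zeros(d_sae),
    "W_dec": torch.randn(d_sae, d_in) * 0.02,
    "b_dec": torch.zeros(d_in),
}
fake_activations = torch.randn(2, 8, d_in)  # (batch, sequence, d_in)
metrics = reconstruction_metrics(fake_activations, fake_sae)
print(metrics)
```

With a real checkpoint, pass the dict returned by `torch.load` in place of `fake_sae` and activations captured at the hook point in place of `fake_activations`.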
Model Specifications
| Property | Value |
|---|---|
| Base Model | google/t5gemma-2-270m-270m |
| Architecture | T5 Text-to-Text (Encoder-Decoder) |
| Total Parameters | ~540M (270M encoder + 270M decoder) |
| Hidden Size (d_model) | 640 |
| FFN Dimension | 2,560 |
| Attention Heads | 10 |
| Vocabulary Size | 32,128 |
| Encoder Layers | 18 |
| Decoder Layers | 18 |
SAE Configuration
| Property | Value |
|---|---|
| SAE Input Dimension (d_in) | 640 |
| SAE Hidden Dimension (d_sae) | 4,096 |
| Expansion Factor | 6.4x |
| L1 Coefficient | 0.01 |
| Training Epochs | 5 |
| Batch Size | 2 |
| Learning Rate | 1e-4 |
| Optimizer | AdamW |
| Hook Point | self_attn.o_proj |
| Precision (Model) | float16 |
| Precision (SAE) | float32 |
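The table lists hyperparameters but not the objective itself. As a rough sketch, assuming the common SAE formulation of reconstruction MSE plus an L1 penalty on the feature activations, a single optimisation step with these settings could look like the following. The loss form, the initialisation scale, and the placeholder input are assumptions, not taken from the training code.

```python
import torch

d_in, d_sae, l1_coeff = 640, 4096, 0.01  # values from the table above

torch.manual_seed(0)
# Randomly initialised parameters; names mirror the checkpoint keys.
params = {
    "W_enc": torch.nn.Parameter(torch.randn(d_in, d_sae) * 0.01),
    "b_enc": torch.nn.Parameter(torch.zeros(d_sae)),
    "W_dec": torch.nn.Parameter(torch.randn(d_sae, d_in) * 0.01),
    "b_dec": torch.nn.Parameter(torch.zeros(d_in)),
}
optimizer = torch.optim.AdamW(params.values(), lr=1e-4)


def sae_loss(x, p):
    """Assumed objective: reconstruction MSE plus an L1 sparsity penalty."""
    features = torch.relu(x @ p["W_enc"] + p["b_enc"])
    recon = features @ p["W_dec"] + p["b_dec"]
    mse = (recon - x).pow(2).mean()
    l1 = features.abs().sum(dim=-1).mean()
    return mse + l1_coeff * l1


# One optimisation step on placeholder activations (a batch of 2 sequences).
x = torch.randn(2, 16, d_in)
loss = sae_loss(x, params)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")
```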
Coverage (36 SAEs)
| Component | Layers | Files |
|---|---|---|
| Encoder | 0-17 | encoder/sae_encoder_00.pt - sae_encoder_17.pt |
| Decoder | 0-17 | decoder/sae_decoder_00.pt - sae_decoder_17.pt |
| Total | 36 | 36 checkpoint files |
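The naming pattern in the table can be expanded into the full list of checkpoint paths. A minimal sketch, where `checkpoint_filenames` is a hypothetical helper and not part of the repository:

```python
def checkpoint_filenames(layers_per_stack=18):
    """List repo-relative paths of all SAE checkpoints, encoder first."""
    return [
        f"{stack}/sae_{stack}_{idx:02d}.pt"
        for stack in ("encoder", "decoder")
        for idx in range(layers_per_stack)
    ]


files = checkpoint_filenames()
print(len(files), files[0], files[-1])
# 36 encoder/sae_encoder_00.pt decoder/sae_decoder_17.pt
```

Each entry can be passed as the `filename` argument of `hf_hub_download`.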
Use Cases
- Mechanistic Interpretability - Understand what features represent
- Activation Steering - Modify model behavior by steering features
- Feature Analysis - Find concept-specific features
- Model Interventions - Ablate or enhance specific capabilities
Activation Steering
class SteeringHook:
    """Forward hook that rescales one SAE feature and returns the SAE reconstruction."""

    def __init__(self, sae, feature_idx, strength):
        self.sae = sae
        self.feature_idx = feature_idx
        self.strength = strength

    def __call__(self, module, input, output):
        acts = output[0] if isinstance(output, tuple) else output
        # Run the SAE in float32 on the same device as the activations
        W_enc, b_enc, W_dec, b_dec = (
            self.sae[k].to(acts.device) for k in ('W_enc', 'b_enc', 'W_dec', 'b_dec')
        )
        features = torch.relu(acts.float() @ W_enc + b_enc)
        features[..., self.feature_idx] *= (1 + self.strength)
        modified = (features @ W_dec + b_dec).to(acts.dtype)
        # Returning a value from a forward hook replaces the module's output
        if isinstance(output, tuple):
            return (modified,) + tuple(output[1:])
        return modified

# Install hook (assumes `model` and `sae` are already loaded)
hook = SteeringHook(sae, feature_idx=123, strength=0.5)
handle = model.model.encoder.layers[0].self_attn.o_proj.register_forward_hook(hook)
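A forward hook stays active until its handle is removed, so every later forward pass is steered as well. One way to limit steering to a block of code is a small context manager. The sketch below uses a hypothetical `temporary_hook` helper and a toy `nn.Linear` layer in place of the real projection module, so it runs without the base model.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def temporary_hook(module, hook):
    """Attach a forward hook and always remove it on exit."""
    handle = module.register_forward_hook(hook)
    try:
        yield handle
    finally:
        handle.remove()


# Toy stand-in for the hooked projection layer.
layer = nn.Linear(4, 4)
x = torch.ones(1, 4)


def double_output(module, inputs, output):
    # Returning a tensor from a forward hook replaces the module's output.
    return output * 2


baseline = layer(x)
with temporary_hook(layer, double_output):
    steered = layer(x)
restored = layer(x)

print(torch.allclose(steered, baseline * 2))  # True
print(torch.allclose(restored, baseline))     # True
```

The same pattern should apply to a steering hook: wrap generation in `with temporary_hook(module, hook):` and the model returns to its unmodified behavior afterwards.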
Checkpoint Structure
{
    'model_name': 'google/t5gemma-2-270m-270m',
    'layer_type': 'encoder',  # or 'decoder'
    'layer_idx': 0,           # 0 to 17
    'd_in': 640,
    'd_sae': 4096,
    'W_enc': Tensor,          # shape [640, 4096]
    'b_enc': Tensor,          # shape [4096]
    'W_dec': Tensor,          # shape [4096, 640]
    'b_dec': Tensor,          # shape [640]
    'history': {'loss': [...], 'l0': [...]}
}
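Before using a downloaded file, it can help to confirm that it matches this layout. The following is a minimal sketch of such a check: `check_checkpoint` is a hypothetical helper, and the demo uses a tiny stand-in object with a `.shape` attribute so the example runs without PyTorch.

```python
EXPECTED_SHAPES = {
    "W_enc": ("d_in", "d_sae"),
    "b_enc": ("d_sae",),
    "W_dec": ("d_sae", "d_in"),
    "b_dec": ("d_in",),
}


def check_checkpoint(ckpt):
    """Return a list of problems; an empty list means the layout matches."""
    problems = [
        f"missing key: {key}"
        for key in ("d_in", "d_sae", *EXPECTED_SHAPES)
        if key not in ckpt
    ]
    if problems:
        return problems
    dims = {"d_in": ckpt["d_in"], "d_sae": ckpt["d_sae"]}
    for key, names in EXPECTED_SHAPES.items():
        expected = tuple(dims[n] for n in names)
        actual = tuple(ckpt[key].shape)
        if actual != expected:
            problems.append(f"{key}: expected {expected}, got {actual}")
    return problems


class FakeTensor:
    """Minimal stand-in for a tensor: only carries a shape."""

    def __init__(self, *shape):
        self.shape = shape


demo = {
    "d_in": 640,
    "d_sae": 4096,
    "W_enc": FakeTensor(640, 4096),
    "b_enc": FakeTensor(4096),
    "W_dec": FakeTensor(4096, 640),
    "b_dec": FakeTensor(640),
}
print(check_checkpoint(demo))  # []
```

With a real file, the call would be `check_checkpoint(torch.load(path, map_location="cpu"))`.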
Training Details
- Dataset: 1500 diverse text samples, 5 epochs
- Hook Point: self_attn.o_proj (attention output)
- Final Loss: ~0.0014
- Final L0: ~1367/4096 (about 33% of features active)
- MSE: 0.105
- Cosine Similarity: 0.80
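The L0 figure is easier to interpret as a fraction of the dictionary size. A short arithmetic check, using only the numbers reported above:

```python
d_sae = 4096
final_l0 = 1367  # approximate value reported above

active_fraction = final_l0 / d_sae
print(f"active: {active_fraction:.1%}, inactive: {1 - active_fraction:.1%}")
# active: 33.4%, inactive: 66.6%
```

Roughly a third of the features fire on a typical activation, so about two thirds are zero.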
Related Work
- Gemma Scope - Sparse autoencoders for Gemma
- Neuronpedia - Interactive SAE visualization
- SAELens - SAE training tools
- TransformerLens - Interpretability library
License
MIT License
Credits
Trained by mindchain
Training Date: December 2025
Keywords: SAE, sparse autoencoder, T5, T5Gemma, interpretability, mechanistic interpretability, activation steering, feature visualization, Neuronpedia, Gemma Scope, explainable AI, XAI, model steering, feature engineering, representation learning, dictionary learning, layer interpretation, hidden state analysis, NLP, natural language processing, text-to-text, language model, LLM, large language model, LLM interpretability, encoder-decoder, transformer, neural network interpretation, sparse representations, features, embeddings