---
license: mit
library_name: transformers
base_model: google/t5gemma-2-270m-270m
colab:
  - link: https://colab.research.google.com/github/haddock-development/t5gemma2-sae-all-layers/blob/main/colab.ipynb
    title: T5Gemma 2 SAE - Quick Start
model-index:
  - name: google/t5gemma-2-270m-270m
    results:
      - task:
          type: mechanistic-interpretability
        dataset:
          type: custom
          name: T5Gemma 2 SAE
        metrics:
          - name: SAE Checkpoints
            value: 36
tags:
  - sae
  - sparse-autoencoder
  - t5gemma
  - t5gemma2
  - mechanistic-interpretability
  - activation-steering
  - steering
  - neuronpedia
  - gemma-scope
  - sae-lens
  - llm-interpretability
  - explainable-ai
  - xai
  - model-steering
  - feature-engineering
  - representation-learning
---

# T5Gemma 2 Sparse Autoencoders (All 36 Layers)

**Sparse Autoencoders (SAEs)** trained on all 36 layers (18 encoder + 18 decoder) of `google/t5gemma-2-270m-270m` for mechanistic interpretability and activation steering.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/haddock-development/t5gemma2-sae-all-layers/blob/main/colab.ipynb)

## Quick Start

```python
from huggingface_hub import hf_hub_download
import torch

# Load one SAE checkpoint (encoder layer 0)
sae_path = hf_hub_download(
    repo_id="mindchain/t5gemma2-sae-all-layers",
    filename="encoder/sae_encoder_00.pt"
)
sae = torch.load(sae_path, map_location="cpu")

# Forward pass
activations = ...
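# One way to obtain `activations` (a sketch only; the toy Linear below stands
# in for the real hook point, e.g. model.model.encoder.layers[0].self_attn.o_proj,
# so this snippet runs without loading the base model):
captured = {}

def capture_hook(module, inputs, output):
    captured["acts"] = output.detach()   # save a detached copy of the layer output

toy_o_proj = torch.nn.Linear(640, 640)   # stand-in for o_proj, d_model = 640
handle = toy_o_proj.register_forward_hook(capture_hook)
toy_o_proj(torch.randn(1, 8, 640))       # stands in for a real forward pass
handle.remove()
activations = captured["acts"]           # shape (batch, seq, 640)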
# your activations, shape (batch, seq, 640)
features = torch.relu(activations.float() @ sae['W_enc'] + sae['b_enc'])
reconstruction = features @ sae['W_dec'] + sae['b_dec']
```

## Model Specifications

| Property | Value |
|----------|-------|
| Base Model | google/t5gemma-2-270m-270m |
| Architecture | Text-to-text (encoder-decoder) |
| Total Parameters | ~540M (270M encoder + 270M decoder) |
| Hidden Size (d_model) | 640 |
| FFN Dimension | 2,560 |
| Attention Heads | 10 |
| Vocabulary Size | 32,128 |
| Encoder Layers | 18 |
| Decoder Layers | 18 |

## SAE Configuration

| Property | Value |
|----------|-------|
| SAE Input Dimension (d_in) | 640 |
| SAE Hidden Dimension (d_sae) | 4,096 |
| Expansion Factor | 6.4x |
| L1 Coefficient | 0.01 |
| Training Epochs | 5 |
| Batch Size | 2 |
| Learning Rate | 1e-4 |
| Optimizer | AdamW |
| Hook Point | self_attn.o_proj |
| Precision (Model) | float16 |
| Precision (SAE) | float32 |

## Coverage (36 SAEs)

| Component | Layers | Files |
|-----------|--------|-------|
| Encoder | 0-17 | encoder/sae_encoder_00.pt - sae_encoder_17.pt |
| Decoder | 0-17 | decoder/sae_decoder_00.pt - sae_decoder_17.pt |
| **Total** | **36** | **36 checkpoint files** |

## Use Cases

1. **Mechanistic Interpretability** - Understand what features represent
2. **Activation Steering** - Modify model behavior by steering features
3. **Feature Analysis** - Find concept-specific features
4.
**Model Interventions** - Ablate or enhance specific capabilities

## Activation Steering

```python
class SteeringHook:
    def __init__(self, sae, feature_idx, strength):
        self.sae = sae
        self.feature_idx = feature_idx
        self.strength = strength

    def __call__(self, module, input, output):
        acts = output[0] if isinstance(output, tuple) else output
        # Encode, scale the chosen feature, then decode back to d_model
        features = torch.relu(acts.float() @ self.sae['W_enc'] + self.sae['b_enc'])
        features[..., self.feature_idx] *= (1 + self.strength)
        modified = (features @ self.sae['W_dec'] + self.sae['b_dec']).to(acts.dtype)
        # Returning a value from a forward hook replaces the module's output
        if isinstance(output, tuple):
            return (modified,) + output[1:]
        return modified

# Install hook
hook = SteeringHook(sae, feature_idx=123, strength=0.5)
handle = model.model.encoder.layers[0].self_attn.o_proj.register_forward_hook(hook)
```

## Checkpoint Structure

```python
{
    'model_name': 'google/t5gemma-2-270m-270m',
    'layer_type': 'encoder' or 'decoder',
    'layer_idx': 0-17,
    'd_in': 640,
    'd_sae': 4096,
    'W_enc': Tensor([640, 4096]),
    'b_enc': Tensor([4096]),
    'W_dec': Tensor([4096, 640]),
    'b_dec': Tensor([640]),
    'history': {'loss': [...], 'l0': [...]}
}
```

## Training Details

- **Dataset**: 1,500 diverse text samples, 5 epochs
- **Hook Point**: self_attn.o_proj (attention output)
- **Final Loss**: ~0.0014
- **Final L0**: ~1367/4096 (~33% of features active per token)
- **MSE**: 0.105
- **Cosine Similarity**: 0.80

## Related Work

- [Gemma Scope](https://huggingface.co/google/gemma-scope) - Sparse autoencoders for Gemma
- [Neuronpedia](https://neuronpedia.org) - Interactive SAE visualization
- [SAELens](https://github.com/jbloomAus/SAELens) - SAE training tools
- [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) - Interpretability library

## License

MIT License

## Credits

Trained by [mindchain](https://huggingface.co/mindchain)
Training Date: December 2025

---

**Keywords**: SAE, sparse autoencoder, T5, T5Gemma, interpretability, mechanistic interpretability, activation steering, feature
visualization, Neuronpedia, Gemma Scope, explainable AI, XAI, model steering, feature engineering, representation learning, dictionary learning, layer interpretation, hidden state analysis, NLP, natural language processing, text-to-text, language model, LLM, large language model, LLM interpretability, encoder-decoder, transformer, neural network interpretation, sparse representations, features, embeddings
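The checkpoint layout and the L0/MSE/cosine-similarity metrics quoted above can be recomputed in a few lines. A self-contained sketch follows; it uses a randomly initialised stand-in checkpoint with the documented shapes (real checkpoints come from `hf_hub_download` as in the Quick Start), so the printed numbers are illustrative only.

```python
import torch

d_in, d_sae = 640, 4096
ckpt = {                         # same layout as the Checkpoint Structure section
    'W_enc': torch.randn(d_in, d_sae) * 0.01,
    'b_enc': torch.zeros(d_sae),
    'W_dec': torch.randn(d_sae, d_in) * 0.01,
    'b_dec': torch.zeros(d_in),
}

acts = torch.randn(4, 16, d_in)                      # (batch, seq, d_in)
feats = torch.relu(acts @ ckpt['W_enc'] + ckpt['b_enc'])
recon = feats @ ckpt['W_dec'] + ckpt['b_dec']

l0 = (feats > 0).float().sum(-1).mean()              # avg active features per token
mse = ((recon - acts) ** 2).mean()
cos = torch.nn.functional.cosine_similarity(
    recon.flatten(1), acts.flatten(1), dim=-1).mean()
print(f"L0={l0:.0f}/{d_sae}  MSE={mse:.4f}  cos={cos:.3f}")
```

With a trained checkpoint in place of the random one, these are the quantities reported under Training Details.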