---
license: mit
library_name: transformers
base_model: google/t5gemma-2-270m-270m
colab:
  - link: >-
      https://colab.research.google.com/github/haddock-development/t5gemma2-sae-all-layers/blob/main/colab.ipynb
    title: T5Gemma 2 SAE - Quick Start
model-index:
  - name: google/t5gemma-2-270m-270m
    results:
      - task:
          type: mechanistic-interpretability
        dataset:
          type: custom
          name: T5Gemma 2 SAE
        metrics:
          - name: SAE Checkpoints
            value: 36
tags:
  - sae
  - sparse-autoencoder
  - t5gemma
  - t5gemma2
  - mechanistic-interpretability
  - activation-steering
  - steering
  - neuronpedia
  - gemma-scope
  - sae-lens
  - llm-interpretability
  - explainable-ai
  - xai
  - model-steering
  - feature-engineering
  - representation-learning
---

# T5Gemma 2 Sparse Autoencoders (All 36 Layers)

Sparse Autoencoders (SAEs) trained on all 36 layers of google/t5gemma-2-270m-270m for mechanistic interpretability and activation steering.

[Open in Colab](https://colab.research.google.com/github/haddock-development/t5gemma2-sae-all-layers/blob/main/colab.ipynb)

## Quick Start

```python
from huggingface_hub import hf_hub_download
import torch

# Load one SAE checkpoint (encoder layer 0)
sae_path = hf_hub_download(
    repo_id="mindchain/t5gemma2-sae-all-layers",
    filename="encoder/sae_encoder_00.pt"
)
sae = torch.load(sae_path, map_location="cpu")

# SAE forward pass
activations = ...  # Your activations, shape [batch, seq_len, 640]
features = torch.relu(activations.float() @ sae['W_enc'] + sae['b_enc'])
reconstruction = features @ sae['W_dec'] + sae['b_dec']
```
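
The `activations` placeholder stands for hidden states captured at the hook point the SAEs were trained on (`self_attn.o_proj`). Below is a minimal sketch of collecting them from the base model with a forward hook; it assumes the checkpoint loads through `AutoModelForSeq2SeqLM` and uses the same module path as the steering example further down.

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "google/t5gemma-2-270m-270m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16)
model.eval()

captured = {}

def capture_hook(module, inputs, output):
    # o_proj output: hidden states of shape [batch, seq_len, 640]
    captured["acts"] = output.detach()

# Same hook point the SAEs were trained on: encoder layer 0, self_attn.o_proj
target = model.model.encoder.layers[0].self_attn.o_proj
handle = target.register_forward_hook(capture_hook)

enc_inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
with torch.no_grad():
    model.get_encoder()(**enc_inputs)  # encoder-only pass is enough to trigger the hook
handle.remove()

activations = captured["acts"]  # plug into the SAE forward pass above
```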

## Model Specifications

| Property | Value |
|---|---|
| Base Model | google/t5gemma-2-270m-270m |
| Architecture | T5 Text-to-Text (Encoder-Decoder) |
| Total Parameters | ~540M (270M encoder + 270M decoder) |
| Hidden Size (d_model) | 640 |
| FFN Dimension | 2,560 |
| Attention Heads | 10 |
| Vocabulary Size | 32,128 |
| Encoder Layers | 18 |
| Decoder Layers | 18 |

## SAE Configuration

| Property | Value |
|---|---|
| SAE Input Dimension (d_in) | 640 |
| SAE Hidden Dimension (d_sae) | 4,096 |
| Expansion Factor | 6.4x |
| L1 Coefficient | 0.01 |
| Training Epochs | 5 |
| Batch Size | 2 |
| Learning Rate | 1e-4 |
| Optimizer | AdamW |
| Hook Point | `self_attn.o_proj` |
| Precision (Model) | float16 |
| Precision (SAE) | float32 |
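
These hyperparameters correspond to the usual SAE objective: reconstruction MSE plus an L1 penalty on the feature activations, optimized with AdamW. The following is a rough sketch of one training step under the listed settings; it is illustrative, not the original training script.

```python
import torch

d_in, d_sae = 640, 4096
l1_coeff, lr = 0.01, 1e-4

W_enc = torch.nn.Parameter(torch.randn(d_in, d_sae) * 0.01)
b_enc = torch.nn.Parameter(torch.zeros(d_sae))
W_dec = torch.nn.Parameter(torch.randn(d_sae, d_in) * 0.01)
b_dec = torch.nn.Parameter(torch.zeros(d_in))
optimizer = torch.optim.AdamW([W_enc, b_enc, W_dec, b_dec], lr=lr)

def train_step(acts):
    """acts: float32 activations of shape [batch, seq_len, 640]."""
    features = torch.relu(acts @ W_enc + b_enc)
    recon = features @ W_dec + b_dec
    mse = torch.mean((recon - acts) ** 2)           # reconstruction term
    l1 = features.abs().sum(dim=-1).mean()          # sparsity term on feature activations
    loss = mse + l1_coeff * l1
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```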

## Coverage (36 SAEs)

| Component | Layers | Files |
|---|---|---|
| Encoder | 0-17 | encoder/sae_encoder_00.pt - encoder/sae_encoder_17.pt |
| Decoder | 0-17 | decoder/sae_decoder_00.pt - decoder/sae_decoder_17.pt |
| Total | 36 | 36 checkpoint files |
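
Because the filenames follow a fixed pattern, all 36 checkpoints can be pulled in a loop (alternatively, `huggingface_hub.snapshot_download` fetches the whole repository at once):

```python
from huggingface_hub import hf_hub_download
import torch

saes = {}
for component in ("encoder", "decoder"):
    for layer in range(18):
        path = hf_hub_download(
            repo_id="mindchain/t5gemma2-sae-all-layers",
            filename=f"{component}/sae_{component}_{layer:02d}.pt",
        )
        saes[(component, layer)] = torch.load(path, map_location="cpu")
```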

## Use Cases

1. **Mechanistic Interpretability** - Understand what features represent
2. **Activation Steering** - Modify model behavior by steering features
3. **Feature Analysis** - Find concept-specific features (see the sketch below)
4. **Model Interventions** - Ablate or enhance specific capabilities
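
For feature analysis, a simple starting point is to see which SAE features fire most strongly on a given input. A minimal sketch, reusing the `sae` dict and the captured `activations` from the Quick Start:

```python
import torch

# Encode activations into the 4,096-dimensional feature space
features = torch.relu(activations.float() @ sae["W_enc"] + sae["b_enc"])

# Top-10 features by mean activation across the batch and sequence
mean_acts = features.mean(dim=(0, 1))          # shape [4096]
top_vals, top_idx = torch.topk(mean_acts, k=10)
for idx, val in zip(top_idx.tolist(), top_vals.tolist()):
    print(f"feature {idx:4d}: mean activation {val:.4f}")
```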

## Activation Steering

Steer the base model by scaling individual SAE features at the hook point and writing the reconstruction back as the module's output:

```python
class SteeringHook:
    """Forward hook that scales one SAE feature and writes back the reconstruction."""

    def __init__(self, sae, feature_idx, strength):
        self.sae = sae
        self.feature_idx = feature_idx
        self.strength = strength

    def __call__(self, module, input, output):
        acts = output[0] if isinstance(output, tuple) else output
        acts_f32 = acts.float()

        # Encode, scale the chosen feature, then decode back to d_model space
        features = torch.relu(acts_f32 @ self.sae['W_enc'] + self.sae['b_enc'])
        features[:, :, self.feature_idx] *= (1 + self.strength)
        modified = (features @ self.sae['W_dec'] + self.sae['b_dec']).to(acts.dtype)

        # Returning a value from a forward hook replaces the module's output
        if isinstance(output, tuple):
            return (modified,) + output[1:]
        return modified

# Install the hook on the same point the SAEs were trained on (encoder layer 0)
hook = SteeringHook(sae, feature_idx=123, strength=0.5)
handle = model.model.encoder.layers[0].self_attn.o_proj.register_forward_hook(hook)
```
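
With the hook in place, generation proceeds as normal and the scaled feature influences everything downstream of encoder layer 0. Remove the hook afterwards to restore default behavior; the prompt is illustrative and assumes the `model` and `tokenizer` from the Quick Start sketch.

```python
inputs = tokenizer("The weather today is", return_tensors="pt")
with torch.no_grad():
    steered = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(steered[0], skip_special_tokens=True))

handle.remove()  # detach the steering hook
```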

## Checkpoint Structure

```python
{
    'model_name': 'google/t5gemma-2-270m-270m',
    'layer_type': 'encoder' or 'decoder',
    'layer_idx': 0-17,
    'd_in': 640,
    'd_sae': 4096,
    'W_enc': Tensor([640, 4096]),
    'b_enc': Tensor([4096]),
    'W_dec': Tensor([4096, 640]),
    'b_dec': Tensor([640]),
    'history': {'loss': [...], 'l0': [...]}
}
```
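
A quick sanity check that a downloaded checkpoint matches this layout:

```python
import torch

ckpt = torch.load(sae_path, map_location="cpu")
assert ckpt["W_enc"].shape == (ckpt["d_in"], ckpt["d_sae"])   # (640, 4096)
assert ckpt["W_dec"].shape == (ckpt["d_sae"], ckpt["d_in"])   # (4096, 640)
print(ckpt["model_name"], ckpt["layer_type"], ckpt["layer_idx"])
print("final training loss:", ckpt["history"]["loss"][-1])
```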

## Training Details

- **Dataset:** 1,500 diverse text samples, 5 epochs
- **Hook Point:** `self_attn.o_proj` (attention output)
- **Final Loss:** ~0.0014
- **Final L0:** ~1,367 of 4,096 features active per token on average (≈33% of features)
- **MSE:** 0.105
- **Cosine Similarity:** 0.80
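
The MSE, L0, and cosine-similarity figures can be recomputed for any layer from held-out activations. A hedged sketch follows; the exact evaluation protocol of the original run may differ.

```python
import torch
import torch.nn.functional as F

def evaluate_sae(sae, acts):
    """acts: float32 activations of shape [batch, seq_len, 640] from self_attn.o_proj."""
    features = torch.relu(acts @ sae["W_enc"] + sae["b_enc"])
    recon = features @ sae["W_dec"] + sae["b_dec"]

    mse = torch.mean((recon - acts) ** 2).item()
    l0 = (features > 0).float().sum(dim=-1).mean().item()   # avg. active features per token
    cos = F.cosine_similarity(recon, acts, dim=-1).mean().item()
    return {"mse": mse, "l0": l0, "cosine_similarity": cos}
```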

## Related Work

See the Gemma Scope SAE suite, the SAELens library, and Neuronpedia (all referenced in the tags above) for related SAE training and feature-browsing tooling.

## License

MIT License

## Credits

- Trained by mindchain
- Training Date: December 2025


Keywords: SAE, sparse autoencoder, T5, T5Gemma, interpretability, mechanistic interpretability, activation steering, feature visualization, Neuronpedia, Gemma Scope, explainable AI, XAI, model steering, feature engineering, representation learning, dictionary learning, layer interpretation, hidden state analysis, NLP, natural language processing, text-to-text, language model, LLM, large language model, LLM interpretability, encoder-decoder, transformer, neural network interpretation, sparse representations, features, embeddings