Upload BatchTopKSAE
Browse files- config.json +4 -0
- config.py +177 -0
- sae.py +390 -0
config.json
CHANGED
|
@@ -3,6 +3,10 @@
|
|
| 3 |
"architectures": [
|
| 4 |
"BatchTopKSAE"
|
| 5 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"aux_penalty": 0.03125,
|
| 7 |
"bandwidth": 0.001,
|
| 8 |
"dict_size": 128,
|
|
|
|
| 3 |
"architectures": [
|
| 4 |
"BatchTopKSAE"
|
| 5 |
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "config.SAEConfig",
|
| 8 |
+
"AutoModel": "sae.BatchTopKSAE"
|
| 9 |
+
},
|
| 10 |
"aux_penalty": 0.03125,
|
| 11 |
"bandwidth": 0.001,
|
| 12 |
"dict_size": 128,
|
config.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import Optional, Literal
|
| 3 |
+
import torch
|
| 4 |
+
import pyrallis
|
| 5 |
+
from transformers import PretrainedConfig
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from dataclasses import asdict
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class TrainingConfig:
    """Settings for one sparse-autoencoder training run.

    Groups the parent-model location, SAE architecture, optimisation,
    hardware, dataset and logging options into a single dataclass.

    ``dtype`` and ``sae_dtype`` may be given either as ``torch.dtype`` objects
    or as strings (``"float32"`` or ``"torch.float32"``); strings are resolved
    to ``torch.dtype`` in ``__post_init__``.
    """

    # Model settings
    model_name: str = "unsloth/Meta-Llama-3.1-8B"
    layer: int = 12
    hook_point: str = "resid_mid"
    act_size: Optional[int] = None  # Will be set after model initialization

    # SAE settings
    sae_type: str = "batchtopk"
    dict_size: int = 2**15
    aux_penalty: float = 1/32
    input_unit_norm: bool = True

    # TopK specific settings
    top_k: int = 50
    top_k_warmup_steps_fraction: float = 0.1
    start_top_k: int = 4096
    top_k_aux: int = 512

    n_batches_to_dead: int = 10

    # Training settings
    lr: float = 3e-4
    bandwidth: float = 0.001
    l1_coeff: float = 0.0018
    num_tokens: int = int(1e9)
    seq_len: int = 1024
    model_batch_size: int = 16
    num_batches_in_buffer: int = 5
    max_grad_norm: float = 1.0
    batch_size: int = 8192

    # scheduler
    warmup_fraction: float = 0.1
    scheduler_type: str = 'linear'

    # Hardware settings
    device: str = "cuda"
    dtype: torch.dtype = field(default=torch.float32)
    sae_dtype: torch.dtype = field(default=torch.float32)

    # Dataset settings
    dataset_path: str = "cerebras/SlimPajama-627B"

    # Logging settings
    wandb_project: str = "turbo-llama-lens"

    performance_log_steps: int = 100
    save_checkpoint_steps: int = 10_000

    def __post_init__(self):
        """Normalise device and dtype fields after dataclass construction.

        Raises:
            AttributeError: if a dtype string names nothing in ``torch``.
            ValueError: if a dtype string names a ``torch`` attribute that is
                not a ``torch.dtype``.
        """
        if self.device == "cuda" and not torch.cuda.is_available():
            print("CUDA not available, falling back to CPU")
            self.device = "cpu"

        # Convert string dtypes to torch.dtype if needed. Both dtype fields
        # are handled so that a string `sae_dtype` is resolved as well.
        for attr_name in ("dtype", "sae_dtype"):
            value = getattr(self, attr_name)
            if not isinstance(value, str):
                continue
            # Accept both "float32" and the str(torch.float32) form.
            resolved = getattr(torch, value.split(".")[-1])
            if not isinstance(resolved, torch.dtype):
                raise ValueError(
                    f"{attr_name}={value!r} does not name a torch dtype"
                )
            setattr(self, attr_name, resolved)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class SAEConfig(PretrainedConfig):
    """Hugging Face configuration object for the sparse autoencoder models.

    Stores the SAE architecture and loss hyperparameters plus optional
    information about the parent model the SAE was trained on. Dtypes are
    kept as plain strings (e.g. ``"float32"``); use ``get_torch_dtype`` to
    turn them into ``torch.dtype`` objects.
    """

    model_type = "sae"

    def __init__(
        self,
        # SAE architecture
        act_size: Optional[int] = None,
        dict_size: int = 2**15,
        sae_type: str = "batchtopk",
        input_unit_norm: bool = True,

        # TopK specific settings
        top_k: int = 50,
        top_k_aux: int = 512,
        n_batches_to_dead: int = 10,

        # Training hyperparameters
        aux_penalty: float = 1/32,
        l1_coeff: float = 0.0018,
        bandwidth: float = 0.001,

        # Hardware settings
        dtype: str = "float32",
        sae_dtype: str = "float32",

        # Optional parent model info
        parent_model_name: Optional[str] = None,
        parent_layer: Optional[int] = None,
        parent_hook_point: Optional[str] = None,

        **kwargs
    ):
        """Store the given settings as attributes.

        Any extra keyword arguments are forwarded to
        ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.act_size = act_size
        self.dict_size = dict_size
        self.sae_type = sae_type
        self.input_unit_norm = input_unit_norm

        self.top_k = top_k
        self.top_k_aux = top_k_aux
        self.n_batches_to_dead = n_batches_to_dead

        self.aux_penalty = aux_penalty
        self.l1_coeff = l1_coeff
        self.bandwidth = bandwidth

        self.dtype = dtype
        self.sae_dtype = sae_dtype

        self.parent_model_name = parent_model_name
        self.parent_layer = parent_layer
        self.parent_hook_point = parent_hook_point

    def get_torch_dtype(self, dtype_str: str) -> torch.dtype:
        """Map a dtype name to a ``torch.dtype``.

        Recognises ``"float32"``, ``"float16"`` and ``"bfloat16"``; any other
        string falls back to ``torch.float32``. A value that is already a
        ``torch.dtype`` is returned unchanged.
        """
        if isinstance(dtype_str, torch.dtype):
            return dtype_str
        dtype_map = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }
        return dtype_map.get(dtype_str, torch.float32)

    @classmethod
    def from_training_config(cls, cfg: TrainingConfig):
        """Convert TrainingConfig to SAEConfig.

        Dtypes are stored by name: the last dot-separated component of
        ``str(cfg.dtype)``, e.g. ``torch.float32`` becomes ``"float32"``.
        The training config's model name, layer and hook point become the
        ``parent_*`` fields.
        """
        return cls(
            act_size=cfg.act_size,
            dict_size=cfg.dict_size,
            sae_type=cfg.sae_type,
            input_unit_norm=cfg.input_unit_norm,
            top_k=cfg.top_k,
            top_k_aux=cfg.top_k_aux,
            n_batches_to_dead=cfg.n_batches_to_dead,
            aux_penalty=cfg.aux_penalty,
            l1_coeff=cfg.l1_coeff,
            bandwidth=cfg.bandwidth,
            dtype=str(cfg.dtype).split('.')[-1],
            sae_dtype=str(cfg.sae_dtype).split('.')[-1],
            parent_model_name=cfg.model_name,
            parent_layer=cfg.layer,
            parent_hook_point=cfg.hook_point,
        )

    def to_training_config(self) -> TrainingConfig:
        """Convert SAEConfig back to TrainingConfig.

        Only the fields that both classes share are transferred; every other
        ``TrainingConfig`` field keeps its default. The mapping is written
        out explicitly because this class is not a dataclass (so
        ``dataclasses.asdict`` cannot be applied to it) and because it has
        attributes, such as the ``parent_*`` fields, that ``TrainingConfig``
        does not accept under those names.
        """
        return TrainingConfig(
            model_name=self.parent_model_name,
            layer=self.parent_layer,
            hook_point=self.parent_hook_point,
            act_size=self.act_size,
            sae_type=self.sae_type,
            dict_size=self.dict_size,
            aux_penalty=self.aux_penalty,
            input_unit_norm=self.input_unit_norm,
            top_k=self.top_k,
            top_k_aux=self.top_k_aux,
            n_batches_to_dead=self.n_batches_to_dead,
            l1_coeff=self.l1_coeff,
            bandwidth=self.bandwidth,
            dtype=self.get_torch_dtype(self.dtype),
            sae_dtype=self.get_torch_dtype(self.sae_dtype),
        )
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@pyrallis.wrap()
def get_config(cfg: TrainingConfig) -> TrainingConfig:
    """Return the training configuration parsed by pyrallis.

    ``pyrallis.wrap`` takes the config class from the annotation of the
    wrapped function's first parameter and passes the parsed instance in as
    that argument, so the function must declare ``cfg``. Callers still invoke
    ``get_config()`` with no arguments; an existing ``TrainingConfig`` may
    also be passed and is returned as-is.
    """
    return cfg
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# For backward compatibility
|
| 169 |
+
def get_default_cfg() -> TrainingConfig:
    """Legacy alias: return whatever ``get_config()`` produces."""
    cfg = get_config()
    return cfg
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def post_init_cfg(cfg: TrainingConfig) -> TrainingConfig:
    """Hook for configuration setup that must run after model initialization.

    Currently a pass-through: the given config is returned unmodified.
    """
    return cfg
|
sae.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PreTrainedModel
|
| 2 |
+
from typing import Optional, Dict, Union
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import torch.autograd as autograd
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
from safetensors.torch import save_file, load_file
|
| 9 |
+
from sae.modeling.config import SAEConfig
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BaseSAE(PreTrainedModel):
    """Base class for autoencoder models.

    Owns the encoder/decoder weights and biases, optional per-sample input
    normalisation, decoder unit-norm maintenance, and a per-feature counter
    of consecutive batches in which the feature did not fire.
    """

    config_class = SAEConfig
    base_model_prefix = "sae"

    def __init__(self, config: SAEConfig):
        """Create the parameters described by ``config``.

        ``W_dec`` starts as the transpose of ``W_enc`` with each row scaled
        to unit L2 norm. The whole module is then cast to ``config.dtype``.
        """
        # Function-scope import: the module does not import logging at file level.
        import logging

        super().__init__(config)
        # Debug-level log instead of an unconditional print to stdout.
        logging.getLogger(__name__).debug(
            "Initialising %s with config: %s", type(self).__name__, config
        )
        self.config = config
        # NOTE(review): this reseeds the *global* torch RNG every time a model
        # is constructed, which also affects the caller's later random draws.
        torch.manual_seed(42)

        self.b_dec = nn.Parameter(torch.zeros(self.config.act_size))
        self.b_enc = nn.Parameter(torch.zeros(self.config.dict_size))
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.config.act_size, self.config.dict_size)
            )
        )
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.config.dict_size, self.config.act_size)
            )
        )
        # Tie the decoder to the encoder at init, then normalise each row.
        self.W_dec.data[:] = self.W_enc.t().data
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        # Plain tensor (not a parameter or buffer), so it is not part of the
        # state dict and is not moved by `.to()`; `update_inactive_features`
        # moves it to the activations' device when needed.
        self.num_batches_not_active = torch.zeros((self.config.dict_size,))

        # NOTE(review): weights are cast to `config.dtype`, while
        # `preprocess_input` casts inputs to `config.sae_dtype`; the two must
        # match for the matmuls to succeed — confirm which one is intended.
        self.to(self.config.get_torch_dtype(self.config.dtype))

    def preprocess_input(self, x):
        """Cast ``x`` to the SAE dtype and optionally standardise it.

        When ``config.input_unit_norm`` is set, the mean over the last
        dimension is subtracted and the result is divided by its standard
        deviation (plus 1e-5).

        Returns:
            ``(x, x_mean, x_std)``; the last two are ``None`` when
            normalisation is disabled.
        """
        x = x.to(self.config.get_torch_dtype(self.config.sae_dtype))
        if self.config.input_unit_norm:
            x_mean = x.mean(dim=-1, keepdim=True)
            x = x - x_mean
            x_std = x.std(dim=-1, keepdim=True)
            x = x / (x_std + 1e-5)
            return x, x_mean, x_std
        else:
            return x, None, None

    def postprocess_output(self, x_reconstruct, x_mean, x_std):
        """Undo the input normalisation on a reconstruction.

        NOTE(review): this multiplies by ``x_std`` whereas
        ``preprocess_input`` divided by ``x_std + 1e-5``, so the inverse is
        approximate.
        """
        if self.config.input_unit_norm:
            x_reconstruct = x_reconstruct * x_std + x_mean
        return x_reconstruct

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        """Renormalise decoder rows and project their gradient accordingly.

        The component of ``W_dec.grad`` parallel to each (normalised) decoder
        row is removed, then ``W_dec`` is replaced by its row-normalised
        version. If no gradient has been computed yet, only the weights are
        renormalised.
        """
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        if self.W_dec.grad is not None:
            W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(
                -1, keepdim=True
            ) * W_dec_normed
            self.W_dec.grad -= W_dec_grad_proj
        self.W_dec.data = W_dec_normed

    @torch.no_grad()
    def update_inactive_features(self, acts):
        """Update the per-feature count of consecutive inactive batches.

        A feature whose activations sum to exactly zero over dimension 0 has
        its counter incremented; a feature with a positive sum is reset to
        zero.
        """
        # The counter is a plain tensor and does not follow the module across
        # devices, so align it with the activations before updating in place.
        if self.num_batches_not_active.device != acts.device:
            self.num_batches_not_active = self.num_batches_not_active.to(acts.device)
        feature_totals = acts.sum(0)
        self.num_batches_not_active += (feature_totals == 0).float()
        self.num_batches_not_active[feature_totals > 0] = 0
|
| 70 |
+
|
| 71 |
+
# @classmethod
|
| 72 |
+
# def from_pretrained(
|
| 73 |
+
# cls,
|
| 74 |
+
# pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
| 75 |
+
# *model_args,
|
| 76 |
+
# **kwargs
|
| 77 |
+
# ) -> "BaseSAE":
|
| 78 |
+
# config = kwargs.pop("config", None)
|
| 79 |
+
# if config is None:
|
| 80 |
+
# config = SAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 81 |
+
|
| 82 |
+
# model = cls(config)
|
| 83 |
+
# model.load_state_dict(
|
| 84 |
+
# load_file(os.path.join(pretrained_model_name_or_path, "model.safetensors"))
|
| 85 |
+
# )
|
| 86 |
+
# return model
|
| 87 |
+
|
| 88 |
+
# def save_pretrained(
|
| 89 |
+
# self,
|
| 90 |
+
# save_directory: Union[str, os.PathLike],
|
| 91 |
+
# **kwargs
|
| 92 |
+
# ):
|
| 93 |
+
# os.makedirs(save_directory, exist_ok=True)
|
| 94 |
+
|
| 95 |
+
# # Save the config
|
| 96 |
+
# self.config.save_pretrained(save_directory)
|
| 97 |
+
|
| 98 |
+
# # Save the model weights
|
| 99 |
+
# save_file(
|
| 100 |
+
# self.state_dict(),
|
| 101 |
+
# os.path.join(save_directory, "model.safetensors")
|
| 102 |
+
# )
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class BatchTopKSAE(BaseSAE):
    """SAE that keeps the largest activations pooled over the whole batch.

    Rather than keeping ``top_k`` features per row, the ReLU activations of
    the batch are flattened and the ``top_k * x.shape[0]`` largest values are
    kept; everything else is set to zero.
    """

    def forward(self, x):
        """Encode, sparsify, decode, and return the loss/metric dictionary.

        Also updates the inactive-feature counters from the sparsified
        activations.
        """
        x, x_mean, x_std = self.preprocess_input(x)

        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc)
        flat_acts = acts.flatten()
        # torch.topk raises when k exceeds the number of elements; clamping
        # means an oversized budget simply keeps every activation.
        k = min(self.config.top_k * x.shape[0], flat_acts.numel())
        acts_topk = torch.topk(flat_acts, k, dim=-1)
        acts_topk = (
            torch.zeros_like(flat_acts)
            .scatter(-1, acts_topk.indices, acts_topk.values)
            .reshape(acts.shape)
        )
        x_reconstruct = acts_topk @ self.W_dec + self.b_dec

        self.update_inactive_features(acts_topk)
        output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
        return output

    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
        """Assemble losses and diagnostics for one forward pass.

        ``loss`` is the mean squared reconstruction error plus the auxiliary
        loss. ``l1_loss`` is computed and reported but is not added to
        ``loss``. ``sae_out`` is the reconstruction mapped back through
        ``postprocess_output``.
        """
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        l1_norm = acts_topk.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts_topk > 0).float().sum(-1).mean()
        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
        loss = l2_loss + aux_loss
        # NOTE(review): this metric uses a strict `>` while
        # get_auxiliary_loss uses `>=`, so a feature sitting exactly at
        # n_batches_to_dead receives aux loss but is not counted here.
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts_topk,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "aux_loss": aux_loss,
            "explained_variance": explained_variance,
            "top_k": self.config.top_k
        }
        return output

    def get_auxiliary_loss(self, x, x_reconstruct, acts):
        """Auxiliary loss that lets dead features model the residual.

        Features whose inactive counter has reached ``n_batches_to_dead`` are
        selected; per row, up to ``top_k_aux`` of their pre-sparsification
        activations are kept and decoded, and the mean squared error against
        the main reconstruction's residual is scaled by ``aux_penalty``.
        Returns a zero scalar tensor when no feature is dead.
        """
        dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
        # Plain int so torch.topk receives an integer k, not a 0-dim tensor.
        num_dead = int(dead_features.sum())
        if num_dead > 0:
            residual = x.float() - x_reconstruct.float()
            dead_acts = acts[:, dead_features]
            acts_topk_aux = torch.topk(
                dead_acts,
                min(self.config.top_k_aux, num_dead),
                dim=-1,
            )
            acts_aux = torch.zeros_like(dead_acts).scatter(
                -1, acts_topk_aux.indices, acts_topk_aux.values
            )
            x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
            l2_loss_aux = (
                self.config.aux_penalty
                * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
            )
            return l2_loss_aux
        else:
            return torch.tensor(0, dtype=x.dtype, device=x.device)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class TopKSAE(BaseSAE):
    """SAE that keeps the ``top_k`` largest activations along the last dim."""

    def forward(self, x):
        """Encode, sparsify per row, decode, and return the loss dictionary.

        Also updates the inactive-feature counters from the sparsified
        activations.
        """
        x, x_mean, x_std = self.preprocess_input(x)

        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc)
        # torch.topk raises when k exceeds the size of the dimension; clamp so
        # an oversized top_k simply keeps every feature.
        k = min(self.config.top_k, acts.shape[-1])
        acts_topk = torch.topk(acts, k, dim=-1)
        acts_topk = torch.zeros_like(acts).scatter(
            -1, acts_topk.indices, acts_topk.values
        )
        x_reconstruct = acts_topk @ self.W_dec + self.b_dec

        self.update_inactive_features(acts_topk)
        output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
        return output

    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
        """Assemble losses and diagnostics for one forward pass.

        ``loss`` is the mean squared reconstruction error plus the L1 penalty
        (``l1_coeff`` times the mean L1 norm of the kept activations) plus
        the auxiliary loss. ``sae_out`` is the reconstruction mapped back
        through ``postprocess_output``.
        """
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        l1_norm = acts_topk.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts_topk > 0).float().sum(-1).mean()
        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
        loss = l2_loss + l1_loss + aux_loss
        # NOTE(review): this metric uses a strict `>` while
        # get_auxiliary_loss uses `>=`, so a feature sitting exactly at
        # n_batches_to_dead receives aux loss but is not counted here.
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts_topk,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "explained_variance": explained_variance,
            "aux_loss": aux_loss,
        }
        return output

    def get_auxiliary_loss(self, x, x_reconstruct, acts):
        """Auxiliary loss that lets dead features model the residual.

        Features whose inactive counter has reached ``n_batches_to_dead`` are
        selected; per row, up to ``top_k_aux`` of their pre-sparsification
        activations are kept and decoded, and the mean squared error against
        the main reconstruction's residual is scaled by ``aux_penalty``.
        Returns a zero scalar tensor when no feature is dead.
        """
        dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
        # Plain int so torch.topk receives an integer k, not a 0-dim tensor.
        num_dead = int(dead_features.sum())
        if num_dead > 0:
            residual = x.float() - x_reconstruct.float()
            dead_acts = acts[:, dead_features]
            acts_topk_aux = torch.topk(
                dead_acts,
                min(self.config.top_k_aux, num_dead),
                dim=-1,
            )
            acts_aux = torch.zeros_like(dead_acts).scatter(
                -1, acts_topk_aux.indices, acts_topk_aux.values
            )
            x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
            l2_loss_aux = (
                self.config.aux_penalty
                * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
            )
            return l2_loss_aux
        else:
            return torch.tensor(0, dtype=x.dtype, device=x.device)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class VanillaSAE(BaseSAE):
    """ReLU sparse autoencoder trained with an L1 sparsity penalty."""

    def forward(self, x):
        """Encode with ReLU, decode, update counters, and return the losses."""
        x, x_mean, x_std = self.preprocess_input(x)
        encoder_input = x - self.b_dec
        acts = F.relu(encoder_input @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        self.update_inactive_features(acts)
        return self.get_loss_dict(x, x_reconstruct, acts, x_mean, x_std)

    def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
        """Build the loss/metric dictionary for one forward pass.

        ``loss`` is the mean squared reconstruction error plus ``l1_coeff``
        times the mean L1 norm of the activations.
        """
        target = x.float()
        squared_error = (x_reconstruct.float() - target).pow(2)

        l2_loss = squared_error.mean()
        l1_norm = acts.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts > 0).float().sum(-1).mean()

        # Features whose inactive streak exceeds the configured limit.
        over_limit = self.num_batches_not_active > self.config.n_batches_to_dead

        # Fraction of variance explained, computed per row and then averaged.
        per_token_error = squared_error.sum(-1).squeeze()
        per_token_variance = (target - target.mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_error / per_token_variance).mean()

        return {
            "sae_out": self.postprocess_output(x_reconstruct, x_mean, x_std),
            "feature_acts": acts,
            "num_dead_features": over_limit.sum(),
            "loss": l2_loss + l1_loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "explained_variance": explained_variance,
        }
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
import torch
|
| 279 |
+
import torch.nn as nn
|
| 280 |
+
|
| 281 |
+
class RectangleFunction(autograd.Function):
    """Indicator of the open interval (-0.5, 0.5) with a masked gradient."""

    @staticmethod
    def forward(ctx, x):
        """Return 1.0 where ``-0.5 < x < 0.5`` and 0.0 elsewhere (float32)."""
        ctx.save_for_backward(x)
        inside = torch.logical_and(x > -0.5, x < 0.5)
        return inside.to(torch.float32)

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the incoming gradient through, zeroed where ``x`` is outside."""
        (x,) = ctx.saved_tensors
        outside = torch.logical_or(x <= -0.5, x >= 0.5)
        return torch.where(outside, torch.zeros_like(grad_output), grad_output)
|
| 293 |
+
|
| 294 |
+
class JumpReLUFunction(autograd.Function):
    """JumpReLU: pass ``x`` through where it exceeds ``exp(log_threshold)``.

    The gradient with respect to ``x`` is the incoming gradient where the
    unit is active. The gradient returned for the threshold input uses a
    rectangle kernel of width ``bandwidth`` around the threshold.
    """

    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        """Return ``x`` where ``x > exp(log_threshold)``, else 0.

        The output keeps the dtype of ``x``, so half-precision activations
        are not silently upcast to float32.
        """
        ctx.save_for_backward(x, log_threshold)
        # bandwidth is not a tensor input, so it is stored on ctx directly
        # instead of being wrapped in a tensor for save_for_backward.
        ctx.bandwidth = bandwidth
        threshold = torch.exp(log_threshold)
        return x * (x > threshold).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, log_threshold, bandwidth)``."""
        x, log_threshold = ctx.saved_tensors
        bandwidth = ctx.bandwidth
        threshold = torch.exp(log_threshold)
        x_grad = (x > threshold).to(grad_output.dtype) * grad_output
        threshold_grad = (
            -(threshold / bandwidth)
            * RectangleFunction.apply((x - threshold) / bandwidth)
            * grad_output
        )
        return x_grad, threshold_grad, None  # None for bandwidth
|
| 313 |
+
|
| 314 |
+
class JumpReLU(nn.Module):
    """Module wrapper around ``JumpReLUFunction`` with a learnable threshold.

    Holds one log-threshold per feature (initialised to zero) and a fixed
    ``bandwidth`` that is passed to the function on every call.
    """

    def __init__(self, feature_size, bandwidth, device='cpu'):
        super().__init__()
        initial_log_threshold = torch.zeros(feature_size, device=device)
        self.log_threshold = nn.Parameter(initial_log_threshold)
        self.bandwidth = bandwidth

    def forward(self, x):
        """Apply JumpReLU to ``x`` using this module's threshold and bandwidth."""
        return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth)
|
| 322 |
+
|
| 323 |
+
class StepFunction(autograd.Function):
    """Step at ``exp(log_threshold)`` with a rectangle-kernel threshold gradient.

    The gradient with respect to ``x`` is always zero.
    """

    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        """Return 1.0 where ``x > exp(log_threshold)`` and 0.0 elsewhere."""
        ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
        is_above = torch.gt(x, torch.exp(log_threshold))
        return is_above.float()

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, log_threshold, bandwidth)``."""
        x, log_threshold, bandwidth_tensor = ctx.saved_tensors
        width = bandwidth_tensor.item()
        cutoff = torch.exp(log_threshold)
        # Rectangle kernel evaluated on the distance to the threshold,
        # measured in units of the bandwidth.
        kernel = RectangleFunction.apply((x - cutoff) / width)
        threshold_grad = -(1.0 / width) * kernel * grad_output
        return torch.zeros_like(x), threshold_grad, None  # None for bandwidth
|
| 342 |
+
|
| 343 |
+
class JumpReLUSAE(BaseSAE):
    """SAE using a JumpReLU activation with a learnable per-feature threshold.

    Sparsity is penalised through a step-function count of active features,
    scaled by ``config.l1_coeff``.
    """

    def __init__(self, config: SAEConfig):
        """Build the base SAE parameters and the JumpReLU activation."""
        super().__init__(config)
        self.jumprelu = JumpReLU(
            feature_size=config.dict_size,
            bandwidth=config.bandwidth,
            device=config.device if hasattr(config, 'device') else 'cpu'
        )

    def forward(self, x, use_pre_enc_bias=False):
        """Encode, apply JumpReLU, decode, and return the loss dictionary.

        Args:
            x: input activations.
            use_pre_enc_bias: when true, ``b_dec`` is subtracted from the
                encoder input. The reconstruction target used for the loss
                is the preprocessed input in either case.
        """
        x, x_mean, x_std = self.preprocess_input(x)
        # Keep the loss target separate from the encoder input: the decoder
        # adds b_dec back, so the target must not be the bias-shifted tensor.
        encoder_input = x - self.b_dec if use_pre_enc_bias else x
        pre_activations = torch.relu(encoder_input @ self.W_enc + self.b_enc)
        feature_magnitudes = self.jumprelu(pre_activations)

        x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec

        # Track inactive features so `num_dead_features` is meaningful. The
        # counter is a plain tensor, so it is aligned with the activations'
        # device before the in-place update.
        self.num_batches_not_active = self.num_batches_not_active.to(
            feature_magnitudes.device
        )
        self.update_inactive_features(feature_magnitudes)

        return self.get_loss_dict(x, x_reconstructed, feature_magnitudes, x_mean, x_std)

    def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
        """Build the loss/metric dictionary for one forward pass.

        ``loss`` is the mean squared reconstruction error plus ``l1_coeff``
        times the mean count of active features. Both ``"l0_norm"`` and
        ``"l1_norm"`` report that count, and ``"l1_loss"`` holds the scaled
        penalty.
        """
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()

        l0 = StepFunction.apply(acts, self.jumprelu.log_threshold, self.config.bandwidth).sum(dim=-1).mean()
        l0_loss = self.config.l1_coeff * l0
        l1_loss = l0_loss

        loss = l2_loss + l1_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()

        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        output = {
            "sae_out": sae_out,
            "feature_acts": acts,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0,
            "l1_norm": l0,
            "explained_variance": explained_variance,
        }
        return output
|