File size: 4,516 Bytes
4e4764b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Olmo2ForCausalLM
class SAE(nn.Module):
    """Sparse autoencoder: a tied-shape pair of linear maps with a ReLU
    feature bottleneck, trained with an L1 sparsity penalty weighted by the
    current decoder column norms.

    Args:
        input_size: dimensionality of the activations being reconstructed.
        hidden_size: number of dictionary features (columns of the decoder).
        init_scale: UNUSED. Accepted for backward compatibility with existing
            callers; initialization scale is drawn randomly per-feature below.
    """

    def __init__(self, input_size, hidden_size, init_scale=0.1):
        super().__init__()

        # Store dimensions for external inspection.
        self.input_size = input_size
        self.hidden_size = hidden_size

        # encode: (input_size -> hidden_size), decode: (hidden_size -> input_size).
        self.encode = nn.Linear(input_size, hidden_size, bias=True)
        self.decode = nn.Linear(hidden_size, input_size, bias=True)

        with torch.no_grad():
            # Random directions, one per dictionary feature (columns).
            decoder_weights = torch.randn(input_size, hidden_size)
            # Normalize each column to unit norm.
            decoder_weights /= torch.linalg.vector_norm(
                decoder_weights, dim=0, keepdim=True
            )
            # Scale columns by random values in [0.05, 1.0). Note the scale
            # survives only in the encoder: constrain_weights() below
            # re-normalizes the decoder columns back to unit norm.
            scales = torch.rand(hidden_size) * 0.95 + 0.05
            decoder_weights *= scales

            # In-place copies under no_grad (idiomatic replacement for the
            # deprecated `.data =` assignment pattern).
            self.decode.weight.copy_(decoder_weights)
            self.encode.weight.copy_(decoder_weights.T)
            self.encode.bias.zero_()
            self.decode.bias.zero_()

        self.constrain_weights()

    @property
    def device(self):
        """Return the device the model parameters are on."""
        return next(self.parameters()).device

    def constrain_weights(self):
        """Constrain the decoder weight columns to have unit norm (in place)."""
        with torch.no_grad():
            self.decode.weight /= torch.linalg.vector_norm(
                self.decode.weight, dim=0, keepdim=True
            )

    def forward(self, x):
        """Encode `x` into non-negative features and reconstruct it.

        Returns:
            (reconstruction, features) — reconstruction has the shape of `x`,
            features has trailing dimension `hidden_size`.
        """
        features = F.relu(self.encode(x))
        reconstruction = self.decode(features)
        return reconstruction, features

    def get_decoder_norms(self):
        """Return a 1-D tensor (hidden_size,) of decoder column norms,
        on the same device/dtype as the decoder weights."""
        return torch.linalg.vector_norm(self.decode.weight, dim=0)

    @property
    def W_dec(self):
        """Decoder weight matrix, exposed for analysis code."""
        return self.decode.weight

    def compute_loss(self, x, recon, feats, lambda_):
        """Reconstruction MSE plus L1 sparsity penalty.

        The reconstruction term sums over the feature dimension and averages
        over the batch. The sparsity term is the L1 of the feature activations
        weighted by the current decoder column norms (so the penalty is
        invariant to rescaling a feature between encoder and decoder).
        """
        recon_mse = (recon - x).pow(2).sum(-1).mean()
        sparsity = (feats.abs() * self.get_decoder_norms()).sum(1).mean()
        return recon_mse + lambda_ * sparsity

class SteerableOlmo2ForCausalLM(Olmo2ForCausalLM):
    """Olmo2 causal LM with SAE feature steering.

    A forward hook on one decoder layer re-expresses the layer's hidden
    states in an SAE's feature basis, clamps selected features to fixed
    values, and writes the edited hidden states back — preserving the part
    of the residual stream the SAE cannot reconstruct.
    """

    def __init__(self, config):
        super().__init__(config)
        self.steering_layer = None    # index into self.model.layers to hook
        self.sae = None               # SAE providing the feature basis
        self.steering_features = {}   # feature_idx -> clamp value
        self.steering_hook = None     # RemovableHandle from register_forward_hook
        self.sae_max = None           # optional per-feature max activations

    def set_sae_and_layer(self, sae, layer):
        """Attach `sae` to decoder layer index `layer` and (re)register the hook."""
        self.sae = sae
        self.steering_layer = layer
        self._register_steering_hook()

    def set_sae_max(self, sae_max):
        """Store per-feature max activations (indexable by feature idx),
        used by set_steering(..., as_multiple_of_max=True)."""
        self.sae_max = sae_max

    def set_steering(self, feature_idx, value, *, as_multiple_of_max=False):
        """Clamp feature `feature_idx` to `value` on every forward pass.

        If `as_multiple_of_max` and sae_max is set, `value` is interpreted as
        a multiple of that feature's stored maximum activation.
        """
        if as_multiple_of_max and self.sae_max is not None:
            value = float(value) * float(self.sae_max[feature_idx])
        self.steering_features[feature_idx] = value

    def clear_steering(self):
        """Remove all feature clamps (the hook stays registered but is a no-op)."""
        self.steering_features = {}

    @torch.no_grad()
    def _steering_hook_fn(self, module, input, output):
        """Forward hook: clamp selected SAE features in the layer output."""
        if not self.steering_features or self.sae is None:
            return output

        hidden_states = output[0]
        # BUG FIX: match SAE.forward — features are ReLU(encode(x)), not the
        # raw pre-activation. Without the ReLU, the clamp baseline (and hence
        # the steering delta) disagrees with the feature basis the SAE was
        # trained in whenever a pre-activation is negative.
        feats = F.relu(self.sae.encode(hidden_states))
        recon = self.sae.decode(feats)
        # Residual the SAE cannot reconstruct; added back so unsteered
        # directions of the hidden state pass through unchanged.
        error = hidden_states - recon

        feats_steered = feats.clone()
        for idx, clamp_value in self.steering_features.items():
            feats_steered[..., idx] = clamp_value

        recon_steered = self.sae.decode(feats_steered)
        hidden_steered = recon_steered + error

        # Preserve any extra outputs (e.g. attention weights / caches).
        return (hidden_steered,) + output[1:]

    def _register_steering_hook(self):
        """Register the steering hook on self.model.layers[steering_layer],
        removing any previously registered hook first."""
        if self.steering_hook is not None:
            self.steering_hook.remove()
            self.steering_hook = None

        if self.steering_layer is not None:
            target_layer = self.model.layers[self.steering_layer]
            self.steering_hook = target_layer.register_forward_hook(self._steering_hook_fn)

    def remove_steering_hook(self):
        """Detach the steering hook entirely (idempotent)."""
        if self.steering_hook is not None:
            self.steering_hook.remove()
            self.steering_hook = None