theapemachine committed on
Commit
82900ee
·
verified ·
1 Parent(s): ba8ab4a

Add cortex/steering_vector.py

Browse files
Files changed (1) hide show
  1. cortex/steering_vector.py +184 -0
cortex/steering_vector.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SteeringVector: Activation-space behavioral control.
3
+
4
+ Inspired by Representation Engineering (Zou et al. 2023) and LoRRA.
5
+
6
+ Architecture:
7
+ - Maintains a set of named "concept directions" in activation space
8
+ - Each direction is a vector in R^D extracted via contrastive activation pairs
9
+ - At inference time, directions are added to the residual stream with learnable weights
10
+ - Directions can be extracted, composed, and interpolated
11
+
12
+ Failure mode addressed:
13
+ - Behavioral inflexibility: models have fixed behaviors baked in during training.
14
+ Steering vectors allow runtime control without retraining.
15
+ - Safety/alignment: can steer toward/away from toxicity, bias, refusal behaviors
16
+ - Persona control: steer toward specific communication styles
17
+ - Truthfulness: steer toward directions associated with factual vs confabulated outputs
18
+
19
+ Injection point: RESIDUAL_STREAM
20
+ - Rationale: The residual stream is the "information highway" of the transformer.
21
+ Additive modifications here have the most direct effect on all downstream layers.
22
+ """
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ from typing import Optional, Union, List, Dict, Tuple
28
+ from cortex.core import CortexModule, InjectionPoint
29
+
30
+
31
class SteeringVector(CortexModule):
    """
    Add learned or extracted direction vectors to the residual stream.

    Supports two modes:
    1. Extracted: directions from contrastive activation analysis (RepE-style),
       computed with :meth:`extract_direction` and installed with
       :meth:`set_direction`.
    2. Learned: directions trained end-to-end from task data (``directions``,
       ``alphas`` and ``layer_scale`` are all ``nn.Parameter``s).

    Multiple named directions can be composed with individual weights.

    Args:
        hidden_dim: Model hidden dimension.
        num_directions: Number of independent steering directions.
        direction_names: Optional names for each direction; defaults to
            ``direction_0 .. direction_{num_directions - 1}``.
        alpha_init: Initial steering strength for every direction (learnable).
            With the default of 0.0 the module is initially a no-op.
        normalize: Whether to L2-normalize directions before they are applied.
        target_layers: Forwarded unchanged to the ``CortexModule`` constructor.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_directions: int = 4,
        direction_names: Optional[List[str]] = None,
        alpha_init: float = 0.0,
        normalize: bool = True,
        target_layers: Union[List[int], str] = "middle",
    ):
        super().__init__(InjectionPoint.RESIDUAL_STREAM, target_layers)

        self.hidden_dim = hidden_dim
        self.num_directions = num_directions
        self.normalize = normalize

        if direction_names is None:
            direction_names = [f"direction_{i}" for i in range(num_directions)]
        self.direction_names = direction_names

        # Learnable direction vectors, one row per direction: [num_directions, hidden_dim]
        self.directions = nn.Parameter(torch.randn(num_directions, hidden_dim) * 0.02)

        # Per-direction steering strength (learnable). float() guards against an
        # integer alpha_init, for which torch.full would build an integer tensor
        # that cannot be a parameter requiring gradients.
        self.alphas = nn.Parameter(torch.full((num_directions,), float(alpha_init)))

        # Single learnable scalar applied to the combined steering offset
        # (one value for the whole module, not one per layer).
        self.layer_scale = nn.Parameter(torch.ones(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_idx: int,
        **kwargs
    ) -> torch.Tensor:
        """
        Add weighted steering vectors to the residual stream.

            h_new = h + layer_scale * Σ_i (alpha_i * direction_i)

        Args:
            hidden_states: Activations whose last dimension is ``hidden_dim``.
            layer_idx: Index of the calling layer; accepted but not used here.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor with the same shape and dtype as ``hidden_states``.
        """
        if self.normalize:
            directions = F.normalize(self.directions, dim=-1)
        else:
            directions = self.directions

        offset = (self.alphas.unsqueeze(-1) * directions).sum(dim=0)  # [D]
        offset = self.layer_scale * offset

        # Match the residual-stream dtype: without this, fp32 parameters added to
        # fp16/bf16 activations would silently promote the whole stream to fp32.
        offset = offset.to(dtype=hidden_states.dtype)

        # Broadcasting over the leading dimensions keeps the input shape intact.
        return hidden_states + offset

    def set_direction(self, name_or_idx: Union[str, int], direction: torch.Tensor, alpha: float = 1.0):
        """
        Set a steering direction from an externally computed vector.

        When ``normalize`` is True the stored vector is L2-normalized in
        ``forward``, so only its orientation matters and ``alpha`` alone sets
        the steering strength.

        Args:
            name_or_idx: Direction name or index.
            direction: Direction vector with ``hidden_dim`` elements.
            alpha: Steering strength.

        Raises:
            ValueError: If the name is unknown or ``direction`` does not contain
                exactly ``hidden_dim`` elements.
        """
        if isinstance(name_or_idx, str):
            try:
                idx = self.direction_names.index(name_or_idx)
            except ValueError:
                raise ValueError(
                    f"Unknown direction name {name_or_idx!r}; "
                    f"known names: {self.direction_names}"
                ) from None
        else:
            idx = name_or_idx

        # Reject wrongly sized vectors up front; a size-1 tensor would otherwise
        # be broadcast across the whole row without any error.
        if direction.numel() != self.hidden_dim:
            raise ValueError(
                f"direction must have {self.hidden_dim} elements, "
                f"got shape {tuple(direction.shape)}"
            )

        with torch.no_grad():
            self.directions[idx] = direction.reshape(-1)
            self.alphas[idx] = alpha

    @staticmethod
    def extract_direction(
        model: nn.Module,
        positive_prompts: List[str],
        negative_prompts: List[str],
        tokenizer,
        layer_idx: int,
        device: str = "cuda"
    ) -> torch.Tensor:
        """
        Extract a steering direction via contrastive activation analysis.

        Following RepE (Zou et al. 2023):
        1. Run positive prompts through the model, collect last-token activations at layer_idx
        2. Run negative prompts through the model, collect last-token activations
        3. Compute the difference: direction = mean(positive) - mean(negative)
        4. If there are at least 4 positive prompts and the two sets have the
           same size, replace it with the first principal component of the
           centered pairwise differences, oriented to agree with step 3.

        Note that the mean-difference result keeps its natural magnitude while
        the PCA result is unit-norm. ``model`` is switched to eval mode and is
        left in that mode.

        Args:
            model: The LLM; called as ``model(**inputs, output_hidden_states=True)``.
            positive_prompts: Prompts exemplifying the desired behavior.
            negative_prompts: Prompts exemplifying the undesired behavior.
            tokenizer: Model's tokenizer.
            layer_idx: Index into ``outputs.hidden_states`` to extract from.
            device: Device the tokenized inputs are moved to.

        Returns:
            direction: [hidden_dim] steering direction vector (detached).

        Raises:
            ValueError: If either prompt list is empty.
        """
        if not positive_prompts or not negative_prompts:
            raise ValueError("positive_prompts and negative_prompts must both be non-empty")

        model.eval()

        def get_activations(prompts):
            collected = []
            with torch.no_grad():
                for prompt in prompts:
                    inputs = tokenizer(
                        prompt, return_tensors="pt", padding=True, truncation=True
                    ).to(device)
                    outputs = model(**inputs, output_hidden_states=True)
                    hidden = outputs.hidden_states[layer_idx]  # [1, T, D]
                    collected.append(hidden[:, -1, :])  # last token, [1, D]
            return torch.cat(collected, dim=0)  # [N, D]

        pos_acts = get_activations(positive_prompts)
        neg_acts = get_activations(negative_prompts)

        mean_diff = pos_acts.mean(dim=0) - neg_acts.mean(dim=0)
        direction = mean_diff

        # PCA refinement needs paired samples: only valid when both sets have
        # the same number of activations.
        if len(positive_prompts) >= 4 and pos_acts.shape[0] == neg_acts.shape[0]:
            diffs = pos_acts - neg_acts  # [N, D]
            # SVD may be unavailable in half precision; compute in fp32 instead.
            if diffs.dtype in (torch.float16, torch.bfloat16):
                diffs = diffs.float()
            diffs = diffs - diffs.mean(dim=0)  # Center
            _, _, Vt = torch.linalg.svd(diffs, full_matrices=False)
            component = Vt[0]  # First principal component (unit norm)

            # Singular vectors are only defined up to sign; orient the component
            # so it points from the negative set toward the positive set.
            if (component * mean_diff.to(component.dtype)).sum() < 0:
                component = -component
            direction = component.to(mean_diff.dtype)

        return direction.detach()

    def get_direction_info(self) -> Dict[str, Dict[str, float]]:
        """Return ``{name: {"alpha": ..., "norm": ...}}`` for every named direction.

        ``norm`` is the L2 norm of the stored (un-normalized) direction vector.
        """
        return {
            name: {
                "alpha": self.alphas[i].item(),
                "norm": self.directions[i].norm().item(),
            }
            for i, name in enumerate(self.direction_names)
        }

    def extra_repr(self):
        """Summarize the configuration, followed by the parent's ``extra_repr``."""
        return (f"hidden_dim={self.hidden_dim}, num_directions={self.num_directions}, "
                f"names={self.direction_names}, {super().extra_repr()}")