igardner commited on
Commit
3d558ae
·
1 Parent(s): 22b6988

Added thing. Changed up project a little. Just fucking around basically

Browse files
Files changed (1) hide show
  1. concept_steerer.py +196 -0
concept_steerer.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # concept_steerer.py
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ from typing import List, Dict, Optional, Tuple
5
+ import numpy as np
6
+
7
+ class ConceptSteerer:
8
+ def __init__(
9
+ self,
10
+ model_name: str = "unsloth/Llama-3.2-1B-Instruct",
11
+ device: str = "auto"
12
+ ):
13
+ """
14
+ A robust class for performing activation steering on LLMs.
15
+
16
+ Args:
17
+ model_name: The Hugging Face model name.
18
+ device: The device to load the model on ("auto", "cuda", "cpu").
19
+ """
20
+ print(f"Loading model {model_name}...")
21
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
22
+ if self.tokenizer.pad_token is None:
23
+ self.tokenizer.pad_token = self.tokenizer.eos_token
24
+
25
+ self.model = AutoModelForCausalLM.from_pretrained(
26
+ model_name,
27
+ device_map=device,
28
+ torch_dtype=torch.float16 if device != "cpu" else torch.float32,
29
+ attn_implementation="sdpa", # Use optimized attention
30
+ trust_remote_code=False
31
+ )
32
+ self.model.eval()
33
+ self.num_layers = len(self.model.model.layers)
34
+ self.concepts = {} # name -> steering vector
35
+
36
+ def _format_prompt_for_model(self, prompt: str) -> str:
37
+ """Format the prompt according to the model's chat template if available."""
38
+ if hasattr(self.tokenizer, 'apply_chat_template'):
39
+ messages = [{"role": "user", "content": prompt}]
40
+ return self.tokenizer.apply_chat_template(
41
+ messages, tokenize=False, add_generation_prompt=True
42
+ )
43
+ return prompt
44
+
45
+ def _get_mean_activation(self, prompts: List[str], layer: int, token_pos: int = -1) -> torch.Tensor:
46
+ """Get the mean activation for a set of prompts at a specific layer and token position."""
47
+ acts = []
48
+ for prompt in prompts:
49
+ formatted_prompt = self._format_prompt_for_model(prompt)
50
+ inputs = self.tokenizer(
51
+ formatted_prompt,
52
+ return_tensors="pt",
53
+ padding=True,
54
+ truncation=True,
55
+ max_length=512
56
+ ).to(self.model.device)
57
+
58
+ with torch.no_grad():
59
+ outputs = self.model(**inputs, output_hidden_states=True)
60
+
61
+ # Resolve token index
62
+ seq_len = inputs.input_ids.shape[1]
63
+ if token_pos >= 0:
64
+ idx = min(token_pos, seq_len - 1)
65
+ else:
66
+ idx = seq_len + token_pos
67
+
68
+ act = outputs.hidden_states[layer][0, idx, :].float().cpu()
69
+ acts.append(act)
70
+
71
+ return torch.stack(acts).mean(dim=0)
72
+
73
+ def register_concept(
74
+ self,
75
+ name: str,
76
+ positive_prompts: List[str],
77
+ negative_prompts: List[str],
78
+ layer: int = -1, # Default to last layer
79
+ token_pos: int = -1
80
+ ):
81
+ """Create and register a steering vector from contrastive examples."""
82
+ if layer < 0:
83
+ layer = self.num_layers + layer
84
+
85
+ pos_acts = self._get_mean_activation(positive_prompts, layer, token_pos)
86
+ neg_acts = self._get_mean_activation(negative_prompts, layer, token_pos)
87
+ steering_vec = (pos_acts - neg_acts)
88
+ # Normalize to unit vector for consistent scaling
89
+ self.concepts[name] = steering_vec / steering_vec.norm()
90
+
91
+ def steer_by_relation(
92
+ self,
93
+ name: str,
94
+ A: str, B: str, C: str, D: str,
95
+ layer: int = -1,
96
+ token_pos: int = -1,
97
+ num_examples: int = 5
98
+ ):
99
+ """
100
+ Create a composite concept using the relation (A is to B) as (C is to D).
101
+ Generates examples on-the-fly using the model itself.
102
+ """
103
+ if layer < 0:
104
+ layer = self.num_layers + layer
105
+
106
+ def generate_examples(seed_prompt: str, num: int) -> List[str]:
107
+ examples = []
108
+ for _ in range(num):
109
+ inputs = self.tokenizer(
110
+ self._format_prompt_for_model(seed_prompt),
111
+ return_tensors="pt"
112
+ ).to(self.model.device)
113
+ with torch.no_grad():
114
+ out = self.model.generate(
115
+ **inputs,
116
+ max_new_tokens=20,
117
+ do_sample=True,
118
+ temperature=0.7,
119
+ top_p=0.9,
120
+ pad_token_id=self.tokenizer.pad_token_id
121
+ )
122
+ full_text = self.tokenizer.decode(out[0], skip_special_tokens=True)
123
+ # Extract just the generated part
124
+ generated = full_text[len(seed_prompt):].strip()
125
+ examples.append(generated)
126
+ return examples
127
+
128
+ # Generate examples for each concept
129
+ pos_examples = generate_examples(f"{A} is to {B} as {C} is to", num_examples)
130
+ neg_examples = generate_examples(f"{A} is to {B} as {D} is to", num_examples)
131
+
132
+ # Create the composite vector: (A-B) + (C-D)
133
+ AB_vec = self._get_mean_activation([A], layer, -1) - self._get_mean_activation([B], layer, -1)
134
+ CD_vec = self._get_mean_activation([C], layer, -1) - self._get_mean_activation([D], layer, -1)
135
+ composite_vec = AB_vec + CD_vec
136
+ self.concepts[name] = composite_vec / composite_vec.norm()
137
+
138
+ def generate(
139
+ self,
140
+ prompt: str,
141
+ steering_config: Optional[Dict[str, float]] = None,
142
+ layer: int = -1,
143
+ token_pos: int = -1,
144
+ max_new_tokens: int = 100,
145
+ **gen_kwargs
146
+ ) -> str:
147
+ """Generate text with optional activation steering."""
148
+ if layer < 0:
149
+ layer = self.num_layers + layer
150
+
151
+ if steering_config is None:
152
+ steering_config = {}
153
+
154
+ inputs = self.tokenizer(
155
+ self._format_prompt_for_model(prompt),
156
+ return_tensors="pt"
157
+ ).to(self.model.device)
158
+
159
+ # Resolve token index for the hook
160
+ seq_len = inputs.input_ids.shape[1]
161
+ if token_pos >= 0:
162
+ hook_token_idx = min(token_pos, seq_len - 1)
163
+ else:
164
+ hook_token_idx = seq_len + token_pos
165
+
166
+ def hook_fn(module, input, output):
167
+ total_steer = torch.zeros_like(output[0][0, hook_token_idx, :])
168
+ for concept_name, strength in steering_config.items():
169
+ if concept_name in self.concepts:
170
+ vec = self.concepts[concept_name].to(output[0].device, dtype=output[0].dtype)
171
+ total_steer += vec * strength
172
+ output[0][0, hook_token_idx, :] += total_steer
173
+ return output
174
+
175
+ handle = self.model.model.layers[layer].register_forward_hook(hook_fn)
176
+ try:
177
+ with torch.no_grad():
178
+ out = self.model.generate(
179
+ **inputs,
180
+ max_new_tokens=max_new_tokens,
181
+ pad_token_id=self.tokenizer.pad_token_id,
182
+ do_sample=True,
183
+ temperature=0.6,
184
+ top_p=0.9,
185
+ **gen_kwargs
186
+ )
187
+ result = self.tokenizer.decode(out[0], skip_special_tokens=True)
188
+ # Remove the prompt from the result if it's a chat model
189
+ if result.startswith(prompt):
190
+ result = result[len(prompt):].strip()
191
+ return result
192
+ finally:
193
+ handle.remove()
194
+
195
+ def get_concept_names(self) -> List[str]:
196
+ return list(self.concepts.keys())