AmnaHassan committed
Commit b6d2a24 · verified · 1 Parent(s): 52792a5

Create patching_logic.py

Files changed (1)
  1. patching_logic.py +199 -0
patching_logic.py ADDED
@@ -0,0 +1,199 @@
import re

import torch
import torch.nn.functional as F
from transformer_lens import HookedTransformer

# --- LOAD MODEL (loaded only once when the module is imported) ---
model = HookedTransformer.from_pretrained("gpt2", device="cpu")

# ============================================================
# SAFE TRIMMING HELPERS FOR LONG INPUT
# ============================================================

def trim_to_context(text, keep_last_tokens=None):
    """
    Trim `text` so its tokenized length is <= model.cfg.n_ctx.
    Keeps the LAST tokens in the sequence (closest to the verb).
    """
    toks = model.tokenizer.encode(text)
    n_ctx = model.cfg.n_ctx
    max_keep = n_ctx if keep_last_tokens is None else min(keep_last_tokens, n_ctx)

    if len(toks) > max_keep:
        toks = toks[-max_keep:]
        return model.tokenizer.decode(toks)
    return text

def trim_pair_for_patching(good_sentence, bad_sentence):
    """
    Trim both good & bad sentences to same-size windows <= n_ctx.
    Ensures cache shapes match when patching.
    """
    good_toks = model.tokenizer.encode(good_sentence)
    bad_toks = model.tokenizer.encode(bad_sentence)
    n_ctx = model.cfg.n_ctx

    if len(good_toks) > n_ctx or len(bad_toks) > n_ctx:
        good_toks = good_toks[-n_ctx:]
        bad_toks = bad_toks[-n_ctx:]
        return model.tokenizer.decode(good_toks), model.tokenizer.decode(bad_toks)
    return good_sentence, bad_sentence

# ============================================================
# UPDATED NEXT-TOKEN SCORING
# ============================================================

def score_next_token(sentence_prefix, token1, token2):
    """
    Safely compute p(token1) and p(token2) as the next token after a prefix.
    Automatically trims the prefix to fit GPT-2's context window.
    """
    # Reserve one slot for the BOS token that model.to_tokens prepends.
    max_prefix_toks = max(1, model.cfg.n_ctx - 1)
    prefix_trimmed = trim_to_context(sentence_prefix, keep_last_tokens=max_prefix_toks)

    tokens = model.to_tokens(prefix_trimmed)
    with torch.no_grad():
        logits = model(tokens)[0, -1]
    probs = F.softmax(logits, dim=-1)

    # Leading space so we look up the standalone word token in GPT-2's BPE.
    t1 = model.tokenizer.encode(" " + token1)[0]
    t2 = model.tokenizer.encode(" " + token2)[0]

    return float(probs[t1]), float(probs[t2])
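
# A minimal usage sketch for the scorer above (not called by the pipeline;
# the example prefix is an assumption for illustration only, and the actual
# probabilities depend on GPT-2).
def _example_score_usage():
    p_is, p_are = score_next_token("The keys to the cabinet", "is", "are")
    print(f"p(' is') = {p_is:.4f}, p(' are') = {p_are:.4f}")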

# ============================================================
# DETECT VERB PAIR + BUILD BAD SENTENCE VARIANT
# ============================================================

def detect_and_build_variants(user_sentence):
    """
    Find the last supported agreement verb in the sentence and build a
    minimally different "bad" variant using the wrong agreement form.
    """
    verb_pairs = [
        ("has", "have"),
        ("is", "are"),
        ("was", "were"),
        ("does", "do"),
    ]

    s = user_sentence.strip()

    for singular, plural in verb_pairs:
        for actual, wrong in ((singular, plural), (plural, singular)):
            # Match the verb as a whole word only, so e.g. "is" does not
            # fire inside words like "this" or "island".
            matches = list(re.finditer(rf"\b{actual}\b", s, flags=re.IGNORECASE))
            if not matches:
                continue

            # Split around the LAST occurrence (closest to the prediction site).
            m = matches[-1]
            prefix = s[:m.start()]
            remainder = s[m.end():]
            bad_sentence = (prefix + wrong + remainder).strip()

            return prefix, actual, wrong, bad_sentence, (singular, plural)

    return None, None, None, None, None
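
# Illustrative check of the detector (a sketch, not called by the pipeline).
# For "The boys who met the teacher have a question" it returns
# actual="have", wrong="has", and the bad sentence
# "The boys who met the teacher has a question".
def _example_detect_usage():
    prefix, actual, wrong, bad, pair = detect_and_build_variants(
        "The boys who met the teacher have a question"
    )
    print(prefix, actual, wrong, bad, pair)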

# ============================================================
# UPDATED ACTIVATION PATCHING (SAFE)
# ============================================================

def patch_layer_user(layer, good_sentence, bad_sentence, verb_pair):
    """
    Patch blocks.{layer}.attn.hook_z from the good run into the bad run.
    Sentences are trimmed to the same token length for safe patching.
    """
    # Tokenize both sentences
    good_toks = model.tokenizer.encode(good_sentence)
    bad_toks = model.tokenizer.encode(bad_sentence)

    # Trim to the shortest length
    min_len = min(len(good_toks), len(bad_toks))

    # Operate on raw tokens so we don't accidentally introduce BOS or other
    # tokenizer-specific issues that cause length mismatches. Both
    # run_with_cache and run_with_hooks are called with prepend_bos=False
    # to maintain this consistency.
    good_trimmed_tokens = good_toks[-min_len:]
    bad_trimmed_tokens = bad_toks[-min_len:]

    good_trimmed_str = model.tokenizer.decode(good_trimmed_tokens)
    bad_trimmed_str = model.tokenizer.decode(bad_trimmed_tokens)

    # Get the activation cache for the good sentence.
    _, cache_good = model.run_with_cache(good_trimmed_str, prepend_bos=False)

    def patch_hook(value, hook):
        # Because both runs use prepend_bos=False, cache_good[hook.name] and
        # `value` share the same sequence length (min_len), so replacing the
        # whole activation is safe.
        return cache_good[hook.name]

    # Run the bad sentence with the good layer's attention outputs patched in.
    patched_logits = model.run_with_hooks(
        bad_trimmed_str,
        fwd_hooks=[(f"blocks.{layer}.attn.hook_z", patch_hook)],
        prepend_bos=False,
    )

    probs = F.softmax(patched_logits[0, -1], dim=-1)
    singular, plural = verb_pair
    # " has", " have", etc. are single tokens in GPT-2's vocabulary.
    t_sing = model.tokenizer.encode(" " + singular)[0]
    t_plur = model.tokenizer.encode(" " + plural)[0]

    return float(probs[t_sing]), float(probs[t_plur])
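
# A minimal sketch of the same idea restricted to a single attention head
# (an assumption for illustration; the pipeline below does not call it).
# In TransformerLens, hook_z has shape [batch, pos, head_index, d_head],
# so patching one head is just a slice assignment.
def patch_single_head_user(layer, head, good_sentence, bad_sentence, verb_pair):
    good_toks = model.tokenizer.encode(good_sentence)
    bad_toks = model.tokenizer.encode(bad_sentence)
    min_len = min(len(good_toks), len(bad_toks))
    good_str = model.tokenizer.decode(good_toks[-min_len:])
    bad_str = model.tokenizer.decode(bad_toks[-min_len:])

    _, cache_good = model.run_with_cache(good_str, prepend_bos=False)

    def head_hook(value, hook):
        # Overwrite only one head's output; leave the rest of the layer alone.
        value[:, :, head, :] = cache_good[hook.name][:, :, head, :]
        return value

    patched_logits = model.run_with_hooks(
        bad_str,
        fwd_hooks=[(f"blocks.{layer}.attn.hook_z", head_hook)],
        prepend_bos=False,
    )
    probs = F.softmax(patched_logits[0, -1], dim=-1)
    singular, plural = verb_pair
    t_sing = model.tokenizer.encode(" " + singular)[0]
    t_plur = model.tokenizer.encode(" " + plural)[0]
    return float(probs[t_sing]), float(probs[t_plur])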

# ============================================================
# FULL USER PIPELINE
# ============================================================

def run_user_activation_pipeline(user_sentence, n_layers_to_check=None):
    prefix, actual, wrong, bad_sentence, verb_pair = detect_and_build_variants(user_sentence)

    if prefix is None:
        return {"error": "No supported verb pair found (has/have, is/are, was/were, does/do)."}

    # Strip the trailing space from the prefix: " has"/" have" etc. already
    # carry the leading space as part of the GPT-2 token, so scoring after a
    # dangling space token would skew the probabilities.
    scoring_prefix = prefix.rstrip()

    p_actual, p_wrong = score_next_token(scoring_prefix, actual, wrong)

    singular, plural = verb_pair
    correct_token = actual

    p_sing, p_plur = score_next_token(scoring_prefix, singular, plural)

    n_layers = model.cfg.n_layers
    if n_layers_to_check is None:
        n_layers_to_check = n_layers

    layer_probs_correct = []
    for layer in range(min(n_layers_to_check, n_layers)):
        p_sing_patched, p_plur_patched = patch_layer_user(layer, user_sentence, bad_sentence, verb_pair)
        p_correct = p_plur_patched if correct_token == plural else p_sing_patched
        layer_probs_correct.append(p_correct)

    return {
        "user_sentence": user_sentence,
        "prefix_used_for_scoring": scoring_prefix,
        "verb_pair": verb_pair,
        "actual_verb_in_sentence": actual,
        "wrong_verb_used_for_bad_sentence": wrong,
        "bad_sentence": bad_sentence,
        "p_actual_token_raw": p_actual,
        "p_wrong_token_raw": p_wrong,
        "p_singular": p_sing,
        "p_plural": p_plur,
        "layer_probs_correct_after_patch": layer_probs_correct,
    }
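
# Minimal usage sketch. The example sentence is an assumption for
# illustration; actual probabilities depend on GPT-2 and are not shown here.
if __name__ == "__main__":
    result = run_user_activation_pipeline("The keys to the cabinet are on the table")
    for key, value in result.items():
        print(f"{key}: {value}")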