# Activation-patching-app-gpt-2 / patching_logic.py
# Author: AmnaHassan
# Commit b6d2a24 (verified): "Create patching_logic.py"
import re

import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformer_lens import HookedTransformer
# --- LOAD MODEL ---
# Built a single time, as a side effect of importing this module. Every helper
# below reads this module-level `model` for its tokenizer, config and forward
# passes.
model = HookedTransformer.from_pretrained(
    "gpt2",
    device="cpu",
)
# ============================================================
# SAFE TRIMMING HELPERS FOR LONG INPUT
# ============================================================
def trim_to_context(text, keep_last_tokens=None):
    """
    Keep only the trailing tokens of `text` so it fits the model context.

    Args:
        text: String to trim.
        keep_last_tokens: Maximum number of tokens to keep. Defaults to
            model.cfg.n_ctx; values above n_ctx are capped at n_ctx and
            values below 0 are treated as 0.

    Returns:
        `text` unchanged if its token count (tokenizer.encode, no BOS) is
        already within the limit. Otherwise, the decoded LAST tokens, i.e. the
        end of the text (for a scoring prefix, the part closest to the verb).
        Re-encoding the decoded string is not guaranteed to give exactly the
        same token count.
    """
    toks = model.tokenizer.encode(text)
    n_ctx = model.cfg.n_ctx
    limit = n_ctx if keep_last_tokens is None else min(keep_last_tokens, n_ctx)
    limit = max(limit, 0)
    if len(toks) <= limit:
        return text
    # Slice with an explicit start index: toks[-0:] would return the whole
    # list instead of keeping zero tokens.
    return model.tokenizer.decode(toks[len(toks) - limit:])
def trim_pair_for_patching(good_sentence, bad_sentence):
    """
    Trim both sentences to token windows of the SAME size, at most n_ctx.

    Each sentence keeps its last `window` tokens, where
    window = min(len(good_tokens), len(bad_tokens), model.cfg.n_ctx).
    Equal token counts are what make cached activations shape-compatible for
    position-wise patching. If both sentences already have exactly `window`
    tokens, the original strings are returned untouched.

    NOTE: the trimmed results are decoded strings. Re-tokenizing them is not
    guaranteed to reproduce the same token count, so for exact alignment
    prefer patching on token IDs directly.

    Returns:
        (good_text, bad_text) tuple of strings.
    """
    good_toks = model.tokenizer.encode(good_sentence)
    bad_toks = model.tokenizer.encode(bad_sentence)
    window = min(len(good_toks), len(bad_toks), model.cfg.n_ctx)
    if len(good_toks) == window and len(bad_toks) == window:
        return good_sentence, bad_sentence
    # Explicit start indices so that window == 0 keeps nothing (toks[-0:]
    # would keep everything).
    good_kept = good_toks[len(good_toks) - window:]
    bad_kept = bad_toks[len(bad_toks) - window:]
    return model.tokenizer.decode(good_kept), model.tokenizer.decode(bad_kept)
# ============================================================
# UPDATED NEXT-TOKEN SCORING
# ============================================================
def score_next_token(sentence_prefix, token1, token2):
    """
    Return (p(token1), p(token2)) for the token right after `sentence_prefix`.

    Each candidate word is scored as " " + word, and only the FIRST BPE token
    of that encoding is used. For the verbs used in this module that is
    presumably the whole word; verify before scoring other words.

    The prefix is trimmed from the left so that, together with one extra
    position, it fits in model.cfg.n_ctx. The extra position is presumably for
    the BOS token that `model.to_tokens` prepends by default.

    Returns:
        Tuple of two Python floats.
    """
    # The candidates carry their own leading space. Trailing whitespace on the
    # prefix (e.g. "The keys " from detect_and_build_variants) would become a
    # separate " " token and distort the next-token distribution.
    prefix = sentence_prefix.rstrip()
    max_prefix_toks = max(1, model.cfg.n_ctx - 1)
    prefix_trimmed = trim_to_context(prefix, keep_last_tokens=max_prefix_toks)
    tokens = model.to_tokens(prefix_trimmed)
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(tokens)[0, -1]
    probs = F.softmax(logits, dim=-1)
    t1 = model.tokenizer.encode(" " + token1)[0]
    t2 = model.tokenizer.encode(" " + token2)[0]
    return probs[t1].item(), probs[t2].item()
# ============================================================
# DETECT VERB PAIR + BUILD BAD SENTENCE VARIANT
# ============================================================
def detect_and_build_variants(user_sentence):
    """
    Find a supported agreement verb and build the number-flipped "bad" sentence.

    Verb pairs are tried in order: has/have, is/are, was/were, does/do. For the
    first pair that matches, the singular form is used if it occurs, otherwise
    the plural form. The LAST occurrence of that verb is the one replaced.
    Matching is whole-word and case-insensitive, so words such as "this",
    "dog", "chases" or "care" do not trigger false matches.

    Returns:
        (prefix, actual, wrong, bad_sentence, (singular, plural)), where:
        - prefix: the stripped sentence up to the matched verb, including any
          trailing space;
        - actual, wrong: the lowercase verb forms;
        - bad_sentence: the stripped sentence with the verb swapped, keeping
          the original verb's capitalization ("Is" -> "Are", "IS" -> "ARE").
        Returns five Nones when no supported verb is found.
    """
    verb_pairs = [
        ("has", "have"),
        ("is", "are"),
        ("was", "were"),
        ("does", "do"),
    ]
    s = user_sentence.strip()
    for singular, plural in verb_pairs:
        # Singular first, mirroring the original priority within a pair.
        for actual, wrong in ((singular, plural), (plural, singular)):
            pattern = re.compile(rf"\b{re.escape(actual)}\b", re.IGNORECASE)
            matches = list(pattern.finditer(s))
            if not matches:
                continue
            last = matches[-1]
            found = last.group(0)
            # Indices come from `s` itself (not s.lower()), so slicing stays
            # aligned even for characters whose lowercase form changes length.
            if found.isupper():
                replacement = wrong.upper()
            elif found[0].isupper():
                replacement = wrong.capitalize()
            else:
                replacement = wrong
            prefix = s[:last.start()]
            bad_sentence = (prefix + replacement + s[last.end():]).strip()
            return prefix, actual, wrong, bad_sentence, (singular, plural)
    return None, None, None, None, None
# ============================================================
# UPDATED ACTIVATION PATCHING (SAFE)
# ============================================================
def patch_layer_user(layer, good_sentence, bad_sentence, verb_pair):
    """
    Patch blocks.{layer}.attn.hook_z from `good_sentence` into `bad_sentence`.

    Both sentences are tokenized without BOS and right-aligned to a common
    length by dropping tokens from the left. That length is the shorter
    sentence's token count, capped at model.cfg.n_ctx. The token IDs are fed
    to the model directly rather than decoded and re-tokenized, because a
    decode/encode round-trip does not guarantee the same token count, and a
    length mismatch would break the position-wise replacement.

    Args:
        layer: Index of the transformer block whose attention hook_z is patched.
        good_sentence: Source ("clean") sentence for the cached activations.
        bad_sentence: Sentence that is run with the patched activations.
        verb_pair: (singular, plural) verb words to score.

    Returns:
        (p(" " + singular), p(" " + plural)) at the last position of the
        patched run, as Python floats. Only the first BPE token of each
        candidate is scored.

    Raises:
        ValueError: if `layer` is outside [0, n_layers) or either sentence
            tokenizes to zero tokens.
    """
    n_layers = model.cfg.n_layers
    if not 0 <= layer < n_layers:
        raise ValueError(f"layer must be in [0, {n_layers}), got {layer}")

    good_toks = model.tokenizer.encode(good_sentence)
    bad_toks = model.tokenizer.encode(bad_sentence)
    seq_len = min(len(good_toks), len(bad_toks), model.cfg.n_ctx)
    if seq_len == 0:
        raise ValueError("Both sentences must contain at least one token.")

    # Token tensors bypass string tokenization, so no BOS is added (same as
    # the original prepend_bos=False) and both runs have exactly seq_len positions.
    good_tokens = torch.tensor([good_toks[-seq_len:]], device=model.cfg.device)
    bad_tokens = torch.tensor([bad_toks[-seq_len:]], device=model.cfg.device)
    hook_name = f"blocks.{layer}.attn.hook_z"

    with torch.no_grad():
        # Cache only the single activation we are going to patch in.
        _, cache_good = model.run_with_cache(good_tokens, names_filter=hook_name)

        def patch_hook(value, hook):
            # Same batch and sequence length in both runs, so a direct
            # replacement is shape-safe.
            return cache_good[hook.name]

        patched_logits = model.run_with_hooks(
            bad_tokens,
            fwd_hooks=[(hook_name, patch_hook)],
        )

    probs = F.softmax(patched_logits[0, -1], dim=-1)
    singular, plural = verb_pair
    t_sing = model.tokenizer.encode(" " + singular)[0]
    t_plur = model.tokenizer.encode(" " + plural)[0]
    return probs[t_sing].item(), probs[t_plur].item()
# ============================================================
# FULL USER PIPELINE
# ============================================================
def run_user_activation_pipeline(user_sentence, n_layers_to_check=None):
    """
    Score a sentence's agreement verb and run per-layer activation patching.

    Steps:
      1. Detect a supported verb and build the number-flipped bad sentence.
      2. Score p(singular) and p(plural) after the verb prefix; the actual and
         wrong probabilities are the same two values, reordered.
      3. For each of the first `n_layers_to_check` layers (default: all),
         patch attn hook_z from the good sentence into the bad one and record
         the patched probability of the verb form actually used.

    NOTE(review): patching runs on the FULL sentences and scores the position
    after their last token, not the position after the prefix. Confirm this is
    the intended measurement.

    Returns:
        A dict of inputs, raw probabilities and per-layer patched
        probabilities, or {"error": ...} if no supported verb is found.
    """
    prefix, actual, wrong, bad_sentence, verb_pair = detect_and_build_variants(user_sentence)
    if prefix is None:
        return {"error": "No supported verb pair found (has/have, is/are, was/were, does/do)."}

    singular, plural = verb_pair
    # One forward pass is enough: (actual, wrong) is (singular, plural) in
    # some order, and both are scored after the same prefix.
    p_sing, p_plur = score_next_token(prefix, singular, plural)
    if actual == singular:
        p_actual, p_wrong = p_sing, p_plur
    else:
        p_actual, p_wrong = p_plur, p_sing

    # bad_sentence is built from the stripped input. Strip the good sentence
    # too, so stray leading/trailing whitespace tokens do not misalign the
    # right-aligned patch.
    good_sentence = user_sentence.strip()

    n_layers = model.cfg.n_layers
    if n_layers_to_check is None:
        n_layers_to_check = n_layers
    layer_probs_correct = []
    for layer in range(min(n_layers_to_check, n_layers)):
        p_sing_patched, p_plur_patched = patch_layer_user(
            layer, good_sentence, bad_sentence, verb_pair
        )
        layer_probs_correct.append(p_plur_patched if actual == plural else p_sing_patched)

    return {
        "user_sentence": user_sentence,
        "prefix_used_for_scoring": prefix,
        "verb_pair": verb_pair,
        "actual_verb_in_sentence": actual,
        "wrong_verb_used_for_bad_sentence": wrong,
        "bad_sentence": bad_sentence,
        "p_actual_token_raw": p_actual,
        "p_wrong_token_raw": p_wrong,
        "p_singular": p_sing,
        "p_plural": p_plur,
        "layer_probs_correct_after_patch": layer_probs_correct,
    }