import re

import torch
import torch.nn.functional as F
from transformer_lens import HookedTransformer
from datasets import load_dataset

# Load GPT-2 small on CPU; every helper below shares this global `model`.
model = HookedTransformer.from_pretrained("gpt2", device="cpu")

# ----------------------------------------------------------------------
# Context-window trimming helpers
# ----------------------------------------------------------------------

def trim_to_context(text, keep_last_tokens=None):
    """
    Trim `text` so its tokenized length <= model.cfg.n_ctx.
    Keeps the LAST tokens in the sequence (closest to the verb).
    """
    toks = model.tokenizer.encode(text)
    n_ctx = model.cfg.n_ctx
    max_keep = n_ctx if keep_last_tokens is None else min(keep_last_tokens, n_ctx)

    if len(toks) > max_keep:
        toks = toks[-max_keep:]
        return model.tokenizer.decode(toks)
    return text

def trim_pair_for_patching(good_sentence, bad_sentence):
    """
    Trim both good & bad sentences to same-size windows <= n_ctx.
    Ensures cache shapes match when patching.
    """
    good_toks = model.tokenizer.encode(good_sentence)
    bad_toks = model.tokenizer.encode(bad_sentence)
    n_ctx = model.cfg.n_ctx

    # Use one shared window: trimming each sequence to n_ctx independently
    # could still leave the two with different lengths.
    window = min(n_ctx, len(good_toks), len(bad_toks))
    if len(good_toks) > window or len(bad_toks) > window:
        good_toks = good_toks[-window:]
        bad_toks = bad_toks[-window:]
        return model.tokenizer.decode(good_toks), model.tokenizer.decode(bad_toks)
    return good_sentence, bad_sentence

# ----------------------------------------------------------------------
# Next-token scoring
# ----------------------------------------------------------------------

def score_next_token(sentence_prefix, token1, token2):
    """
    Safely compute p(token1), p(token2) after a prefix.
    Automatically trims the prefix to fit GPT-2's context window.
    """
    max_prefix_toks = max(1, model.cfg.n_ctx - 1)  # leave room for the BOS token
    # rstrip so a trailing space does not become its own token before " verb".
    prefix_trimmed = trim_to_context(sentence_prefix.rstrip(), keep_last_tokens=max_prefix_toks)

    tokens = model.to_tokens(prefix_trimmed)
    with torch.no_grad():
        logits = model(tokens)[0, -1]
    probs = F.softmax(logits, dim=-1)

    # Score each verb with a leading space (how GPT-2's BPE encodes a
    # mid-sentence word). Taking [0] assumes the verb is a single token,
    # which holds for has/have, is/are, was/were, does/do.
    t1 = model.tokenizer.encode(" " + token1)[0]
    t2 = model.tokenizer.encode(" " + token2)[0]

    return probs[t1].item(), probs[t2].item()
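
# Quick illustration of the contract (hypothetical sentence; exact
# probabilities vary by model version, so treat the values as unknown):
#   p_is, p_are = score_next_token("The keys to the cabinet", "is", "are")
# With the plural subject "keys", p_are should normally exceed p_is.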

# ----------------------------------------------------------------------
# Minimal-pair construction (good sentence vs. agreement-violating bad one)
# ----------------------------------------------------------------------

def detect_and_build_variants(user_sentence):
    verb_pairs = [
        ("has", "have"),
        ("is", "are"),
        ("was", "were"),
        ("does", "do"),
    ]

    s = user_sentence.strip()
    s_lower = s.lower()

    for singular, plural in verb_pairs:
        # Try the singular form first, mirroring the pair ordering above.
        for actual, wrong in ((singular, plural), (plural, singular)):
            # Match whole words only, so e.g. "is" does not fire inside "this".
            matches = list(re.finditer(rf"\b{actual}\b", s_lower))
            if not matches:
                continue
            m = matches[-1]  # last occurrence, closest to the end of the sentence
            prefix = s[:m.start()]
            remainder = s[m.end():]
            bad_sentence = (prefix + wrong + remainder).strip()
            return prefix, actual, wrong, bad_sentence, (singular, plural)

    return None, None, None, None, None
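
# Illustration of the return shape (hypothetical input; the trailing space
# in `prefix` is preserved so the bad sentence can be rebuilt around it):
#   detect_and_build_variants("The keys to the cabinet are on the table")
#   -> ("The keys to the cabinet ", "are", "is",
#       "The keys to the cabinet is on the table", ("is", "are"))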

# ----------------------------------------------------------------------
# Activation patching
# ----------------------------------------------------------------------

def patch_layer_user(layer, good_sentence, bad_sentence, verb_pair):
    """
    Patch blocks.{layer}.attn.hook_z from good -> bad.
    Sentences are trimmed to the same token length for safe patching.
    """
    good_toks = model.tokenizer.encode(good_sentence)
    bad_toks = model.tokenizer.encode(bad_sentence)

    # Keep the last min_len tokens of each sequence so both runs see
    # identical sequence lengths (required for shape-safe patching).
    min_len = min(len(good_toks), len(bad_toks))
    good_trimmed = torch.tensor([good_toks[-min_len:]], device=model.cfg.device)
    bad_trimmed = torch.tensor([bad_toks[-min_len:]], device=model.cfg.device)

    # Cache the clean run's activations. Passing token tensors directly
    # avoids a decode/re-encode round trip that can change token counts.
    _, cache_good = model.run_with_cache(good_trimmed)

    def patch_hook(value, hook):
        # Replace the corrupted run's attention-head outputs (hook_z,
        # shape [batch, pos, head_index, d_head]) with the clean ones.
        return cache_good[hook.name]

    patched_logits = model.run_with_hooks(
        bad_trimmed,
        fwd_hooks=[(f"blocks.{layer}.attn.hook_z", patch_hook)],
    )

    probs = F.softmax(patched_logits[0, -1], dim=-1)
    singular, plural = verb_pair
    t_sing = model.tokenizer.encode(" " + singular)[0]
    t_plur = model.tokenizer.encode(" " + plural)[0]

    return probs[t_sing].item(), probs[t_plur].item()
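
# A finer-grained variant (a sketch, not used by the pipeline below):
# hook_z activations have shape [batch, pos, head_index, d_head], so a
# single head can be patched by copying only its slice from the clean
# cache. `head` is a hypothetical parameter added for illustration.
def make_single_head_patch_hook(cache_good, head):
    def hook_fn(value, hook):
        # Overwrite one head's output; leave the other heads untouched.
        value[:, :, head, :] = cache_good[hook.name][:, :, head, :]
        return value
    return hook_fn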

# ----------------------------------------------------------------------
# End-to-end pipeline
# ----------------------------------------------------------------------

def run_user_activation_pipeline(user_sentence, n_layers_to_check=None):
    prefix, actual, wrong, bad_sentence, verb_pair = detect_and_build_variants(user_sentence)

    if prefix is None:
        return {"error": "No supported verb pair found (has/have, is/are, was/were, does/do)."}

    p_actual, p_wrong = score_next_token(prefix, actual, wrong)

    singular, plural = verb_pair
    correct_token = actual
    # The same forward pass already scored both verbs; just reorder into
    # (singular, plural) instead of running the model a second time.
    p_sing, p_plur = (p_actual, p_wrong) if actual == singular else (p_wrong, p_actual)

    n_layers = model.cfg.n_layers
    if n_layers_to_check is None:
        n_layers_to_check = n_layers

    layer_probs_correct = []
    for layer in range(min(n_layers_to_check, n_layers)):
        p_sing_patched, p_plur_patched = patch_layer_user(layer, user_sentence, bad_sentence, verb_pair)
        p_correct = p_plur_patched if correct_token == plural else p_sing_patched
        layer_probs_correct.append(p_correct)

    return {
        "user_sentence": user_sentence,
        "prefix_used_for_scoring": prefix,
        "verb_pair": verb_pair,
        "actual_verb_in_sentence": actual,
        "wrong_verb_used_for_bad_sentence": wrong,
        "bad_sentence": bad_sentence,
        "p_actual_token_raw": p_actual,
        "p_wrong_token_raw": p_wrong,
        "p_singular": p_sing,
        "p_plural": p_plur,
        "layer_probs_correct_after_patch": layer_probs_correct,
    }
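
# Example usage (a minimal sketch; the sentence and printout are illustrative):
if __name__ == "__main__":
    result = run_user_activation_pipeline("The keys to the cabinet are on the table")
    if "error" in result:
        print(result["error"])
    else:
        print(f"bad sentence: {result['bad_sentence']}")
        print("p(correct verb) after patching hook_z at each layer:")
        for layer, p in enumerate(result["layer_probs_correct_after_patch"]):
            print(f"  layer {layer:2d}: {p:.4f}")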