cdpearlman committed on
Commit
43cf4ff
·
1 Parent(s): 9bd0946

test(ablation): Add reproduction script for head ablation

Browse files
Files changed (1) hide show
  1. tests/reproduce_ablation.py +86 -0
tests/reproduce_ablation.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import sys
3
+ import os
4
+ import torch
5
+ import pytest
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+
8
+ # Add project root to path
9
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
10
+
11
+ from utils.model_patterns import execute_forward_pass, execute_forward_pass_with_head_ablation, load_model_and_get_patterns
12
+
def test_ablation_changes_output():
    """
    Verify that ablating an attention head changes the model output
    compared to a baseline (un-ablated) forward pass.

    Uses GPT-2 (small) and ablates layer 0, head 0. Skips -- rather than
    fails -- when the model cannot be downloaded or loaded.
    """
    model_name = "gpt2"  # small model keeps the test cheap
    prompt = "The quick brown fox jumps over the"

    print(f"Loading model: {model_name}")
    # Probe that the model is loadable via the project utility; skip the
    # test otherwise (e.g. no network, no cached weights).
    # NOTE: load_model_and_get_patterns returns only the pattern lists,
    # not the model object, so the model is re-loaded below for the
    # forward passes. Its return value is intentionally discarded here.
    try:
        load_model_and_get_patterns(model_name)
    except Exception as e:  # broad by design: any load failure => skip
        pytest.skip(f"Could not load model {model_name}: {e}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    # Capture config for GPT-2: layer-N attention lives at
    # transformer.h.{N}.attn. Capturing only layer 0 is enough for the
    # baseline-vs-ablation comparison.
    config = {
        "attention_modules": ["transformer.h.0.attn"],
        "block_modules": ["transformer.h.0"],
        "norm_parameters": [],
        "logit_lens_parameter": "transformer.ln_f.weight",
    }

    # 1. Baseline run (no ablation).
    print("Running baseline...")
    baseline_result = execute_forward_pass(model, tokenizer, prompt, config)
    baseline_top_token = baseline_result['actual_output']['token']
    baseline_top_prob = baseline_result['actual_output']['probability']
    print(f"Baseline Output: '{baseline_top_token}' ({baseline_top_prob:.4f})")

    # 2. Ablated run: zero out layer 0, head 0.
    print("Running ablation (L0H0)...")
    ablation_result = execute_forward_pass_with_head_ablation(
        model, tokenizer, prompt, config,
        ablate_layer_num=0,
        ablate_head_indices=[0],
    )
    ablated_top_token = ablation_result['actual_output']['token']
    ablated_top_prob = ablation_result['actual_output']['probability']
    print(f"Ablated Output: '{ablated_top_token}' ({ablated_top_prob:.4f})")

    # 3. Assertions: the top-token probability must move, even if the
    # argmax token itself does not (that depends on head importance).
    prob_diff = abs(baseline_top_prob - ablated_top_prob)
    print(f"Probability Difference: {prob_diff}")

    # NOTE(review): if L0H0 were truly inert this could fail, but in
    # practice ablating it perturbs the output distribution.
    assert prob_diff > 1e-6, "Ablation of L0H0 did not change the top token probability at all!"

    # The result payload must record what was ablated.
    assert ablation_result['ablated_layer'] == 0
    assert ablation_result['ablated_heads'] == [0]
# Allow running the reproduction directly (outside pytest) for quick
# manual debugging of the head-ablation path.
if __name__ == "__main__":
    test_ablation_changes_output()