cdpearlman committed on
Commit
890f413
·
1 Parent(s): 43cf4ff

test(ablation): Verify multi-layer head ablation utility

Browse files
Files changed (1) hide show
  1. tests/test_multi_layer_ablation.py +56 -0
tests/test_multi_layer_ablation.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard-library imports.
import sys
import os

# Third-party imports.
# NOTE(review): torch and pytest are imported but never referenced anywhere in
# this file — confirm they are actually needed before removing.
import torch
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer

# Add project root to path so that `utils` is importable when this file is run
# directly (tests/ sits one level below the project root).
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Project-local helpers under test.
from utils.model_patterns import execute_forward_pass, execute_forward_pass_with_multi_layer_head_ablation
12
+
13
def test_multi_layer_ablation():
    """
    Check that ablating attention heads across more than one layer both
    changes the model's output probability and is reported back in the
    result's metadata.
    """
    model_id = "gpt2"
    text = "The quick brown fox jumps over the"

    tok = AutoTokenizer.from_pretrained(model_id)
    lm = AutoModelForCausalLM.from_pretrained(model_id)
    lm.eval()

    # Module-path configuration for the first two GPT-2 transformer blocks.
    pattern_config = {
        "attention_modules": ["transformer.h.0.attn", "transformer.h.1.attn"],
        "block_modules": ["transformer.h.0", "transformer.h.1"],
        "norm_parameters": [],
        "logit_lens_parameter": "transformer.ln_f.weight",
    }

    # 1. Unmodified forward pass as the reference point.
    reference = execute_forward_pass(lm, tok, text, pattern_config)
    reference_prob = reference['actual_output']['probability']

    # 2. Ablate head 0 of layer 0 and head 1 of layer 1.
    # heads_by_layer expects a mapping of {layer_num: [head_indices]}.
    ablation_spec = {0: [0], 1: [1]}

    ablated_result = execute_forward_pass_with_multi_layer_head_ablation(
        lm, tok, text, pattern_config, ablation_spec
    )
    ablated_prob = ablated_result['actual_output']['probability']

    print(f"Baseline: {reference_prob}, Ablated: {ablated_prob}")

    # Removing heads should perturb the predicted-token probability.
    assert abs(reference_prob - ablated_prob) > 1e-6

    # The result must echo back exactly which heads were ablated.
    assert ablated_result['ablated_heads_by_layer'] == ablation_spec
54
+
55
# Allow running this file directly as a script, without the pytest runner.
if __name__ == "__main__":
    test_multi_layer_ablation()