Spaces:
Sleeping
Sleeping
Commit ·
43cf4ff
1
Parent(s): 9bd0946
test(ablation): Add reproduction script for head ablation
Browse files- tests/reproduce_ablation.py +86 -0
tests/reproduce_ablation.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import pytest
|
| 6 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 7 |
+
|
| 8 |
+
# Add project root to path
|
| 9 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 10 |
+
|
| 11 |
+
from utils.model_patterns import execute_forward_pass, execute_forward_pass_with_head_ablation, load_model_and_get_patterns
|
| 12 |
+
|
| 13 |
+
def test_ablation_changes_output():
    """
    Verify that ablating an attention head changes the model output.

    Runs a baseline forward pass on GPT-2, then a second pass with head 0 of
    layer 0 ablated, and asserts that the top-token probability differs and
    that the ablation metadata is recorded in the result dict.

    Skips (rather than fails) when the model cannot be downloaded/loaded,
    e.g. in an offline CI environment.
    """
    model_name = "gpt2"  # Small model keeps the test runnable on CPU.
    prompt = "The quick brown fox jumps over the"

    print(f"Loading model: {model_name}")
    # load_model_and_get_patterns loads the model internally only to derive
    # naming patterns and does NOT return the model object; we call it here
    # purely as a smoke check that the model is loadable, then load the
    # model/tokenizer ourselves below for the forward passes.
    try:
        load_model_and_get_patterns(model_name)
    except Exception as e:
        # pytest.skip raises Skipped, so no `return` is needed afterwards.
        pytest.skip(f"Could not load model {model_name}: {e}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    # Capture config for the forward-pass utilities. For GPT-2, layer N's
    # attention module lives at transformer.h.{N}.attn; capturing layer 0
    # is sufficient for this test.
    config = {
        "attention_modules": ["transformer.h.0.attn"],
        "block_modules": ["transformer.h.0"],
        "norm_parameters": [],
        "logit_lens_parameter": "transformer.ln_f.weight",
    }

    # 1. Baseline run.
    print("Running baseline...")
    baseline_result = execute_forward_pass(model, tokenizer, prompt, config)
    baseline_top_token = baseline_result['actual_output']['token']
    baseline_top_prob = baseline_result['actual_output']['probability']
    print(f"Baseline Output: '{baseline_top_token}' ({baseline_top_prob:.4f})")

    # 2. Ablated run (layer 0, head 0).
    print("Running ablation (L0H0)...")
    ablation_result = execute_forward_pass_with_head_ablation(
        model, tokenizer, prompt, config,
        ablate_layer_num=0,
        ablate_head_indices=[0]
    )
    ablated_top_token = ablation_result['actual_output']['token']
    ablated_top_prob = ablation_result['actual_output']['probability']
    print(f"Ablated Output: '{ablated_top_token}' ({ablated_top_prob:.4f})")

    # 3. Assertions.
    # The top token may or may not change depending on how important L0H0 is,
    # but its probability should move by a nonzero amount.
    prob_diff = abs(baseline_top_prob - ablated_top_prob)
    print(f"Probability Difference: {prob_diff}")

    # NOTE(review): if L0H0 were a completely dead head this could fail in
    # principle; empirically it always perturbs the distribution.
    assert prob_diff > 1e-6, "Ablation of L0H0 did not change the top token probability at all!"

    # The result dict must record which layer/heads were ablated.
    assert ablation_result['ablated_layer'] == 0
    assert ablation_result['ablated_heads'] == [0]
|
| 84 |
+
|
| 85 |
+
# Allow running this reproduction script directly (outside pytest).
if __name__ == "__main__":
    test_ablation_changes_output()
|