Spaces:
Sleeping
Sleeping
| """ | |
| Tests for scientifically accurate head ablation via pre-projection hooking. | |
| Verifies that ablation hooks are placed on the INPUT to c_proj (pre-projection), | |
| where per-head dimensions are still separable, rather than on the OUTPUT of the | |
| attention module (post-projection), where heads are mixed. | |
| """ | |
| import sys | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| import pytest | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| from utils.model_patterns import _find_output_proj_submodule | |
| def gpt2_model_and_tokenizer(): | |
| """Load GPT-2 once for all tests in this module.""" | |
| try: | |
| model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float32) | |
| model.eval() | |
| tokenizer = AutoTokenizer.from_pretrained("gpt2") | |
| return model, tokenizer | |
| except Exception as e: | |
| pytest.skip(f"Could not load GPT-2: {e}") | |
| class TestFindOutputProjSubmodule: | |
| def test_find_output_proj_gpt2(self, gpt2_model_and_tokenizer): | |
| """GPT-2 attention modules should have c_proj as the output projection.""" | |
| model, _ = gpt2_model_and_tokenizer | |
| attn_module = model.transformer.h[0].attn | |
| name, submodule = _find_output_proj_submodule(attn_module) | |
| assert name == "c_proj" | |
| assert submodule is attn_module.c_proj | |
| def test_find_output_proj_unknown_raises(self): | |
| """A plain nn.Module with no recognized projection children should raise ValueError.""" | |
| plain_module = nn.Module() | |
| plain_module.add_module("some_layer", nn.Linear(10, 10)) | |
| with pytest.raises(ValueError, match="No output projection found"): | |
| _find_output_proj_submodule(plain_module) | |
| class TestPreHookPlacement: | |
| def test_pre_hook_zeros_correct_dims(self, gpt2_model_and_tokenizer): | |
| """Pre-hook on c_proj receives input where per-head dims are separable. | |
| Zeroing head 0 dims [0:64] should leave [64:768] untouched.""" | |
| model, tokenizer = gpt2_model_and_tokenizer | |
| captured_input = {} | |
| def capture_pre_hook(module, args): | |
| captured_input['x'] = args[0].clone() | |
| return None # Don't modify | |
| hook = model.transformer.h[0].attn.c_proj.register_forward_pre_hook(capture_pre_hook) | |
| try: | |
| inputs = tokenizer("The cat sat on the mat", return_tensors="pt") | |
| with torch.no_grad(): | |
| model(**inputs, use_cache=False) | |
| finally: | |
| hook.remove() | |
| x = captured_input['x'] | |
| # Shape should be [batch, seq, 768] | |
| assert x.shape[-1] == 768 | |
| # Verify per-head structure: zeroing [0:64] leaves [64:768] intact | |
| x_modified = x.clone() | |
| x_modified[:, :, 0:64] = 0.0 | |
| # The rest should be exactly equal | |
| assert torch.equal(x_modified[:, :, 64:], x[:, :, 64:]) | |
| # The zeroed part should actually be zero | |
| assert torch.all(x_modified[:, :, 0:64] == 0.0) | |
| def test_ablation_changes_output(self, gpt2_model_and_tokenizer): | |
| """Ablating head 0 at layer 0 via pre-hook on c_proj should change logits.""" | |
| model, tokenizer = gpt2_model_and_tokenizer | |
| prompt = "The quick brown fox jumps over the" | |
| inputs = tokenizer(prompt, return_tensors="pt") | |
| # Baseline | |
| with torch.no_grad(): | |
| baseline_logits = model(**inputs, use_cache=False).logits | |
| # Ablated: pre-hook zeros head 0 on layer 0's c_proj | |
| def ablation_pre_hook(module, args): | |
| x = args[0].clone() | |
| x[:, :, 0:64] = 0.0 | |
| return (x,) | |
| hook = model.transformer.h[0].attn.c_proj.register_forward_pre_hook(ablation_pre_hook) | |
| try: | |
| with torch.no_grad(): | |
| ablated_logits = model(**inputs, use_cache=False).logits | |
| finally: | |
| hook.remove() | |
| # Logits must differ | |
| assert not torch.allclose(baseline_logits, ablated_logits, atol=1e-6), \ | |
| "Ablation via pre-hook on c_proj did not change logits" | |