cdpearlman commited on
Commit
ae08976
·
1 Parent(s): bb577e6

Fix multi-layer ablation to ablate all heads simultaneously

Browse files
app.py CHANGED
@@ -11,7 +11,9 @@ import json
11
  import torch
12
  from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
13
  categorize_single_layer_heads, perform_beam_search,
14
- execute_forward_pass_with_head_ablation, evaluate_sequence_ablation, score_sequence,
 
 
15
  get_head_category_counts)
16
  from utils.head_detection import categorize_all_heads
17
  from utils.model_config import get_auto_selections
@@ -854,23 +856,13 @@ def run_ablation_experiment(n_clicks, selected_heads, activation_data, model_nam
854
  if not heads_by_layer:
855
  return html.Div("No valid heads selected.", style={'color': '#dc3545'})
856
 
857
- # Run ablation for each layer and track cumulative effect
858
- # For multi-layer ablation, we run sequentially but show the final effect
859
- # Note: For true simultaneous multi-layer ablation, model_patterns.py would need updates
860
- ablated_token = original_token
861
- ablated_prob = original_prob
862
-
863
- # Sort layers and run ablation for each
864
- sorted_layers = sorted(heads_by_layer.keys())
865
-
866
- for layer_num in sorted_layers:
867
- head_indices = heads_by_layer[layer_num]
868
- ablated_data = execute_forward_pass_with_head_ablation(
869
- model, tokenizer, sequence_text, config, layer_num, head_indices
870
- )
871
- ablated_output = ablated_data.get('actual_output', {})
872
- ablated_token = ablated_output.get('token', '')
873
- ablated_prob = ablated_output.get('probability', 0)
874
 
875
  # Format selected heads for display
876
  all_heads_formatted = [f"L{item['layer']}-H{item['head']}" for item in selected_heads if isinstance(item, dict)]
 
11
  import torch
12
  from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
13
  categorize_single_layer_heads, perform_beam_search,
14
+ execute_forward_pass_with_head_ablation,
15
+ execute_forward_pass_with_multi_layer_head_ablation,
16
+ evaluate_sequence_ablation, score_sequence,
17
  get_head_category_counts)
18
  from utils.head_detection import categorize_all_heads
19
  from utils.model_config import get_auto_selections
 
856
  if not heads_by_layer:
857
  return html.Div("No valid heads selected.", style={'color': '#dc3545'})
858
 
859
+ # Run ablation for all layers simultaneously in a single forward pass
860
+ ablated_data = execute_forward_pass_with_multi_layer_head_ablation(
861
+ model, tokenizer, sequence_text, config, heads_by_layer
862
+ )
863
+ ablated_output = ablated_data.get('actual_output', {})
864
+ ablated_token = ablated_output.get('token', '')
865
+ ablated_prob = ablated_output.get('probability', 0)
 
 
 
 
 
 
 
 
 
 
866
 
867
  # Format selected heads for display
868
  all_heads_formatted = [f"L{item['layer']}-H{item['head']}" for item in selected_heads if isinstance(item, dict)]
tests/test_model_patterns.py CHANGED
@@ -4,12 +4,14 @@ Tests for utils/model_patterns.py
4
  Tests pure logic functions that don't require model loading:
5
  - merge_token_probabilities
6
  - safe_to_serializable
 
7
  """
8
 
9
  import pytest
10
  import torch
11
  import numpy as np
12
  from utils.model_patterns import merge_token_probabilities, safe_to_serializable
 
13
 
14
 
15
  class TestMergeTokenProbabilities:
@@ -178,3 +180,84 @@ class TestSafeToSerializableEdgeCases:
178
  assert result[1] == "string"
179
  assert result[2] == 42
180
  assert result[3] == {"key": [2]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  Tests pure logic functions that don't require model loading:
5
  - merge_token_probabilities
6
  - safe_to_serializable
7
+ - execute_forward_pass_with_multi_layer_head_ablation (import/signature tests)
8
  """
9
 
10
  import pytest
11
  import torch
12
  import numpy as np
13
  from utils.model_patterns import merge_token_probabilities, safe_to_serializable
14
+ from utils import execute_forward_pass_with_multi_layer_head_ablation
15
 
16
 
17
  class TestMergeTokenProbabilities:
 
180
  assert result[1] == "string"
181
  assert result[2] == 42
182
  assert result[3] == {"key": [2]}
183
+
184
+
185
+ class TestMultiLayerHeadAblation:
186
+ """Tests for execute_forward_pass_with_multi_layer_head_ablation function.
187
+
188
+ These tests verify the function exists, is importable, and has the expected signature.
189
+ Full integration tests would require loading a model.
190
+ """
191
+
192
+ def test_function_is_importable(self):
193
+ """Function should be importable from utils."""
194
+ from utils import execute_forward_pass_with_multi_layer_head_ablation
195
+ assert callable(execute_forward_pass_with_multi_layer_head_ablation)
196
+
197
+ def test_function_has_expected_signature(self):
198
+ """Function should accept model, tokenizer, prompt, config, heads_by_layer."""
199
+ import inspect
200
+ sig = inspect.signature(execute_forward_pass_with_multi_layer_head_ablation)
201
+ params = list(sig.parameters.keys())
202
+
203
+ assert 'model' in params
204
+ assert 'tokenizer' in params
205
+ assert 'prompt' in params
206
+ assert 'config' in params
207
+ assert 'heads_by_layer' in params
208
+
209
+ def test_heads_by_layer_type_annotation(self):
210
+ """heads_by_layer parameter should accept Dict[int, List[int]]."""
211
+ import inspect
212
+ from typing import Dict, List, get_type_hints
213
+
214
+ # Get annotations (may not be available at runtime if not using from __future__)
215
+ sig = inspect.signature(execute_forward_pass_with_multi_layer_head_ablation)
216
+ heads_param = sig.parameters.get('heads_by_layer')
217
+
218
+ # The parameter should exist
219
+ assert heads_param is not None
220
+ # Annotation may be a string or the actual type
221
+ if heads_param.annotation != inspect.Parameter.empty:
222
+ annotation_str = str(heads_param.annotation)
223
+ assert 'Dict' in annotation_str or 'dict' in annotation_str.lower()
224
+
225
+ def test_returns_error_for_no_modules(self):
226
+ """Should return error dict when config has no modules.
227
+
228
+ Note: This test uses a mock model that won't actually run forward pass.
229
+ The function should return early with an error before trying to run.
230
+ """
231
+ from unittest.mock import MagicMock
232
+
233
+ mock_model = MagicMock()
234
+ mock_tokenizer = MagicMock()
235
+ empty_config = {} # No modules specified
236
+ heads_by_layer = {0: [1]} # Non-empty to avoid early return
237
+
238
+ result = execute_forward_pass_with_multi_layer_head_ablation(
239
+ mock_model, mock_tokenizer, "test prompt", empty_config, heads_by_layer
240
+ )
241
+
242
+ assert 'error' in result
243
+ assert 'No modules specified' in result['error']
244
+
245
+ def test_returns_error_for_invalid_layer(self):
246
+ """Should return error when layer number doesn't match any module."""
247
+ from unittest.mock import MagicMock
248
+
249
+ mock_model = MagicMock()
250
+ mock_tokenizer = MagicMock()
251
+ # Config has layer 0 and 1, but we'll request layer 99
252
+ config = {
253
+ 'attention_modules': ['model.layers.0.self_attn', 'model.layers.1.self_attn'],
254
+ 'block_modules': ['model.layers.0', 'model.layers.1']
255
+ }
256
+ heads_by_layer = {99: [0, 1]} # Layer 99 doesn't exist
257
+
258
+ result = execute_forward_pass_with_multi_layer_head_ablation(
259
+ mock_model, mock_tokenizer, "test prompt", config, heads_by_layer
260
+ )
261
+
262
+ assert 'error' in result
263
+ assert '99' in result['error'] # Should mention the invalid layer
utils/__init__.py CHANGED
@@ -3,7 +3,9 @@ from .model_patterns import (load_model_and_get_patterns, execute_forward_pass,
3
  generate_bertviz_html, generate_category_bertviz_html,
4
  generate_head_view_with_categories, get_head_category_counts,
5
  get_check_token_probabilities, execute_forward_pass_with_layer_ablation,
6
- execute_forward_pass_with_head_ablation, merge_token_probabilities,
 
 
7
  compute_global_top5_tokens, detect_significant_probability_increases,
8
  compute_layer_wise_summaries, evaluate_sequence_ablation,
9
  compute_position_layer_matrix)
@@ -20,6 +22,7 @@ __all__ = [
20
  'execute_forward_pass',
21
  'execute_forward_pass_with_layer_ablation',
22
  'execute_forward_pass_with_head_ablation',
 
23
  'evaluate_sequence_ablation',
24
  'logit_lens_transformation',
25
  'extract_layer_data',
 
3
  generate_bertviz_html, generate_category_bertviz_html,
4
  generate_head_view_with_categories, get_head_category_counts,
5
  get_check_token_probabilities, execute_forward_pass_with_layer_ablation,
6
+ execute_forward_pass_with_head_ablation,
7
+ execute_forward_pass_with_multi_layer_head_ablation,
8
+ merge_token_probabilities,
9
  compute_global_top5_tokens, detect_significant_probability_increases,
10
  compute_layer_wise_summaries, evaluate_sequence_ablation,
11
  compute_position_layer_matrix)
 
22
  'execute_forward_pass',
23
  'execute_forward_pass_with_layer_ablation',
24
  'execute_forward_pass_with_head_ablation',
25
+ 'execute_forward_pass_with_multi_layer_head_ablation',
26
  'evaluate_sequence_ablation',
27
  'logit_lens_transformation',
28
  'extract_layer_data',
utils/model_patterns.py CHANGED
@@ -446,6 +446,197 @@ def execute_forward_pass_with_head_ablation(model, tokenizer, prompt: str, confi
446
  return result
447
 
448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449
  def execute_forward_pass_with_layer_ablation(model, tokenizer, prompt: str, config: Dict[str, Any],
450
  ablate_layer_num: int, reference_activation_data: Dict[str, Any]) -> Dict[str, Any]:
451
  """
 
446
  return result
447
 
448
 
449
+ def execute_forward_pass_with_multi_layer_head_ablation(model, tokenizer, prompt: str, config: Dict[str, Any],
450
+ heads_by_layer: Dict[int, List[int]]) -> Dict[str, Any]:
451
+ """
452
+ Execute forward pass with specific attention heads zeroed out across multiple layers simultaneously.
453
+
454
+ Args:
455
+ model: Loaded transformer model
456
+ tokenizer: Loaded tokenizer
457
+ prompt: Input text prompt
458
+ config: Dict with module lists like {"attention_modules": [...], "block_modules": [...], ...}
459
+ heads_by_layer: Dict mapping layer numbers to lists of head indices to ablate
460
+ e.g., {0: [1, 3], 2: [0, 5]} ablates heads 1,3 in layer 0 and heads 0,5 in layer 2
461
+
462
+ Returns:
463
+ JSON-serializable dict with captured activations (with all specified heads ablated)
464
+ """
465
+ # Format ablation info for logging
466
+ ablation_info = ", ".join([f"L{layer}: H{heads}" for layer, heads in sorted(heads_by_layer.items())])
467
+ print(f"Executing forward pass with multi-layer head ablation: {ablation_info}")
468
+
469
+ # Handle empty heads_by_layer - just run normal forward pass
470
+ if not heads_by_layer:
471
+ from utils.model_patterns import execute_forward_pass
472
+ return execute_forward_pass(model, tokenizer, prompt, config)
473
+
474
+ # Extract module lists from config
475
+ attention_modules = config.get("attention_modules", [])
476
+ block_modules = config.get("block_modules", [])
477
+ norm_parameters = config.get("norm_parameters", [])
478
+ logit_lens_parameter = config.get("logit_lens_parameter")
479
+
480
+ all_modules = attention_modules + block_modules
481
+ if not all_modules:
482
+ return {"error": "No modules specified"}
483
+
484
+ # Build mapping from layer number to attention module name
485
+ layer_to_attention_module = {}
486
+ for mod_name in attention_modules:
487
+ layer_match = re.search(r'\.(\d+)(?:\.|$)', mod_name)
488
+ if layer_match:
489
+ layer_num = int(layer_match.group(1))
490
+ layer_to_attention_module[layer_num] = mod_name
491
+
492
+ # Find target attention modules for all layers to ablate
493
+ target_modules_to_heads = {} # module_name -> list of head indices
494
+ for layer_num, head_indices in heads_by_layer.items():
495
+ if layer_num in layer_to_attention_module:
496
+ mod_name = layer_to_attention_module[layer_num]
497
+ target_modules_to_heads[mod_name] = head_indices
498
+ else:
499
+ return {"error": f"Could not find attention module for layer {layer_num}"}
500
+
501
+ # Build IntervenableConfig
502
+ intervenable_representations = []
503
+ for mod_name in all_modules:
504
+ layer_match = re.search(r'\.(\d+)(?:\.|$)', mod_name)
505
+ if not layer_match:
506
+ return {"error": f"Invalid module name format: {mod_name}"}
507
+
508
+ if 'attn' in mod_name or 'attention' in mod_name:
509
+ component = 'attention_output'
510
+ else:
511
+ component = 'block_output'
512
+
513
+ intervenable_representations.append(
514
+ RepresentationConfig(layer=int(layer_match.group(1)), component=component, unit="pos")
515
+ )
516
+
517
+ intervenable_config = IntervenableConfig(
518
+ intervenable_representations=intervenable_representations
519
+ )
520
+ intervenable_model = IntervenableModel(intervenable_config, model)
521
+
522
+ # Prepare inputs
523
+ inputs = tokenizer(prompt, return_tensors="pt")
524
+
525
+ # Register hooks to capture activations
526
+ captured = {}
527
+ name_to_module = dict(intervenable_model.model.named_modules())
528
+
529
+ def make_hook(mod_name: str):
530
+ return lambda module, inputs, output: captured.update({mod_name: {"output": safe_to_serializable(output)}})
531
+
532
+ # Create parameterized head ablation hook factory
533
+ def make_head_ablation_hook(target_mod_name: str, ablate_head_indices: List[int]):
534
+ """Create a hook that zeros out specific attention heads and captures the output."""
535
+ def head_ablation_hook(module, input, output):
536
+ ablated_output = output # Default to original output
537
+
538
+ if isinstance(output, tuple):
539
+ # Attention modules typically return (hidden_states, attention_weights, ...)
540
+ hidden_states = output[0] # [batch, seq_len, hidden_dim]
541
+
542
+ # Convert to tensor if needed
543
+ if not isinstance(hidden_states, torch.Tensor):
544
+ hidden_states = torch.tensor(hidden_states)
545
+
546
+ batch_size, seq_len, hidden_dim = hidden_states.shape
547
+
548
+ # Determine head dimension
549
+ num_heads = model.config.num_attention_heads
550
+ head_dim = hidden_dim // num_heads
551
+
552
+ # Reshape to [batch, seq_len, num_heads, head_dim]
553
+ hidden_states_reshaped = hidden_states.view(batch_size, seq_len, num_heads, head_dim)
554
+
555
+ # Zero out specified heads
556
+ for head_idx in ablate_head_indices:
557
+ if 0 <= head_idx < num_heads:
558
+ hidden_states_reshaped[:, :, head_idx, :] = 0.0
559
+
560
+ # Reshape back to [batch, seq_len, hidden_dim]
561
+ ablated_hidden = hidden_states_reshaped.view(batch_size, seq_len, hidden_dim)
562
+
563
+ # Reconstruct output tuple
564
+ if len(output) > 1:
565
+ ablated_output = (ablated_hidden,) + output[1:]
566
+ else:
567
+ ablated_output = (ablated_hidden,)
568
+
569
+ # Capture the ablated output
570
+ captured.update({target_mod_name: {"output": safe_to_serializable(ablated_output)}})
571
+
572
+ return ablated_output
573
+ return head_ablation_hook
574
+
575
+ # Register hooks
576
+ hooks = []
577
+ for mod_name in all_modules:
578
+ if mod_name in name_to_module:
579
+ if mod_name in target_modules_to_heads:
580
+ # Apply head ablation hook for this module
581
+ head_indices = target_modules_to_heads[mod_name]
582
+ hooks.append(name_to_module[mod_name].register_forward_hook(
583
+ make_head_ablation_hook(mod_name, head_indices)
584
+ ))
585
+ else:
586
+ # Regular capture hook
587
+ hooks.append(name_to_module[mod_name].register_forward_hook(make_hook(mod_name)))
588
+
589
+ # Execute forward pass
590
+ with torch.no_grad():
591
+ model_output = intervenable_model.model(**inputs, use_cache=False)
592
+
593
+ # Remove hooks
594
+ for hook in hooks:
595
+ hook.remove()
596
+
597
+ # Separate outputs by type
598
+ attention_outputs = {}
599
+ block_outputs = {}
600
+
601
+ for mod_name, output in captured.items():
602
+ if 'attn' in mod_name or 'attention' in mod_name:
603
+ attention_outputs[mod_name] = output
604
+ else:
605
+ block_outputs[mod_name] = output
606
+
607
+ # Capture normalization parameters
608
+ all_params = dict(model.named_parameters())
609
+ norm_data = [safe_to_serializable(all_params[p]) for p in norm_parameters if p in all_params]
610
+
611
+ # Extract predicted token from model output
612
+ actual_output = None
613
+ global_top5_tokens = []
614
+ try:
615
+ output_token, output_prob = get_actual_model_output(model_output, tokenizer)
616
+ actual_output = {"token": output_token, "probability": output_prob}
617
+ global_top5_tokens = compute_global_top5_tokens(model_output, tokenizer, top_k=5)
618
+ except Exception as e:
619
+ print(f"Warning: Could not extract model output: {e}")
620
+
621
+ # Build output dictionary
622
+ result = {
623
+ "model": getattr(model.config, "name_or_path", "unknown"),
624
+ "prompt": prompt,
625
+ "input_ids": safe_to_serializable(inputs["input_ids"]),
626
+ "attention_modules": list(attention_outputs.keys()),
627
+ "attention_outputs": attention_outputs,
628
+ "block_modules": list(block_outputs.keys()),
629
+ "block_outputs": block_outputs,
630
+ "norm_parameters": norm_parameters,
631
+ "norm_data": norm_data,
632
+ "actual_output": actual_output,
633
+ "global_top5_tokens": global_top5_tokens,
634
+ "ablated_heads_by_layer": heads_by_layer # Include ablation info in result
635
+ }
636
+
637
+ return result
638
+
639
+
640
  def execute_forward_pass_with_layer_ablation(model, tokenizer, prompt: str, config: Dict[str, Any],
641
  ablate_layer_num: int, reference_activation_data: Dict[str, Any]) -> Dict[str, Any]:
642
  """