cdpearlman commited on
Commit
ac1a7df
·
1 Parent(s): 8a752fe

Phase 1: Backend infrastructure for token merging, layer-wise tracking, and head ablation

Browse files

- Added merge_token_probabilities() to sum probabilities of tokens with/without leading space
- Added compute_global_top5_tokens() to get top 5 from final model output
- Updated logit_lens_transformation() to return merged probabilities
- Updated get_check_token_probabilities() to sum token variant probabilities
- Added detect_significant_probability_increases() to find layers with >=25%% relative increase
- Added _get_token_probabilities_for_layer() helper for tracking specific tokens
- Updated extract_layer_data() to track global top 5 across all layers with deltas
- Added compute_layer_wise_summaries() to create convenient summary structures
- Added execute_forward_pass_with_head_ablation() to zero out specific attention heads
- Updated execute_forward_pass() and execute_forward_pass_with_layer_ablation() to include global_top5_tokens
- All functions exported and properly integrated

Files changed (2) hide show
  1. utils/__init__.py +6 -1
  2. utils/model_patterns.py +421 -21
utils/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from .model_patterns import load_model_and_get_patterns, execute_forward_pass, logit_lens_transformation, extract_layer_data, generate_bertviz_html, generate_category_bertviz_html, get_check_token_probabilities, execute_forward_pass_with_layer_ablation
2
  from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
3
  from .head_detection import categorize_all_heads, categorize_single_layer_heads, format_categorization_summary, HeadCategorizationConfig
4
  from .prompt_comparison import compare_attention_layers, compare_output_probabilities, format_comparison_summary, ComparisonConfig
@@ -7,11 +7,16 @@ __all__ = [
7
  'load_model_and_get_patterns',
8
  'execute_forward_pass',
9
  'execute_forward_pass_with_layer_ablation',
 
10
  'logit_lens_transformation',
11
  'extract_layer_data',
12
  'generate_bertviz_html',
13
  'generate_category_bertviz_html',
14
  'get_check_token_probabilities',
 
 
 
 
15
  'get_model_family',
16
  'get_family_config',
17
  'get_auto_selections',
 
1
+ from .model_patterns import load_model_and_get_patterns, execute_forward_pass, logit_lens_transformation, extract_layer_data, generate_bertviz_html, generate_category_bertviz_html, get_check_token_probabilities, execute_forward_pass_with_layer_ablation, execute_forward_pass_with_head_ablation, merge_token_probabilities, compute_global_top5_tokens, detect_significant_probability_increases, compute_layer_wise_summaries
2
  from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
3
  from .head_detection import categorize_all_heads, categorize_single_layer_heads, format_categorization_summary, HeadCategorizationConfig
4
  from .prompt_comparison import compare_attention_layers, compare_output_probabilities, format_comparison_summary, ComparisonConfig
 
7
  'load_model_and_get_patterns',
8
  'execute_forward_pass',
9
  'execute_forward_pass_with_layer_ablation',
10
+ 'execute_forward_pass_with_head_ablation',
11
  'logit_lens_transformation',
12
  'extract_layer_data',
13
  'generate_bertviz_html',
14
  'generate_category_bertviz_html',
15
  'get_check_token_probabilities',
16
+ 'merge_token_probabilities',
17
+ 'compute_global_top5_tokens',
18
+ 'detect_significant_probability_increases',
19
+ 'compute_layer_wise_summaries',
20
  'get_model_family',
21
  'get_family_config',
22
  'get_auto_selections',
utils/model_patterns.py CHANGED
@@ -64,6 +64,63 @@ def safe_to_serializable(obj: Any) -> Any:
64
  return obj
65
 
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def get_actual_model_output(model_output, tokenizer) -> Tuple[str, float]:
68
  """
69
  Extract the predicted token from model's output.
@@ -181,9 +238,12 @@ def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any])
181
 
182
  # Extract predicted token from model output
183
  actual_output = None
 
184
  try:
185
  output_token, output_prob = get_actual_model_output(model_output, tokenizer)
186
  actual_output = {"token": output_token, "probability": output_prob}
 
 
187
  except Exception as e:
188
  print(f"Warning: Could not extract model output: {e}")
189
 
@@ -199,13 +259,187 @@ def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any])
199
  "norm_parameters": norm_parameters,
200
  "norm_data": norm_data,
201
  "logit_lens_parameter": logit_lens_parameter,
202
- "actual_output": actual_output
 
203
  }
204
 
205
  print(f"Captured {len(captured)} module outputs using PyVene")
206
  return result
207
 
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  def execute_forward_pass_with_layer_ablation(model, tokenizer, prompt: str, config: Dict[str, Any],
210
  ablate_layer_num: int, reference_activation_data: Dict[str, Any]) -> Dict[str, Any]:
211
  """
@@ -347,9 +581,11 @@ def execute_forward_pass_with_layer_ablation(model, tokenizer, prompt: str, conf
347
 
348
  # Extract predicted token from model output
349
  actual_output = None
 
350
  try:
351
  output_token, output_prob = get_actual_model_output(model_output, tokenizer)
352
  actual_output = {"token": output_token, "probability": output_prob}
 
353
  except Exception as e:
354
  print(f"Warning: Could not extract model output: {e}")
355
 
@@ -366,6 +602,7 @@ def execute_forward_pass_with_layer_ablation(model, tokenizer, prompt: str, conf
366
  "norm_data": norm_data,
367
  "logit_lens_parameter": logit_lens_parameter,
368
  "actual_output": actual_output,
 
369
  "ablated_layer": ablate_layer_num
370
  }
371
 
@@ -375,6 +612,7 @@ def execute_forward_pass_with_layer_ablation(model, tokenizer, prompt: str, conf
375
  def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, logit_lens_parameter: str, tokenizer, norm_parameter: Optional[str] = None, top_k: int = 5) -> List[Tuple[str, float]]:
376
  """
377
  Transform layer output to top K token probabilities using logit lens.
 
378
 
379
  For standard logit lens, use block/layer outputs (residual stream), not component outputs.
380
  The residual stream contains the full hidden state with all accumulated information.
@@ -392,7 +630,7 @@ def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, lo
392
  top_k: Number of top tokens to return (default: 5)
393
 
394
  Returns:
395
- List of (token_string, probability) tuples for top K tokens
396
  """
397
  with torch.no_grad():
398
  # Convert to tensor and ensure proper shape [batch, seq_len, hidden_dim]
@@ -412,13 +650,18 @@ def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, lo
412
  # Step 3: Get probabilities via softmax
413
  probs = F.softmax(logits[0, -1, :], dim=-1)
414
 
415
- # Step 4: Extract top K tokens
416
- top_probs, top_indices = torch.topk(probs, k=top_k)
417
 
418
- return [
419
  (tokenizer.decode([idx.item()], skip_special_tokens=False), prob.item())
420
  for idx, prob in zip(top_indices, top_probs)
421
  ]
 
 
 
 
 
422
 
423
 
424
  def get_norm_layer_from_parameter(model, norm_parameter: Optional[str]) -> Optional[Any]:
@@ -458,6 +701,63 @@ def get_norm_layer_from_parameter(model, norm_parameter: Optional[str]) -> Optio
458
  return None
459
 
460
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  def _get_top_tokens(activation_data: Dict[str, Any], module_name: str, model, tokenizer, top_k: int = 5) -> Optional[List[Tuple[str, float]]]:
462
  """
463
  Helper: Get top K tokens for a layer's block output.
@@ -486,8 +786,8 @@ def get_check_token_probabilities(activation_data: Dict[str, Any], model, tokeni
486
  """
487
  Collect check token probabilities across all layers.
488
 
489
- Tries both with and without leading space and uses the variant with higher probability.
490
- Returns layer numbers and probabilities for plotting.
491
  """
492
  if not check_token or not check_token.strip():
493
  return None
@@ -510,14 +810,15 @@ def get_check_token_probabilities(activation_data: Dict[str, Any], model, tokeni
510
  (' ' + check_token.strip(), tokenizer.encode(' ' + check_token.strip(), add_special_tokens=False))
511
  ]
512
 
513
- # Determine which variant to use (choose one with valid token IDs)
514
- target_token_id = None
515
  for variant_text, token_ids in token_variants:
516
  if token_ids:
517
- target_token_id = token_ids[-1] # Use last sub-token
518
- break
 
519
 
520
- if target_token_id is None:
521
  return None
522
 
523
  # Get norm parameter
@@ -526,7 +827,7 @@ def get_check_token_probabilities(activation_data: Dict[str, Any], model, tokeni
526
  final_norm = get_norm_layer_from_parameter(model, norm_parameter)
527
  lm_head = model.get_output_embeddings()
528
 
529
- # Collect probabilities for all layers
530
  layers = []
531
  probabilities = []
532
 
@@ -543,13 +844,15 @@ def get_check_token_probabilities(activation_data: Dict[str, Any], model, tokeni
543
 
544
  logits = lm_head(hidden)
545
  probs = F.softmax(logits[0, -1, :], dim=-1)
546
- prob = probs[target_token_id].item()
 
 
547
 
548
  layers.append(layer_num)
549
- probabilities.append(prob)
550
 
551
  return {
552
- 'token': tokenizer.decode([target_token_id], skip_special_tokens=False),
553
  'layers': layers,
554
  'probabilities': probabilities
555
  }
@@ -558,6 +861,43 @@ def get_check_token_probabilities(activation_data: Dict[str, Any], model, tokeni
558
  return None
559
 
560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561
  def _compute_certainty(probs: List[float]) -> float:
562
  """
563
  Compute normalized certainty from probability distribution.
@@ -655,12 +995,47 @@ def _get_top_attended_tokens(activation_data: Dict[str, Any], layer_num: int, to
655
  return None
656
 
657
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
658
  def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> List[Dict[str, Any]]:
659
  """
660
  Extract layer-by-layer data for accordion display with top-5, deltas, certainty, and attention.
 
661
 
662
  Returns:
663
- List of dicts with: layer_num, top_token, top_prob, top_5_tokens, deltas, certainty, top_attended_tokens
 
664
  """
665
  layer_modules = activation_data.get('block_modules', [])
666
  if not layer_modules:
@@ -677,8 +1052,14 @@ def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> Lis
677
  )
678
 
679
  logit_lens_enabled = activation_data.get('logit_lens_parameter') is not None
 
 
 
 
 
680
  layer_data = []
681
- prev_token_probs = {} # Track previous layer's token probabilities
 
682
 
683
  for layer_num, module_name in layer_info:
684
  top_tokens = _get_top_tokens(activation_data, module_name, model, tokenizer, top_k=5) if logit_lens_enabled else None
@@ -686,10 +1067,23 @@ def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> Lis
686
  # Get top-3 attended tokens for this layer
687
  top_attended = _get_top_attended_tokens(activation_data, layer_num, tokenizer, top_k=3)
688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
689
  if top_tokens:
690
  top_token, top_prob = top_tokens[0]
691
 
692
- # Compute deltas vs previous layer
693
  deltas = {}
694
  for token, prob in top_tokens:
695
  prev_prob = prev_token_probs.get(token, 0.0)
@@ -708,11 +1102,14 @@ def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> Lis
708
  'top_5_tokens': top_tokens[:5], # New: top-5 for bar chart
709
  'deltas': deltas,
710
  'certainty': certainty,
711
- 'top_attended_tokens': top_attended # New: attention view
 
 
712
  })
713
 
714
  # Update previous layer probabilities
715
  prev_token_probs = {token: prob for token, prob in top_tokens}
 
716
  else:
717
  layer_data.append({
718
  'layer_num': layer_num,
@@ -723,8 +1120,11 @@ def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> Lis
723
  'top_5_tokens': [],
724
  'deltas': {},
725
  'certainty': 0.0,
726
- 'top_attended_tokens': top_attended
 
 
727
  })
 
728
 
729
  return layer_data
730
 
 
64
  return obj
65
 
66
 
67
+ def merge_token_probabilities(token_probs: List[Tuple[str, float]]) -> List[Tuple[str, float]]:
68
+ """
69
+ Merge tokens with and without leading space, summing their probabilities.
70
+
71
+ Example: [(" cat", 0.15), ("cat", 0.05), (" dog", 0.10)] -> [("cat", 0.20), ("dog", 0.10)]
72
+
73
+ Args:
74
+ token_probs: List of (token_string, probability) tuples
75
+
76
+ Returns:
77
+ List of (token_string, merged_probability) tuples, sorted by probability (descending)
78
+ """
79
+ merged = {} # Map from stripped token -> total probability
80
+
81
+ for token, prob in token_probs:
82
+ # Strip leading space to get canonical form
83
+ canonical = token.lstrip()
84
+ merged[canonical] = merged.get(canonical, 0.0) + prob
85
+
86
+ # Convert back to list and sort by probability (descending)
87
+ result = sorted(merged.items(), key=lambda x: x[1], reverse=True)
88
+ return result
89
+
90
+
91
+ def compute_global_top5_tokens(model_output, tokenizer, top_k: int = 5) -> List[Tuple[str, float]]:
92
+ """
93
+ Compute the global top-5 tokens from model's final output with merged probabilities.
94
+
95
+ Args:
96
+ model_output: Output from model(**inputs) containing logits
97
+ tokenizer: Tokenizer for decoding
98
+ top_k: Number of top tokens to return (default: 5)
99
+
100
+ Returns:
101
+ List of (token_string, probability) tuples for top K tokens with merged probabilities
102
+ """
103
+ with torch.no_grad():
104
+ # Get probabilities for next token (last position)
105
+ logits = model_output.logits[0, -1, :] # [vocab_size]
106
+ probs = F.softmax(logits, dim=-1)
107
+
108
+ # Get more candidates to account for merging (get 2x top_k)
109
+ top_probs, top_indices = torch.topk(probs, k=min(top_k * 2, len(probs)))
110
+
111
+ # Decode tokens
112
+ candidates = [
113
+ (tokenizer.decode([idx.item()], skip_special_tokens=False), prob.item())
114
+ for idx, prob in zip(top_indices, top_probs)
115
+ ]
116
+
117
+ # Merge tokens with/without leading space
118
+ merged = merge_token_probabilities(candidates)
119
+
120
+ # Return top K after merging
121
+ return merged[:top_k]
122
+
123
+
124
  def get_actual_model_output(model_output, tokenizer) -> Tuple[str, float]:
125
  """
126
  Extract the predicted token from model's output.
 
238
 
239
  # Extract predicted token from model output
240
  actual_output = None
241
+ global_top5_tokens = []
242
  try:
243
  output_token, output_prob = get_actual_model_output(model_output, tokenizer)
244
  actual_output = {"token": output_token, "probability": output_prob}
245
+ # Compute global top 5 tokens with merged probabilities
246
+ global_top5_tokens = compute_global_top5_tokens(model_output, tokenizer, top_k=5)
247
  except Exception as e:
248
  print(f"Warning: Could not extract model output: {e}")
249
 
 
259
  "norm_parameters": norm_parameters,
260
  "norm_data": norm_data,
261
  "logit_lens_parameter": logit_lens_parameter,
262
+ "actual_output": actual_output,
263
+ "global_top5_tokens": global_top5_tokens # New: global top 5 from final output
264
  }
265
 
266
  print(f"Captured {len(captured)} module outputs using PyVene")
267
  return result
268
 
269
 
270
+ def execute_forward_pass_with_head_ablation(model, tokenizer, prompt: str, config: Dict[str, Any],
271
+ ablate_layer_num: int, ablate_head_indices: List[int]) -> Dict[str, Any]:
272
+ """
273
+ Execute forward pass with specific attention heads zeroed out.
274
+
275
+ Args:
276
+ model: Loaded transformer model
277
+ tokenizer: Loaded tokenizer
278
+ prompt: Input text prompt
279
+ config: Dict with module lists like {"attention_modules": [...], "block_modules": [...], ...}
280
+ ablate_layer_num: Layer number containing heads to ablate
281
+ ablate_head_indices: List of head indices to zero out (e.g., [0, 2, 5])
282
+
283
+ Returns:
284
+ JSON-serializable dict with captured activations (with ablated heads)
285
+ """
286
+ print(f"Executing forward pass with head ablation: Layer {ablate_layer_num}, Heads {ablate_head_indices}")
287
+
288
+ # Extract module lists from config
289
+ attention_modules = config.get("attention_modules", [])
290
+ block_modules = config.get("block_modules", [])
291
+ norm_parameters = config.get("norm_parameters", [])
292
+ logit_lens_parameter = config.get("logit_lens_parameter")
293
+
294
+ all_modules = attention_modules + block_modules
295
+ if not all_modules:
296
+ return {"error": "No modules specified"}
297
+
298
+ # Find the target attention module for the layer to ablate
299
+ target_attention_module = None
300
+ for mod_name in attention_modules:
301
+ layer_match = re.search(r'\.(\d+)(?:\.|$)', mod_name)
302
+ if layer_match and int(layer_match.group(1)) == ablate_layer_num:
303
+ target_attention_module = mod_name
304
+ break
305
+
306
+ if not target_attention_module:
307
+ return {"error": f"Could not find attention module for layer {ablate_layer_num}"}
308
+
309
+ # Build IntervenableConfig
310
+ intervenable_representations = []
311
+ for mod_name in all_modules:
312
+ layer_match = re.search(r'\.(\d+)(?:\.|$)', mod_name)
313
+ if not layer_match:
314
+ return {"error": f"Invalid module name format: {mod_name}"}
315
+
316
+ if 'attn' in mod_name or 'attention' in mod_name:
317
+ component = 'attention_output'
318
+ else:
319
+ component = 'block_output'
320
+
321
+ intervenable_representations.append(
322
+ RepresentationConfig(layer=int(layer_match.group(1)), component=component, unit="pos")
323
+ )
324
+
325
+ intervenable_config = IntervenableConfig(
326
+ intervenable_representations=intervenable_representations
327
+ )
328
+ intervenable_model = IntervenableModel(intervenable_config, model)
329
+
330
+ # Prepare inputs
331
+ inputs = tokenizer(prompt, return_tensors="pt")
332
+
333
+ # Register hooks to capture activations
334
+ captured = {}
335
+ name_to_module = dict(intervenable_model.model.named_modules())
336
+
337
+ def make_hook(mod_name: str):
338
+ return lambda module, inputs, output: captured.update({mod_name: {"output": safe_to_serializable(output)}})
339
+
340
+ # Create head ablation hook
341
+ def head_ablation_hook(module, input, output):
342
+ """Zero out specific attention heads in the output."""
343
+ if isinstance(output, tuple):
344
+ # Attention modules typically return (hidden_states, attention_weights, ...)
345
+ hidden_states = output[0] # [batch, seq_len, hidden_dim]
346
+
347
+ # Convert to tensor if needed
348
+ if not isinstance(hidden_states, torch.Tensor):
349
+ hidden_states = torch.tensor(hidden_states)
350
+
351
+ batch_size, seq_len, hidden_dim = hidden_states.shape
352
+
353
+ # Determine head dimension
354
+ # Assuming hidden_dim = num_heads * head_dim
355
+ # We need to get num_heads from the model config
356
+ num_heads = model.config.num_attention_heads
357
+ head_dim = hidden_dim // num_heads
358
+
359
+ # Reshape to [batch, seq_len, num_heads, head_dim]
360
+ hidden_states_reshaped = hidden_states.view(batch_size, seq_len, num_heads, head_dim)
361
+
362
+ # Zero out specified heads
363
+ for head_idx in ablate_head_indices:
364
+ if 0 <= head_idx < num_heads:
365
+ hidden_states_reshaped[:, :, head_idx, :] = 0.0
366
+
367
+ # Reshape back to [batch, seq_len, hidden_dim]
368
+ ablated_hidden = hidden_states_reshaped.view(batch_size, seq_len, hidden_dim)
369
+
370
+ # Reconstruct output tuple
371
+ if len(output) > 1:
372
+ return (ablated_hidden,) + output[1:]
373
+ else:
374
+ return (ablated_hidden,)
375
+ else:
376
+ # If output is not a tuple, just return as is (shouldn't happen for attention)
377
+ return output
378
+
379
+ # Register hooks
380
+ hooks = []
381
+ for mod_name in all_modules:
382
+ if mod_name in name_to_module:
383
+ if mod_name == target_attention_module:
384
+ # Apply head ablation hook
385
+ hooks.append(name_to_module[mod_name].register_forward_hook(head_ablation_hook))
386
+ else:
387
+ # Regular capture hook
388
+ hooks.append(name_to_module[mod_name].register_forward_hook(make_hook(mod_name)))
389
+
390
+ # Execute forward pass
391
+ with torch.no_grad():
392
+ model_output = intervenable_model.model(**inputs, use_cache=False)
393
+
394
+ # Remove hooks
395
+ for hook in hooks:
396
+ hook.remove()
397
+
398
+ # Separate outputs by type
399
+ attention_outputs = {}
400
+ block_outputs = {}
401
+
402
+ for mod_name, output in captured.items():
403
+ if 'attn' in mod_name or 'attention' in mod_name:
404
+ attention_outputs[mod_name] = output
405
+ else:
406
+ block_outputs[mod_name] = output
407
+
408
+ # Capture normalization parameters
409
+ all_params = dict(model.named_parameters())
410
+ norm_data = [safe_to_serializable(all_params[p]) for p in norm_parameters if p in all_params]
411
+
412
+ # Extract predicted token from model output
413
+ actual_output = None
414
+ global_top5_tokens = []
415
+ try:
416
+ output_token, output_prob = get_actual_model_output(model_output, tokenizer)
417
+ actual_output = {"token": output_token, "probability": output_prob}
418
+ global_top5_tokens = compute_global_top5_tokens(model_output, tokenizer, top_k=5)
419
+ except Exception as e:
420
+ print(f"Warning: Could not extract model output: {e}")
421
+
422
+ # Build output dictionary
423
+ result = {
424
+ "model": getattr(model.config, "name_or_path", "unknown"),
425
+ "prompt": prompt,
426
+ "input_ids": safe_to_serializable(inputs["input_ids"]),
427
+ "attention_modules": list(attention_outputs.keys()),
428
+ "attention_outputs": attention_outputs,
429
+ "block_modules": list(block_outputs.keys()),
430
+ "block_outputs": block_outputs,
431
+ "norm_parameters": norm_parameters,
432
+ "norm_data": norm_data,
433
+ "logit_lens_parameter": logit_lens_parameter,
434
+ "actual_output": actual_output,
435
+ "global_top5_tokens": global_top5_tokens,
436
+ "ablated_layer": ablate_layer_num,
437
+ "ablated_heads": ablate_head_indices
438
+ }
439
+
440
+ return result
441
+
442
+
443
  def execute_forward_pass_with_layer_ablation(model, tokenizer, prompt: str, config: Dict[str, Any],
444
  ablate_layer_num: int, reference_activation_data: Dict[str, Any]) -> Dict[str, Any]:
445
  """
 
581
 
582
  # Extract predicted token from model output
583
  actual_output = None
584
+ global_top5_tokens = []
585
  try:
586
  output_token, output_prob = get_actual_model_output(model_output, tokenizer)
587
  actual_output = {"token": output_token, "probability": output_prob}
588
+ global_top5_tokens = compute_global_top5_tokens(model_output, tokenizer, top_k=5)
589
  except Exception as e:
590
  print(f"Warning: Could not extract model output: {e}")
591
 
 
602
  "norm_data": norm_data,
603
  "logit_lens_parameter": logit_lens_parameter,
604
  "actual_output": actual_output,
605
+ "global_top5_tokens": global_top5_tokens,
606
  "ablated_layer": ablate_layer_num
607
  }
608
 
 
612
  def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, logit_lens_parameter: str, tokenizer, norm_parameter: Optional[str] = None, top_k: int = 5) -> List[Tuple[str, float]]:
613
  """
614
  Transform layer output to top K token probabilities using logit lens.
615
+ Returns merged probabilities (tokens with/without leading space are combined).
616
 
617
  For standard logit lens, use block/layer outputs (residual stream), not component outputs.
618
  The residual stream contains the full hidden state with all accumulated information.
 
630
  top_k: Number of top tokens to return (default: 5)
631
 
632
  Returns:
633
+ List of (token_string, probability) tuples for top K tokens with merged probabilities
634
  """
635
  with torch.no_grad():
636
  # Convert to tensor and ensure proper shape [batch, seq_len, hidden_dim]
 
650
  # Step 3: Get probabilities via softmax
651
  probs = F.softmax(logits[0, -1, :], dim=-1)
652
 
653
+ # Step 4: Extract top candidates (get 2x top_k to account for merging)
654
+ top_probs, top_indices = torch.topk(probs, k=min(top_k * 2, len(probs)))
655
 
656
+ candidates = [
657
  (tokenizer.decode([idx.item()], skip_special_tokens=False), prob.item())
658
  for idx, prob in zip(top_indices, top_probs)
659
  ]
660
+
661
+ # Step 5: Merge tokens with/without leading space
662
+ merged = merge_token_probabilities(candidates)
663
+
664
+ return merged[:top_k]
665
 
666
 
667
  def get_norm_layer_from_parameter(model, norm_parameter: Optional[str]) -> Optional[Any]:
 
701
  return None
702
 
703
 
704
+ def _get_token_probabilities_for_layer(activation_data: Dict[str, Any], module_name: str,
705
+ model, tokenizer, target_tokens: List[str]) -> Dict[str, float]:
706
+ """
707
+ Get probabilities for specific tokens at a given layer.
708
+
709
+ Args:
710
+ activation_data: Activation data from forward pass
711
+ module_name: Layer module name
712
+ model: Transformer model
713
+ tokenizer: Tokenizer
714
+ target_tokens: List of token strings to get probabilities for
715
+
716
+ Returns:
717
+ Dict mapping token -> probability (merged for variants with/without space)
718
+ """
719
+ try:
720
+ if module_name not in activation_data.get('block_outputs', {}):
721
+ return {}
722
+
723
+ layer_output = activation_data['block_outputs'][module_name]['output']
724
+ norm_params = activation_data.get('norm_parameters', [])
725
+ norm_parameter = norm_params[0] if norm_params else None
726
+ final_norm = get_norm_layer_from_parameter(model, norm_parameter)
727
+ lm_head = model.get_output_embeddings()
728
+
729
+ with torch.no_grad():
730
+ hidden = torch.tensor(layer_output) if not isinstance(layer_output, torch.Tensor) else layer_output
731
+ if hidden.dim() == 4:
732
+ hidden = hidden.squeeze(0)
733
+
734
+ if final_norm is not None:
735
+ hidden = final_norm(hidden)
736
+
737
+ logits = lm_head(hidden)
738
+ probs = F.softmax(logits[0, -1, :], dim=-1)
739
+
740
+ # For each target token, get probabilities for both variants (with/without space)
741
+ token_probs = {}
742
+ for token in target_tokens:
743
+ # Try both variants and sum probabilities
744
+ variants = [token, ' ' + token]
745
+ total_prob = 0.0
746
+
747
+ for variant in variants:
748
+ token_ids = tokenizer.encode(variant, add_special_tokens=False)
749
+ if token_ids:
750
+ tid = token_ids[-1] # Use last sub-token
751
+ total_prob += probs[tid].item()
752
+
753
+ token_probs[token] = total_prob
754
+
755
+ return token_probs
756
+ except Exception as e:
757
+ print(f"Warning: Could not compute token probabilities for {module_name}: {e}")
758
+ return {}
759
+
760
+
761
  def _get_top_tokens(activation_data: Dict[str, Any], module_name: str, model, tokenizer, top_k: int = 5) -> Optional[List[Tuple[str, float]]]:
762
  """
763
  Helper: Get top K tokens for a layer's block output.
 
786
  """
787
  Collect check token probabilities across all layers.
788
 
789
+ Sums probabilities of token variants (with and without leading space).
790
+ Returns layer numbers and merged probabilities for plotting.
791
  """
792
  if not check_token or not check_token.strip():
793
  return None
 
810
  (' ' + check_token.strip(), tokenizer.encode(' ' + check_token.strip(), add_special_tokens=False))
811
  ]
812
 
813
+ # Get token IDs for both variants (if they exist and differ)
814
+ target_token_ids = []
815
  for variant_text, token_ids in token_variants:
816
  if token_ids:
817
+ tid = token_ids[-1] # Use last sub-token
818
+ if tid not in target_token_ids:
819
+ target_token_ids.append(tid)
820
 
821
+ if not target_token_ids:
822
  return None
823
 
824
  # Get norm parameter
 
827
  final_norm = get_norm_layer_from_parameter(model, norm_parameter)
828
  lm_head = model.get_output_embeddings()
829
 
830
+ # Collect probabilities for all layers (sum both variants)
831
  layers = []
832
  probabilities = []
833
 
 
844
 
845
  logits = lm_head(hidden)
846
  probs = F.softmax(logits[0, -1, :], dim=-1)
847
+
848
+ # Sum probabilities of all variants
849
+ merged_prob = sum(probs[tid].item() for tid in target_token_ids)
850
 
851
  layers.append(layer_num)
852
+ probabilities.append(merged_prob)
853
 
854
  return {
855
+ 'token': check_token.strip(), # Return canonical form without leading space
856
  'layers': layers,
857
  'probabilities': probabilities
858
  }
 
861
  return None
862
 
863
 
864
+ def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[str, float]],
865
+ layer_wise_deltas: Dict[int, Dict[str, float]],
866
+ threshold: float = 0.25) -> List[int]:
867
+ """
868
+ Detect layers where any global top 5 token has significant probability increase.
869
+
870
+ A layer is significant if any token has ≥25% relative increase from previous layer.
871
+ Example: 0.20 → 0.25 is (0.25-0.20)/0.20 = 25% increase.
872
+
873
+ Args:
874
+ layer_wise_probs: Dict mapping layer_num → {token: prob}
875
+ layer_wise_deltas: Dict mapping layer_num → {token: delta}
876
+ threshold: Relative increase threshold (default: 0.25 = 25%)
877
+
878
+ Returns:
879
+ List of layer numbers with significant increases
880
+ """
881
+ significant_layers = []
882
+
883
+ for layer_num in sorted(layer_wise_probs.keys()):
884
+ probs = layer_wise_probs[layer_num]
885
+ deltas = layer_wise_deltas.get(layer_num, {})
886
+
887
+ for token, prob in probs.items():
888
+ delta = deltas.get(token, 0.0)
889
+ prev_prob = prob - delta
890
+
891
+ # Check for significant relative increase (avoid division by zero)
892
+ if prev_prob > 1e-6 and delta > 0:
893
+ relative_increase = delta / prev_prob
894
+ if relative_increase >= threshold:
895
+ significant_layers.append(layer_num)
896
+ break # Only need to flag layer once
897
+
898
+ return significant_layers
899
+
900
+
901
  def _compute_certainty(probs: List[float]) -> float:
902
  """
903
  Compute normalized certainty from probability distribution.
 
995
  return None
996
 
997
 
998
+ def compute_layer_wise_summaries(layer_data: List[Dict[str, Any]]) -> Dict[str, Any]:
999
+ """
1000
+ Compute summary structures from layer data for easy access.
1001
+
1002
+ Args:
1003
+ layer_data: List of layer data dicts from extract_layer_data()
1004
+
1005
+ Returns:
1006
+ Dict with: layer_wise_top5_probs, layer_wise_top5_deltas, significant_layers
1007
+ """
1008
+ layer_wise_top5_probs = {} # layer_num -> {token: prob}
1009
+ layer_wise_top5_deltas = {} # layer_num -> {token: delta}
1010
+
1011
+ for layer_info in layer_data:
1012
+ layer_num = layer_info.get('layer_num')
1013
+ if layer_num is not None:
1014
+ layer_wise_top5_probs[layer_num] = layer_info.get('global_top5_probs', {})
1015
+ layer_wise_top5_deltas[layer_num] = layer_info.get('global_top5_deltas', {})
1016
+
1017
+ # Detect significant layers
1018
+ significant_layers = detect_significant_probability_increases(
1019
+ layer_wise_top5_probs,
1020
+ layer_wise_top5_deltas,
1021
+ threshold=0.25
1022
+ )
1023
+
1024
+ return {
1025
+ 'layer_wise_top5_probs': layer_wise_top5_probs,
1026
+ 'layer_wise_top5_deltas': layer_wise_top5_deltas,
1027
+ 'significant_layers': significant_layers
1028
+ }
1029
+
1030
+
1031
  def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> List[Dict[str, Any]]:
1032
  """
1033
  Extract layer-by-layer data for accordion display with top-5, deltas, certainty, and attention.
1034
+ Also tracks global top 5 tokens across all layers.
1035
 
1036
  Returns:
1037
+ List of dicts with: layer_num, top_token, top_prob, top_5_tokens, deltas, certainty, top_attended_tokens,
1038
+ global_top5_probs, global_top5_deltas
1039
  """
1040
  layer_modules = activation_data.get('block_modules', [])
1041
  if not layer_modules:
 
1052
  )
1053
 
1054
  logit_lens_enabled = activation_data.get('logit_lens_parameter') is not None
1055
+
1056
+ # Get global top 5 tokens from final output
1057
+ global_top5_tokens = activation_data.get('global_top5_tokens', [])
1058
+ global_top5_token_names = [token for token, _ in global_top5_tokens]
1059
+
1060
  layer_data = []
1061
+ prev_token_probs = {} # Track previous layer's token probabilities (layer's own top 5)
1062
+ prev_global_probs = {} # Track previous layer's global top 5 probabilities
1063
 
1064
  for layer_num, module_name in layer_info:
1065
  top_tokens = _get_top_tokens(activation_data, module_name, model, tokenizer, top_k=5) if logit_lens_enabled else None
 
1067
  # Get top-3 attended tokens for this layer
1068
  top_attended = _get_top_attended_tokens(activation_data, layer_num, tokenizer, top_k=3)
1069
 
1070
+ # Get probabilities for global top 5 tokens at this layer
1071
+ global_top5_probs = {}
1072
+ global_top5_deltas = {}
1073
+ if logit_lens_enabled and global_top5_token_names:
1074
+ global_top5_probs = _get_token_probabilities_for_layer(
1075
+ activation_data, module_name, model, tokenizer, global_top5_token_names
1076
+ )
1077
+ # Compute deltas for global top 5
1078
+ for token in global_top5_token_names:
1079
+ current_prob = global_top5_probs.get(token, 0.0)
1080
+ prev_prob = prev_global_probs.get(token, 0.0)
1081
+ global_top5_deltas[token] = current_prob - prev_prob
1082
+
1083
  if top_tokens:
1084
  top_token, top_prob = top_tokens[0]
1085
 
1086
+ # Compute deltas vs previous layer (for layer's own top 5)
1087
  deltas = {}
1088
  for token, prob in top_tokens:
1089
  prev_prob = prev_token_probs.get(token, 0.0)
 
1102
  'top_5_tokens': top_tokens[:5], # New: top-5 for bar chart
1103
  'deltas': deltas,
1104
  'certainty': certainty,
1105
+ 'top_attended_tokens': top_attended,
1106
+ 'global_top5_probs': global_top5_probs, # New: global top 5 probs at this layer
1107
+ 'global_top5_deltas': global_top5_deltas # New: global top 5 deltas
1108
  })
1109
 
1110
  # Update previous layer probabilities
1111
  prev_token_probs = {token: prob for token, prob in top_tokens}
1112
+ prev_global_probs = global_top5_probs.copy()
1113
  else:
1114
  layer_data.append({
1115
  'layer_num': layer_num,
 
1120
  'top_5_tokens': [],
1121
  'deltas': {},
1122
  'certainty': 0.0,
1123
+ 'top_attended_tokens': top_attended,
1124
+ 'global_top5_probs': {},
1125
+ 'global_top5_deltas': {}
1126
  })
1127
+ prev_global_probs = {}
1128
 
1129
  return layer_data
1130