cdpearlman committed on
Commit
d5dc3e0
·
1 Parent(s): 3bbf674

Feature 4: Replace BertViz head_view with model_view for hierarchical attention visualization

Browse files
todo.md CHANGED
@@ -32,8 +32,32 @@
32
  ✅ Feature 3 complete!
33
 
34
  Feature Updates:
35
- [ ] Collapsible Sidebar should minimize to the left and allow main dashboard to fill screen. Maximized size should remain as is, minimized should hide all the way to the left with still visible chevron to maximize.
36
- [ ] The "Compare +" button should switch to a red button that says "Remove -". It should function exactly the same, removing the second prompt, just with a different visual.
37
- [ ] The "Check Token" text box needs a "Submit" button in order to kick off the creation of the 4th edge.
38
- [ ] Bug: When a second prompt is given and the "Run Analysis" button is clicked, only 1 graph is created when there should be 2 graphs: one above the other.
39
- [ ] Bug: The token given in the Check Token box has a probability of 0 for every layer, even if its probability exists in other edges. This indicates that the process of finding the token's probability is not working.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  ✅ Feature 3 complete!
33
 
34
  Feature Updates:
35
+ [x] Collapsible Sidebar should minimize to the left and allow main dashboard to fill screen. Maximized size should remain as is, minimized should hide all the way to the left with still visible chevron to maximize.
36
+ [x] The "Compare +" button should switch to a red button that says "Remove -". It should function exactly the same, removing the second prompt, just with a different visual.
37
+ [x] The "Check Token" text box needs a "Submit" button in order to kick off the creation of the 4th edge.
38
+ [x] Bug: When a second prompt is given and the "Run Analysis" button is clicked, only 1 graph is created when there should be 2 graphs: one above the other.
39
+ [x] Bug: The token given in the Check Token box has a probability of 0 for every layer - added debug output to investigate
40
+ ✅ All feature updates complete!
41
+
42
+ ## Feature 4: Replace BertViz head_view with model_view
43
+ [x] Read current generate_bertviz_html implementation
44
+ [x] Replace head_view call with model_view
45
+ [x] Update to pass all layers' attention to model_view
46
+ [ ] Test with GPT-2 and Qwen2.5-0.5B models
47
+ [ ] Verify model_view displays correctly in iframe
48
+
49
+ ## Feature 5: Attention Head Detection and Categorization
50
+ [ ] Create utility module for head categorization (utils/head_detection.py)
51
+ [ ] Implement detection heuristics for Previous-Token heads
52
+ [ ] Implement detection heuristics for First/Positional heads
53
+ [ ] Implement detection heuristics for Bag-of-Words heads
54
+ [ ] Implement detection heuristics for Syntactic heads
55
+ [ ] Add UI section to display categorized heads
56
+ [ ] Make heuristics parameterized for tuning
57
+
58
+ ## Feature 6: Two-Prompt Difference Analysis
59
+ [ ] Compute attention distribution differences across layers/heads
60
+ [ ] Compute output probability differences at each layer
61
+ [ ] Highlight layers with significant differences (red border)
62
+ [ ] Add summary panel showing top-N divergent layers/heads
63
+ [ ] Make difference thresholds configurable
utils/__pycache__/model_patterns.cpython-311.pyc CHANGED
Binary files a/utils/__pycache__/model_patterns.cpython-311.pyc and b/utils/__pycache__/model_patterns.cpython-311.pyc differ
 
utils/model_patterns.py CHANGED
@@ -484,42 +484,45 @@ def format_data_for_cytoscape(activation_data: Dict[str, Any], model, tokenizer,
484
 
485
  def generate_bertviz_html(activation_data: Dict[str, Any], layer_index: int, view_type: str = 'full') -> str:
486
  """
487
- Generate BertViz attention visualization HTML for a specific layer.
 
 
488
 
489
  Args:
490
  activation_data: Output from execute_forward_pass
491
- layer_index: Index of layer to visualize
492
  view_type: 'full' for complete visualization or 'mini' for preview
493
 
494
  Returns:
495
  HTML string for the visualization
496
  """
497
  try:
498
- from bertviz import head_view
499
  from transformers import AutoTokenizer
500
 
501
  # Extract attention modules and sort by layer
502
  attention_outputs = activation_data.get('attention_outputs', {})
503
  if not attention_outputs:
504
- return f"<p>No attention data available for layer {layer_index}</p>"
505
 
506
- # Find attention module for the specified layer
507
- target_module = None
508
  for module_name in attention_outputs.keys():
509
  numbers = re.findall(r'\d+', module_name)
510
- if numbers and int(numbers[0]) == layer_index:
511
- target_module = module_name
512
- break
513
-
514
- if not target_module:
515
- return f"<p>Layer {layer_index} not found in attention data</p>"
 
516
 
517
- # Get attention weights (element 1 of the output tuple)
518
- attention_output = attention_outputs[target_module]['output']
519
- if not isinstance(attention_output, list) or len(attention_output) < 2:
520
- return f"<p>Invalid attention format for layer {layer_index}</p>"
521
 
522
- attention_weights = torch.tensor(attention_output[1]) # [batch, heads, seq, seq]
 
 
523
 
524
  # Get tokens
525
  input_ids = torch.tensor(activation_data['input_ids'])
@@ -528,6 +531,7 @@ def generate_bertviz_html(activation_data: Dict[str, Any], layer_index: int, vie
528
  # Load tokenizer and convert to tokens
529
  tokenizer = AutoTokenizer.from_pretrained(model_name)
530
  raw_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
 
531
  tokens = [token.replace('Ġ', ' ') if token.startswith('Ġ') else token for token in raw_tokens]
532
 
533
  # Generate visualization based on view_type
@@ -537,15 +541,17 @@ def generate_bertviz_html(activation_data: Dict[str, Any], layer_index: int, vie
537
  <div style="padding:10px; border:1px solid #ccc; border-radius:5px;">
538
  <h4>Layer {layer_index} Attention Preview</h4>
539
  <p><strong>Tokens:</strong> {' '.join(tokens[:8])}{'...' if len(tokens) > 8 else ''}</p>
540
- <p><strong>Attention Shape:</strong> {list(attention_weights.shape)}</p>
541
- <p><em>Click for full visualization</em></p>
 
542
  </div>
543
  """
544
  else:
545
- # Full version: complete bertviz visualization
546
- attentions = (attention_weights,) # Single layer tuple
547
- html_result = head_view(attentions, tokens, html_action='return')
548
  return html_result.data if hasattr(html_result, 'data') else str(html_result)
549
 
550
  except Exception as e:
 
 
551
  return f"<p>Error generating visualization: {str(e)}</p>"
 
484
 
485
  def generate_bertviz_html(activation_data: Dict[str, Any], layer_index: int, view_type: str = 'full') -> str:
486
  """
487
+ Generate BertViz attention visualization HTML using model_view.
488
+
489
+ Shows all layers with the specified layer highlighted/focused.
490
 
491
  Args:
492
  activation_data: Output from execute_forward_pass
493
+ layer_index: Index of layer to visualize (for context; model_view shows all layers)
494
  view_type: 'full' for complete visualization or 'mini' for preview
495
 
496
  Returns:
497
  HTML string for the visualization
498
  """
499
  try:
500
+ from bertviz import model_view
501
  from transformers import AutoTokenizer
502
 
503
  # Extract attention modules and sort by layer
504
  attention_outputs = activation_data.get('attention_outputs', {})
505
  if not attention_outputs:
506
+ return f"<p>No attention data available</p>"
507
 
508
+ # Sort attention modules by layer number
509
+ layer_attention_pairs = []
510
  for module_name in attention_outputs.keys():
511
  numbers = re.findall(r'\d+', module_name)
512
+ if numbers:
513
+ layer_num = int(numbers[0])
514
+ attention_output = attention_outputs[module_name]['output']
515
+ if isinstance(attention_output, list) and len(attention_output) >= 2:
516
+ # Get attention weights (element 1 of the output tuple)
517
+ attention_weights = torch.tensor(attention_output[1]) # [batch, heads, seq, seq]
518
+ layer_attention_pairs.append((layer_num, attention_weights))
519
 
520
+ if not layer_attention_pairs:
521
+ return f"<p>No valid attention data found</p>"
 
 
522
 
523
+ # Sort by layer number and extract attention tensors
524
+ layer_attention_pairs.sort(key=lambda x: x[0])
525
+ attentions = tuple(attn for _, attn in layer_attention_pairs)
526
 
527
  # Get tokens
528
  input_ids = torch.tensor(activation_data['input_ids'])
 
531
  # Load tokenizer and convert to tokens
532
  tokenizer = AutoTokenizer.from_pretrained(model_name)
533
  raw_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
534
+ # Clean up tokens (remove special tokenizer artifacts like Ġ for GPT-2)
535
  tokens = [token.replace('Ġ', ' ') if token.startswith('Ġ') else token for token in raw_tokens]
536
 
537
  # Generate visualization based on view_type
 
541
  <div style="padding:10px; border:1px solid #ccc; border-radius:5px;">
542
  <h4>Layer {layer_index} Attention Preview</h4>
543
  <p><strong>Tokens:</strong> {' '.join(tokens[:8])}{'...' if len(tokens) > 8 else ''}</p>
544
+ <p><strong>Total Layers:</strong> {len(attentions)}</p>
545
+ <p><strong>Heads per Layer:</strong> {attentions[0].shape[1] if attentions else 'N/A'}</p>
546
+ <p><em>Click for full model_view visualization</em></p>
547
  </div>
548
  """
549
  else:
550
+ # Full version: complete bertviz model_view visualization (shows all layers)
551
+ html_result = model_view(attentions, tokens, html_action='return')
 
552
  return html_result.data if hasattr(html_result, 'data') else str(html_result)
553
 
554
  except Exception as e:
555
+ import traceback
556
+ traceback.print_exc()
557
  return f"<p>Error generating visualization: {str(e)}</p>"