cdpearlman commited on
Commit
7fa8fb4
·
1 Parent(s): d86e476

Attention refactor, better categorization and explanation

Browse files
app.py CHANGED
@@ -18,7 +18,7 @@ import json
18
  import torch
19
  from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
20
  perform_beam_search, execute_forward_pass_with_multi_layer_head_ablation)
21
- from utils.head_detection import categorize_all_heads
22
  from utils.model_config import get_auto_selections
23
  from utils.token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution
24
 
@@ -576,10 +576,11 @@ def update_pipeline_content(activation_data, model_name):
576
  except:
577
  pass
578
 
579
- # Agent G: Get full head categorization for attention stage UI (expandable categories)
580
  head_categories = None
581
  try:
582
- head_categories = categorize_all_heads(activation_data)
 
583
  except:
584
  pass
585
 
 
18
  import torch
19
  from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
20
  perform_beam_search, execute_forward_pass_with_multi_layer_head_ablation)
21
+ from utils.head_detection import get_active_head_summary
22
  from utils.model_config import get_auto_selections
23
  from utils.token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution
24
 
 
576
  except:
577
  pass
578
 
579
+ # Get head categorization from pre-computed JSON + runtime verification
580
  head_categories = None
581
  try:
582
+ from utils.head_detection import get_active_head_summary
583
+ head_categories = get_active_head_summary(activation_data, model_name)
584
  except:
585
  pass
586
 
components/pipeline.py CHANGED
@@ -400,133 +400,263 @@ def create_attention_content(attention_html=None, top_attended=None, layer_info=
400
  """
401
  Create content for the attention stage.
402
 
403
- Agent G: Removed "Most attended tokens" section (deprecated). Now shows head categorization
404
- to help users understand what different attention heads are doing.
405
 
406
  Args:
407
  attention_html: BertViz HTML string for attention visualization
408
- top_attended: DEPRECATED - no longer used, kept for backward compatibility
409
  layer_info: Optional layer information for context
410
- head_categories: Dict mapping category names to lists of head info dicts (from categorize_all_heads)
411
- Each head info has: {'layer': N, 'head': M, 'label': 'LN-HM', ...}
412
- Can also accept counts dict for backward compatibility.
413
  """
414
  content_items = [
415
  html.Div([
416
  html.H5("What happens here:", style={'color': '#495057', 'marginBottom': '8px'}),
417
  html.P([
418
  "The model looks at ", html.Strong("all tokens at once"),
419
- " and figures out which ones are related to each other. This is called 'attention' - ",
420
  "each token 'attends to' other tokens to gather context for its prediction."
421
  ], style={'color': '#6c757d', 'fontSize': '14px', 'marginBottom': '12px'}),
422
  html.P([
423
- "Attention has multiple ", html.Strong("heads"), " - each head learns to look for different types of relationships. ",
424
- "For example, one head might track subject-verb agreement, while another tracks pronouns and their referents."
425
  ], style={'color': '#6c757d', 'fontSize': '14px', 'marginBottom': '16px'})
426
  ])
427
  ]
428
 
429
- # Agent G: Head Categorization Summary with expandable categories
430
- if head_categories:
431
- category_labels = {
432
- 'previous_token': ('Previous-Token', '#667eea', 'Heads that attend to the immediately preceding token'),
433
- 'first_token': ('First/Positional', '#764ba2', 'Heads that focus on the first token or positional patterns'),
434
- 'bow': ('Bag-of-Words', '#f093fb', 'Heads with diffuse attention across many tokens'),
435
- 'syntactic': ('Syntactic', '#4facfe', 'Heads that capture grammatical relationships'),
436
- 'other': ('Other', '#6c757d', 'Heads with mixed or specialized patterns')
 
 
 
 
437
  }
438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  category_sections = []
440
- for cat_key in ['previous_token', 'first_token', 'bow', 'syntactic', 'other']:
441
- cat_data = head_categories.get(cat_key, [])
 
 
 
 
442
 
443
- # Handle both list format (full data) and int format (counts only, backward compat)
444
- if isinstance(cat_data, int):
445
- count = cat_data
446
- head_list = []
447
- else:
448
- count = len(cat_data) if cat_data else 0
449
- head_list = cat_data
 
450
 
451
- if count > 0 and cat_key in category_labels:
452
- label, color, tooltip = category_labels[cat_key]
453
-
454
- # Build head list display (only if we have full data)
455
- head_chips = []
456
- if head_list:
457
- for head_info in head_list:
458
- head_label = head_info.get('label', f"L{head_info.get('layer', '?')}-H{head_info.get('head', '?')}")
459
- head_chips.append(
460
- html.Span(head_label, style={
461
- 'display': 'inline-block',
462
- 'padding': '4px 8px',
463
- 'margin': '2px',
464
- 'backgroundColor': f'{color}15',
465
- 'border': f'1px solid {color}30',
466
- 'borderRadius': '4px',
467
- 'fontSize': '12px',
468
- 'fontFamily': 'monospace'
469
- })
470
- )
471
-
472
- # Create expandable section for this category
473
- category_sections.append(
474
- html.Details([
475
- html.Summary([
476
- html.Span(label, style={'fontWeight': '500', 'color': '#495057'}),
477
- html.Span(f" ({count})", style={'marginLeft': '4px', 'color': '#6c757d'})
478
- ], style={
479
- 'padding': '8px 12px',
480
- 'backgroundColor': f'{color}15',
481
- 'border': f'1px solid {color}30',
482
- 'borderRadius': '8px',
483
- 'cursor': 'pointer',
484
- 'userSelect': 'none',
485
- 'listStyle': 'none',
486
- 'display': 'flex',
487
- 'alignItems': 'center'
488
- }, title=tooltip),
489
- # Expanded content - list of heads
490
  html.Div([
491
- html.P(tooltip, style={
492
- 'color': '#6c757d',
493
- 'fontSize': '12px',
494
- 'marginBottom': '8px',
495
- 'fontStyle': 'italic'
496
- }),
497
- html.Div(head_chips if head_chips else [
498
- html.Span("Head details not available", style={'color': '#999', 'fontSize': '12px'})
 
 
 
 
499
  ], style={
500
- 'display': 'flex',
501
- 'flexWrap': 'wrap',
502
- 'gap': '4px'
 
 
 
 
503
  })
504
  ], style={
505
- 'padding': '12px',
506
- 'backgroundColor': '#fafbfc',
507
- 'borderRadius': '0 0 8px 8px',
508
- 'marginTop': '-1px',
509
- 'border': f'1px solid {color}30',
510
- 'borderTop': 'none'
511
  })
512
- ], style={'marginBottom': '8px'})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
 
515
  if category_sections:
 
 
 
 
 
 
 
 
 
 
516
  content_items.append(
517
  html.Div([
518
- html.H5("Attention Head Categories:", style={'color': '#495057', 'marginBottom': '12px'}),
519
  html.P([
520
- html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '6px'}),
521
- "Click each category to expand and see which heads belong to it."
522
  ], style={'color': '#6c757d', 'fontSize': '12px', 'marginBottom': '12px'}),
523
- html.Div(category_sections)
 
 
 
 
 
 
 
 
 
 
524
  ], style={'marginBottom': '16px'})
525
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
 
527
  # BertViz visualization with navigation instructions
528
  if attention_html:
529
- # Agent G: Enhanced navigation instructions for head view
530
  content_items.append(
531
  html.Div([
532
  html.H5("How to Navigate the Attention Visualization:", style={'color': '#495057', 'marginBottom': '12px'}),
@@ -537,7 +667,6 @@ def create_attention_content(attention_html=None, top_attended=None, layer_info=
537
  html.Span("Click on layer/head numbers at the top to view specific attention heads.",
538
  style={'color': '#6c757d'})
539
  ], style={'marginBottom': '4px'}),
540
- # Sub-points for click behaviors
541
  html.Div([
542
  html.Span("• ", style={'color': '#f093fb', 'fontWeight': 'bold'}),
543
  html.Strong("Single click ", style={'color': '#495057'}),
 
400
  """
401
  Create content for the attention stage.
402
 
403
+ Displays head categorization with active/inactive states, activation bars,
404
+ suggested prompts, and guided interpretation.
405
 
406
  Args:
407
  attention_html: BertViz HTML string for attention visualization
408
+ top_attended: DEPRECATED - no longer used
409
  layer_info: Optional layer information for context
410
+ head_categories: Output from get_active_head_summary() dict with 'categories' key
411
+ containing per-category data with activation scores.
412
+ Falls back gracefully if None or old format.
413
  """
414
  content_items = [
415
  html.Div([
416
  html.H5("What happens here:", style={'color': '#495057', 'marginBottom': '8px'}),
417
  html.P([
418
  "The model looks at ", html.Strong("all tokens at once"),
419
+ " and figures out which ones are related to each other. This is called 'attention' — ",
420
  "each token 'attends to' other tokens to gather context for its prediction."
421
  ], style={'color': '#6c757d', 'fontSize': '14px', 'marginBottom': '12px'}),
422
  html.P([
423
+ "Attention has multiple ", html.Strong("heads"), " — each head learns to look for different types of relationships. ",
424
+ "Below you can see what role each head plays and whether it's active on your current input."
425
  ], style={'color': '#6c757d', 'fontSize': '14px', 'marginBottom': '16px'})
426
  ])
427
  ]
428
 
429
+ # New: Head Roles Panel using get_active_head_summary() output
430
+ if head_categories and isinstance(head_categories, dict) and 'categories' in head_categories:
431
+ categories = head_categories['categories']
432
+
433
+ # Color scheme per category
434
+ category_colors = {
435
+ 'previous_token': '#667eea',
436
+ 'induction': '#e67e22',
437
+ 'duplicate_token': '#9b59b6',
438
+ 'positional': '#2ecc71',
439
+ 'diffuse': '#3498db',
440
+ 'other': '#95a5a6'
441
  }
442
 
443
+ # Find the top recommended head for guided interpretation
444
+ guided_head = None
445
+ guided_cat = None
446
+ for cat_key in ['previous_token', 'induction', 'positional']:
447
+ cat_data = categories.get(cat_key, {})
448
+ heads = cat_data.get('heads', [])
449
+ active_heads = [h for h in heads if h.get('is_active')]
450
+ if active_heads:
451
+ best = max(active_heads, key=lambda h: h['activation_score'])
452
+ if guided_head is None or best['activation_score'] > guided_head['activation_score']:
453
+ guided_head = best
454
+ guided_cat = cat_data.get('display_name', cat_key)
455
+
456
+ # Guided interpretation recommendation
457
+ if guided_head:
458
+ content_items.append(
459
+ html.Div([
460
+ html.I(className='fas fa-lightbulb', style={'color': '#f39c12', 'marginRight': '8px', 'fontSize': '16px'}),
461
+ html.Span([
462
+ html.Strong("Try this: "),
463
+ f"Select Layer {guided_head['layer']}, Head {guided_head['head']} in the visualization below — ",
464
+ f"this is a {guided_cat} head ",
465
+ f"(activation: {guided_head['activation_score']:.0%} on your input)."
466
+ ], style={'color': '#495057', 'fontSize': '13px'})
467
+ ], style={
468
+ 'padding': '12px 16px', 'backgroundColor': '#fef9e7', 'borderRadius': '8px',
469
+ 'border': '1px solid #f9e79f', 'marginBottom': '16px', 'display': 'flex', 'alignItems': 'center'
470
+ })
471
+ )
472
+
473
+ # Build category sections
474
  category_sections = []
475
+ category_order = ['previous_token', 'induction', 'duplicate_token', 'positional', 'diffuse', 'other']
476
+
477
+ for cat_key in category_order:
478
+ cat_data = categories.get(cat_key, {})
479
+ if not cat_data:
480
+ continue
481
 
482
+ display_name = cat_data.get('display_name', cat_key)
483
+ description = cat_data.get('description', '')
484
+ educational_text = cat_data.get('educational_text', '')
485
+ icon_name = cat_data.get('icon', 'circle')
486
+ is_applicable = cat_data.get('is_applicable', True)
487
+ suggested_prompt = cat_data.get('suggested_prompt')
488
+ heads = cat_data.get('heads', [])
489
+ color = category_colors.get(cat_key, '#95a5a6')
490
 
491
+ # Active vs inactive indicator
492
+ has_active_heads = any(h.get('is_active') for h in heads)
493
+ status_icon = '●' if (is_applicable and has_active_heads) else '○'
494
+ status_color = color if (is_applicable and has_active_heads) else '#ccc'
495
+
496
+ # Skip "other" if no heads (which is the normal case)
497
+ if cat_key == 'other' and not heads:
498
+ continue
499
+
500
+ # Build head items with activation bars
501
+ head_items = []
502
+ if heads:
503
+ for head_info in heads:
504
+ activation = head_info.get('activation_score', 0.0)
505
+ is_active = head_info.get('is_active', False)
506
+ label = head_info.get('label', f"L{head_info['layer']}-H{head_info['head']}")
507
+
508
+ # Activation bar
509
+ bar_width = max(activation * 100, 2) # Min 2% for visibility
510
+ bar_color = color if is_active else '#ddd'
511
+
512
+ head_items.append(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  html.Div([
514
+ # Head label
515
+ html.Span(label, style={
516
+ 'fontFamily': 'monospace', 'fontSize': '12px', 'fontWeight': '500',
517
+ 'minWidth': '60px', 'color': '#495057' if is_active else '#aaa',
518
+ }, title=f"See Layer {head_info['layer']}, Head {head_info['head']} in the visualization below"),
519
+ # Activation bar
520
+ html.Div([
521
+ html.Div(style={
522
+ 'width': f'{bar_width}%', 'height': '100%',
523
+ 'backgroundColor': bar_color, 'borderRadius': '3px',
524
+ 'transition': 'width 0.3s ease'
525
+ })
526
  ], style={
527
+ 'flex': '1', 'height': '12px', 'backgroundColor': '#f0f0f0',
528
+ 'borderRadius': '3px', 'margin': '0 8px', 'overflow': 'hidden'
529
+ }),
530
+ # Score label
531
+ html.Span(f"{activation:.2f}", style={
532
+ 'fontSize': '11px', 'fontFamily': 'monospace',
533
+ 'color': '#495057' if is_active else '#bbb', 'minWidth': '32px'
534
  })
535
  ], style={
536
+ 'display': 'flex', 'alignItems': 'center', 'marginBottom': '4px',
537
+ 'opacity': '1' if is_active else '0.5'
 
 
 
 
538
  })
539
+ )
540
+
541
+ # Build the category section
542
+ # Header content
543
+ summary_children = [
544
+ html.Span(status_icon, style={
545
+ 'color': status_color, 'fontSize': '16px', 'marginRight': '8px'
546
+ }),
547
+ html.Span(display_name, style={'fontWeight': '500', 'color': '#495057'}),
548
+ ]
549
+
550
+ if heads:
551
+ active_count = sum(1 for h in heads if h.get('is_active'))
552
+ summary_children.append(
553
+ html.Span(f" ({active_count}/{len(heads)} active)", style={
554
+ 'marginLeft': '6px', 'color': '#6c757d', 'fontSize': '12px'
555
+ })
556
+ )
557
+
558
+ if not is_applicable:
559
+ summary_children.append(
560
+ html.Span(" — not triggered on this input", style={
561
+ 'marginLeft': '6px', 'color': '#aaa', 'fontSize': '12px', 'fontStyle': 'italic'
562
+ })
563
  )
564
+
565
+ # Expanded content
566
+ expanded_children = []
567
+
568
+ # Educational explanation
569
+ if educational_text:
570
+ expanded_children.append(
571
+ html.P(educational_text, style={
572
+ 'color': '#6c757d', 'fontSize': '13px', 'marginBottom': '10px',
573
+ 'fontStyle': 'italic', 'lineHeight': '1.5'
574
+ })
575
+ )
576
+
577
+ # Suggested prompt (for grayed-out categories)
578
+ if not is_applicable and suggested_prompt:
579
+ expanded_children.append(
580
+ html.Div([
581
+ html.I(className='fas fa-flask', style={'color': '#e67e22', 'marginRight': '6px'}),
582
+ html.Span(suggested_prompt, style={'color': '#e67e22', 'fontSize': '12px'})
583
+ ], style={
584
+ 'padding': '8px 12px', 'backgroundColor': '#fef5e7',
585
+ 'borderRadius': '6px', 'marginBottom': '10px', 'border': '1px solid #fde8c8'
586
+ })
587
+ )
588
+
589
+ # Head activation bars
590
+ if head_items:
591
+ expanded_children.append(html.Div(head_items))
592
+
593
+ category_sections.append(
594
+ html.Details([
595
+ html.Summary(summary_children, style={
596
+ 'padding': '10px 14px',
597
+ 'backgroundColor': f'{color}08' if is_applicable else '#fafafa',
598
+ 'border': f'1px solid {color}25' if is_applicable else '1px solid #eee',
599
+ 'borderRadius': '8px', 'cursor': 'pointer', 'userSelect': 'none',
600
+ 'listStyle': 'none', 'display': 'flex', 'alignItems': 'center'
601
+ }),
602
+ html.Div(expanded_children, style={
603
+ 'padding': '12px 14px', 'backgroundColor': '#fafbfc',
604
+ 'borderRadius': '0 0 8px 8px', 'marginTop': '-1px',
605
+ 'border': f'1px solid {color}25' if is_applicable else '1px solid #eee',
606
+ 'borderTop': 'none'
607
+ })
608
+ ], style={'marginBottom': '8px'}, open=(cat_key == 'previous_token')) # Default-open first category
609
+ )
610
 
611
  if category_sections:
612
+ # Legend
613
+ legend = html.Div([
614
+ html.Span("● = active on your input", style={
615
+ 'color': '#495057', 'fontSize': '11px', 'marginRight': '16px'
616
+ }),
617
+ html.Span("○ = role exists but not triggered", style={
618
+ 'color': '#aaa', 'fontSize': '11px'
619
+ })
620
+ ], style={'marginBottom': '10px'})
621
+
622
  content_items.append(
623
  html.Div([
624
+ html.H5("Attention Head Roles:", style={'color': '#495057', 'marginBottom': '8px'}),
625
  html.P([
626
+ "Each category represents a type of behavior we detected in this model's attention heads. ",
627
+ "Click a category to see individual heads and how strongly they're activated on your input."
628
  ], style={'color': '#6c757d', 'fontSize': '12px', 'marginBottom': '12px'}),
629
+ legend,
630
+ html.Div(category_sections),
631
+ # Accuracy caveat
632
+ html.Div([
633
+ html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '6px', 'fontSize': '11px'}),
634
+ html.Span(
635
+ "These categories are simplified labels based on each head's dominant behavior. "
636
+ "In reality, heads can serve multiple roles and may behave differently on different inputs.",
637
+ style={'color': '#999', 'fontSize': '11px'}
638
+ )
639
+ ], style={'marginTop': '12px', 'padding': '8px 12px', 'backgroundColor': '#f8f9fa', 'borderRadius': '6px'})
640
  ], style={'marginBottom': '16px'})
641
  )
642
+ elif head_categories is None:
643
+ # Model not analyzed — show fallback message
644
+ content_items.append(
645
+ html.Div([
646
+ html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '8px'}),
647
+ html.Span(
648
+ "Head categorization is not available for this model. "
649
+ "The attention visualization below still shows the full attention patterns.",
650
+ style={'color': '#6c757d', 'fontSize': '13px'}
651
+ )
652
+ ], style={
653
+ 'padding': '12px', 'backgroundColor': '#f8f9fa', 'borderRadius': '8px',
654
+ 'border': '1px solid #dee2e6', 'marginBottom': '16px'
655
+ })
656
+ )
657
 
658
  # BertViz visualization with navigation instructions
659
  if attention_html:
 
660
  content_items.append(
661
  html.Div([
662
  html.H5("How to Navigate the Attention Visualization:", style={'color': '#495057', 'marginBottom': '12px'}),
 
667
  html.Span("Click on layer/head numbers at the top to view specific attention heads.",
668
  style={'color': '#6c757d'})
669
  ], style={'marginBottom': '4px'}),
 
670
  html.Div([
671
  html.Span("• ", style={'color': '#f093fb', 'fontWeight': 'bold'}),
672
  html.Strong("Single click ", style={'color': '#495057'}),
jarvis_llmvis_ux_review.md ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LLMVis UX & Explanation Review
2
+ **Date:** 2026-02-26
3
+ **Reviewer:** JARVIS
4
+ **Method:** Playwright automated walkthrough of https://cdpearlman-llmvis.hf.space (GPT-2 124M, prompt: "The cat sat on the mat. The cat")
5
+ **Reference:** `attention_handoff.md` (attention head categorization spec)
6
+
7
+ ---
8
+
9
+ ## Executive Summary
10
+
11
+ The app is in solid working shape. The pipeline storytelling is clean, the BertViz integration works, and attribution renders well. The two biggest gaps against the handoff spec are: (1) the attention head categorization is broken — 132/144 heads are mislabeled as "First/Positional," swamping all meaningful signal; and (2) the induction, duplicate, and diffuse head categories from the spec are entirely absent. Beyond that, the attention visualization is the weakest explanation panel — it shows the heatmap but doesn't teach the student what to look for. Ablation UX also has friction and never surfaced results in testing.
12
+
13
+ ---
14
+
15
+ ## 1. Overall Layout & First Impression
16
+
17
+ **What's good:**
18
+ - Clean gradient header, uncluttered layout
19
+ - The pipeline section ("How the Model Processes Your Input") is a strong pedagogical frame — the numbered steps with the flow chip bar (Input → Tokens → Embed → Attention → MLP → Output) is excellent
20
+ - Glossary modal auto-opens on first visit, which is a good onboarding move
21
+ - The sidebar module selection (showing `transformer.h.{N}.attn` etc.) is a nice power-user layer
22
+
23
+ **Issues:**
24
+ - **Glossary modal close button is off-screen** at default viewport widths. The `×` renders at x≈1858 on a 1400px window. Students on laptops will be stuck staring at a modal they can't close without scrolling right. Fix: position the close button inside the modal boundary, not at the document edge.
25
+ - **45-second cold start with no feedback.** After clicking Analyze, the pipeline stages show "Awaiting analysis..." with no progress indicator, spinner, or ETA. For a student, this looks broken. Fix: add a loading spinner or "Model is warming up (~30s)..." message on first run.
26
+ - **Generation Settings sliders are confusing.** "Number of Generation Choices" with values 1/3/5 is jargon. Students don't know what beam search is. The label should be "Explore How Many Different Continuations?" or similar, with a tooltip. The current glossary entry on Beam Search is good but isn't linked from the slider.
27
+
28
+ ---
29
+
30
+ ## 2. Tokenization Stage
31
+
32
+ **What's good:**
33
+ - Clean token→ID table. Exactly the right content.
34
+ - "Your text is split into 10 tokens" summary in the header is great.
35
+
36
+ **Issues:**
37
+ - **No visual "aha" moment.** The table shows Token→ID correctly, but doesn't show *why* "The" becomes 464 vs "the" becoming 262. The capitalization distinction (same word, different token) is sitting right there in this example and the app doesn't call it out. This is a perfect teachable moment — highlight it.
38
+ - **No subword tokenization example.** The prompt was simple English so all tokens were whole words. When a student types something with subwords (e.g., "transformers"), they won't know that's unusual. Consider adding a note: "Notice: some words may split into multiple pieces — try typing 'unhappiness' to see subword tokenization."
39
+ - **The token ID numbers mean nothing to students.** Worth a one-liner: "These IDs are just addresses in a vocabulary table of 50,257 words and word-pieces."
40
+
41
+ ---
42
+
43
+ ## 3. Embedding Stage
44
+
45
+ **What's good:**
46
+ - The `Token ID → Lookup Table → [768-dimensional vector]` flow diagram is clean and conceptually correct.
47
+ - The callout box ("How the lookup table was created: During training on billions of text examples...") is excellent — this is exactly the kind of "where did this come from?" context students need.
48
+
49
+ **Issues:**
50
+ - **No actual data shown.** The stage says "768-dimensional vector" but never shows a student what even 5 dimensions of that vector look like. Even a truncated display like `[0.23, -1.41, 0.07, ...]` would make it real.
51
+ - **No similarity demo.** The explanation says "words with similar meanings (like 'happy' and 'joyful') have similar vectors" — but doesn't show it. A small cosine similarity callout using tokens actually in the input ("'cat' and 'mat' are somewhat similar; 'cat' and 'The' are not") would land this point.
52
+ - **Missing: positional embeddings.** This is a significant omission. The embedding stage in a transformer is `token_embedding + positional_embedding`. The current explanation only covers token embeddings. Students who read further literature will be confused. Add: "Each token also gets a positional embedding added — a second vector encoding *where* in the sequence it appears."
53
+
54
+ ---
55
+
56
+ ## 4. Attention Stage
57
+
58
+ This is the most important and most underbuilt section. The handoff doc has a detailed vision that is only partially implemented.
59
+
60
+ ### 4a. Head Category Panel
61
+
62
+ **Critical bug: First/Positional is consuming 132/144 heads.**
63
+
64
+ The categorization output:
65
+ - Previous-Token: 6 heads ✓ (reasonable)
66
+ - First/Positional: **132 heads** ✗ (this is ~92% of all heads — clearly wrong)
67
+ - Syntactic: 5 heads (plausible)
68
+ - Other: 1 head
69
+
70
+ This makes the category panel meaningless. A student sees a wall of 132 head IDs under "First/Positional" and learns nothing. The classification threshold for positional heads is almost certainly too loose, OR the `all_scores` from the offline script are being compared against an incorrect threshold. The handoff spec calls for a cap of ~8 heads per category with layer diversity enforcement — that logic is either not implemented or the thresholds need significant tuning.
71
+
72
+ **Missing categories from the spec:**
73
+ The handoff doc specifies 6 categories:
74
+ 1. ✅ Previous Token (implemented)
75
+ 2. ❌ **Induction** (missing entirely)
76
+ 3. ❌ **Duplicate Token** (missing entirely)
77
+ 4. ✅ First/Positional (implemented but broken threshold)
78
+ 5. ❌ **Diffuse / Bag-of-Words** (missing entirely)
79
+ 6. ✅ Other/Unclassified (implemented)
80
+
81
+ "Syntactic" appears as a category but isn't in the handoff spec — unclear where it came from or how it's detected.
82
+
83
+ **Missing: runtime activation scoring.** The spec calls for each head to show an activation score on the *current input* (e.g., whether induction heads are firing given the repeated "The cat" in the prompt). Nothing like this exists yet — heads are just listed as belonging to categories with no indication of whether they're active or dormant on this specific input.
84
+
85
+ **Missing: greyed-out heads with "suggested prompts."** The spec's pedagogically most powerful idea — "Try adding a repeated sentence to see induction heads light up" — doesn't exist at all. This is the thing that turns passive observation into active discovery.
86
+
87
+ ### 4b. Attention Visualization (BertViz)
88
+
89
+ **What's good:**
90
+ - BertViz integration works and renders the attention heatmap
91
+ - The navigation instructions (single click, double click, hover) are clear
92
+
93
+ **Issues:**
94
+ - **No guided interpretation.** The visualization shows lines but doesn't tell the student what they're looking at. For a student who just read that "some heads track pronouns," they need a nudge: "Try Layer 4, Head 11 — this head often looks at the previous word." Right now the student opens a heatmap of spaghetti lines and has no idea what to conclude.
95
+ - **The attention viz and head category panel are disconnected.** Clicking a head in the category list should highlight/select it in the BertViz below. The handoff spec mentions this: "Clicking a head navigates to its attention heatmap." That linkage doesn't exist.
96
+ - **No explanation of what "good" attention looks like.** The viz shows all heads at once by default. For a 12×12 model that's 144 attention patterns — overwhelming. The default view should be a single interesting head (e.g., the strongest previous-token head), not all heads.
97
+ - **Layer selector is bare.** The "Layer: [dropdown]" control has no context. Why would a student change the layer? Add: "Earlier layers tend to capture syntax; later layers capture meaning."
98
+
99
+ ---
100
+
101
+ ## 5. MLP (Feed-Forward) Stage
102
+
103
+ **What's good:**
104
+ - The `768d → 3072d → 768d` expand/compress diagram is clean
105
+ - The "Why expand then compress?" callout box is excellent — the neuron activation framing is correct
106
+ - "This happens in each of the model's 12 layers, with attention and MLP working together" is a good summary
107
+
108
+ **Issues:**
109
+ - **No connection to the current input.** The Paris/France example is generic and not connected to the actual prompt being analyzed. Consider: "For your prompt, the MLP layers are likely retrieving knowledge about common English sentence structures."
110
+ - **No visualization.** MLP is the only stage with purely static text and a diagram. Even a simple bar chart of "top activated neurons at layer X" would make this real. The handoff doc doesn't spec this out, but it's a gap.
111
+ - **Missing: the residual stream framing.** The glossary defines "Residual Stream" but the MLP stage doesn't mention that the MLP *adds* to the residual stream rather than replacing it. This is fundamental to why the model can accumulate knowledge across layers.
112
+
113
+ ---
114
+
115
+ ## 6. Output Selection Stage
116
+
117
+ **What's good:**
118
+ - Top-5 next-token predictions with probability bars is exactly right
119
+ - The full-sentence context display with highlighted predicted token is excellent UX
120
+ - The "Note on Token Selection" callout about Beam Search and MoE is appropriately nuanced
121
+
122
+ **Issues:**
123
+ - **"13.5% confidence" framing is misleading.** "Confidence" implies certainty; this is a softmax probability, which is better described as "the model assigned a 13.5% probability to 'was' as the next word." Students may misread this as "the model is 13.5% confident it's right."
124
+ - **No contrast with wrong predictions.** The chart shows top-5 but doesn't explain *why* the model predicted "was" over "sat." A connection back to attribution ("The token 'cat' had the highest influence on predicting 'was'") would close the loop.
125
+ - **The token slider is unclear.** "Step through generated tokens" with a slider defaulting to 0 and showing "was" is confusing — it looks like there's nothing to step through. Label it: "Generated token 1 of 1: was" and grey out or hide the slider when only 1 token was generated.
126
+
127
+ ---
128
+
129
+ ## 7. Token Attribution Panel
130
+
131
+ **What's good:**
132
+ - The visualization works well — darker tokens = more important is intuitive
133
+ - The bar chart with normalized attribution scores is clean
134
+ - Results matched expectations: "was" (the second "cat" token, position 9) scored 1.0, "The" scored 0.87 — sensible given the prompt structure
135
+
136
+ **Issues:**
137
+ - **"Simple Gradient" is selected by default, not "Integrated Gradients."** The UI labels Simple Gradient as "faster, less accurate" and Integrated Gradients as "more accurate, slower" — but defaults to the less accurate one. For an educational tool where accuracy matters more than speed, this should be reversed. Or at minimum, note: "For learning purposes, Integrated Gradients gives more reliable results."
138
+ - **No explanation of what attribution scores mean in plain English.** The callout says "Tokens with higher attribution scores contributed more to the model's prediction" — but students need: "The second 'cat' scored highest because the model is pattern-matching 'The cat...' to predict what typically follows 'The cat' in English text."
139
+ - **No visual connection to the actual attention visualization.** If "was" had high attribution from "cat," students should be able to click through to see which attention heads facilitated that. Right now attribution and attention are completely siloed.
140
+ - **Target Token dropdown is confusing.** "Use top predicted token (default)" is fine, but the empty text box below it with "Leave empty to compute attribution for the top predicted token" is redundant and confusing — why show a text box that you immediately tell them not to fill?
141
+
142
+ ---
143
+
144
+ ## 8. Ablation Panel
145
+
146
+ **Issues (mostly UX):**
147
+ - **Ablation didn't show results in automated testing** — the head selection reset when switching tabs, suggesting state management issues between the Ablation and Attribution tabs.
148
+ - **No presets or suggestions.** The student faces a blank "Layer / Head" picker and has no idea which heads are interesting to ablate. The category panel above already identified previous-token heads (L4-H11, etc.) — there should be a "Try ablating this head" link from the category panel directly into the ablation form.
149
+ - **"Run Ablation Experiment" is permanently greyed out** until a head is added. The disabled state has no tooltip explaining why. Add: "Add at least one head above to run the experiment."
150
+ - **No explanation of what to expect.** Before running, tell students: "If this head is important, the top prediction may change. If it doesn't change, the head wasn't critical for this input."
151
+ - **No result interpretation.** After running (when it works), the diff between original and ablated predictions needs plain-English interpretation: "Removing L4-H11 changed 'was' (13.5%) → 'sat' (18.2%). This suggests that head was suppressing 'sat' as a prediction."
152
+
153
+ ---
154
+
155
+ ## 9. Sidebar
156
+
157
+ **What's good:**
158
+ - The "Model loaded successfully! Detected family: GPT-2 architecture" green badge is good UX
159
+ - Module selection dropdowns (Attention Modules, Layer Blocks, Normalization Parameters) make sense for power users
160
+
161
+ **Issues:**
162
+ - **Sidebar purpose is unclear to students.** There's no explanation of what changing "Attention Modules" does or why a student would want to. This entire panel reads like a developer debug tool that was left exposed.
163
+ - **"Clear Selections" does what, exactly?** No tooltip.
164
+ - Consider: either hide the sidebar behind an "Advanced" toggle for student mode, or add inline documentation for each control.
165
+
166
+ ---
167
+
168
+ ## 10. Chatbot (Robot Icon)
169
+
170
+ The robot icon is visible at bottom-right but the chat panel contents weren't captured in automated testing (JS error prevented inspection). Recommend manual review of the chatbot's response quality and whether it contextualizes responses to the current model/prompt state.
171
+
172
+ ---
173
+
174
+ ## Priority Recommendations for Cursor
175
+
176
+ ### 🔴 Critical (do these first)
177
+
178
+ 1. **Fix attention head categorization thresholds.** First/Positional capturing 132/144 heads makes the entire category panel meaningless. Tighten the threshold, enforce the ~8-head cap per category from the spec, and add layer diversity. This is the highest-impact fix.
179
+
180
+ 2. **Add the missing head categories.** Induction, Duplicate Token, and Diffuse are all specced in `attention_handoff.md` with detection logic. They need to be implemented. Induction is especially important for this exact prompt (repeated "The cat").
181
+
182
+ 3. **Fix the modal close button off-screen bug.** Students can't close the glossary modal on standard laptop viewports. Easy CSS fix: `position: absolute; right: 16px` inside the modal container, not the document.
183
+
184
+ 4. **Add a loading state after clicking Analyze.** 45 seconds of static "Awaiting analysis..." with no spinner is a UX failure. Add a pulsing animation or "Loading model..." progress message.
185
+
186
+ ### 🟡 High Priority
187
+
188
+ 5. **Connect head categories to the BertViz visualization.** Clicking a head ID (e.g., L4-H11) in the category panel should auto-select that head in the attention viz below.
189
+
190
+ 6. **Add runtime activation scoring to head categories.** Per the spec: show whether each head type is active on the current input. Gray out induction heads if there's no repetition in the input, with a "Try: 'The cat sat. The cat'" suggested prompt.
191
+
192
+ 7. **Add positional embeddings to the Embedding stage explanation.** Currently missing an entire half of what embeddings are.
193
+
194
+ 8. **Fix ablation state management.** Head selections shouldn't reset when switching between Ablation and Attribution tabs.
195
+
196
+ 9. **Change attribution default to Integrated Gradients.** It's the more accurate method; this is an educational tool, not a speed benchmark.
197
+
198
+ 10. **Capitalize on the tokenization "aha" moment.** "The" (464) vs "the" (262) is sitting right there in the example. Call it out explicitly.
199
+
200
+ ### 🟢 Enhancements
201
+
202
+ 11. **Add guided "what to look for" text to the attention visualization.** Pick one interesting head per model (pre-annotated) and surface it as a recommendation: "Try Layer 4, Head 11 to see a previous-token head in action."
203
+
204
+ 12. **Add suggested prompts for exploring each head category.** "To see induction heads activate, try: 'The cat sat on the mat. The cat...'"
205
+
206
+ 13. **Reframe "confidence" in Output stage.** Replace with "probability" throughout.
207
+
208
+ 14. **Link attribution results to attention heads.** "The token 'cat' was most influential — see which heads connected it to the prediction in the Attention stage."
209
+
210
+ 15. **Fix the Output stage token slider** — hide or disable it when only 1 token was generated.
211
+
212
+ 16. **Add a brief "what would you like to explore?" prompt to the ablation UI** with pre-suggested heads from the category panel.
213
+
214
+ 17. **Sidebar: add explanatory text** for what Module Selection controls, or hide it in an "Advanced" section.
215
+
216
+ ---
217
+
218
+ ## What's Already Strong (Don't Break)
219
+
220
+ - The 5-stage pipeline structure and the flow chip bar — keep it exactly as is
221
+ - The BertViz integration — it works and the navigation instructions are clear
222
+ - The callout boxes in Embedding and MLP — these are the best explanation text in the app
223
+ - The token attribution visualization (darker = more important) — intuitive and correct
224
+ - The top-5 output prediction chart — exactly the right content
225
+ - The glossary modal content — all 8 entries are well-written
226
+
227
+ ---
228
+
229
+ ## Comparison to Handoff Spec
230
+
231
+ | Spec Feature | Status |
232
+ |---|---|
233
+ | 6 head categories (Previous Token, Induction, Duplicate, Positional, Diffuse, Other) | ⚠️ Partial — 3/6 missing, Positional broken |
234
+ | Per-head activation scores on current input | ❌ Not implemented |
235
+ | Active/inactive state display (filled vs open circle) | ❌ Not implemented |
236
+ | Greyed-out heads with suggested prompts | ❌ Not implemented |
237
+ | Click head → navigate to attention heatmap | ❌ Not implemented |
238
+ | Runtime verification module | ❌ Not implemented |
239
+ | One-time offline analysis script | ✅ Appears to have run (JSON exists) |
240
+ | Educational tooltips per category | ⚠️ Partial — descriptions exist but brief |
rag_docs/head_categories_explained.md CHANGED
@@ -1,56 +1,58 @@
1
- # Attention Head Categories Explained
2
 
3
- ## What Are Head Categories?
4
 
5
- The dashboard automatically analyzes all attention heads in the model and categorizes them based on their behavior patterns. This helps you understand what each head is doing without having to inspect every attention map manually.
6
 
7
- Head categories appear in **Stage 3 (Attention)** of the pipeline. Click any category to expand it and see which specific heads (like L0-H3, L2-H11) belong to it.
 
8
 
9
- ## The Five Categories
10
 
11
- ### Previous-Token Heads
12
 
13
- **What they do**: These heads strongly attend to the **immediately preceding token**. For every token at position *i*, the head focuses most of its attention on position *i-1*.
 
14
 
15
- **Why they matter**: Previous-token heads help the model track local context -- the word that just came before. They're important for bigram patterns (common two-word combinations like "of the" or "in a").
16
 
17
- **Detection**: A head is classified as Previous-Token if, on average, more than 40% of each token's attention goes to the token directly before it.
18
 
19
- **In the dashboard**: These heads are labeled with a purple color. Ablating them often causes noticeable changes in predictions.
20
 
21
- ### First/Positional Heads
 
22
 
23
- **What they do**: These heads focus heavily on the **first token** in the sequence or show strong **positional patterns** (always attending to a specific position regardless of content).
24
 
25
- **Why they matter**: The first token often serves as a "default" attention target. Positional heads help the model keep track of where it is in the sequence.
26
 
27
- **Detection**: Classified when average attention to the first token exceeds 25%.
28
 
29
- ### Bag-of-Words (BoW) Heads
 
30
 
31
- **What they do**: These heads spread their attention **broadly and evenly** across many tokens, without focusing strongly on any particular one.
32
 
33
- **Why they matter**: BoW heads capture a general summary of the entire input. They help the model maintain an overall sense of what the text is about.
34
 
35
- **Detection**: Classified when the attention distribution has high entropy (≥ 0.65 normalized) and no single token receives more than 35% attention.
 
36
 
37
- ### Syntactic Heads
38
 
39
- **What they do**: These heads attend to tokens at **consistent distances**, suggesting they track grammatical or structural relationships (like subject-verb pairs).
40
 
41
- **Why they matter**: Syntactic heads help the model understand grammar and sentence structure. They might connect a verb to its subject or a pronoun to what it refers to.
42
 
43
- **Detection**: Classified when tokens consistently attend to other tokens at similar distances, with low variance in attention distances.
44
 
45
- ### Other
46
 
47
- **What they do**: Heads that don't clearly fit any of the above patterns. They may have mixed or context-dependent behavior.
 
 
48
 
49
- **Why they matter**: "Other" doesn't mean unimportant. These heads may serve specialized roles that only activate for certain inputs. They're worth investigating through ablation experiments.
50
 
51
- ## Using Categories for Experiments
52
-
53
- Head categories are especially useful for guiding ablation experiments:
54
- - Ablate a **Previous-Token** head to see if local context patterns break
55
- - Ablate a **BoW** head to see if the model loses global context
56
- - Compare the effect of ablating heads from different categories on the same prompt
 
1
+ # Attention Head Categories
2
 
3
+ This document explains the different types of attention heads found in transformer models. These categories are determined through **offline analysis** using TransformerLens and **verified at runtime** against your actual input.
4
 
5
+ ## Categories
6
 
7
+ ### Previous Token
8
+ **Symbol:** ● (active on most inputs)
9
 
10
+ Attends to the immediately preceding token — like reading left to right. This head helps the model track local word-by-word patterns. It's one of the most common and reliable head types.
11
 
12
+ **What to look for in the visualization:** Strong diagonal line one position below the main diagonal.
13
 
14
+ ### Induction
15
+ **Symbol:** ● when repeated tokens exist, ○ otherwise
16
 
17
+ Completes repeated patterns: if the model saw [A][B] before and now sees [A], it predicts [B] will follow. This is one of the most important mechanisms in transformer language models.
18
 
19
+ **Requires:** Repeated tokens in your input. If no tokens repeat, this category appears grayed out.
20
 
21
+ **Try this prompt:** "The cat sat on the mat. The cat" — the repeated "The cat" activates induction heads.
22
 
23
+ ### Duplicate Token
24
+ **Symbol:** ● when duplicate tokens exist, ○ otherwise
25
 
26
+ Notices when the same word appears more than once, acting like a highlighter for repeated words. Helps the model track which words have already been said.
27
 
28
+ **Requires:** Repeated tokens in your input.
29
 
30
+ **Try this prompt:** "The cat sat. The cat slept." — the repeated words activate duplicate-token heads.
31
 
32
+ ### Positional / First-Token
33
+ **Symbol:** ● (active on most inputs)
34
 
35
+ Always pays attention to the very first word, using it as a fixed anchor point. The first token often serves as a "default" position when no specific token is relevant.
36
 
37
+ **What to look for:** Strong vertical line at column 0 (all tokens attending to position 0).
38
 
39
+ ### Diffuse / Spread
40
+ **Symbol:** ● (active on most inputs)
41
 
42
+ Spreads attention evenly across many words, gathering general context rather than focusing on one spot. Provides a "big picture" summary of the input.
43
 
44
+ **What to look for:** No strong patterns — attention is spread roughly evenly across all tokens.
45
 
46
+ ### Other / Unclassified
47
 
48
+ Heads whose dominant pattern doesn't fit the categories above. These may perform more complex or context-dependent operations.
49
 
50
+ ## How It Works
51
 
52
+ 1. **Offline Analysis:** A TransformerLens script analyzes each head across many test inputs and assigns categories based on dominant behavior patterns.
53
+ 2. **Runtime Verification:** When you enter a prompt, the app checks whether each head's known role is actually active on your specific input.
54
+ 3. **Active vs Inactive:** A filled circle (●) means the head's role is triggered. An open circle (○) means the role exists but isn't triggered on your current input (e.g., no repeated tokens for induction).
55
 
56
+ ## Important Note
57
 
58
+ These categories are simplified labels based on each head's dominant behavior pattern. In reality, attention heads can serve multiple roles and may behave differently depending on the input.
 
 
 
 
 
scripts/analyze_heads.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Offline Head Analysis Script
4
+
5
+ Uses TransformerLens to analyze attention head behaviors across test inputs
6
+ and generates a JSON file with head categories for each model.
7
+
8
+ Usage:
9
+ python scripts/analyze_heads.py --model gpt2
10
+ python scripts/analyze_heads.py --model gpt2 gpt2-medium EleutherAI/pythia-70m
11
+ python scripts/analyze_heads.py --all
12
+
13
+ Output:
14
+ Writes to utils/head_categories.json
15
+ """
16
+
17
+ import os
18
+ os.environ["USE_TF"] = "0" # Prevent TensorFlow noise
19
+
20
+ import argparse
21
+ import json
22
+ import sys
23
+ import time
24
+ from pathlib import Path
25
+ from typing import Dict, List, Any, Tuple
26
+
27
+ import torch
28
+ import numpy as np
29
+
30
+ # Add project root to path
31
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
32
+ sys.path.insert(0, str(PROJECT_ROOT))
33
+
34
+ JSON_OUTPUT_PATH = PROJECT_ROOT / "utils" / "head_categories.json"
35
+
36
+ # ============================================================================
37
+ # TransformerLens model name mapping
38
+ # ============================================================================
39
+ # TL uses its own naming conventions. Map from HuggingFace names
40
+ # (used in our model_config.py) to TL names.
41
+
42
+ HF_TO_TL_NAME = {
43
+ "gpt2": "gpt2-small",
44
+ "openai-community/gpt2": "gpt2-small",
45
+ "gpt2-medium": "gpt2-medium",
46
+ "openai-community/gpt2-medium": "gpt2-medium",
47
+ "gpt2-large": "gpt2-large",
48
+ "openai-community/gpt2-large": "gpt2-large",
49
+ "gpt2-xl": "gpt2-xl",
50
+ "openai-community/gpt2-xl": "gpt2-xl",
51
+ "EleutherAI/pythia-70m": "pythia-70m",
52
+ "EleutherAI/pythia-160m": "pythia-160m",
53
+ "EleutherAI/pythia-410m": "pythia-410m",
54
+ "EleutherAI/pythia-1b": "pythia-1b",
55
+ "EleutherAI/pythia-1.4b": "pythia-1.4b",
56
+ "facebook/opt-125m": "opt-125m",
57
+ "facebook/opt-350m": "opt-350m",
58
+ "facebook/opt-1.3b": "opt-1.3b",
59
+ }
60
+
61
+ # Default models to analyze
62
+ DEFAULT_MODELS = ["gpt2"]
63
+
64
+ ALL_PRIORITY_MODELS = [
65
+ "gpt2",
66
+ "gpt2-medium",
67
+ "EleutherAI/pythia-70m",
68
+ "EleutherAI/pythia-160m",
69
+ "EleutherAI/pythia-410m",
70
+ "facebook/opt-125m",
71
+ ]
72
+
73
+ # ============================================================================
74
+ # Category metadata (shared across all models)
75
+ # ============================================================================
76
+
77
+ CATEGORY_METADATA = {
78
+ "previous_token": {
79
+ "display_name": "Previous Token",
80
+ "description": "Attends to the immediately preceding token — like reading left to right",
81
+ "icon": "arrow-left",
82
+ "educational_text": "This head looks at the word right before the current one. Like reading left to right, it helps track local word-by-word patterns.",
83
+ "requires_repetition": False,
84
+ },
85
+ "induction": {
86
+ "display_name": "Induction",
87
+ "description": "Completes repeated patterns: if it saw [A][B] before and now sees [A], it predicts [B]",
88
+ "icon": "repeat",
89
+ "educational_text": "This head finds patterns that happened before and predicts they'll happen again. If it saw 'the cat' earlier, it expects the same words to follow.",
90
+ "requires_repetition": True,
91
+ "suggested_prompt": "Try: 'The cat sat on the mat. The cat' — the repeated 'The cat' lets induction heads activate.",
92
+ },
93
+ "duplicate_token": {
94
+ "display_name": "Duplicate Token",
95
+ "description": "Notices when the same word appears more than once",
96
+ "icon": "clone",
97
+ "educational_text": "This head notices when the same word appears more than once, like a highlighter for repeated words. It helps the model track which words have already been said.",
98
+ "requires_repetition": True,
99
+ "suggested_prompt": "Try a prompt with repeated words like 'The cat sat. The cat slept.' to see duplicate-token heads light up.",
100
+ },
101
+ "positional": {
102
+ "display_name": "Positional / First-Token",
103
+ "description": "Always pays attention to the very first word, using it as an anchor point",
104
+ "icon": "map-pin",
105
+ "educational_text": "This head always pays attention to the very first word, using it as an anchor point. The first token serves as a 'default' position when no other token is specifically relevant.",
106
+ "requires_repetition": False,
107
+ },
108
+ "diffuse": {
109
+ "display_name": "Diffuse / Spread",
110
+ "description": "Spreads attention evenly across many words, gathering general context",
111
+ "icon": "expand-arrows-alt",
112
+ "educational_text": "This head spreads its attention evenly across many words, gathering general context rather than focusing on one spot. It provides a 'big picture' summary of the input.",
113
+ "requires_repetition": False,
114
+ },
115
+ }
116
+
117
+
118
+ # ============================================================================
119
+ # Test input generation
120
+ # ============================================================================
121
+
122
def generate_test_inputs(tokenizer) -> Dict[str, List[str]]:
    """Build the prompt battery used for offline head analysis.

    Args:
        tokenizer: Accepted for interface compatibility with callers that
            pass the model tokenizer; not used by this function.

    Returns:
        A dict with two keys:
            "natural"    — ordinary English sentences, used to score
                           previous-token / positional / diffuse behavior.
            "repetitive" — prompts with repeated tokens, used to score
                           induction and duplicate-token behavior.
    """
    return {
        # Ordinary prose with no intentional repetition.
        "natural": [
            "The quick brown fox jumps over the lazy dog.",
            "In the beginning, there was nothing but darkness and silence.",
            "Machine learning models process data to make predictions about the future.",
            "She walked through the park and noticed the flowers blooming everywhere.",
            "The president announced new economic policies at the press conference today.",
            "After years of research, scientists finally discovered the missing link.",
            "The library was quiet except for the occasional turning of pages.",
            "Programming is both an art and a science requiring careful thought.",
            "The restaurant on the corner served the best pizza in the entire city.",
            "Education is the most powerful tool for changing the world around us.",
            "The storm clouds gathered on the horizon as the wind began to howl.",
            "Mathematics provides the foundation for understanding complex physical phenomena.",
            "The children played happily in the garden while their parents watched.",
            "Economic growth depends on innovation, investment, and human capital development.",
            "The old man sat on the bench and watched the pigeons gather crumbs.",
            "Artificial intelligence will transform every industry in the coming decades.",
            "The river flowed gently through the valley between the tall mountains.",
            "Good communication skills are essential for success in any professional career.",
            "The concert hall was packed with enthusiastic fans waiting for the show.",
            "Climate change poses significant challenges for agriculture and food security.",
        ],
        # Prompts with deliberate token repetition so induction and
        # duplicate-token heads have something to fire on.
        "repetitive": [
            "The cat sat on the mat. The cat sat on the mat.",
            "One two three four five. One two three four five.",
            "Hello world hello world hello world hello world.",
            "Alice went to the store. Bob went to the store. Alice went to the store.",
            "The dog chased the ball. The dog chased the ball. The dog chased.",
            "Red blue green red blue green red blue green red.",
            "I like apples and I like oranges and I like apples.",
            "The sun rises in the east. The sun sets in the west. The sun rises.",
            "Monday Tuesday Wednesday Monday Tuesday Wednesday Monday.",
            "She said hello and he said hello and she said hello again.",
            "The key to success is practice. The key to success is patience.",
            "We went to the park and then we went to the park again.",
            "First second third first second third first second third.",
            "The teacher asked the student. The student asked the teacher. The teacher asked.",
            "North south east west north south east west north south.",
            "Open the door. Close the door. Open the door. Close the door.",
            "The big red ball bounced. The big red ball rolled.",
            "Cat dog cat dog cat dog cat dog cat dog.",
            "Learn practice improve learn practice improve learn practice.",
            "The man walked. The woman walked. The man walked. The woman walked.",
        ],
    }
177
+
178
+
179
+ # ============================================================================
180
+ # Head scoring functions
181
+ # ============================================================================
182
+
183
def score_previous_token(attn_patterns: torch.Tensor) -> torch.Tensor:
    """
    Score each head for previous-token behavior.

    Averages, over all query positions i > 0, the attention weight placed on
    position i-1 — i.e. the mean of the sub-diagonal of each head's
    attention matrix.

    Args:
        attn_patterns: [n_layers, n_heads, seq_len, seq_len] attention weights.

    Returns:
        [n_layers, n_heads] scores (zeros when seq_len < 2).
    """
    n_layers, n_heads, seq_len, _ = attn_patterns.shape

    # A single-token sequence has no "previous token" to attend to.
    if seq_len < 2:
        return torch.zeros(n_layers, n_heads)

    # Sub-diagonal (offset -1) holds attention from position i to i-1.
    sub_diag = torch.diagonal(attn_patterns, offset=-1, dim1=-2, dim2=-1)
    return sub_diag.sum(dim=-1) / (seq_len - 1)
201
+
202
+
203
def score_positional(attn_patterns: torch.Tensor) -> torch.Tensor:
    """
    Score each head for first-token / positional behavior.

    Column 0 of the attention matrix holds every query position's attention
    to the first token; averaging that column over query positions gives a
    per-head "anchors on token 0" score.

    Args:
        attn_patterns: [n_layers, n_heads, seq_len, seq_len] attention weights.

    Returns:
        [n_layers, n_heads] mean attention to position 0.
    """
    first_token_attn = attn_patterns[..., 0]
    return first_token_attn.mean(dim=-1)
212
+
213
+
214
def score_diffuse(attn_patterns: torch.Tensor) -> torch.Tensor:
    """
    Score each head for diffuse / bag-of-words behavior.

    Computes the Shannon entropy of each query position's attention
    distribution, normalizes by the maximum possible entropy (uniform
    attention over seq_len keys), and averages over positions. A score near
    1.0 means attention is spread almost evenly.

    Args:
        attn_patterns: [n_layers, n_heads, seq_len, seq_len] attention weights.

    Returns:
        [n_layers, n_heads] mean normalized entropy.
    """
    n_layers, n_heads, seq_len, _ = attn_patterns.shape

    # Epsilon keeps log() finite on exact zeros.
    smoothed = attn_patterns + 1e-10
    per_position_entropy = -(smoothed * torch.log(smoothed)).sum(dim=-1)

    # Normalize by log(seq_len), the entropy of a uniform distribution.
    # seq_len == 1 gives max_entropy == 0; skip division in that case.
    max_entropy = np.log(seq_len)
    if max_entropy > 0:
        per_position_entropy = per_position_entropy / max_entropy

    return per_position_entropy.mean(dim=-1)
230
+
231
+
232
def score_induction(attn_patterns: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """
    Score each head for induction behavior.

    Whenever token[i] repeats an earlier token[j] (j < i-1), an induction
    head should attend from position i to position j+1 — the token that
    followed the previous occurrence. The score is the mean of that
    attention weight over all such (i, j) pairs.

    Args:
        attn_patterns: [n_layers, n_heads, seq_len, seq_len] attention weights.
        tokens: 1-D tensor of token ids for the same sequence.

    Returns:
        [n_layers, n_heads] scores (zeros when no token repeats).
    """
    n_layers, n_heads, seq_len, _ = attn_patterns.shape
    totals = torch.zeros(n_layers, n_heads)
    n_matches = 0

    ids = tokens.tolist()
    for i in range(2, seq_len):
        for j in range(i - 1):
            if ids[i] != ids[j]:
                continue
            successor = j + 1
            if successor >= seq_len:
                continue
            totals += attn_patterns[:, :, i, successor]
            n_matches += 1

    return totals / n_matches if n_matches else totals
255
+
256
+
257
def score_duplicate_token(attn_patterns: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """
    Score each head for duplicate-token behavior.

    For every pair of identical tokens, measures the attention the later
    occurrence pays directly to the earlier occurrence, averaged over all
    such pairs.

    Args:
        attn_patterns: [n_layers, n_heads, seq_len, seq_len] attention weights.
        tokens: 1-D tensor of token ids for the same sequence.

    Returns:
        [n_layers, n_heads] scores (zeros when no token repeats).
    """
    n_layers, n_heads, seq_len, _ = attn_patterns.shape
    totals = torch.zeros(n_layers, n_heads)
    n_pairs = 0

    ids = tokens.tolist()
    for later in range(1, seq_len):
        for earlier in range(later):
            if ids[later] == ids[earlier]:
                totals += attn_patterns[:, :, later, earlier]
                n_pairs += 1

    return totals / n_pairs if n_pairs else totals
278
+
279
+
280
+ # ============================================================================
281
+ # Main analysis
282
+ # ============================================================================
283
+
284
def analyze_model(model_name: str, device: str = "cpu") -> Dict[str, Any]:
    """
    Run full head analysis for a model.

    Loads the model with TransformerLens, scores every attention head on the
    natural and repetitive prompt batteries from generate_test_inputs(), and
    selects the top heads per category via select_top_heads().

    Args:
        model_name: HuggingFace-style model name; mapped to a TransformerLens
            name via HF_TO_TL_NAME (falls back to the name itself).
        device: Torch device string passed to HookedTransformer.from_pretrained.

    Returns:
        A dict ready for JSON serialization: model metadata, per-category
        top heads (with CATEGORY_METADATA merged in), and the full per-head
        score matrices under "all_scores".
    """
    # Function-local import: transformer_lens is only required when the
    # analysis actually runs.
    from transformer_lens import HookedTransformer

    tl_name = HF_TO_TL_NAME.get(model_name, model_name)
    print(f"\n{'='*60}")
    print(f"Analyzing: {model_name} (TL name: {tl_name})")
    print(f"{'='*60}")

    print("Loading model...")
    model = HookedTransformer.from_pretrained(tl_name, device=device)

    n_layers = model.cfg.n_layers
    n_heads = model.cfg.n_heads
    print(f" Layers: {n_layers}, Heads per layer: {n_heads}")

    # Generate test inputs
    test_inputs = generate_test_inputs(model.tokenizer)

    # Score accumulators, one [n_layers, n_heads] matrix per category.
    prev_token_scores = torch.zeros(n_layers, n_heads)
    positional_scores = torch.zeros(n_layers, n_heads)
    diffuse_scores = torch.zeros(n_layers, n_heads)
    induction_scores = torch.zeros(n_layers, n_heads)
    duplicate_scores = torch.zeros(n_layers, n_heads)

    natural_count = 0
    repetitive_count = 0

    # Pass 1: natural prompts feed the general categories
    # (prev_token, positional, diffuse).
    print("\nAnalyzing natural prompts...")
    for prompt in test_inputs["natural"]:
        try:
            tokens = model.to_tokens(prompt)
            # Skip degenerate prompts that tokenize to fewer than 3 tokens.
            if tokens.shape[1] < 3:
                continue

            with torch.no_grad():
                _, cache = model.run_with_cache(tokens)

            # Stack per-layer attention patterns into one tensor:
            # [n_layers, n_heads, seq_len, seq_len].
            attn_patterns = torch.stack([
                cache["pattern", layer][0]  # Remove batch dim
                for layer in range(n_layers)
            ])

            prev_token_scores += score_previous_token(attn_patterns)
            positional_scores += score_positional(attn_patterns)
            diffuse_scores += score_diffuse(attn_patterns)
            natural_count += 1

        except Exception as e:
            # Best-effort: a single bad prompt should not abort the analysis.
            print(f" Warning: Skipped prompt: {e}")
            continue

    print(f" Processed {natural_count} natural prompts")

    # Pass 2: repetitive prompts feed induction + duplicate-token scores.
    print("Analyzing repetitive prompts...")
    for prompt in test_inputs["repetitive"]:
        try:
            tokens = model.to_tokens(prompt)
            if tokens.shape[1] < 4:
                continue

            with torch.no_grad():
                _, cache = model.run_with_cache(tokens)

            attn_patterns = torch.stack([
                cache["pattern", layer][0]
                for layer in range(n_layers)
            ])

            # tokens[0] strips the batch dimension for the token-id sequence.
            induction_scores += score_induction(attn_patterns, tokens[0])
            duplicate_scores += score_duplicate_token(attn_patterns, tokens[0])

            # Also accumulate general scores for these prompts.
            # NOTE: natural_count is deliberately incremented here too — it is
            # the denominator for the general-score averages, so it counts
            # every prompt (natural + repetitive) that fed those accumulators.
            prev_token_scores += score_previous_token(attn_patterns)
            positional_scores += score_positional(attn_patterns)
            diffuse_scores += score_diffuse(attn_patterns)
            natural_count += 1

            repetitive_count += 1

        except Exception as e:
            print(f" Warning: Skipped prompt: {e}")
            continue

    print(f" Processed {repetitive_count} repetitive prompts")

    # Convert accumulated sums into per-prompt averages.
    if natural_count > 0:
        prev_token_scores /= natural_count
        positional_scores /= natural_count
        diffuse_scores /= natural_count
    if repetitive_count > 0:
        induction_scores /= repetitive_count
        duplicate_scores /= repetitive_count

    # Select top heads per category
    all_category_scores = {
        "previous_token": prev_token_scores,
        "induction": induction_scores,
        "duplicate_token": duplicate_scores,
        "positional": positional_scores,
        "diffuse": diffuse_scores,
    }

    # Print score summaries (argmax over the flattened matrix, decoded back
    # to layer/head indices).
    print("\nScore summaries (max per category):")
    for cat_name, scores in all_category_scores.items():
        max_score = scores.max().item()
        max_idx = scores.argmax()
        max_layer = max_idx // n_heads
        max_head = max_idx % n_heads
        print(f" {cat_name:20s}: max={max_score:.4f} at L{max_layer}-H{max_head}")

    # Build per-category entries: static CATEGORY_METADATA copy + top heads.
    categories_data = {}

    for cat_name, scores in all_category_scores.items():
        top_heads = select_top_heads(scores, n_layers, n_heads, cat_name)

        # dict() copies the metadata so the shared template is not mutated.
        cat_entry = dict(CATEGORY_METADATA[cat_name])
        cat_entry["top_heads"] = top_heads
        categories_data[cat_name] = cat_entry

        print(f"\n {cat_name} ({len(top_heads)} heads):")
        for h in top_heads:
            print(f" L{h['layer']}-H{h['head']}: {h['score']:.4f}")

    # Build the full model entry (JSON-serializable: tensors -> lists).
    model_entry = {
        "model_name": model_name,
        "num_layers": n_layers,
        "num_heads": n_heads,
        "analysis_date": time.strftime("%Y-%m-%d"),
        "categories": categories_data,
        "all_scores": {
            cat: scores.tolist()
            for cat, scores in all_category_scores.items()
        }
    }

    return model_entry
433
+
434
+
435
def select_top_heads(
    scores: torch.Tensor,
    n_layers: int,
    n_heads: int,
    category: str,
    max_heads: int = 8,
    primary_threshold: float = 0.25,
    min_threshold: float = 0.10,
) -> List[Dict[str, Any]]:
    """
    Select the top heads for a category, enforcing layer diversity.

    Strategy:
    1. Take all heads scoring at or above primary_threshold (highest first).
    2. Then add the best head from each layer not yet represented, among
       candidates above min_threshold.
    3. Cap the result at max_heads.

    Args:
        scores: [n_layers, n_heads] score matrix for one category.
        n_layers / n_heads: dimensions of the score matrix.
        category: Category name; accepted for call-site symmetry, not used
            in the selection logic.
        max_heads: Hard cap on the number of heads returned.
        primary_threshold: Score (after rounding to 4 dp) that qualifies a
            head in the first pass.
        min_threshold: Minimum raw score to be considered at all.

    Returns:
        List of {"layer", "head", "score"} dicts sorted by (layer, head).
    """
    # Every head clearing the minimum bar, highest score first. The list
    # comprehension visits (layer, head) in ascending order, and the stable
    # sort preserves that order among ties.
    ranked = [
        {"layer": layer, "head": head, "score": round(scores[layer, head].item(), 4)}
        for layer in range(n_layers)
        for head in range(n_heads)
        if scores[layer, head].item() > min_threshold
    ]
    ranked.sort(key=lambda entry: entry["score"], reverse=True)

    chosen: List[Dict[str, Any]] = []
    taken = set()
    seen_layers = set()

    def _take(entry):
        # Record a selection and mark its (layer, head) key and layer as used.
        chosen.append(entry)
        taken.add((entry["layer"], entry["head"]))
        seen_layers.add(entry["layer"])

    # Pass 1: strongest heads at or above the primary threshold.
    for entry in ranked:
        if entry["score"] >= primary_threshold and len(chosen) < max_heads:
            if (entry["layer"], entry["head"]) not in taken:
                _take(entry)

    # Pass 2: best remaining head from each not-yet-represented layer.
    for entry in ranked:
        if len(chosen) >= max_heads:
            break
        key = (entry["layer"], entry["head"])
        if entry["layer"] not in seen_layers and key not in taken:
            _take(entry)

    # Present the result in (layer, head) order for stable display.
    chosen.sort(key=lambda entry: (entry["layer"], entry["head"]))
    return chosen[:max_heads]
496
+
497
+
498
+ # ============================================================================
499
+ # CLI
500
+ # ============================================================================
501
+
502
+ def main():
503
+ parser = argparse.ArgumentParser(description="Analyze attention head categories using TransformerLens")
504
+ parser.add_argument("--model", nargs="+", default=None,
505
+ help="HuggingFace model name(s) to analyze (e.g., gpt2, EleutherAI/pythia-70m)")
506
+ parser.add_argument("--all", action="store_true",
507
+ help="Analyze all priority models")
508
+ parser.add_argument("--device", default="cpu",
509
+ help="Device to use (cpu or cuda)")
510
+ parser.add_argument("--output", type=str, default=None,
511
+ help="Output JSON path (default: utils/head_categories.json)")
512
+ args = parser.parse_args()
513
+
514
+ # Determine models to analyze
515
+ if args.all:
516
+ models = ALL_PRIORITY_MODELS
517
+ elif args.model:
518
+ models = args.model
519
+ else:
520
+ models = DEFAULT_MODELS
521
+
522
+ output_path = Path(args.output) if args.output else JSON_OUTPUT_PATH
523
+
524
+ # Load existing data if present
525
+ existing_data = {}
526
+ if output_path.exists():
527
+ try:
528
+ with open(output_path, 'r') as f:
529
+ existing_data = json.load(f)
530
+ print(f"Loaded existing data from {output_path} ({len(existing_data)} models)")
531
+ except (json.JSONDecodeError, IOError):
532
+ pass
533
+
534
+ # Analyze each model
535
+ for model_name in models:
536
+ try:
537
+ result = analyze_model(model_name, device=args.device)
538
+
539
+ # Store under the HuggingFace name
540
+ existing_data[model_name] = result
541
+
542
+ # Also store under the short name for lookup
543
+ short_name = model_name.split('/')[-1] if '/' in model_name else None
544
+ if short_name and short_name != model_name:
545
+ existing_data[short_name] = result
546
+
547
+ except Exception as e:
548
+ print(f"\nERROR analyzing {model_name}: {e}")
549
+ import traceback
550
+ traceback.print_exc()
551
+ continue
552
+
553
+ # Write output
554
+ output_path.parent.mkdir(parents=True, exist_ok=True)
555
+ with open(output_path, 'w') as f:
556
+ json.dump(existing_data, f, indent=2)
557
+
558
+ print(f"\n{'='*60}")
559
+ print(f"Done! Wrote {len(existing_data)} model entries to {output_path}")
560
+ print(f"{'='*60}")
561
+
562
+
563
+ if __name__ == "__main__":
564
+ main()
tests/conftest.py CHANGED
@@ -199,12 +199,4 @@ def mock_attribution_result():
199
  }
200
 
201
 
202
- # =============================================================================
203
- # Head Categorization Config
204
- # =============================================================================
205
 
206
- @pytest.fixture
207
- def default_head_config():
208
- """Default head categorization configuration for testing."""
209
- from utils.head_detection import HeadCategorizationConfig
210
- return HeadCategorizationConfig()
 
199
  }
200
 
201
 
 
 
 
202
 
 
 
 
 
 
tests/test_head_detection.py CHANGED
@@ -1,313 +1,431 @@
1
  """
2
  Tests for utils/head_detection.py
3
 
4
- Tests attention head categorization heuristics using synthetic attention matrices.
5
  """
6
 
7
  import pytest
8
  import torch
 
9
  import numpy as np
 
 
10
  from utils.head_detection import (
11
- compute_attention_entropy,
12
- detect_previous_token_head,
13
- detect_first_token_head,
14
- detect_bow_head,
15
- detect_syntactic_head,
16
- categorize_attention_head,
17
- categorize_all_heads,
18
- format_categorization_summary,
19
- HeadCategorizationConfig
20
  )
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  class TestComputeAttentionEntropy:
24
- """Tests for compute_attention_entropy function."""
25
-
26
  def test_uniform_distribution_high_entropy(self):
27
- """Uniform attention should have high (near 1.0) normalized entropy."""
28
- # 4 positions with equal attention
29
- uniform = torch.tensor([0.25, 0.25, 0.25, 0.25])
30
- entropy = compute_attention_entropy(uniform)
31
-
32
- # Normalized entropy should be close to 1.0 for uniform
33
- assert 0.95 <= entropy <= 1.0, f"Expected ~1.0, got {entropy}"
34
-
35
  def test_peaked_distribution_low_entropy(self):
36
- """Peaked attention should have low normalized entropy."""
37
- # One position dominates
38
- peaked = torch.tensor([0.97, 0.01, 0.01, 0.01])
39
- entropy = compute_attention_entropy(peaked)
40
-
41
- # Should be low entropy
42
- assert entropy < 0.3, f"Expected low entropy, got {entropy}"
43
-
44
- def test_entropy_bounds(self):
45
- """Entropy should always be between 0 and 1 (normalized)."""
46
- test_cases = [
47
- torch.tensor([1.0, 0.0, 0.0, 0.0]), # Extreme peaked
48
- torch.tensor([0.5, 0.5, 0.0, 0.0]), # Two positions
49
- torch.tensor([0.25, 0.25, 0.25, 0.25]), # Uniform
50
- ]
51
-
52
- for weights in test_cases:
53
- entropy = compute_attention_entropy(weights)
54
- assert 0.0 <= entropy <= 1.0, f"Entropy {entropy} out of bounds"
55
-
56
-
57
- class TestDetectPreviousTokenHead:
58
- """Tests for detect_previous_token_head function."""
59
-
60
- def test_detects_previous_token_pattern(self, previous_token_attention_matrix, default_head_config):
61
- """Should detect matrix with strong previous-token attention."""
62
- is_prev, score = detect_previous_token_head(
63
- previous_token_attention_matrix,
64
- default_head_config
65
- )
66
-
67
- assert is_prev == True
68
- assert score > 0.5, f"Expected high score, got {score}"
69
-
70
- def test_rejects_uniform_attention(self, uniform_attention_matrix, default_head_config):
71
- """Should reject matrix with uniform attention."""
72
- is_prev, score = detect_previous_token_head(
73
- uniform_attention_matrix,
74
- default_head_config
75
- )
76
-
77
- assert is_prev == False
78
- assert score < 0.4, f"Expected low score, got {score}"
79
-
80
- def test_short_sequence_returns_false(self, default_head_config):
81
- """Sequence shorter than min_seq_len should return False."""
82
- short_matrix = torch.ones(2, 2) / 2
83
- is_prev, score = detect_previous_token_head(short_matrix, default_head_config)
84
-
85
- assert is_prev == False
86
- assert score == 0.0
87
 
 
 
88
 
89
- class TestDetectFirstTokenHead:
90
- """Tests for detect_first_token_head function."""
91
-
92
- def test_detects_first_token_pattern(self, first_token_attention_matrix, default_head_config):
93
- """Should detect matrix with strong first-token attention."""
94
- is_first, score = detect_first_token_head(
95
- first_token_attention_matrix,
96
- default_head_config
97
- )
 
 
 
 
 
 
98
 
99
- assert is_first == True
100
- assert score > 0.5, f"Expected high score, got {score}"
101
-
102
- def test_low_first_token_attention(self, default_head_config):
103
- """Matrix with low attention to first token should not be detected."""
104
- # Create matrix where first token gets very little attention
105
- # Use size 5 to be above min_seq_len and avoid overlap at [0,0]
106
- size = 5
107
- matrix = torch.zeros(size, size)
108
- for i in range(size):
109
- # Distribute attention: 5% to first token, 95% to last token
110
- matrix[i, 0] = 0.05
111
- matrix[i, -1] = 0.95
112
 
113
- is_first, score = detect_first_token_head(matrix, default_head_config)
 
 
 
 
 
114
 
115
- assert is_first == False
116
- assert score < 0.25, f"Expected low score, got {score}"
117
 
 
 
 
 
118
 
119
- class TestDetectBowHead:
120
- """Tests for detect_bow_head (bag-of-words / diffuse attention)."""
121
-
122
- def test_detects_uniform_as_bow(self, uniform_attention_matrix, default_head_config):
123
- """Uniform attention should be detected as BoW head."""
124
- is_bow, score = detect_bow_head(uniform_attention_matrix, default_head_config)
125
 
126
- # Uniform has high entropy and low max attention - should be BoW
127
- assert is_bow == True
128
- assert score > 0.9, f"Expected high entropy score, got {score}"
129
-
130
- def test_rejects_peaked_attention(self, peaked_attention_matrix, default_head_config):
131
- """Peaked attention should not be detected as BoW."""
132
- is_bow, score = detect_bow_head(peaked_attention_matrix, default_head_config)
 
 
 
133
 
134
- # Peaked attention has low entropy - should not be BoW
135
- assert is_bow == False
 
 
 
 
136
 
 
 
137
 
138
- class TestDetectSyntacticHead:
139
- """Tests for detect_syntactic_head function."""
140
-
141
- def test_consistent_distance_pattern(self, default_head_config):
142
- """Matrix with consistent distance pattern should be detected as syntactic."""
143
- # Create matrix where each position attends to position 2 tokens back
144
  size = 6
145
  matrix = torch.zeros(size, size)
146
- for i in range(size):
147
- target = max(0, i - 2) # 2 tokens back
148
- matrix[i, target] = 1.0
149
-
150
- is_syn, score = detect_syntactic_head(matrix, default_head_config)
151
-
152
- # Should have consistent distance pattern
153
- assert score > 0.0, f"Expected positive score for consistent pattern"
154
-
155
- def test_random_attention_returns_valid_values(self, default_head_config):
156
- """Random attention should return valid boolean and score."""
157
- torch.manual_seed(42)
158
- random_matrix = torch.softmax(torch.randn(6, 6), dim=-1)
159
-
160
- is_syn, score = detect_syntactic_head(random_matrix, default_head_config)
161
-
162
- # Check it returns valid types (bool or numpy bool, and numeric score)
163
- assert is_syn in [True, False] or bool(is_syn) in [True, False]
164
- assert 0 <= float(score) <= 1
165
-
166
-
167
- class TestCategorizeAttentionHead:
168
- """Tests for categorize_attention_head function."""
169
-
170
- def test_categorizes_previous_token_head(self, previous_token_attention_matrix, default_head_config):
171
- """Should categorize previous-token pattern correctly."""
172
- result = categorize_attention_head(
173
- previous_token_attention_matrix,
174
- layer_idx=0,
175
- head_idx=3,
176
- config=default_head_config
177
- )
178
-
179
- assert result['category'] == 'previous_token'
180
- assert result['layer'] == 0
181
- assert result['head'] == 3
182
- assert result['label'] == 'L0-H3'
183
- assert 'scores' in result
184
-
185
- def test_categorizes_first_token_head(self, first_token_attention_matrix, default_head_config):
186
- """Should categorize first-token pattern correctly."""
187
- result = categorize_attention_head(
188
- first_token_attention_matrix,
189
- layer_idx=2,
190
- head_idx=5,
191
- config=default_head_config
192
- )
193
 
194
- assert result['category'] == 'first_token'
195
- assert result['label'] == 'L2-H5'
196
-
197
- def test_categorizes_bow_head(self, default_head_config):
198
- """Should categorize diffuse attention as BoW when it doesn't match other patterns."""
199
- # Create BoW-like matrix: diffuse attention but first token gets LESS than threshold
200
- # This avoids triggering first_token detection (threshold 0.25)
 
 
 
 
 
201
  size = 5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  matrix = torch.zeros(size, size)
203
  for i in range(size):
204
- # First token gets only 0.1, rest get roughly equal share
205
- matrix[i, 0] = 0.1
206
- remaining = 0.9 / (size - 1)
207
- for j in range(1, size):
208
- matrix[i, j] = remaining
209
-
210
- result = categorize_attention_head(
211
- matrix,
212
- layer_idx=1,
213
- head_idx=0,
214
- config=default_head_config
215
- )
216
-
217
- assert result['category'] == 'bow'
218
-
219
- def test_result_structure(self, uniform_attention_matrix):
220
- """Result should have all required keys."""
221
- result = categorize_attention_head(
222
- uniform_attention_matrix,
223
- layer_idx=0,
224
- head_idx=0
225
- )
226
 
227
- required_keys = ['layer', 'head', 'category', 'scores', 'label']
228
- for key in required_keys:
229
- assert key in result, f"Missing key: {key}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
 
 
 
 
231
 
232
- class TestCategorizeAllHeads:
233
- """Tests for categorize_all_heads function."""
234
-
235
- def test_returns_all_categories(self, mock_activation_data, default_head_config):
236
- """Should return dict with all category keys."""
237
- result = categorize_all_heads(mock_activation_data, default_head_config)
 
 
 
 
 
 
 
 
 
 
 
238
 
239
- expected_categories = ['previous_token', 'first_token', 'bow', 'syntactic', 'other']
240
- for cat in expected_categories:
241
- assert cat in result, f"Missing category: {cat}"
242
- assert isinstance(result[cat], list)
243
-
244
- def test_handles_empty_attention_data(self, default_head_config):
245
- """Should handle activation data with no attention outputs."""
246
- empty_data = {'attention_outputs': {}}
247
- result = categorize_all_heads(empty_data, default_head_config)
 
248
 
249
- # Should return empty lists for all categories
250
- for cat, heads in result.items():
251
- assert heads == []
252
-
253
-
254
- class TestFormatCategorizationSummary:
255
- """Tests for format_categorization_summary function."""
256
-
257
- def test_formats_empty_categorization(self):
258
- """Should format empty categorization without error."""
259
- empty = {
260
- 'previous_token': [],
261
- 'first_token': [],
262
- 'bow': [],
263
- 'syntactic': [],
264
- 'other': []
265
  }
266
- result = format_categorization_summary(empty)
267
-
268
- assert isinstance(result, str)
269
- assert "Total Heads: 0" in result
270
-
271
- def test_formats_with_heads(self):
272
- """Should format categorization with heads correctly."""
273
- categorized = {
274
- 'previous_token': [
275
- {'layer': 0, 'head': 1, 'label': 'L0-H1'},
276
- {'layer': 0, 'head': 2, 'label': 'L0-H2'},
277
- ],
278
- 'first_token': [
279
- {'layer': 1, 'head': 0, 'label': 'L1-H0'},
280
- ],
281
- 'bow': [],
282
- 'syntactic': [],
283
- 'other': []
284
- }
285
- result = format_categorization_summary(categorized)
286
-
287
- assert "Total Heads: 3" in result
288
- assert "Previous-Token Heads: 2" in result
289
- assert "First/Positional-Token Heads: 1" in result
290
- assert "Layer 0" in result
291
- assert "Layer 1" in result
292
-
293
-
294
- class TestHeadCategorizationConfig:
295
- """Tests for HeadCategorizationConfig defaults."""
296
-
297
- def test_default_values(self):
298
- """Default config should have reasonable values."""
299
- config = HeadCategorizationConfig()
300
 
301
- assert 0 < config.prev_token_threshold < 1
302
- assert 0 < config.first_token_threshold < 1
303
- assert 0 < config.bow_entropy_threshold < 1
304
- assert config.min_seq_len > 0
305
-
306
- def test_config_is_mutable(self):
307
- """Config values should be mutable for customization."""
308
- config = HeadCategorizationConfig()
309
- original = config.prev_token_threshold
 
310
 
311
- config.prev_token_threshold = 0.8
312
- assert config.prev_token_threshold == 0.8
313
- assert config.prev_token_threshold != original
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  Tests for utils/head_detection.py
3
 
4
+ Tests the offline JSON + runtime verification head categorization system.
5
  """
6
 
7
  import pytest
8
  import torch
9
+ import json
10
  import numpy as np
11
+ from pathlib import Path
12
+ from unittest.mock import patch, mock_open
13
  from utils.head_detection import (
14
+ load_head_categories,
15
+ verify_head_activation,
16
+ get_active_head_summary,
17
+ clear_category_cache,
18
+ _compute_attention_entropy,
19
+ _find_repeated_tokens,
 
 
 
20
  )
21
 
22
 
23
+ # =============================================================================
24
+ # Sample JSON data for mocking
25
+ # =============================================================================
26
+
27
+ SAMPLE_JSON = {
28
+ "test-model": {
29
+ "model_name": "test-model",
30
+ "num_layers": 2,
31
+ "num_heads": 4,
32
+ "analysis_date": "2026-02-26",
33
+ "categories": {
34
+ "previous_token": {
35
+ "display_name": "Previous Token",
36
+ "description": "Attends to the previous token",
37
+ "educational_text": "Looks at the word before.",
38
+ "icon": "arrow-left",
39
+ "requires_repetition": False,
40
+ "top_heads": [
41
+ {"layer": 0, "head": 1, "score": 0.85},
42
+ {"layer": 1, "head": 2, "score": 0.72}
43
+ ]
44
+ },
45
+ "induction": {
46
+ "display_name": "Induction",
47
+ "description": "Pattern matching",
48
+ "educational_text": "Finds repeated patterns.",
49
+ "icon": "repeat",
50
+ "requires_repetition": True,
51
+ "suggested_prompt": "Try repeating words.",
52
+ "top_heads": [
53
+ {"layer": 1, "head": 0, "score": 0.90}
54
+ ]
55
+ },
56
+ "duplicate_token": {
57
+ "display_name": "Duplicate Token",
58
+ "description": "Finds duplicates",
59
+ "educational_text": "Spots repeated words.",
60
+ "icon": "clone",
61
+ "requires_repetition": True,
62
+ "suggested_prompt": "Try typing the same word twice.",
63
+ "top_heads": [
64
+ {"layer": 0, "head": 3, "score": 0.78}
65
+ ]
66
+ },
67
+ "positional": {
68
+ "display_name": "Positional",
69
+ "description": "First token focus",
70
+ "educational_text": "Anchors to position 0.",
71
+ "icon": "map-pin",
72
+ "requires_repetition": False,
73
+ "top_heads": [
74
+ {"layer": 0, "head": 0, "score": 0.88}
75
+ ]
76
+ },
77
+ "diffuse": {
78
+ "display_name": "Diffuse",
79
+ "description": "Spread attention",
80
+ "educational_text": "Even distribution.",
81
+ "icon": "expand-arrows-alt",
82
+ "requires_repetition": False,
83
+ "top_heads": [
84
+ {"layer": 1, "head": 3, "score": 0.80}
85
+ ]
86
+ }
87
+ },
88
+ "all_scores": {}
89
+ }
90
+ }
91
+
92
+
93
+ @pytest.fixture(autouse=True)
94
+ def clear_cache():
95
+ """Clear the category cache before each test."""
96
+ clear_category_cache()
97
+ yield
98
+ clear_category_cache()
99
+
100
+
101
+ # =============================================================================
102
+ # Tests for _compute_attention_entropy
103
+ # =============================================================================
104
+
105
  class TestComputeAttentionEntropy:
106
+ """Tests for _compute_attention_entropy helper."""
107
+
108
  def test_uniform_distribution_high_entropy(self):
109
+ """Uniform attention should have entropy near 1.0."""
110
+ weights = torch.ones(8) / 8
111
+ entropy = _compute_attention_entropy(weights)
112
+ assert entropy > 0.95
113
+
 
 
 
114
  def test_peaked_distribution_low_entropy(self):
115
+ """Peaked attention should have low entropy."""
116
+ weights = torch.zeros(8)
117
+ weights[0] = 0.98
118
+ weights[1:] = 0.02 / 7
119
+ entropy = _compute_attention_entropy(weights)
120
+ assert entropy < 0.3
121
+
122
+ def test_entropy_in_range(self):
123
+ """Entropy should always be between 0 and 1."""
124
+ for _ in range(10):
125
+ weights = torch.softmax(torch.randn(6), dim=0)
126
+ entropy = _compute_attention_entropy(weights)
127
+ assert 0.0 <= entropy <= 1.0
128
+
129
+
130
+ # =============================================================================
131
+ # Tests for _find_repeated_tokens
132
+ # =============================================================================
133
+
134
+ class TestFindRepeatedTokens:
135
+ """Tests for _find_repeated_tokens helper."""
136
+
137
+ def test_no_repeats(self):
138
+ """No repetition returns empty dict."""
139
+ assert _find_repeated_tokens([1, 2, 3, 4]) == {}
140
+
141
+ def test_simple_repeat(self):
142
+ """Repeated token returns positions."""
143
+ result = _find_repeated_tokens([10, 20, 10, 30])
144
+ assert 10 in result
145
+ assert result[10] == [0, 2]
146
+ assert 20 not in result
147
+
148
+ def test_multiple_repeats(self):
149
+ """Multiple repeated tokens tracked."""
150
+ result = _find_repeated_tokens([5, 6, 5, 6, 7])
151
+ assert 5 in result and 6 in result
152
+ assert 7 not in result
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
+ def test_empty_input(self):
155
+ assert _find_repeated_tokens([]) == {}
156
 
157
+
158
+ # =============================================================================
159
+ # Tests for load_head_categories
160
+ # =============================================================================
161
+
162
+ class TestLoadHeadCategories:
163
+ """Tests for load_head_categories function."""
164
+
165
+ def test_loads_from_json(self, tmp_path):
166
+ """Should load model data from JSON file."""
167
+ json_file = tmp_path / "head_categories.json"
168
+ json_file.write_text(json.dumps(SAMPLE_JSON))
169
+
170
+ with patch('utils.head_detection._JSON_PATH', json_file):
171
+ result = load_head_categories("test-model")
172
 
173
+ assert result is not None
174
+ assert result["model_name"] == "test-model"
175
+ assert "previous_token" in result["categories"]
176
+
177
+ def test_returns_none_for_unknown_model(self, tmp_path):
178
+ """Should return None when model not in JSON."""
179
+ json_file = tmp_path / "head_categories.json"
180
+ json_file.write_text(json.dumps(SAMPLE_JSON))
181
+
182
+ with patch('utils.head_detection._JSON_PATH', json_file):
183
+ result = load_head_categories("nonexistent-model")
 
 
184
 
185
+ assert result is None
186
+
187
+ def test_returns_none_when_no_file(self, tmp_path):
188
+ """Should return None when JSON file doesn't exist."""
189
+ with patch('utils.head_detection._JSON_PATH', tmp_path / "missing.json"):
190
+ result = load_head_categories("test-model")
191
 
192
+ assert result is None
 
193
 
194
+ def test_caches_results(self, tmp_path):
195
+ """Should cache loaded data."""
196
+ json_file = tmp_path / "head_categories.json"
197
+ json_file.write_text(json.dumps(SAMPLE_JSON))
198
 
199
+ with patch('utils.head_detection._JSON_PATH', json_file):
200
+ result1 = load_head_categories("test-model")
201
+ # Delete file to prove cache is used
202
+ json_file.unlink()
203
+ result2 = load_head_categories("test-model")
 
204
 
205
+ assert result1 is result2
206
+
207
+ def test_short_name_alias(self, tmp_path):
208
+ """Should find model by short name (after /)."""
209
+ data = {"my-model": {"model_name": "my-model", "categories": {}}}
210
+ json_file = tmp_path / "head_categories.json"
211
+ json_file.write_text(json.dumps(data))
212
+
213
+ with patch('utils.head_detection._JSON_PATH', json_file):
214
+ result = load_head_categories("org/my-model")
215
 
216
+ assert result is not None
217
+
218
+
219
+ # =============================================================================
220
+ # Tests for verify_head_activation
221
+ # =============================================================================
222
 
223
+ class TestVerifyHeadActivation:
224
+ """Tests for verify_head_activation function."""
225
 
226
+ def test_previous_token_strong(self):
227
+ """Strong previous-token pattern should score high."""
 
 
 
 
228
  size = 6
229
  matrix = torch.zeros(size, size)
230
+ for i in range(1, size):
231
+ matrix[i, i - 1] = 0.8
232
+ matrix[i, i] = 0.2
233
+ matrix[0, 0] = 1.0
234
+
235
+ score = verify_head_activation(matrix, [1, 2, 3, 4, 5, 6], "previous_token")
236
+ assert score > 0.6
237
+
238
+ def test_previous_token_weak(self):
239
+ """Uniform attention should have low previous-token score."""
240
+ size = 6
241
+ matrix = torch.ones(size, size) / size
242
+ score = verify_head_activation(matrix, [1, 2, 3, 4, 5, 6], "previous_token")
243
+ assert score < 0.3
244
+
245
+ def test_induction_with_repetition(self):
246
+ """Induction pattern should score > 0 when repeated tokens are present."""
247
+ # Tokens: [A, B, C, A, ?] head should attend to B (position 1) from position 3
248
+ size = 5
249
+ matrix = torch.ones(size, size) / size # Baseline uniform
250
+ matrix[3, 1] = 0.7 # Position 3 (second A) attends to position 1 (B after first A)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
+ token_ids = [10, 20, 30, 10, 40] # Token 10 repeats
253
+ score = verify_head_activation(matrix, token_ids, "induction")
254
+ assert score > 0.3
255
+
256
+ def test_induction_no_repetition(self):
257
+ """Induction should return 0.0 when no tokens repeat."""
258
+ matrix = torch.ones(4, 4) / 4
259
+ score = verify_head_activation(matrix, [1, 2, 3, 4], "induction")
260
+ assert score == 0.0
261
+
262
+ def test_duplicate_token_with_repeats(self):
263
+ """Duplicate-token head should score > 0 when later positions attend to earlier same token."""
264
  size = 5
265
+ matrix = torch.ones(size, size) / size
266
+ matrix[3, 0] = 0.6 # Position 3 (second occurrence of token 10) attends to position 0
267
+
268
+ token_ids = [10, 20, 30, 10, 40]
269
+ score = verify_head_activation(matrix, token_ids, "duplicate_token")
270
+ assert score > 0.3
271
+
272
+ def test_duplicate_token_no_repeats(self):
273
+ """Should return 0.0 when no duplicates."""
274
+ matrix = torch.ones(4, 4) / 4
275
+ score = verify_head_activation(matrix, [1, 2, 3, 4], "duplicate_token")
276
+ assert score == 0.0
277
+
278
+ def test_positional_strong(self):
279
+ """Strong first-token attention should score high."""
280
+ size = 6
281
  matrix = torch.zeros(size, size)
282
  for i in range(size):
283
+ matrix[i, 0] = 0.7
284
+ matrix[i, i] = 0.3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
+ score = verify_head_activation(matrix, [1, 2, 3, 4, 5, 6], "positional")
287
+ assert score > 0.5
288
+
289
+ def test_diffuse_uniform(self):
290
+ """Uniform attention should have high diffuse score."""
291
+ size = 8
292
+ matrix = torch.ones(size, size) / size
293
+ score = verify_head_activation(matrix, list(range(size)), "diffuse")
294
+ assert score > 0.8
295
+
296
+ def test_diffuse_peaked(self):
297
+ """Peaked attention should have low diffuse score."""
298
+ size = 8
299
+ matrix = torch.zeros(size, size)
300
+ matrix[:, 0] = 1.0
301
+ score = verify_head_activation(matrix, list(range(size)), "diffuse")
302
+ assert score < 0.3
303
 
304
+ def test_unknown_category(self):
305
+ """Unknown category should return 0.0."""
306
+ matrix = torch.ones(4, 4) / 4
307
+ assert verify_head_activation(matrix, [1, 2, 3, 4], "nonexistent") == 0.0
308
 
309
+ def test_short_sequence(self):
310
+ """Very short sequence should return 0.0."""
311
+ matrix = torch.ones(1, 1)
312
+ assert verify_head_activation(matrix, [1], "previous_token") == 0.0
313
+
314
+
315
+ # =============================================================================
316
+ # Tests for get_active_head_summary
317
+ # =============================================================================
318
+
319
+ class TestGetActiveHeadSummary:
320
+ """Tests for get_active_head_summary function."""
321
+
322
+ def _make_activation_data(self, token_ids, num_layers=2, num_heads=4, seq_len=None):
323
+ """Helper: create mock activation_data with given token_ids."""
324
+ if seq_len is None:
325
+ seq_len = len(token_ids)
326
 
327
+ attention_outputs = {}
328
+ for layer in range(num_layers):
329
+ # Create uniform attention [1, num_heads, seq_len, seq_len]
330
+ attn = torch.ones(1, num_heads, seq_len, seq_len) / seq_len
331
+ attention_outputs[f'model.layers.{layer}.self_attn'] = {
332
+ 'output': [
333
+ [[0.1] * seq_len], # hidden states (unused)
334
+ attn.tolist()
335
+ ]
336
+ }
337
 
338
+ return {
339
+ 'model': 'test-model',
340
+ 'input_ids': [token_ids],
341
+ 'attention_outputs': attention_outputs,
 
 
 
 
 
 
 
 
 
 
 
 
342
  }
343
+
344
+ def test_returns_none_for_unknown_model(self, tmp_path):
345
+ """Should return None when model not in JSON."""
346
+ json_file = tmp_path / "head_categories.json"
347
+ json_file.write_text(json.dumps(SAMPLE_JSON))
348
+
349
+ with patch('utils.head_detection._JSON_PATH', json_file):
350
+ data = self._make_activation_data([1, 2, 3, 4])
351
+ result = get_active_head_summary(data, "unknown-model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
 
353
+ assert result is None
354
+
355
+ def test_returns_categories_structure(self, tmp_path):
356
+ """Should return proper structure with categories."""
357
+ json_file = tmp_path / "head_categories.json"
358
+ json_file.write_text(json.dumps(SAMPLE_JSON))
359
+
360
+ with patch('utils.head_detection._JSON_PATH', json_file):
361
+ data = self._make_activation_data([1, 2, 3, 4])
362
+ result = get_active_head_summary(data, "test-model")
363
 
364
+ assert result is not None
365
+ assert result["model_available"] is True
366
+ assert "categories" in result
367
+ assert "previous_token" in result["categories"]
368
+ assert "induction" in result["categories"]
369
+
370
+ def test_heads_have_activation_scores(self, tmp_path):
371
+ """Each head should have an activation_score."""
372
+ json_file = tmp_path / "head_categories.json"
373
+ json_file.write_text(json.dumps(SAMPLE_JSON))
374
+
375
+ with patch('utils.head_detection._JSON_PATH', json_file):
376
+ data = self._make_activation_data([1, 2, 3, 4])
377
+ result = get_active_head_summary(data, "test-model")
378
+
379
+ for cat_key, cat_data in result["categories"].items():
380
+ for head in cat_data.get("heads", []):
381
+ assert "activation_score" in head
382
+ assert "is_active" in head
383
+ assert "label" in head
384
+
385
+ def test_induction_grayed_when_no_repeats(self, tmp_path):
386
+ """Induction should be non-applicable when no repeated tokens."""
387
+ json_file = tmp_path / "head_categories.json"
388
+ json_file.write_text(json.dumps(SAMPLE_JSON))
389
+
390
+ with patch('utils.head_detection._JSON_PATH', json_file):
391
+ data = self._make_activation_data([1, 2, 3, 4]) # No repeats
392
+ result = get_active_head_summary(data, "test-model")
393
+
394
+ induction = result["categories"]["induction"]
395
+ assert induction["is_applicable"] is False
396
+ assert all(h["activation_score"] == 0.0 for h in induction["heads"])
397
+
398
+ def test_induction_active_with_repeats(self, tmp_path):
399
+ """Induction should be applicable when tokens repeat."""
400
+ json_file = tmp_path / "head_categories.json"
401
+ json_file.write_text(json.dumps(SAMPLE_JSON))
402
+
403
+ with patch('utils.head_detection._JSON_PATH', json_file):
404
+ data = self._make_activation_data([10, 20, 10, 30]) # Token 10 repeats
405
+ result = get_active_head_summary(data, "test-model")
406
+
407
+ induction = result["categories"]["induction"]
408
+ assert induction["is_applicable"] is True
409
+
410
+ def test_suggested_prompt_included(self, tmp_path):
411
+ """Suggested prompt should appear for repetition-dependent categories."""
412
+ json_file = tmp_path / "head_categories.json"
413
+ json_file.write_text(json.dumps(SAMPLE_JSON))
414
+
415
+ with patch('utils.head_detection._JSON_PATH', json_file):
416
+ data = self._make_activation_data([1, 2, 3, 4])
417
+ result = get_active_head_summary(data, "test-model")
418
+
419
+ assert result["categories"]["induction"]["suggested_prompt"] is not None
420
+ assert result["categories"]["duplicate_token"]["suggested_prompt"] is not None
421
+
422
+ def test_other_category_always_present(self, tmp_path):
423
+ """Other/Unclassified category should always be in the result."""
424
+ json_file = tmp_path / "head_categories.json"
425
+ json_file.write_text(json.dumps(SAMPLE_JSON))
426
+
427
+ with patch('utils.head_detection._JSON_PATH', json_file):
428
+ data = self._make_activation_data([1, 2, 3, 4])
429
+ result = get_active_head_summary(data, "test-model")
430
+
431
+ assert "other" in result["categories"]
utils/__init__.py CHANGED
@@ -8,7 +8,7 @@ from .model_patterns import (load_model_and_get_patterns, execute_forward_pass,
8
  detect_significant_probability_increases,
9
  evaluate_sequence_ablation, generate_bertviz_model_view_html)
10
  from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
11
- from .head_detection import categorize_all_heads, categorize_single_layer_heads, format_categorization_summary, HeadCategorizationConfig
12
  from .beam_search import perform_beam_search
13
  from .ablation_metrics import compute_kl_divergence, score_sequence, get_token_probability_deltas
14
  from .token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution, create_attribution_visualization_data
@@ -38,10 +38,9 @@ __all__ = [
38
  'MODEL_FAMILIES',
39
 
40
  # Head detection
41
- 'categorize_all_heads',
42
- 'categorize_single_layer_heads',
43
- 'format_categorization_summary',
44
- 'HeadCategorizationConfig',
45
 
46
  # Beam search
47
  'perform_beam_search',
 
8
  detect_significant_probability_increases,
9
  evaluate_sequence_ablation, generate_bertviz_model_view_html)
10
  from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
11
+ from .head_detection import load_head_categories, verify_head_activation, get_active_head_summary
12
  from .beam_search import perform_beam_search
13
  from .ablation_metrics import compute_kl_divergence, score_sequence, get_token_probability_deltas
14
  from .token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution, create_attribution_visualization_data
 
38
  'MODEL_FAMILIES',
39
 
40
  # Head detection
41
+ 'load_head_categories',
42
+ 'verify_head_activation',
43
+ 'get_active_head_summary',
 
44
 
45
  # Beam search
46
  'perform_beam_search',
utils/head_categories.json ADDED
@@ -0,0 +1,1099 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "gpt2": {
3
+ "model_name": "gpt2",
4
+ "num_layers": 12,
5
+ "num_heads": 12,
6
+ "analysis_date": "2026-02-26",
7
+ "categories": {
8
+ "previous_token": {
9
+ "display_name": "Previous Token",
10
+ "description": "Attends to the immediately preceding token \u2014 like reading left to right",
11
+ "icon": "arrow-left",
12
+ "educational_text": "This head looks at the word right before the current one. Like reading left to right, it helps track local word-by-word patterns.",
13
+ "requires_repetition": false,
14
+ "top_heads": [
15
+ {
16
+ "layer": 1,
17
+ "head": 0,
18
+ "score": 0.3655
19
+ },
20
+ {
21
+ "layer": 2,
22
+ "head": 2,
23
+ "score": 0.5679
24
+ },
25
+ {
26
+ "layer": 2,
27
+ "head": 5,
28
+ "score": 0.3384
29
+ },
30
+ {
31
+ "layer": 2,
32
+ "head": 9,
33
+ "score": 0.4052
34
+ },
35
+ {
36
+ "layer": 3,
37
+ "head": 2,
38
+ "score": 0.4164
39
+ },
40
+ {
41
+ "layer": 3,
42
+ "head": 6,
43
+ "score": 0.3359
44
+ },
45
+ {
46
+ "layer": 3,
47
+ "head": 7,
48
+ "score": 0.4419
49
+ },
50
+ {
51
+ "layer": 4,
52
+ "head": 11,
53
+ "score": 0.97
54
+ }
55
+ ]
56
+ },
57
+ "induction": {
58
+ "display_name": "Induction",
59
+ "description": "Completes repeated patterns: if it saw [A][B] before and now sees [A], it predicts [B]",
60
+ "icon": "repeat",
61
+ "educational_text": "This head finds patterns that happened before and predicts they'll happen again. If it saw 'the cat' earlier, it expects the same words to follow.",
62
+ "requires_repetition": true,
63
+ "suggested_prompt": "Try: 'The cat sat on the mat. The cat' \u2014 the repeated 'The cat' lets induction heads activate.",
64
+ "top_heads": [
65
+ {
66
+ "layer": 5,
67
+ "head": 0,
68
+ "score": 0.3363
69
+ },
70
+ {
71
+ "layer": 5,
72
+ "head": 1,
73
+ "score": 0.4412
74
+ },
75
+ {
76
+ "layer": 5,
77
+ "head": 5,
78
+ "score": 0.4119
79
+ },
80
+ {
81
+ "layer": 5,
82
+ "head": 8,
83
+ "score": 0.3032
84
+ },
85
+ {
86
+ "layer": 6,
87
+ "head": 9,
88
+ "score": 0.3017
89
+ },
90
+ {
91
+ "layer": 7,
92
+ "head": 10,
93
+ "score": 0.2849
94
+ },
95
+ {
96
+ "layer": 8,
97
+ "head": 1,
98
+ "score": 0.2608
99
+ },
100
+ {
101
+ "layer": 10,
102
+ "head": 7,
103
+ "score": 0.2196
104
+ }
105
+ ]
106
+ },
107
+ "duplicate_token": {
108
+ "display_name": "Duplicate Token",
109
+ "description": "Notices when the same word appears more than once",
110
+ "icon": "clone",
111
+ "educational_text": "This head notices when the same word appears more than once, like a highlighter for repeated words. It helps the model track which words have already been said.",
112
+ "requires_repetition": true,
113
+ "suggested_prompt": "Try a prompt with repeated words like 'The cat sat. The cat slept.' to see duplicate-token heads light up.",
114
+ "top_heads": [
115
+ {
116
+ "layer": 0,
117
+ "head": 1,
118
+ "score": 0.4175
119
+ },
120
+ {
121
+ "layer": 0,
122
+ "head": 5,
123
+ "score": 0.4155
124
+ },
125
+ {
126
+ "layer": 1,
127
+ "head": 11,
128
+ "score": 0.3256
129
+ },
130
+ {
131
+ "layer": 3,
132
+ "head": 0,
133
+ "score": 0.2416
134
+ },
135
+ {
136
+ "layer": 4,
137
+ "head": 7,
138
+ "score": 0.1238
139
+ },
140
+ {
141
+ "layer": 11,
142
+ "head": 8,
143
+ "score": 0.1741
144
+ }
145
+ ]
146
+ },
147
+ "positional": {
148
+ "display_name": "Positional / First-Token",
149
+ "description": "Always pays attention to the very first word, using it as an anchor point",
150
+ "icon": "map-pin",
151
+ "educational_text": "This head always pays attention to the very first word, using it as an anchor point. The first token serves as a 'default' position when no other token is specifically relevant.",
152
+ "requires_repetition": false,
153
+ "top_heads": [
154
+ {
155
+ "layer": 7,
156
+ "head": 2,
157
+ "score": 0.9077
158
+ },
159
+ {
160
+ "layer": 9,
161
+ "head": 6,
162
+ "score": 0.9077
163
+ },
164
+ {
165
+ "layer": 9,
166
+ "head": 9,
167
+ "score": 0.9064
168
+ },
169
+ {
170
+ "layer": 9,
171
+ "head": 11,
172
+ "score": 0.9301
173
+ },
174
+ {
175
+ "layer": 10,
176
+ "head": 10,
177
+ "score": 0.9098
178
+ },
179
+ {
180
+ "layer": 11,
181
+ "head": 2,
182
+ "score": 0.8962
183
+ },
184
+ {
185
+ "layer": 11,
186
+ "head": 6,
187
+ "score": 0.9231
188
+ },
189
+ {
190
+ "layer": 11,
191
+ "head": 9,
192
+ "score": 0.9117
193
+ }
194
+ ]
195
+ },
196
+ "diffuse": {
197
+ "display_name": "Diffuse / Spread",
198
+ "description": "Spreads attention evenly across many words, gathering general context",
199
+ "icon": "expand-arrows-alt",
200
+ "educational_text": "This head spreads its attention evenly across many words, gathering general context rather than focusing on one spot. It provides a 'big picture' summary of the input.",
201
+ "requires_repetition": false,
202
+ "top_heads": [
203
+ {
204
+ "layer": 0,
205
+ "head": 10,
206
+ "score": 0.6076
207
+ },
208
+ {
209
+ "layer": 0,
210
+ "head": 11,
211
+ "score": 0.5915
212
+ },
213
+ {
214
+ "layer": 1,
215
+ "head": 2,
216
+ "score": 0.5851
217
+ },
218
+ {
219
+ "layer": 1,
220
+ "head": 4,
221
+ "score": 0.5693
222
+ },
223
+ {
224
+ "layer": 1,
225
+ "head": 10,
226
+ "score": 0.6001
227
+ },
228
+ {
229
+ "layer": 2,
230
+ "head": 7,
231
+ "score": 0.6227
232
+ },
233
+ {
234
+ "layer": 2,
235
+ "head": 10,
236
+ "score": 0.6325
237
+ },
238
+ {
239
+ "layer": 11,
240
+ "head": 0,
241
+ "score": 0.6132
242
+ }
243
+ ]
244
+ }
245
+ },
246
+ "all_scores": {
247
+ "previous_token": [
248
+ [
249
+ 0.1650262176990509,
250
+ 0.005524545907974243,
251
+ 0.13794219493865967,
252
+ 0.11309953033924103,
253
+ 0.19386060535907745,
254
+ 0.02020726539194584,
255
+ 0.18705399334430695,
256
+ 0.3287373483181,
257
+ 0.1688501238822937,
258
+ 0.14645136892795563,
259
+ 0.12409798055887222,
260
+ 0.14697492122650146
261
+ ],
262
+ [
263
+ 0.36550161242485046,
264
+ 0.22920921444892883,
265
+ 0.1901777684688568,
266
+ 0.13691475987434387,
267
+ 0.1552433967590332,
268
+ 0.1548655927181244,
269
+ 0.14041779935359955,
270
+ 0.1399569809436798,
271
+ 0.14001941680908203,
272
+ 0.12206045538187027,
273
+ 0.18723534047603607,
274
+ 0.05272947624325752
275
+ ],
276
+ [
277
+ 0.24368862807750702,
278
+ 0.11734970659017563,
279
+ 0.5678969025611877,
280
+ 0.33175796270370483,
281
+ 0.3293865919113159,
282
+ 0.33843594789505005,
283
+ 0.1687498688697815,
284
+ 0.2169996052980423,
285
+ 0.33436763286590576,
286
+ 0.405174195766449,
287
+ 0.20988500118255615,
288
+ 0.1365954577922821
289
+ ],
290
+ [
291
+ 0.08308680355548859,
292
+ 0.16770434379577637,
293
+ 0.41642817854881287,
294
+ 0.32616299390792847,
295
+ 0.09816452860832214,
296
+ 0.12414131313562393,
297
+ 0.33591750264167786,
298
+ 0.4418589174747467,
299
+ 0.3060630261898041,
300
+ 0.21817748248577118,
301
+ 0.1548490822315216,
302
+ 0.2623787224292755
303
+ ],
304
+ [
305
+ 0.24851615726947784,
306
+ 0.22178645431995392,
307
+ 0.10810651630163193,
308
+ 0.2638419270515442,
309
+ 0.1461866945028305,
310
+ 0.19259677827358246,
311
+ 0.16893190145492554,
312
+ 0.20602412521839142,
313
+ 0.11169518530368805,
314
+ 0.16701465845108032,
315
+ 0.09775038063526154,
316
+ 0.9700173139572144
317
+ ],
318
+ [
319
+ 0.1162194162607193,
320
+ 0.09808940440416336,
321
+ 0.20977501571178436,
322
+ 0.16994376480579376,
323
+ 0.2316969633102417,
324
+ 0.10760845243930817,
325
+ 0.26810961961746216,
326
+ 0.1556214690208435,
327
+ 0.13168412446975708,
328
+ 0.10098359733819962,
329
+ 0.1563761830329895,
330
+ 0.11529763042926788
331
+ ],
332
+ [
333
+ 0.23046550154685974,
334
+ 0.13669200241565704,
335
+ 0.10113422572612762,
336
+ 0.12357200682163239,
337
+ 0.12948814034461975,
338
+ 0.14964132010936737,
339
+ 0.11104538291692734,
340
+ 0.17790208756923676,
341
+ 0.3313186764717102,
342
+ 0.09724397212266922,
343
+ 0.1065865010023117,
344
+ 0.19595712423324585
345
+ ],
346
+ [
347
+ 0.2756780683994293,
348
+ 0.09617989510297775,
349
+ 0.0887245386838913,
350
+ 0.14660504460334778,
351
+ 0.11926672607660294,
352
+ 0.12578082084655762,
353
+ 0.10664939880371094,
354
+ 0.11368991434574127,
355
+ 0.18360558152198792,
356
+ 0.130024254322052,
357
+ 0.10562390089035034,
358
+ 0.10479450225830078
359
+ ],
360
+ [
361
+ 0.10714849084615707,
362
+ 0.10390549898147583,
363
+ 0.11945408582687378,
364
+ 0.10176572948694229,
365
+ 0.15246066451072693,
366
+ 0.1935780942440033,
367
+ 0.13547158241271973,
368
+ 0.24629735946655273,
369
+ 0.14471763372421265,
370
+ 0.12072619050741196,
371
+ 0.12850022315979004,
372
+ 0.10024647414684296
373
+ ],
374
+ [
375
+ 0.1123703345656395,
376
+ 0.10224141925573349,
377
+ 0.10966678708791733,
378
+ 0.24468424916267395,
379
+ 0.09359707683324814,
380
+ 0.11123354732990265,
381
+ 0.09214123338460922,
382
+ 0.11035183817148209,
383
+ 0.09690441191196442,
384
+ 0.09199563413858414,
385
+ 0.16506430506706238,
386
+ 0.08864383399486542
387
+ ],
388
+ [
389
+ 0.09993860870599747,
390
+ 0.1017073541879654,
391
+ 0.09143912047147751,
392
+ 0.1137363463640213,
393
+ 0.11926724761724472,
394
+ 0.1261630356311798,
395
+ 0.09609334915876389,
396
+ 0.1267780214548111,
397
+ 0.09360888600349426,
398
+ 0.15695181488990784,
399
+ 0.09125342220067978,
400
+ 0.16533184051513672
401
+ ],
402
+ [
403
+ 0.1551479697227478,
404
+ 0.10182406008243561,
405
+ 0.09162592142820358,
406
+ 0.14142417907714844,
407
+ 0.10655181109905243,
408
+ 0.09299013763666153,
409
+ 0.08795793354511261,
410
+ 0.10052843391895294,
411
+ 0.18854694068431854,
412
+ 0.09097206592559814,
413
+ 0.14251284301280975,
414
+ 0.13573673367500305
415
+ ]
416
+ ],
417
+ "induction": [
418
+ [
419
+ 0.07627037912607193,
420
+ 0.0035299647133797407,
421
+ 0.050907380878925323,
422
+ 0.018350504338741302,
423
+ 0.055634528398513794,
424
+ 0.015752490609884262,
425
+ 0.09711054712533951,
426
+ 0.08642718195915222,
427
+ 0.07673756778240204,
428
+ 0.06478650867938995,
429
+ 0.05675221234560013,
430
+ 0.0686919093132019
431
+ ],
432
+ [
433
+ 0.098502516746521,
434
+ 0.08570204675197601,
435
+ 0.09086534380912781,
436
+ 0.05725013464689255,
437
+ 0.06655086576938629,
438
+ 0.08535383641719818,
439
+ 0.04390129819512367,
440
+ 0.05150846764445305,
441
+ 0.05973561853170395,
442
+ 0.05239921063184738,
443
+ 0.10886937379837036,
444
+ 0.03350156173110008
445
+ ],
446
+ [
447
+ 0.0880986899137497,
448
+ 0.029988640919327736,
449
+ 0.06596572697162628,
450
+ 0.09502042829990387,
451
+ 0.06376759707927704,
452
+ 0.07735122740268707,
453
+ 0.07770463079214096,
454
+ 0.08998467028141022,
455
+ 0.08355952054262161,
456
+ 0.08642251044511795,
457
+ 0.0951002761721611,
458
+ 0.038624998182058334
459
+ ],
460
+ [
461
+ 0.012395743280649185,
462
+ 0.0515044704079628,
463
+ 0.0702400729060173,
464
+ 0.038637131452560425,
465
+ 0.03541486710309982,
466
+ 0.04828893393278122,
467
+ 0.07664503902196884,
468
+ 0.05478388071060181,
469
+ 0.05722055584192276,
470
+ 0.05503711849451065,
471
+ 0.05377575010061264,
472
+ 0.05681142956018448
473
+ ],
474
+ [
475
+ 0.023173518478870392,
476
+ 0.04842953383922577,
477
+ 0.02587379515171051,
478
+ 0.0371115505695343,
479
+ 0.043572355061769485,
480
+ 0.025999004021286964,
481
+ 0.057220708578825,
482
+ 0.05670655891299248,
483
+ 0.05118811875581741,
484
+ 0.029776636511087418,
485
+ 0.02828892692923546,
486
+ 0.050957612693309784
487
+ ],
488
+ [
489
+ 0.3362796902656555,
490
+ 0.44116583466529846,
491
+ 0.04926660656929016,
492
+ 0.060651201754808426,
493
+ 0.049554307013750076,
494
+ 0.41194018721580505,
495
+ 0.038970425724983215,
496
+ 0.01051054522395134,
497
+ 0.30320701003074646,
498
+ 0.07053252309560776,
499
+ 0.05541849881410599,
500
+ 0.03842315822839737
501
+ ],
502
+ [
503
+ 0.04865153878927231,
504
+ 0.13892090320587158,
505
+ 0.023456398397684097,
506
+ 0.043447092175483704,
507
+ 0.05254914611577988,
508
+ 0.06307318806648254,
509
+ 0.06592734158039093,
510
+ 0.06641103327274323,
511
+ 0.06890955567359924,
512
+ 0.3017217516899109,
513
+ 0.053376901894807816,
514
+ 0.05453646928071976
515
+ ],
516
+ [
517
+ 0.04203842580318451,
518
+ 0.06195511296391487,
519
+ 0.18403273820877075,
520
+ 0.06932497024536133,
521
+ 0.025891464203596115,
522
+ 0.03674555569887161,
523
+ 0.05915430188179016,
524
+ 0.08904685080051422,
525
+ 0.029217243194580078,
526
+ 0.047680627554655075,
527
+ 0.28489723801612854,
528
+ 0.15201476216316223
529
+ ],
530
+ [
531
+ 0.03113759122788906,
532
+ 0.2607646584510803,
533
+ 0.04262052848935127,
534
+ 0.03490695357322693,
535
+ 0.020729169249534607,
536
+ 0.039468441158533096,
537
+ 0.17247121036052704,
538
+ 0.02061128057539463,
539
+ 0.0941251665353775,
540
+ 0.044258393347263336,
541
+ 0.09541143476963043,
542
+ 0.03278326988220215
543
+ ],
544
+ [
545
+ 0.06156448647379875,
546
+ 0.09029851853847504,
547
+ 0.06509305536746979,
548
+ 0.04298751801252365,
549
+ 0.02618749439716339,
550
+ 0.029909756034612656,
551
+ 0.08973383903503418,
552
+ 0.06374338269233704,
553
+ 0.02463320828974247,
554
+ 0.10424073040485382,
555
+ 0.016569094732403755,
556
+ 0.04829319566488266
557
+ ],
558
+ [
559
+ 0.0732613354921341,
560
+ 0.15449705719947815,
561
+ 0.048853177577257156,
562
+ 0.12552715837955475,
563
+ 0.1161937341094017,
564
+ 0.020513027906417847,
565
+ 0.08032035827636719,
566
+ 0.21955707669258118,
567
+ 0.07728692889213562,
568
+ 0.014143750071525574,
569
+ 0.056671954691410065,
570
+ 0.1141514927148819
571
+ ],
572
+ [
573
+ 0.10236237943172455,
574
+ 0.0509863905608654,
575
+ 0.02403058484196663,
576
+ 0.046142492443323135,
577
+ 0.03625836968421936,
578
+ 0.05091869831085205,
579
+ 0.02450958639383316,
580
+ 0.057415880262851715,
581
+ 0.09816241264343262,
582
+ 0.045323897153139114,
583
+ 0.12710919976234436,
584
+ 0.06512586772441864
585
+ ]
586
+ ],
587
+ "duplicate_token": [
588
+ [
589
+ 0.061639100313186646,
590
+ 0.4175182282924652,
591
+ 0.05723930522799492,
592
+ 0.039668913930654526,
593
+ 0.0939607322216034,
594
+ 0.41551661491394043,
595
+ 0.07361333817243576,
596
+ 0.0333673469722271,
597
+ 0.0963386595249176,
598
+ 0.0499253086745739,
599
+ 0.17845425009727478,
600
+ 0.0740630105137825
601
+ ],
602
+ [
603
+ 0.03887755423784256,
604
+ 0.03720149025321007,
605
+ 0.07625596970319748,
606
+ 0.052537791430950165,
607
+ 0.06014804169535637,
608
+ 0.09469039738178253,
609
+ 0.05574027821421623,
610
+ 0.03633364289999008,
611
+ 0.05319533869624138,
612
+ 0.04128124564886093,
613
+ 0.10213665664196014,
614
+ 0.3255976736545563
615
+ ],
616
+ [
617
+ 0.0270945243537426,
618
+ 0.02465079165995121,
619
+ 0.003460302483290434,
620
+ 0.01619820110499859,
621
+ 0.008633781224489212,
622
+ 0.012598037719726562,
623
+ 0.04559514671564102,
624
+ 0.06271781027317047,
625
+ 0.014696493744850159,
626
+ 0.012923041358590126,
627
+ 0.07460619509220123,
628
+ 0.027807259932160378
629
+ ],
630
+ [
631
+ 0.24161744117736816,
632
+ 0.013565832749009132,
633
+ 0.006801762618124485,
634
+ 0.0032485886476933956,
635
+ 0.02135937288403511,
636
+ 0.024630073457956314,
637
+ 0.015564021654427052,
638
+ 0.005436367355287075,
639
+ 0.007849231362342834,
640
+ 0.015441101975739002,
641
+ 0.04518696293234825,
642
+ 0.013415353372693062
643
+ ],
644
+ [
645
+ 0.0038080490194261074,
646
+ 0.00991421565413475,
647
+ 0.025079775601625443,
648
+ 0.011280774138867855,
649
+ 0.04912680760025978,
650
+ 0.006715251598507166,
651
+ 0.021937724202871323,
652
+ 0.12375693023204803,
653
+ 0.026765504851937294,
654
+ 0.011192137375473976,
655
+ 0.025936853140592575,
656
+ 8.196845010388643e-05
657
+ ],
658
+ [
659
+ 0.023429764434695244,
660
+ 0.016590412706136703,
661
+ 0.017092403024435043,
662
+ 0.03277356177568436,
663
+ 0.016331162303686142,
664
+ 0.021816818043589592,
665
+ 0.011733165010809898,
666
+ 0.005887174047529697,
667
+ 0.01492474414408207,
668
+ 0.030711984261870384,
669
+ 0.07108811289072037,
670
+ 0.06261330097913742
671
+ ],
672
+ [
673
+ 0.02555452659726143,
674
+ 0.029351357370615005,
675
+ 0.021288855001330376,
676
+ 0.024492312222719193,
677
+ 0.039061177521944046,
678
+ 0.03344884514808655,
679
+ 0.06831201910972595,
680
+ 0.03736294433474541,
681
+ 0.019588876515626907,
682
+ 0.04092007130384445,
683
+ 0.01721787452697754,
684
+ 0.019499698653817177
685
+ ],
686
+ [
687
+ 0.020283106714487076,
688
+ 0.02244160696864128,
689
+ 0.01908939704298973,
690
+ 0.0162697471678257,
691
+ 0.02050776034593582,
692
+ 0.02750096097588539,
693
+ 0.026029860600829124,
694
+ 0.03217357397079468,
695
+ 0.014307908713817596,
696
+ 0.006763854529708624,
697
+ 0.04564401134848595,
698
+ 0.027008097618818283
699
+ ],
700
+ [
701
+ 0.027883464470505714,
702
+ 0.041265588253736496,
703
+ 0.028905224055051804,
704
+ 0.013592107221484184,
705
+ 0.0074845412746071815,
706
+ 0.03488120436668396,
707
+ 0.04030846059322357,
708
+ 0.010207113809883595,
709
+ 0.035800714045763016,
710
+ 0.029832065105438232,
711
+ 0.02576960064470768,
712
+ 0.014182129874825478
713
+ ],
714
+ [
715
+ 0.017836367711424828,
716
+ 0.029379570856690407,
717
+ 0.022140078246593475,
718
+ 0.036215025931596756,
719
+ 0.024319598451256752,
720
+ 0.026142369955778122,
721
+ 0.018539801239967346,
722
+ 0.019365690648555756,
723
+ 0.011654431000351906,
724
+ 0.025902757421135902,
725
+ 0.015683690086007118,
726
+ 0.010347607545554638
727
+ ],
728
+ [
729
+ 0.02144056186079979,
730
+ 0.046325650066137314,
731
+ 0.021630164235830307,
732
+ 0.05147164314985275,
733
+ 0.042117439210414886,
734
+ 0.02441989816725254,
735
+ 0.02136657014489174,
736
+ 0.05447021871805191,
737
+ 0.03011142648756504,
738
+ 0.020071811974048615,
739
+ 0.016738489270210266,
740
+ 0.04836065694689751
741
+ ],
742
+ [
743
+ 0.13101476430892944,
744
+ 0.03627091646194458,
745
+ 0.0201750285923481,
746
+ 0.06851539760828018,
747
+ 0.029396140947937965,
748
+ 0.03782244399189949,
749
+ 0.014253688976168633,
750
+ 0.044284969568252563,
751
+ 0.17414367198944092,
752
+ 0.021388430148363113,
753
+ 0.06319155544042587,
754
+ 0.055135130882263184
755
+ ]
756
+ ],
757
+ "positional": [
758
+ [
759
+ 0.5065976977348328,
760
+ 0.07629109919071198,
761
+ 0.5960054397583008,
762
+ 0.1072789654135704,
763
+ 0.1979677975177765,
764
+ 0.13927273452281952,
765
+ 0.40057316422462463,
766
+ 0.294817179441452,
767
+ 0.383198618888855,
768
+ 0.5544258952140808,
769
+ 0.40033283829689026,
770
+ 0.47870078682899475
771
+ ],
772
+ [
773
+ 0.2410203516483307,
774
+ 0.4396105706691742,
775
+ 0.4307883381843567,
776
+ 0.5517755746841431,
777
+ 0.5317303538322449,
778
+ 0.5054966807365417,
779
+ 0.6495388746261597,
780
+ 0.6267575025558472,
781
+ 0.5890303254127502,
782
+ 0.6793325543403625,
783
+ 0.07594899833202362,
784
+ 0.21587026119232178
785
+ ],
786
+ [
787
+ 0.4007927477359772,
788
+ 0.7385829091072083,
789
+ 0.1999039351940155,
790
+ 0.30451780557632446,
791
+ 0.46449190378189087,
792
+ 0.3399127125740051,
793
+ 0.514499306678772,
794
+ 0.29614612460136414,
795
+ 0.31728798151016235,
796
+ 0.2615760266780853,
797
+ 0.3395046591758728,
798
+ 0.7219924926757812
799
+ ],
800
+ [
801
+ 0.8190140724182129,
802
+ 0.6275245547294617,
803
+ 0.25404971837997437,
804
+ 0.6006070375442505,
805
+ 0.8895429372787476,
806
+ 0.7170742154121399,
807
+ 0.3035760521888733,
808
+ 0.35117024183273315,
809
+ 0.4254607558250427,
810
+ 0.5432918071746826,
811
+ 0.6645973920822144,
812
+ 0.47774600982666016
813
+ ],
814
+ [
815
+ 0.5796363949775696,
816
+ 0.5921002626419067,
817
+ 0.793941080570221,
818
+ 0.49824151396751404,
819
+ 0.7273139953613281,
820
+ 0.6757563948631287,
821
+ 0.64992356300354,
822
+ 0.3122835159301758,
823
+ 0.8277088403701782,
824
+ 0.6422610878944397,
825
+ 0.8769611120223999,
826
+ 0.14915767312049866
827
+ ],
828
+ [
829
+ 0.7556132078170776,
830
+ 0.8456296920776367,
831
+ 0.6256846785545349,
832
+ 0.5377398729324341,
833
+ 0.5960881114006042,
834
+ 0.7833361625671387,
835
+ 0.723742663860321,
836
+ 0.7974669933319092,
837
+ 0.7113959789276123,
838
+ 0.8386362791061401,
839
+ 0.6537194848060608,
840
+ 0.7253992557525635
841
+ ],
842
+ [
843
+ 0.538119912147522,
844
+ 0.7342842817306519,
845
+ 0.8442155718803406,
846
+ 0.7554894685745239,
847
+ 0.6839307546615601,
848
+ 0.7064528465270996,
849
+ 0.7554677724838257,
850
+ 0.6205617189407349,
851
+ 0.5202042460441589,
852
+ 0.8443636894226074,
853
+ 0.8635346293449402,
854
+ 0.6343041062355042
855
+ ],
856
+ [
857
+ 0.6614936590194702,
858
+ 0.8791419267654419,
859
+ 0.9076933860778809,
860
+ 0.7058827877044678,
861
+ 0.8025026321411133,
862
+ 0.7749000787734985,
863
+ 0.838254451751709,
864
+ 0.8037239909172058,
865
+ 0.6864684224128723,
866
+ 0.7610327005386353,
867
+ 0.8215873837471008,
868
+ 0.8486534357070923
869
+ ],
870
+ [
871
+ 0.8073843121528625,
872
+ 0.8061873316764832,
873
+ 0.7319211959838867,
874
+ 0.8467031717300415,
875
+ 0.7768716812133789,
876
+ 0.6048685908317566,
877
+ 0.7132378816604614,
878
+ 0.6679729223251343,
879
+ 0.6701217889785767,
880
+ 0.7771828770637512,
881
+ 0.7071925401687622,
882
+ 0.8558918237686157
883
+ ],
884
+ [
885
+ 0.8133878707885742,
886
+ 0.8669012784957886,
887
+ 0.8068772554397583,
888
+ 0.5790890455245972,
889
+ 0.8904383778572083,
890
+ 0.8204380869865417,
891
+ 0.9076582789421082,
892
+ 0.7966066002845764,
893
+ 0.8762456774711609,
894
+ 0.9064305424690247,
895
+ 0.7492377758026123,
896
+ 0.9301468133926392
897
+ ],
898
+ [
899
+ 0.8455430269241333,
900
+ 0.8402767181396484,
901
+ 0.890575110912323,
902
+ 0.7642854452133179,
903
+ 0.7333279252052307,
904
+ 0.7862328290939331,
905
+ 0.8635441660881042,
906
+ 0.6658955812454224,
907
+ 0.888232409954071,
908
+ 0.7337470054626465,
909
+ 0.9097886085510254,
910
+ 0.7254845499992371
911
+ ],
912
+ [
913
+ 0.3025703728199005,
914
+ 0.8144607543945312,
915
+ 0.8962485194206238,
916
+ 0.6487042307853699,
917
+ 0.7963070869445801,
918
+ 0.8672806620597839,
919
+ 0.9231362342834473,
920
+ 0.8210302591323853,
921
+ 0.07466430962085724,
922
+ 0.9117152094841003,
923
+ 0.6209774017333984,
924
+ 0.6903347969055176
925
+ ]
926
+ ],
927
+ "diffuse": [
928
+ [
929
+ 0.5471135377883911,
930
+ 0.1322605162858963,
931
+ 0.492602676153183,
932
+ 0.21496565639972687,
933
+ 0.45495811104774475,
934
+ 0.25727584958076477,
935
+ 0.5676304697990417,
936
+ 0.5459160804748535,
937
+ 0.5383939146995544,
938
+ 0.5441114902496338,
939
+ 0.6075721383094788,
940
+ 0.5915287137031555
941
+ ],
942
+ [
943
+ 0.56114661693573,
944
+ 0.5631774663925171,
945
+ 0.5851024389266968,
946
+ 0.5447676777839661,
947
+ 0.5693410038948059,
948
+ 0.510784924030304,
949
+ 0.4271117150783539,
950
+ 0.48312950134277344,
951
+ 0.5217397212982178,
952
+ 0.4331055283546448,
953
+ 0.60009765625,
954
+ 0.3668949007987976
955
+ ],
956
+ [
957
+ 0.5618427991867065,
958
+ 0.35582321882247925,
959
+ 0.34944772720336914,
960
+ 0.5037699937820435,
961
+ 0.4152102470397949,
962
+ 0.47268810868263245,
963
+ 0.5098887085914612,
964
+ 0.622725248336792,
965
+ 0.47435516119003296,
966
+ 0.48120027780532837,
967
+ 0.6324588060379028,
968
+ 0.4027617573738098
969
+ ],
970
+ [
971
+ 0.17714563012123108,
972
+ 0.4172298312187195,
973
+ 0.42452120780944824,
974
+ 0.31828364729881287,
975
+ 0.18911775946617126,
976
+ 0.38251644372940063,
977
+ 0.5157310366630554,
978
+ 0.4105154871940613,
979
+ 0.41387349367141724,
980
+ 0.4185497760772705,
981
+ 0.40337443351745605,
982
+ 0.4543667733669281
983
+ ],
984
+ [
985
+ 0.3102322220802307,
986
+ 0.38234779238700867,
987
+ 0.3048619031906128,
988
+ 0.4123547673225403,
989
+ 0.3599177300930023,
990
+ 0.34652307629585266,
991
+ 0.447924941778183,
992
+ 0.46825671195983887,
993
+ 0.26102879643440247,
994
+ 0.3940913677215576,
995
+ 0.20296287536621094,
996
+ 0.02204204723238945
997
+ ],
998
+ [
999
+ 0.2029620110988617,
1000
+ 0.08709979802370071,
1001
+ 0.40380486845970154,
1002
+ 0.514489471912384,
1003
+ 0.4261854588985443,
1004
+ 0.1830417364835739,
1005
+ 0.26347407698631287,
1006
+ 0.2405150830745697,
1007
+ 0.2826869487762451,
1008
+ 0.24574777483940125,
1009
+ 0.3901086449623108,
1010
+ 0.3574109673500061
1011
+ ],
1012
+ [
1013
+ 0.46612176299095154,
1014
+ 0.3027900457382202,
1015
+ 0.25536319613456726,
1016
+ 0.3338863253593445,
1017
+ 0.3941308856010437,
1018
+ 0.39528438448905945,
1019
+ 0.3291165232658386,
1020
+ 0.44284719228744507,
1021
+ 0.41498908400535583,
1022
+ 0.12233757972717285,
1023
+ 0.20009461045265198,
1024
+ 0.4175761640071869
1025
+ ],
1026
+ [
1027
+ 0.33120664954185486,
1028
+ 0.1765395551919937,
1029
+ 0.09227706491947174,
1030
+ 0.37284451723098755,
1031
+ 0.2708284258842468,
1032
+ 0.31805992126464844,
1033
+ 0.25206413865089417,
1034
+ 0.21613168716430664,
1035
+ 0.3545899987220764,
1036
+ 0.3042650818824768,
1037
+ 0.14626441895961761,
1038
+ 0.1727096140384674
1039
+ ],
1040
+ [
1041
+ 0.28339847922325134,
1042
+ 0.18787869811058044,
1043
+ 0.36294665932655334,
1044
+ 0.2241670787334442,
1045
+ 0.27335819602012634,
1046
+ 0.4469229280948639,
1047
+ 0.2862758934497833,
1048
+ 0.3158189654350281,
1049
+ 0.3742186725139618,
1050
+ 0.30465927720069885,
1051
+ 0.38407495617866516,
1052
+ 0.21899032592773438
1053
+ ],
1054
+ [
1055
+ 0.23532763123512268,
1056
+ 0.16719377040863037,
1057
+ 0.2597936987876892,
1058
+ 0.4364214539527893,
1059
+ 0.17044395208358765,
1060
+ 0.2712015211582184,
1061
+ 0.13269579410552979,
1062
+ 0.2855920195579529,
1063
+ 0.18635967373847961,
1064
+ 0.1326359510421753,
1065
+ 0.29712045192718506,
1066
+ 0.11560585349798203
1067
+ ],
1068
+ [
1069
+ 0.2210322916507721,
1070
+ 0.19647784531116486,
1071
+ 0.17695878446102142,
1072
+ 0.29771047830581665,
1073
+ 0.33597418665885925,
1074
+ 0.2783747613430023,
1075
+ 0.19375105202198029,
1076
+ 0.3423268496990204,
1077
+ 0.16622166335582733,
1078
+ 0.3245820999145508,
1079
+ 0.1462937742471695,
1080
+ 0.2878214716911316
1081
+ ],
1082
+ [
1083
+ 0.6131904721260071,
1084
+ 0.2794114649295807,
1085
+ 0.18150922656059265,
1086
+ 0.42593449354171753,
1087
+ 0.31345874071121216,
1088
+ 0.20985659956932068,
1089
+ 0.14243251085281372,
1090
+ 0.2698703110218048,
1091
+ 0.5045338869094849,
1092
+ 0.15346872806549072,
1093
+ 0.4387816786766052,
1094
+ 0.3756435811519623
1095
+ ]
1096
+ ]
1097
+ }
1098
+ }
1099
+ }
utils/head_detection.py CHANGED
@@ -1,313 +1,256 @@
1
  """
2
  Attention Head Detection and Categorization
3
 
4
- Implements heuristics to categorize attention heads into:
5
- - Previous-Token Heads: high attention on previous token
6
- - First/Positional Heads: high attention on first token or positional patterns
7
- - Bag-of-Words Heads: diffuse attention on content tokens
8
- - Syntactic Heads: dependency-like patterns
 
 
 
 
9
  - Other: heads that don't fit the above categories
10
  """
11
 
 
 
12
  import torch
13
  import numpy as np
14
  from typing import Dict, List, Tuple, Optional, Any
15
  import re
 
16
 
17
 
18
- class HeadCategorizationConfig:
19
- """
20
- Configuration for attention head categorization heuristics.
21
-
22
- These thresholds are tuned to balance sensitivity (catching relevant patterns)
23
- with specificity (avoiding false positives) for educational purposes.
24
- """
25
-
26
- def __init__(self):
27
- # Previous-token head thresholds
28
- # Heads that primarily attend to the immediately preceding token
29
- self.prev_token_threshold = 0.4 # Minimum avg attention to prev token (40%)
30
- self.prev_token_diagonal_offset = 1 # Check i → i-1 pattern
31
-
32
- # First/Positional head thresholds
33
- # Heads that attend strongly to first token or show positional patterns
34
- self.first_token_threshold = 0.25 # Minimum avg attention to first token (25%)
35
- self.positional_pattern_threshold = 0.4 # For detecting positional patterns
36
-
37
- # Bag-of-words head thresholds
38
- # Heads with diffuse attention across many tokens
39
- self.bow_entropy_threshold = 0.65 # Minimum entropy (normalized, 0-1 scale)
40
- self.bow_max_attention_threshold = 0.35 # Maximum attention to any single token
41
-
42
- # Syntactic head thresholds
43
- # Heads showing structured distance patterns (e.g., subject-verb)
44
- self.syntactic_distance_pattern_threshold = 0.3 # For detecting distance patterns
45
-
46
- # General thresholds
47
- self.min_seq_len = 4 # Minimum sequence length for reliable detection
48
 
 
 
49
 
50
- def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
 
51
  """
52
- Compute normalized entropy of attention distribution.
53
 
54
  Args:
55
- attention_weights: [seq_len] tensor of attention weights for a position
56
 
57
  Returns:
58
- Normalized entropy (0 to 1)
59
- """
60
- # Avoid log(0) by adding small epsilon
61
- epsilon = 1e-10
62
- weights = attention_weights + epsilon
63
-
64
- # Compute entropy: -sum(p * log(p))
65
- entropy = -torch.sum(weights * torch.log(weights))
66
-
67
- # Normalize by max entropy (log(n) where n is sequence length)
68
- max_entropy = np.log(len(weights))
69
- normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0
70
-
71
- return normalized_entropy.item()
72
-
73
-
74
- def detect_previous_token_head(attention_matrix: torch.Tensor, config: HeadCategorizationConfig) -> Tuple[bool, float]:
75
  """
76
- Detect if head shows strong previous-token pattern (i → i-1).
77
 
78
- Args:
79
- attention_matrix: [seq_len, seq_len] attention weights
80
- config: Configuration object
81
 
82
- Returns:
83
- (is_prev_token_head, score) where score is avg attention to previous token
84
- """
85
- seq_len = attention_matrix.shape[0]
86
 
87
- if seq_len < config.min_seq_len:
88
- return False, 0.0
 
 
 
89
 
90
- # Extract the diagonal offset by 1 (i → i-1 pattern)
91
- # For each position i > 0, check attention to position i-1
92
- prev_token_attentions = []
93
- for i in range(1, seq_len):
94
- prev_token_attentions.append(attention_matrix[i, i-1].item())
 
95
 
96
- avg_prev_attention = np.mean(prev_token_attentions)
97
- is_prev_token_head = avg_prev_attention >= config.prev_token_threshold
98
 
99
- return is_prev_token_head, avg_prev_attention
 
100
 
 
 
 
 
101
 
102
- def detect_first_token_head(attention_matrix: torch.Tensor, config: HeadCategorizationConfig) -> Tuple[bool, float]:
 
103
  """
104
- Detect if head shows strong attention to first token(s) or positional patterns.
105
 
106
  Args:
107
- attention_matrix: [seq_len, seq_len] attention weights
108
- config: Configuration object
109
 
110
  Returns:
111
- (is_first_token_head, score) where score is avg attention to first token
112
  """
113
- seq_len = attention_matrix.shape[0]
114
-
115
- if seq_len < config.min_seq_len:
116
- return False, 0.0
117
-
118
- # Check average attention to first token across all positions
119
- first_token_attention = attention_matrix[:, 0].mean().item()
120
- is_first_token_head = first_token_attention >= config.first_token_threshold
121
-
122
- return is_first_token_head, first_token_attention
123
 
124
 
125
- def detect_bow_head(attention_matrix: torch.Tensor, config: HeadCategorizationConfig) -> Tuple[bool, float]:
126
  """
127
- Detect if head shows bag-of-words pattern (diffuse attention).
128
 
129
  Args:
130
- attention_matrix: [seq_len, seq_len] attention weights
131
- config: Configuration object
132
 
133
  Returns:
134
- (is_bow_head, score) where score is average entropy
135
  """
136
- seq_len = attention_matrix.shape[0]
137
-
138
- if seq_len < config.min_seq_len:
139
- return False, 0.0
140
-
141
- # Compute entropy for each position's attention distribution
142
- entropies = []
143
- max_attentions = []
144
-
145
- for i in range(seq_len):
146
- entropy = compute_attention_entropy(attention_matrix[i])
147
- max_attention = attention_matrix[i].max().item()
148
-
149
- entropies.append(entropy)
150
- max_attentions.append(max_attention)
151
-
152
- avg_entropy = np.mean(entropies)
153
- avg_max_attention = np.mean(max_attentions)
154
-
155
- # BoW heads have high entropy and low max attention (diffuse)
156
- is_bow_head = (avg_entropy >= config.bow_entropy_threshold and
157
- avg_max_attention <= config.bow_max_attention_threshold)
158
-
159
- return is_bow_head, avg_entropy
160
 
161
 
162
- def detect_syntactic_head(attention_matrix: torch.Tensor, config: HeadCategorizationConfig) -> Tuple[bool, float]:
 
 
 
 
163
  """
164
- Detect if head shows syntactic/dependency-like patterns.
165
-
166
- This is a simplified heuristic based on consistent distance patterns.
167
 
168
  Args:
169
- attention_matrix: [seq_len, seq_len] attention weights
170
- config: Configuration object
 
171
 
172
  Returns:
173
- (is_syntactic_head, score) where score is pattern consistency
174
  """
175
- seq_len = attention_matrix.shape[0]
176
-
177
- if seq_len < config.min_seq_len:
178
- return False, 0.0
179
-
180
- # Check for consistent distance patterns (e.g., attending to tokens at fixed distances)
181
- # This is a simplified approach; more sophisticated syntactic detection would
182
- # require parsing or linguistic features
183
-
184
- distance_scores = []
185
-
186
- for i in range(seq_len):
187
- # For each position, find the most attended position
188
- max_idx = torch.argmax(attention_matrix[i]).item()
189
- distance = abs(i - max_idx)
 
 
 
 
190
 
191
- # Collect distances (excluding self-attention at distance 0)
192
- if distance > 0:
193
- distance_scores.append(distance)
194
-
195
- if not distance_scores:
196
- return False, 0.0
197
-
198
- # Check if there's a consistent distance pattern
199
- # (simple version: low variance in distances)
200
- distance_variance = np.var(distance_scores)
201
- distance_mean = np.mean(distance_scores)
202
 
203
- # Syntactic heads often have moderate, consistent distances
204
- # (not too short like prev-token, not too diffuse like BoW)
205
- pattern_score = 1.0 / (1.0 + distance_variance) if distance_mean > 1 else 0.0
206
- is_syntactic_head = pattern_score >= config.syntactic_distance_pattern_threshold
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
- return is_syntactic_head, pattern_score
209
-
210
-
211
- def categorize_attention_head(attention_matrix: torch.Tensor,
212
- layer_idx: int,
213
- head_idx: int,
214
- config: Optional[HeadCategorizationConfig] = None) -> Dict[str, Any]:
215
- """
216
- Categorize a single attention head based on its attention pattern.
217
 
218
- Args:
219
- attention_matrix: [seq_len, seq_len] attention weights for this head
220
- layer_idx: Layer index
221
- head_idx: Head index within the layer
222
- config: Configuration object (uses defaults if None)
 
223
 
224
- Returns:
225
- Dictionary with categorization results:
226
- {
227
- 'layer': layer_idx,
228
- 'head': head_idx,
229
- 'category': str (one of: 'previous_token', 'first_token', 'bow', 'syntactic', 'other'),
230
- 'scores': dict of scores for each category,
231
- 'label': formatted label like "L{layer}-H{head}"
232
- }
233
- """
234
- if config is None:
235
- config = HeadCategorizationConfig()
236
-
237
- # Run all detection heuristics
238
- is_prev, prev_score = detect_previous_token_head(attention_matrix, config)
239
- is_first, first_score = detect_first_token_head(attention_matrix, config)
240
- is_bow, bow_score = detect_bow_head(attention_matrix, config)
241
- is_syn, syn_score = detect_syntactic_head(attention_matrix, config)
242
-
243
- # Assign category based on highest-scoring pattern
244
- # Priority: previous_token > first_token > bow > syntactic > other
245
- scores = {
246
- 'previous_token': prev_score if is_prev else 0.0,
247
- 'first_token': first_score if is_first else 0.0,
248
- 'bow': bow_score if is_bow else 0.0,
249
- 'syntactic': syn_score if is_syn else 0.0
250
- }
251
-
252
- # Determine primary category
253
- if is_prev:
254
- category = 'previous_token'
255
- elif is_first:
256
- category = 'first_token'
257
- elif is_bow:
258
- category = 'bow'
259
- elif is_syn:
260
- category = 'syntactic'
261
  else:
262
- category = 'other'
263
-
264
- return {
265
- 'layer': layer_idx,
266
- 'head': head_idx,
267
- 'category': category,
268
- 'scores': scores,
269
- 'label': f"L{layer_idx}-H{head_idx}"
270
- }
271
 
272
 
273
- def categorize_all_heads(activation_data: Dict[str, Any],
274
- config: Optional[HeadCategorizationConfig] = None) -> Dict[str, List[Dict[str, Any]]]:
 
 
275
  """
276
- Categorize all attention heads in the model.
 
277
 
278
  Args:
279
  activation_data: Output from execute_forward_pass with attention data
280
- config: Configuration object (uses defaults if None)
281
 
282
  Returns:
283
- Dictionary mapping category names to lists of head info dicts:
284
  {
285
- 'previous_token': [...],
286
- 'first_token': [...],
287
- 'bow': [...],
288
- 'syntactic': [...],
289
- 'other': [...]
 
 
 
 
 
 
 
 
 
 
 
 
290
  }
 
291
  """
292
- if config is None:
293
- config = HeadCategorizationConfig()
294
-
295
- # Initialize result dict
296
- categorized = {
297
- 'previous_token': [],
298
- 'first_token': [],
299
- 'bow': [],
300
- 'syntactic': [],
301
- 'other': []
302
- }
303
 
 
304
  attention_outputs = activation_data.get('attention_outputs', {})
305
- if not attention_outputs:
306
- return categorized
 
 
 
 
 
307
 
308
- # Process each layer's attention
309
  for module_name, output_dict in attention_outputs.items():
310
- # Extract layer number from module name
311
  numbers = re.findall(r'\d+', module_name)
312
  if not numbers:
313
  continue
@@ -318,153 +261,82 @@ def categorize_all_heads(activation_data: Dict[str, Any],
318
  if not isinstance(attention_output, list) or len(attention_output) < 2:
319
  continue
320
 
321
- # Get attention weights: [batch, heads, seq_len, seq_len]
322
  attention_weights = torch.tensor(attention_output[1])
323
-
324
- # Process each head
325
  num_heads = attention_weights.shape[1]
326
- seq_len = attention_weights.shape[2]
327
-
328
- if seq_len < config.min_seq_len:
329
- continue
330
 
331
  for head_idx in range(num_heads):
332
- # Extract attention matrix for this head: [seq_len, seq_len]
333
- head_attention = attention_weights[0, head_idx, :, :]
334
-
335
- # Categorize this head
336
- head_info = categorize_attention_head(head_attention, layer_idx, head_idx, config)
337
-
338
- # Add to appropriate category list
339
- category = head_info['category']
340
- categorized[category].append(head_info)
341
 
342
- return categorized
343
-
344
-
345
- def categorize_single_layer_heads(activation_data: Dict[str, Any],
346
- layer_num: int,
347
- config: Optional[HeadCategorizationConfig] = None) -> Dict[str, List[Dict[str, Any]]]:
348
- """
349
- Categorize attention heads for a single layer.
350
 
351
- Args:
352
- activation_data: Output from execute_forward_pass with attention data
353
- layer_num: The specific layer number to categorize
354
- config: Configuration object (uses defaults if None)
355
-
356
- Returns:
357
- Dictionary mapping category names to lists of head info dicts for this layer only:
358
- {
359
- 'previous_token': [...],
360
- 'first_token': [...],
361
- 'bow': [...],
362
- 'syntactic': [...],
363
- 'other': [...]
364
- }
365
- """
366
- if config is None:
367
- config = HeadCategorizationConfig()
368
-
369
- # Initialize result dict
370
- categorized = {
371
- 'previous_token': [],
372
- 'first_token': [],
373
- 'bow': [],
374
- 'syntactic': [],
375
- 'other': []
376
  }
377
 
378
- attention_outputs = activation_data.get('attention_outputs', {})
379
- if not attention_outputs:
380
- return categorized
381
 
382
- # Find the attention output for the requested layer
383
- target_module = None
384
- for module_name, output_dict in attention_outputs.items():
385
- # Extract layer number from module name
386
- numbers = re.findall(r'\d+', module_name)
387
- if not numbers:
388
  continue
389
 
390
- if int(numbers[0]) == layer_num:
391
- target_module = module_name
392
- break
393
-
394
- if not target_module:
395
- return categorized
396
-
397
- output_dict = attention_outputs[target_module]
398
- attention_output = output_dict.get('output')
399
-
400
- if not isinstance(attention_output, list) or len(attention_output) < 2:
401
- return categorized
402
-
403
- # Get attention weights: [batch, heads, seq_len, seq_len]
404
- attention_weights = torch.tensor(attention_output[1])
405
-
406
- # Process each head
407
- num_heads = attention_weights.shape[1]
408
- seq_len = attention_weights.shape[2]
409
-
410
- if seq_len < config.min_seq_len:
411
- return categorized
412
-
413
- for head_idx in range(num_heads):
414
- # Extract attention matrix for this head: [seq_len, seq_len]
415
- head_attention = attention_weights[0, head_idx, :, :]
416
 
417
- # Categorize this head
418
- head_info = categorize_attention_head(head_attention, layer_num, head_idx, config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
 
420
- # Add to appropriate category list
421
- category = head_info['category']
422
- categorized[category].append(head_info)
423
-
424
- return categorized
425
-
426
-
427
- def format_categorization_summary(categorized_heads: Dict[str, List[Dict[str, Any]]]) -> str:
428
- """
429
- Format categorization results as a human-readable summary.
430
-
431
- Args:
432
- categorized_heads: Output from categorize_all_heads or categorize_single_layer_heads
433
 
434
- Returns:
435
- Formatted string summary
436
- """
437
- category_names = {
438
- 'previous_token': 'Previous-Token Heads',
439
- 'first_token': 'First/Positional-Token Heads',
440
- 'bow': 'Bag-of-Words Heads',
441
- 'syntactic': 'Syntactic Heads',
442
- 'other': 'Other Heads'
 
443
  }
444
 
445
- summary = []
446
- total_heads = sum(len(heads) for heads in categorized_heads.values())
447
-
448
- summary.append(f"Total Heads: {total_heads}\n")
449
- summary.append("=" * 60)
450
-
451
- for category, display_name in category_names.items():
452
- heads = categorized_heads.get(category, [])
453
- summary.append(f"\n{display_name}: {len(heads)} heads")
454
-
455
- if heads:
456
- # Group by layer
457
- heads_by_layer = {}
458
- for head_info in heads:
459
- layer = head_info['layer']
460
- if layer not in heads_by_layer:
461
- heads_by_layer[layer] = []
462
- heads_by_layer[layer].append(head_info['head'])
463
-
464
- # Format by layer
465
- for layer in sorted(heads_by_layer.keys()):
466
- head_indices = sorted(heads_by_layer[layer])
467
- summary.append(f" Layer {layer}: Heads {head_indices}")
468
-
469
- return "\n".join(summary)
470
-
 
1
  """
2
  Attention Head Detection and Categorization
3
 
4
+ Loads pre-computed head category data from JSON (produced by scripts/analyze_heads.py)
5
+ and performs lightweight runtime verification of head activation on the current input.
6
+
7
+ Categories:
8
+ - Previous Token: attends to the immediately preceding token
9
+ - Induction: completes repeated patterns ([A][B]...[A] → [B])
10
+ - Duplicate Token: attends to earlier occurrences of the same token
11
+ - Positional / First-Token: attends to the first token or positional patterns
12
+ - Diffuse / Spread: high-entropy, evenly distributed attention
13
  - Other: heads that don't fit the above categories
14
  """
15
 
16
+ import json
17
+ import os
18
  import torch
19
  import numpy as np
20
  from typing import Dict, List, Tuple, Optional, Any
21
  import re
22
+ from pathlib import Path
23
 
24
 
25
# Location of the pre-computed head category JSON (written by scripts/analyze_heads.py)
_JSON_PATH = Path(__file__).parent / "head_categories.json"

# Per-model cache of parsed category data so the JSON is read at most once per model
_category_cache: Dict[str, Any] = {}


def load_head_categories(model_name: str) -> Optional[Dict[str, Any]]:
    """
    Return the pre-computed head category data for *model_name*, or None.

    Results are memoized in ``_category_cache``, so the JSON file on disk is
    consulted only on a cache miss. Lookup tries the exact model name first,
    then the short name after the final "/" (e.g. "gpt2" for
    "openai-community/gpt2").

    Args:
        model_name: HuggingFace model name (e.g., "gpt2", "EleutherAI/pythia-70m")

    Returns:
        Dict with the model's category data, or None if the model was not
        analyzed or the JSON file is missing/unreadable.
        Structure: {
            "model_name": str,
            "num_layers": int,
            "num_heads": int,
            "categories": { category_name: { "top_heads": [...], ... } },
            ...
        }
    """
    # Only non-None entries are ever cached, so a .get() miss is unambiguous.
    cached = _category_cache.get(model_name)
    if cached is not None:
        return cached

    if not _JSON_PATH.exists():
        return None

    try:
        with open(_JSON_PATH, 'r') as f:
            all_data = json.load(f)
    except (json.JSONDecodeError, IOError):
        return None

    # Exact name wins; otherwise fall back to the repo-less short name.
    # split('/')[-1] is the whole string when there is no "/", so no branch needed.
    short_name = model_name.split('/')[-1]
    model_data = all_data.get(model_name, all_data.get(short_name))

    if model_data is not None:
        _category_cache[model_name] = model_data

    return model_data
76
+
77
 
78
def clear_category_cache():
    """Reset the module-level head-category cache (primarily for tests)."""
    global _category_cache
    _category_cache = {}
82
 
83
+
84
+ def _compute_attention_entropy(attention_weights: torch.Tensor) -> float:
85
  """
86
+ Compute normalized entropy of an attention distribution.
87
 
88
  Args:
89
+ attention_weights: [seq_len] tensor of attention weights for one position
 
90
 
91
  Returns:
92
+ Normalized entropy (0.0 to 1.0). 1.0 = perfectly uniform, 0.0 = fully peaked.
93
  """
94
+ epsilon = 1e-10
95
+ weights = attention_weights + epsilon
96
+ entropy = -torch.sum(weights * torch.log(weights))
97
+ max_entropy = np.log(len(weights))
98
+ return (entropy / max_entropy).item() if max_entropy > 0 else 0.0
 
 
 
 
 
99
 
100
 
101
+ def _find_repeated_tokens(token_ids: List[int]) -> Dict[int, List[int]]:
102
  """
103
+ Find tokens that appear more than once and their positions.
104
 
105
  Args:
106
+ token_ids: List of token IDs in the sequence
 
107
 
108
  Returns:
109
+ Dict mapping token_id -> list of positions where it appears (only for repeated tokens)
110
  """
111
+ positions: Dict[int, List[int]] = {}
112
+ for i, tid in enumerate(token_ids):
113
+ if tid not in positions:
114
+ positions[tid] = []
115
+ positions[tid].append(i)
116
+
117
+ # Keep only tokens that appear more than once
118
+ return {tid: pos_list for tid, pos_list in positions.items() if len(pos_list) > 1}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
 
121
def verify_head_activation(
    attn_matrix: torch.Tensor,
    token_ids: List[int],
    category: str
) -> float:
    """
    Verify whether a head's known role is active on the current input.

    Args:
        attn_matrix: [seq_len, seq_len] attention weights for this head
        token_ids: List of token IDs in the input
        category: Category name (previous_token, induction, duplicate_token,
            positional, diffuse)

    Returns:
        Activation score (0.0 to 1.0). 0.0 means the role is not triggered on
        this input (or the category name is unknown).
    """
    seq_len = attn_matrix.shape[0]

    if seq_len < 2:
        return 0.0

    if category == "previous_token":
        # Mean of the sub-diagonal: how much each position attends to its predecessor.
        prev_token_attentions = [attn_matrix[i, i - 1].item() for i in range(1, seq_len)]
        return float(np.mean(prev_token_attentions)) if prev_token_attentions else 0.0

    elif category == "induction":
        # Induction pattern: [A][B]...[A] → attend to [B].
        # For each repeated token at position i where token[i]==token[j] (j < i),
        # check if position i attends to position j+1.
        repeated = _find_repeated_tokens(token_ids)
        if not repeated:
            return 0.0  # No repetition → the role cannot fire on this input

        induction_scores = []
        for positions in repeated.values():
            for k in range(1, len(positions)):
                current_pos = positions[k]  # Later occurrence
                for prev_idx in range(k):
                    prev_pos = positions[prev_idx]  # Earlier occurrence
                    target_pos = prev_pos + 1  # The token AFTER the earlier occurrence
                    # Bounds guard: token_ids may be longer than the attention
                    # matrix, so skip positions outside it.
                    if target_pos < seq_len and current_pos < seq_len:
                        induction_scores.append(attn_matrix[current_pos, target_pos].item())

        return float(np.mean(induction_scores)) if induction_scores else 0.0

    elif category == "duplicate_token":
        # Check if later occurrences attend to earlier occurrences of the same token.
        repeated = _find_repeated_tokens(token_ids)
        if not repeated:
            return 0.0  # No duplicates → gray out

        dup_scores = []
        for positions in repeated.values():
            for k in range(1, len(positions)):
                later_pos = positions[k]
                # Same bounds guard as the induction branch: skip positions
                # outside the attention matrix instead of raising IndexError
                # when token_ids is longer than seq_len.
                if later_pos >= seq_len:
                    continue
                # Sum attention to all earlier occurrences (all earlier
                # positions are < later_pos, hence in range).
                earlier_attention = sum(
                    attn_matrix[later_pos, positions[j]].item()
                    for j in range(k)
                )
                dup_scores.append(earlier_attention)

        return float(np.mean(dup_scores)) if dup_scores else 0.0

    elif category == "positional":
        # Mean of column-0 attention (how much each position attends to the first token).
        return attn_matrix[:, 0].mean().item()

    elif category == "diffuse":
        # Average normalized entropy across all positions.
        entropies = [_compute_attention_entropy(attn_matrix[i]) for i in range(seq_len)]
        return float(np.mean(entropies)) if entropies else 0.0

    else:
        # Unknown category name → no score.
        return 0.0
 
 
 
 
 
 
 
 
202
 
203
 
204
+ def get_active_head_summary(
205
+ activation_data: Dict[str, Any],
206
+ model_name: str
207
+ ) -> Optional[Dict[str, Any]]:
208
  """
209
+ Main entry point: load categories from JSON, verify each head on the current input,
210
+ and return a UI-ready structure.
211
 
212
  Args:
213
  activation_data: Output from execute_forward_pass with attention data
214
+ model_name: HuggingFace model name
215
 
216
  Returns:
217
+ Dict with structure:
218
  {
219
+ "model_available": True,
220
+ "categories": {
221
+ "previous_token": {
222
+ "display_name": str,
223
+ "description": str,
224
+ "educational_text": str,
225
+ "icon": str,
226
+ "requires_repetition": bool,
227
+ "suggested_prompt": str or None,
228
+ "is_applicable": bool, # False if requires_repetition but no repeats
229
+ "heads": [
230
+ {"layer": int, "head": int, "offline_score": float,
231
+ "activation_score": float, "is_active": bool, "label": str}
232
+ ]
233
+ },
234
+ ...
235
+ }
236
  }
237
+ Returns None if model not in JSON.
238
  """
239
+ model_data = load_head_categories(model_name)
240
+ if model_data is None:
241
+ return None
 
 
 
 
 
 
 
 
242
 
243
+ # Extract attention weights and token IDs from activation data
244
  attention_outputs = activation_data.get('attention_outputs', {})
245
+ input_ids = activation_data.get('input_ids', [[]])[0]
246
+
247
+ if not attention_outputs or not input_ids:
248
+ return None
249
+
250
+ # Build a lookup: (layer, head) → attention_matrix [seq_len, seq_len]
251
+ head_attention_lookup: Dict[Tuple[int, int], torch.Tensor] = {}
252
 
 
253
  for module_name, output_dict in attention_outputs.items():
 
254
  numbers = re.findall(r'\d+', module_name)
255
  if not numbers:
256
  continue
 
261
  if not isinstance(attention_output, list) or len(attention_output) < 2:
262
  continue
263
 
264
+ # attention_output[1] is [batch, heads, seq_len, seq_len]
265
  attention_weights = torch.tensor(attention_output[1])
 
 
266
  num_heads = attention_weights.shape[1]
 
 
 
 
267
 
268
  for head_idx in range(num_heads):
269
+ head_attention_lookup[(layer_idx, head_idx)] = attention_weights[0, head_idx, :, :]
 
 
 
 
 
 
 
 
270
 
271
+ # Check if input has repeated tokens (needed for applicability check)
272
+ repeated_tokens = _find_repeated_tokens(input_ids)
273
+ has_repetition = len(repeated_tokens) > 0
 
 
 
 
 
274
 
275
+ # Build result
276
+ result = {
277
+ "model_available": True,
278
+ "categories": {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  }
280
 
281
+ categories = model_data.get("categories", {})
 
 
282
 
283
+ # Define category order for consistent display
284
+ category_order = ["previous_token", "induction", "duplicate_token", "positional", "diffuse"]
285
+
286
+ for cat_key in category_order:
287
+ cat_info = categories.get(cat_key)
288
+ if cat_info is None:
289
  continue
290
 
291
+ requires_repetition = cat_info.get("requires_repetition", False)
292
+ is_applicable = not requires_repetition or has_repetition
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
+ heads_result = []
295
+ for head_entry in cat_info.get("top_heads", []):
296
+ layer = head_entry["layer"]
297
+ head = head_entry["head"]
298
+ offline_score = head_entry["score"]
299
+
300
+ # Get activation score on current input
301
+ attn_matrix = head_attention_lookup.get((layer, head))
302
+ if attn_matrix is not None and is_applicable:
303
+ activation_score = verify_head_activation(attn_matrix, input_ids, cat_key)
304
+ else:
305
+ activation_score = 0.0
306
+
307
+ # A head is "active" if its activation score exceeds a minimum threshold
308
+ is_active = activation_score > 0.1 and is_applicable
309
+
310
+ heads_result.append({
311
+ "layer": layer,
312
+ "head": head,
313
+ "offline_score": offline_score,
314
+ "activation_score": round(activation_score, 3),
315
+ "is_active": is_active,
316
+ "label": f"L{layer}-H{head}"
317
+ })
318
 
319
+ result["categories"][cat_key] = {
320
+ "display_name": cat_info.get("display_name", cat_key),
321
+ "description": cat_info.get("description", ""),
322
+ "educational_text": cat_info.get("educational_text", ""),
323
+ "icon": cat_info.get("icon", "circle"),
324
+ "requires_repetition": requires_repetition,
325
+ "suggested_prompt": cat_info.get("suggested_prompt"),
326
+ "is_applicable": is_applicable,
327
+ "heads": heads_result
328
+ }
 
 
 
329
 
330
+ # Add "Other" category (heads not claimed by any top list)
331
+ result["categories"]["other"] = {
332
+ "display_name": "Other / Unclassified",
333
+ "description": "Heads whose patterns don't fit the simple categories above",
334
+ "educational_text": "This head's pattern doesn't fit our simple categories — it may be doing something more complex or context-dependent.",
335
+ "icon": "question-circle",
336
+ "requires_repetition": False,
337
+ "suggested_prompt": None,
338
+ "is_applicable": True,
339
+ "heads": [] # We don't enumerate all "other" heads to keep the UI clean
340
  }
341
 
342
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/model_patterns.py CHANGED
@@ -1421,23 +1421,4 @@ def generate_bertviz_html(activation_data: Dict[str, Any], layer_index: int, vie
1421
  return f"<p>Error generating visualization: {str(e)}</p>"
1422
 
1423
 
1424
- def get_head_category_counts(activation_data: Dict[str, Any]) -> Dict[str, int]:
1425
- """
1426
- Get counts of attention heads in each category.
1427
-
1428
- Useful for UI display showing the distribution of head types.
1429
-
1430
- Args:
1431
- activation_data: Output from execute_forward_pass with attention data
1432
-
1433
- Returns:
1434
- Dict mapping category name to count of heads in that category
1435
- """
1436
- from .head_detection import categorize_all_heads
1437
-
1438
- try:
1439
- categories = categorize_all_heads(activation_data)
1440
- return {category: len(heads) for category, heads in categories.items()}
1441
- except Exception as e:
1442
- print(f"Warning: Could not categorize heads: {e}")
1443
- return {}
 
1421
  return f"<p>Error generating visualization: {str(e)}</p>"
1422
 
1423
 
1424
+