# LLMVis/utils/head_detection.py
"""
Attention Head Detection and Categorization
Loads pre-computed head category data from JSON (produced by scripts/analyze_heads.py)
and performs lightweight runtime verification of head activation on the current input.
Categories:
- Previous Token: attends to the immediately preceding token
- Induction: completes repeated patterns ([A][B]...[A] → [B])
- Duplicate Token: attends to earlier occurrences of the same token
- Positional / First-Token: attends to the first token or positional patterns
- Diffuse / Spread: high-entropy, evenly distributed attention
- Other: heads that don't fit the above categories
"""
import json
import os
import torch
import numpy as np
from typing import Dict, List, Tuple, Optional, Any
import re
from pathlib import Path
# Path to the pre-computed head categories JSON
_JSON_PATH = Path(__file__).parent / "head_categories.json"
# Cache for loaded JSON data (avoids re-reading per request)
_category_cache: Dict[str, Any] = {}
def load_head_categories(model_name: str) -> Optional[Dict[str, Any]]:
    """
    Load pre-computed head category data for a model.
    Args:
        model_name: HuggingFace model name (e.g., "gpt2", "EleutherAI/pythia-70m")
    Returns:
        Dict with model's category data, or None if model not analyzed.
        Structure: {
            "model_name": str,
            "num_layers": int,
            "num_heads": int,
            "categories": { category_name: { "top_heads": [...], ... } },
            ...
        }
    """
    # Serve from the in-process cache when this model was already loaded.
    cached = _category_cache.get(model_name)
    if cached is not None:
        return cached
    if not _JSON_PATH.exists():
        return None
    try:
        with open(_JSON_PATH, 'r') as f:
            all_data = json.load(f)
    except (json.JSONDecodeError, IOError):
        return None
    # Exact key first; fall back to the short alias after the last "/"
    # (e.g. "gpt2" for "openai-community/gpt2").
    entry = all_data.get(model_name)
    if entry is None:
        entry = all_data.get(model_name.rsplit('/', 1)[-1])
    if entry is not None:
        _category_cache[model_name] = entry
    return entry
def clear_category_cache():
    """Reset the module-level category cache (primarily useful in tests)."""
    global _category_cache
    _category_cache = dict()
def _compute_attention_entropy(attention_weights: torch.Tensor) -> float:
"""
Compute normalized entropy of an attention distribution.
Args:
attention_weights: [seq_len] tensor of attention weights for one position
Returns:
Normalized entropy (0.0 to 1.0). 1.0 = perfectly uniform, 0.0 = fully peaked.
"""
epsilon = 1e-10
weights = attention_weights + epsilon
entropy = -torch.sum(weights * torch.log(weights))
max_entropy = np.log(len(weights))
return (entropy / max_entropy).item() if max_entropy > 0 else 0.0
def _find_repeated_tokens(token_ids: List[int]) -> Dict[int, List[int]]:
"""
Find tokens that appear more than once and their positions.
Args:
token_ids: List of token IDs in the sequence
Returns:
Dict mapping token_id -> list of positions where it appears (only for repeated tokens)
"""
positions: Dict[int, List[int]] = {}
for i, tid in enumerate(token_ids):
if tid not in positions:
positions[tid] = []
positions[tid].append(i)
# Keep only tokens that appear more than once
return {tid: pos_list for tid, pos_list in positions.items() if len(pos_list) > 1}
def _score_previous_token(attn_matrix: torch.Tensor, token_ids: List[int]) -> float:
    """Mean attention each position pays to its immediate predecessor (sub-diagonal)."""
    # diagonal(-1) collects attn_matrix[i, i - 1] for i = 1..seq_len-1 in one shot.
    sub_diag = attn_matrix.diagonal(-1)
    return float(sub_diag.mean().item()) if sub_diag.numel() > 0 else 0.0


def _score_induction(attn_matrix: torch.Tensor, token_ids: List[int]) -> float:
    """Induction pattern: at a repeated token [A], attend to the token AFTER its earlier [A]."""
    seq_len = attn_matrix.shape[0]
    repeated = _find_repeated_tokens(token_ids)
    if not repeated:
        return 0.0  # No repetition -> role cannot fire on this input
    scores = []
    for positions in repeated.values():
        for k in range(1, len(positions)):
            current_pos = positions[k]  # later occurrence of the token
            for prev_pos in positions[:k]:  # each earlier occurrence
                target_pos = prev_pos + 1  # token AFTER the earlier occurrence
                # Defensive bounds check in case token_ids is longer than the matrix.
                if target_pos < seq_len and current_pos < seq_len:
                    scores.append(attn_matrix[current_pos, target_pos].item())
    return float(np.mean(scores)) if scores else 0.0


def _score_duplicate_token(attn_matrix: torch.Tensor, token_ids: List[int]) -> float:
    """Total attention each later occurrence of a token pays to its earlier occurrences."""
    repeated = _find_repeated_tokens(token_ids)
    if not repeated:
        return 0.0  # No duplicates -> role cannot fire on this input
    scores = []
    for positions in repeated.values():
        for k in range(1, len(positions)):
            later_pos = positions[k]
            scores.append(sum(attn_matrix[later_pos, p].item() for p in positions[:k]))
    return float(np.mean(scores)) if scores else 0.0


def _score_positional(attn_matrix: torch.Tensor, token_ids: List[int]) -> float:
    """Mean attention paid to the first token (column 0) across all positions."""
    return attn_matrix[:, 0].mean().item()


def _score_diffuse(attn_matrix: torch.Tensor, token_ids: List[int]) -> float:
    """Average normalized entropy across all query positions."""
    entropies = [_compute_attention_entropy(attn_matrix[i]) for i in range(attn_matrix.shape[0])]
    return float(np.mean(entropies)) if entropies else 0.0


# Dispatch table: category name -> scorer(attn_matrix, token_ids) -> float.
_CATEGORY_SCORERS = {
    "previous_token": _score_previous_token,
    "induction": _score_induction,
    "duplicate_token": _score_duplicate_token,
    "positional": _score_positional,
    "diffuse": _score_diffuse,
}


def verify_head_activation(
    attn_matrix: torch.Tensor,
    token_ids: List[int],
    category: str
) -> float:
    """
    Verify whether a head's known role is active on the current input.
    Args:
        attn_matrix: [seq_len, seq_len] attention weights for this head
        token_ids: List of token IDs in the input
        category: Category name (previous_token, induction, duplicate_token, positional, diffuse)
    Returns:
        Activation score (0.0 to 1.0). 0.0 means the role is not triggered on this
        input, the sequence is too short, or the category is unknown.
    """
    if attn_matrix.shape[0] < 2:
        return 0.0
    scorer = _CATEGORY_SCORERS.get(category)
    return scorer(attn_matrix, token_ids) if scorer is not None else 0.0
def _build_head_attention_lookup(
    attention_outputs: Dict[str, Any]
) -> Dict[Tuple[int, int], torch.Tensor]:
    """
    Build a (layer, head) -> [seq_len, seq_len] attention matrix lookup.

    The layer index is taken as the first integer embedded in the hooked module's
    name; entries without a parsable index, or without a list output whose second
    element is a [batch, heads, seq_len, seq_len] weight tensor, are skipped.
    Only batch index 0 is kept.
    """
    lookup: Dict[Tuple[int, int], torch.Tensor] = {}
    for module_name, output_dict in attention_outputs.items():
        numbers = re.findall(r'\d+', module_name)
        if not numbers:
            continue
        layer_idx = int(numbers[0])
        attention_output = output_dict.get('output')
        if not isinstance(attention_output, list) or len(attention_output) < 2:
            continue
        # attention_output[1] is [batch, heads, seq_len, seq_len]
        weights = torch.tensor(attention_output[1])
        for head_idx in range(weights.shape[1]):
            lookup[(layer_idx, head_idx)] = weights[0, head_idx, :, :]
    return lookup


def get_active_head_summary(
    activation_data: Dict[str, Any],
    model_name: str,
    activation_threshold: float = 0.1
) -> Optional[Dict[str, Any]]:
    """
    Main entry point: load categories from JSON, verify each head on the current input,
    and return a UI-ready structure.
    Args:
        activation_data: Output from execute_forward_pass with attention data
        model_name: HuggingFace model name
        activation_threshold: Minimum runtime activation score for a head to be
            flagged "is_active" (default 0.1, matching the previous hard-coded value)
    Returns:
        Dict with structure:
        {
            "model_available": True,
            "categories": {
                "previous_token": {
                    "display_name": str,
                    "description": str,
                    "educational_text": str,
                    "icon": str,
                    "requires_repetition": bool,
                    "suggested_prompt": str or None,
                    "is_applicable": bool,  # False if requires_repetition but no repeats
                    "heads": [
                        {"layer": int, "head": int, "offline_score": float,
                         "activation_score": float, "is_active": bool, "label": str}
                    ]
                },
                ...
            }
        }
        Returns None if model not in JSON or the activation data lacks
        attention outputs / token IDs.
    """
    model_data = load_head_categories(model_name)
    if model_data is None:
        return None
    # Extract attention weights and token IDs captured during the forward pass
    attention_outputs = activation_data.get('attention_outputs', {})
    input_ids = activation_data.get('input_ids', [[]])[0]
    if not attention_outputs or not input_ids:
        return None
    head_attention_lookup = _build_head_attention_lookup(attention_outputs)
    # Repetition-dependent categories (induction, duplicate token) are only
    # meaningful when the prompt actually contains repeated tokens.
    has_repetition = len(_find_repeated_tokens(input_ids)) > 0
    result: Dict[str, Any] = {
        "model_available": True,
        "categories": {}
    }
    categories = model_data.get("categories", {})
    # Fixed ordering keeps the UI display stable across requests
    category_order = ["previous_token", "induction", "duplicate_token", "positional", "diffuse"]
    for cat_key in category_order:
        cat_info = categories.get(cat_key)
        if cat_info is None:
            continue
        requires_repetition = cat_info.get("requires_repetition", False)
        is_applicable = not requires_repetition or has_repetition
        heads_result = []
        for head_entry in cat_info.get("top_heads", []):
            layer = head_entry["layer"]
            head = head_entry["head"]
            offline_score = head_entry["score"]
            # Score the head's known role on the current input; 0.0 when the
            # head's attention wasn't captured or the category isn't applicable.
            attn_matrix = head_attention_lookup.get((layer, head))
            if attn_matrix is not None and is_applicable:
                activation_score = verify_head_activation(attn_matrix, input_ids, cat_key)
            else:
                activation_score = 0.0
            # A head is "active" if its activation score exceeds the threshold
            is_active = activation_score > activation_threshold and is_applicable
            heads_result.append({
                "layer": layer,
                "head": head,
                "offline_score": offline_score,
                "activation_score": round(activation_score, 3),
                "is_active": is_active,
                "label": f"L{layer}-H{head}"
            })
        result["categories"][cat_key] = {
            "display_name": cat_info.get("display_name", cat_key),
            "description": cat_info.get("description", ""),
            "educational_text": cat_info.get("educational_text", ""),
            "icon": cat_info.get("icon", "circle"),
            "requires_repetition": requires_repetition,
            "suggested_prompt": cat_info.get("suggested_prompt"),
            "is_applicable": is_applicable,
            "heads": heads_result
        }
    # Catch-all bucket shown after the known categories; individual "other"
    # heads are not enumerated to keep the UI clean.
    result["categories"]["other"] = {
        "display_name": "Other / Unclassified",
        "description": "Heads whose patterns don't fit the simple categories above",
        "educational_text": "This head's pattern doesn't fit our simple categories — it may be doing something more complex or context-dependent.",
        "icon": "question-circle",
        "requires_repetition": False,
        "suggested_prompt": None,
        "is_applicable": True,
        "heads": []
    }
    return result