| """ | |
| Attention Head Detection and Categorization | |
| Loads pre-computed head category data from JSON (produced by scripts/analyze_heads.py) | |
| and performs lightweight runtime verification of head activation on the current input. | |
| Categories: | |
| - Previous Token: attends to the immediately preceding token | |
| - Induction: completes repeated patterns ([A][B]...[A] → [B]) | |
| - Duplicate Token: attends to earlier occurrences of the same token | |
| - Positional / First-Token: attends to the first token or positional patterns | |
| - Diffuse / Spread: high-entropy, evenly distributed attention | |
| - Other: heads that don't fit the above categories | |
| """ | |
| import json | |
| import os | |
| import torch | |
| import numpy as np | |
| from typing import Dict, List, Tuple, Optional, Any | |
| import re | |
| from pathlib import Path | |
| # Path to the pre-computed head categories JSON | |
| _JSON_PATH = Path(__file__).parent / "head_categories.json" | |
| # Cache for loaded JSON data (avoids re-reading per request) | |
| _category_cache: Dict[str, Any] = {} | |


def load_head_categories(model_name: str) -> Optional[Dict[str, Any]]:
    """
    Load pre-computed head category data for a model.

    Args:
        model_name: HuggingFace model name (e.g., "gpt2", "EleutherAI/pythia-70m")

    Returns:
        Dict with the model's category data, or None if the model has not been analyzed.
        Structure: {
            "model_name": str,
            "num_layers": int,
            "num_heads": int,
            "categories": { category_name: { "top_heads": [...], ... } },
            ...
        }
    """
    global _category_cache

    # Check cache first
    if model_name in _category_cache:
        return _category_cache[model_name]

    # Load JSON
    if not _JSON_PATH.exists():
        return None
    try:
        with open(_JSON_PATH, 'r') as f:
            all_data = json.load(f)
    except (json.JSONDecodeError, IOError):
        return None

    # Try exact match first, then common aliases
    model_data = all_data.get(model_name)
    if model_data is None:
        # Try short name (e.g., "gpt2" for "openai-community/gpt2")
        short_name = model_name.split('/')[-1] if '/' in model_name else model_name
        model_data = all_data.get(short_name)

    if model_data is not None:
        _category_cache[model_name] = model_data
    return model_data
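

# Illustrative only: a minimal sketch of what one model entry in head_categories.json
# is assumed to look like, based on the fields this module reads ("categories",
# "top_heads", "layer", "head", "score", "requires_repetition", ...). The concrete
# head indices and scores below are hypothetical.
#
#   {
#     "gpt2": {
#       "model_name": "gpt2",
#       "num_layers": 12,
#       "num_heads": 12,
#       "categories": {
#         "previous_token": {
#           "display_name": "Previous Token",
#           "description": "Attends to the immediately preceding token",
#           "requires_repetition": false,
#           "top_heads": [{"layer": 4, "head": 11, "score": 0.87}]
#         }
#       }
#     }
#   }
#
# Usage sketch:
#   data = load_head_categories("gpt2")   # returns the dict above, or None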


def clear_category_cache():
    """Clear the loaded category cache (useful for testing)."""
    global _category_cache
    _category_cache = {}


def _compute_attention_entropy(attention_weights: torch.Tensor) -> float:
    """
    Compute normalized entropy of an attention distribution.

    Args:
        attention_weights: [seq_len] tensor of attention weights for one position

    Returns:
        Normalized entropy (0.0 to 1.0). 1.0 = perfectly uniform, 0.0 = fully peaked.
    """
    epsilon = 1e-10
    weights = attention_weights + epsilon
    entropy = -torch.sum(weights * torch.log(weights))
    max_entropy = np.log(len(weights))
    return (entropy / max_entropy).item() if max_entropy > 0 else 0.0
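

# Worked example (illustrative): for a uniform row over 4 positions,
# entropy = -4 * (0.25 * ln 0.25) = ln 4, so the normalized value is ln 4 / ln 4 = 1.0.
# For a fully peaked row like [1, 0, 0, 0] the entropy is ~0, giving ~0.0.
#
#   _compute_attention_entropy(torch.tensor([0.25, 0.25, 0.25, 0.25]))  # ≈ 1.0
#   _compute_attention_entropy(torch.tensor([1.0, 0.0, 0.0, 0.0]))      # ≈ 0.0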


def _find_repeated_tokens(token_ids: List[int]) -> Dict[int, List[int]]:
    """
    Find tokens that appear more than once and their positions.

    Args:
        token_ids: List of token IDs in the sequence

    Returns:
        Dict mapping token_id -> list of positions where it appears (only for repeated tokens)
    """
    positions: Dict[int, List[int]] = {}
    for i, tid in enumerate(token_ids):
        if tid not in positions:
            positions[tid] = []
        positions[tid].append(i)
    # Keep only tokens that appear more than once
    return {tid: pos_list for tid, pos_list in positions.items() if len(pos_list) > 1}
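

# Illustrative example (hypothetical token IDs): in the sequence [50, 11, 50, 99, 11],
# tokens 50 and 11 each occur twice while 99 occurs once, so only 50 and 11 are kept.
#
#   _find_repeated_tokens([50, 11, 50, 99, 11])  # -> {50: [0, 2], 11: [1, 4]}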


def verify_head_activation(
    attn_matrix: torch.Tensor,
    token_ids: List[int],
    category: str
) -> float:
    """
    Verify whether a head's known role is active on the current input.

    Args:
        attn_matrix: [seq_len, seq_len] attention weights for this head
        token_ids: List of token IDs in the input
        category: Category name (previous_token, induction, duplicate_token, positional, diffuse)

    Returns:
        Activation score (0.0 to 1.0). 0.0 means the role is not triggered on this input.
    """
    seq_len = attn_matrix.shape[0]
    if seq_len < 2:
        return 0.0

    if category == "previous_token":
        # Mean of diagonal-1 values: how much each position attends to the previous position
        prev_token_attentions = []
        for i in range(1, seq_len):
            prev_token_attentions.append(attn_matrix[i, i - 1].item())
        return float(np.mean(prev_token_attentions)) if prev_token_attentions else 0.0

    elif category == "induction":
        # Induction pattern: [A][B]...[A] → attend to [B]
        # For each repeated token at position i where token[i] == token[j] (j < i),
        # check if position i attends to position j+1
        repeated = _find_repeated_tokens(token_ids)
        if not repeated:
            return 0.0  # No repetition → gray out
        induction_scores = []
        for tid, positions in repeated.items():
            for k in range(1, len(positions)):
                current_pos = positions[k]  # Later occurrence
                for prev_idx in range(k):
                    prev_pos = positions[prev_idx]  # Earlier occurrence
                    target_pos = prev_pos + 1  # The token AFTER the earlier occurrence
                    if target_pos < seq_len and current_pos < seq_len:
                        induction_scores.append(attn_matrix[current_pos, target_pos].item())
        return float(np.mean(induction_scores)) if induction_scores else 0.0

    elif category == "duplicate_token":
        # Check if later occurrences attend to earlier occurrences of the same token
        repeated = _find_repeated_tokens(token_ids)
        if not repeated:
            return 0.0  # No duplicates → gray out
        dup_scores = []
        for tid, positions in repeated.items():
            for k in range(1, len(positions)):
                later_pos = positions[k]
                # Sum attention to all earlier occurrences
                earlier_attention = sum(
                    attn_matrix[later_pos, positions[j]].item()
                    for j in range(k)
                )
                dup_scores.append(earlier_attention)
        return float(np.mean(dup_scores)) if dup_scores else 0.0

    elif category == "positional":
        # Mean of column-0 attention (how much each position attends to the first token)
        first_token_attention = attn_matrix[:, 0].mean().item()
        return first_token_attention

    elif category == "diffuse":
        # Average normalized entropy across all positions
        entropies = []
        for i in range(seq_len):
            entropies.append(_compute_attention_entropy(attn_matrix[i]))
        return float(np.mean(entropies)) if entropies else 0.0

    else:
        return 0.0
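

# Usage sketch (illustrative, not from the original module): scoring a synthetic
# "previous token" pattern. A matrix that puts all mass on the previous position
# should score close to 1.0 for that category and 0.0 for induction (no repeats).
#
#   attn = torch.zeros(4, 4)
#   attn[0, 0] = 1.0
#   for i in range(1, 4):
#       attn[i, i - 1] = 1.0          # every position attends to its predecessor
#   verify_head_activation(attn, [10, 20, 30, 40], "previous_token")  # -> 1.0
#   verify_head_activation(attn, [10, 20, 30, 40], "induction")       # -> 0.0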


def get_active_head_summary(
    activation_data: Dict[str, Any],
    model_name: str
) -> Optional[Dict[str, Any]]:
    """
    Main entry point: load categories from JSON, verify each head on the current input,
    and return a UI-ready structure.

    Args:
        activation_data: Output from execute_forward_pass with attention data
        model_name: HuggingFace model name

    Returns:
        Dict with structure:
        {
            "model_available": True,
            "categories": {
                "previous_token": {
                    "display_name": str,
                    "description": str,
                    "educational_text": str,
                    "icon": str,
                    "requires_repetition": bool,
                    "suggested_prompt": str or None,
                    "is_applicable": bool,  # False if requires_repetition but no repeats
                    "heads": [
                        {"layer": int, "head": int, "offline_score": float,
                         "activation_score": float, "is_active": bool, "label": str}
                    ]
                },
                ...
            }
        }
        Returns None if the model is not in the JSON.
    """
    model_data = load_head_categories(model_name)
    if model_data is None:
        return None

    # Extract attention weights and token IDs from activation data
    attention_outputs = activation_data.get('attention_outputs', {})
    input_ids = activation_data.get('input_ids', [[]])[0]
    if not attention_outputs or not input_ids:
        return None

    # Build a lookup: (layer, head) → attention_matrix [seq_len, seq_len]
    head_attention_lookup: Dict[Tuple[int, int], torch.Tensor] = {}
    for module_name, output_dict in attention_outputs.items():
        numbers = re.findall(r'\d+', module_name)
        if not numbers:
            continue
        layer_idx = int(numbers[0])
        attention_output = output_dict.get('output')
        if not isinstance(attention_output, list) or len(attention_output) < 2:
            continue
        # attention_output[1] is [batch, heads, seq_len, seq_len]
        attention_weights = torch.tensor(attention_output[1])
        num_heads = attention_weights.shape[1]
        for head_idx in range(num_heads):
            head_attention_lookup[(layer_idx, head_idx)] = attention_weights[0, head_idx, :, :]

    # Check if input has repeated tokens (needed for applicability check)
    repeated_tokens = _find_repeated_tokens(input_ids)
    has_repetition = len(repeated_tokens) > 0

    # Build result
    result = {
        "model_available": True,
        "categories": {}
    }
    categories = model_data.get("categories", {})

    # Define category order for consistent display
    category_order = ["previous_token", "induction", "duplicate_token", "positional", "diffuse"]

    for cat_key in category_order:
        cat_info = categories.get(cat_key)
        if cat_info is None:
            continue

        requires_repetition = cat_info.get("requires_repetition", False)
        is_applicable = not requires_repetition or has_repetition

        heads_result = []
        for head_entry in cat_info.get("top_heads", []):
            layer = head_entry["layer"]
            head = head_entry["head"]
            offline_score = head_entry["score"]

            # Get activation score on current input
            attn_matrix = head_attention_lookup.get((layer, head))
            if attn_matrix is not None and is_applicable:
                activation_score = verify_head_activation(attn_matrix, input_ids, cat_key)
            else:
                activation_score = 0.0

            # A head is "active" if its activation score exceeds a minimum threshold
            is_active = activation_score > 0.1 and is_applicable

            heads_result.append({
                "layer": layer,
                "head": head,
                "offline_score": offline_score,
                "activation_score": round(activation_score, 3),
                "is_active": is_active,
                "label": f"L{layer}-H{head}"
            })

        result["categories"][cat_key] = {
            "display_name": cat_info.get("display_name", cat_key),
            "description": cat_info.get("description", ""),
            "educational_text": cat_info.get("educational_text", ""),
            "icon": cat_info.get("icon", "circle"),
            "requires_repetition": requires_repetition,
            "suggested_prompt": cat_info.get("suggested_prompt"),
            "is_applicable": is_applicable,
            "heads": heads_result
        }

    # Add "Other" category (heads not claimed by any top list)
    result["categories"]["other"] = {
        "display_name": "Other / Unclassified",
        "description": "Heads whose patterns don't fit the simple categories above",
        "educational_text": "This head's pattern doesn't fit our simple categories — it may be doing something more complex or context-dependent.",
        "icon": "question-circle",
        "requires_repetition": False,
        "suggested_prompt": None,
        "is_applicable": True,
        "heads": []  # We don't enumerate all "other" heads to keep the UI clean
    }

    return result
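

# End-to-end sketch (illustrative; assumes a "gpt2" entry exists in
# head_categories.json and that execute_forward_pass produced a dict with the
# fields read above). The module name "transformer.h.0.attn" follows GPT-2's
# naming; the layer index is parsed out of it with re.findall.
#
#   activation_data = {
#       "input_ids": [[464, 3290, 464]],                 # hypothetical token IDs
#       "attention_outputs": {
#           "transformer.h.0.attn": {
#               "output": [hidden_states, attn_probs],   # attn_probs: [1, heads, 3, 3]
#           },
#       },
#   }
#   summary = get_active_head_summary(activation_data, "gpt2")
#   if summary is not None:
#       for name, cat in summary["categories"].items():
#           print(name, [h["label"] for h in cat["heads"] if h["is_active"]])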