# api/backend/architectural_analysis.py
# (header restored from repo-viewer paste; commit 37ed739 by gary-boon:
#  "Add research attention analysis endpoints with Q/K/V extraction")
"""
Architectural Analysis for RQ1 - Architectural Interpretability
Purpose: Extract and format raw architectural signals for transparency visualization
Focus: Internal mechanisms (NOT post-hoc feature attribution)
Key differences from SHAP/explainability:
- Preserves per-head, per-layer granularity (no aggregation)
- Captures activation patterns and confidence metrics
- Supports causal intervention (ablation)
- Real-time architectural transparency
Based on PhD proposal RQ1:
"Transform opaque architectural mechanisms into interpretable visual representations"
"""
import torch
import numpy as np
from typing import Dict, List, Optional, Tuple, Any
import logging
logger = logging.getLogger(__name__)
def compute_head_entropy(attention_weights: torch.Tensor) -> float:
    """
    Measure the Shannon entropy (in bits) of one head's attention distribution.

    High entropy means the head spreads attention across many tokens;
    low entropy means a few tokens dominate.

    Args:
        attention_weights: [seq_len, seq_len] attention matrix for one head

    Returns:
        Entropy in bits, clamped to a finite non-negative value.
    """
    # Collapse the query axis to obtain a single distribution over keys,
    # then nudge by epsilon so log2 never sees an exact zero.
    distribution = attention_weights.mean(dim=0) + 1e-10
    # Shannon entropy: -sum(p * log2(p))
    entropy_bits = -(distribution * torch.log2(distribution)).sum().item()
    # Clamp and map any NaN/Inf to 0.0 so callers always get a finite float.
    entropy_bits = float(np.clip(entropy_bits, 0.0, 1e10))
    return entropy_bits if np.isfinite(entropy_bits) else 0.0
def identify_head_role(attention_weights: torch.Tensor, tokens: List[str]) -> str:
    """
    Heuristically label what an attention head specializes in.

    Labels:
        'positional' - strong diagonal (position/self) attention
        'delimiter'  - concentrates on punctuation/bracket tokens
        'content'    - sharply focused on some content token
        'mixed'      - no dominant pattern

    Args:
        attention_weights: [seq_len, seq_len] attention matrix
        tokens: token strings aligned with the matrix columns

    Returns:
        One of the role labels above.
    """
    # A strong diagonal means the head mostly tracks its own position.
    # (Thresholds are simple heuristics; refine with further research.)
    if torch.diag(attention_weights).mean().item() > 0.3:
        return 'positional'

    # Does the head pile its weight onto structural punctuation?
    punctuation = {'{', '}', '(', ')', '[', ']', ';', ',', ':'}
    punct_cols = [idx for idx, token in enumerate(tokens) if token in punctuation]
    if punct_cols and attention_weights[:, punct_cols].mean().item() > 0.3:
        return 'delimiter'

    # A single very strong weight suggests focused content attention.
    if attention_weights.max().item() > 0.5:
        return 'content'

    return 'mixed'
def extract_per_head_attention(
    attention_tensor: torch.Tensor,
    layer_idx: int,
    tokens: List[str]
) -> List[Dict[str, Any]]:
    """
    Build a per-head summary (matrix + stats + role) for one layer.

    Args:
        attention_tensor: [num_heads, seq_len, seq_len] attention for the layer
        layer_idx: layer index (kept for API symmetry; unused internally)
        tokens: token strings for role classification

    Returns:
        One dict per head with keys: head_idx, attention_matrix, entropy,
        max_weight, role.
    """
    summaries = []
    # Iterating the tensor walks its first (head) dimension.
    for head_idx, head_attn in enumerate(attention_tensor):
        # Sanitize the matrix: strip NaN/Inf and force weights into [0, 1].
        cleaned = np.clip(
            np.nan_to_num(head_attn.cpu().numpy(), nan=0.0, posinf=1.0, neginf=0.0),
            0.0,
            1.0,
        )
        # Entropy/role helpers expect a torch tensor.
        cleaned_tensor = torch.from_numpy(cleaned)

        peak = float(cleaned.max())
        if not np.isfinite(peak):
            peak = 0.0

        summaries.append({
            "head_idx": head_idx,
            "attention_matrix": cleaned.tolist(),
            "entropy": compute_head_entropy(cleaned_tensor),
            "max_weight": peak,
            "role": identify_head_role(cleaned_tensor, tokens),
        })
    return summaries
def compute_activation_metrics(
    hidden_states: torch.Tensor,
    prev_hidden_states: Optional[torch.Tensor] = None
) -> Dict[str, float]:
    """
    Summarize a layer's hidden states with magnitude/entropy/norm/drift stats.

    Args:
        hidden_states: [seq_len, hidden_dim] layer output
        prev_hidden_states: previous layer's hidden states; enables drift

    Returns:
        Dict with keys activation_magnitude, activation_entropy,
        hidden_state_norm, and hidden_state_drift (None when there is no
        previous layer or the drift is non-finite).
    """

    def _sanitize(value: float, lo: float, hi: float) -> float:
        # Clamp to [lo, hi]; NaN survives np.clip, so map non-finite to 0.0.
        clamped = float(np.clip(value, lo, hi))
        return clamped if np.isfinite(clamped) else 0.0

    # Mean per-token L2 norm of the activations.
    magnitude = _sanitize(torch.norm(hidden_states, dim=-1).mean().item(), -1e10, 1e10)

    # Treat the flattened activations as a probability distribution
    # (via softmax) to gauge how varied they are.
    probs = torch.softmax(hidden_states.flatten(), dim=0)
    entropy = _sanitize(-(probs * torch.log2(probs + 1e-10)).sum().item(), 0.0, 1e10)

    # Frobenius norm of the whole hidden-state matrix.
    norm = _sanitize(torch.norm(hidden_states).item(), -1e10, 1e10)

    # Drift relative to the previous layer, when available.
    drift = None
    if prev_hidden_states is not None:
        delta = float(np.clip(
            torch.norm(hidden_states - prev_hidden_states).item(), -1e10, 1e10
        ))
        if np.isfinite(delta):
            drift = delta

    return {
        "activation_magnitude": magnitude,
        "activation_entropy": entropy,
        "hidden_state_norm": norm,
        "hidden_state_drift": drift,
    }
def extract_architectural_data(
    model_outputs: Dict[str, Any],
    input_tokens: List[str],
    output_tokens: List[str],
    model_config: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """
    Extract complete architectural transparency data for visualization.

    This is the main function that formats all data needed for the
    ArchitecturalAttentionExplorer component.

    Args:
        model_outputs: Dict containing 'attentions', 'hidden_states', and
            optionally 'router_logits' (MoE models)
        input_tokens: Input token strings
        output_tokens: Generated token strings
        model_config: Model configuration (num_heads, hidden_size, model_name)

    Returns:
        Complete architectural data dict, or None when the model outputs
        contain no attention weights.
    """
    # Expected: attentions is a tuple of [batch, num_heads, seq_len, seq_len]
    attentions = model_outputs.get('attentions')
    hidden_states = model_outputs.get('hidden_states')
    if attentions is None:
        logger.warning("No attention weights in model outputs")
        return None

    all_tokens = input_tokens + output_tokens  # invariant; hoisted out of the loop
    layers_data = []
    prev_hidden = None
    num_layers = len(attentions)
    # BUGFIX: the num_heads fallback previously read the loop variable
    # `heads_data` after the loop, raising NameError when `attentions` was
    # an empty tuple. Track the observed head count explicitly instead.
    observed_num_heads = 0

    for layer_idx in range(num_layers):
        layer_attn = attentions[layer_idx]  # [batch, num_heads, seq_len, seq_len]
        # Drop the batch dimension (assumes batch_size=1).
        if layer_attn.dim() == 4:
            layer_attn = layer_attn[0]  # [num_heads, seq_len, seq_len]

        heads_data = extract_per_head_attention(layer_attn, layer_idx, all_tokens)
        observed_num_heads = len(heads_data)

        # Default metrics when no hidden states are available for this layer.
        activation_metrics = {"activation_magnitude": 0.0, "activation_entropy": 0.0, "hidden_state_norm": 0.0}
        if hidden_states is not None and layer_idx < len(hidden_states):
            # NOTE(review): hidden_states[layer_idx] is assumed aligned with
            # attention layer_idx — HF models often prepend the embedding
            # output; confirm against the caller.
            current_hidden = hidden_states[layer_idx]
            if current_hidden.dim() == 3:  # [batch, seq_len, hidden_dim]
                current_hidden = current_hidden[0]  # remove batch
            activation_metrics = compute_activation_metrics(current_hidden, prev_hidden)
            prev_hidden = current_hidden

        layers_data.append({
            "layer_idx": layer_idx,
            "attention_heads": heads_data,
            **activation_metrics
        })

    architectural_data = {
        "layers": layers_data,
        "model_info": {
            "num_layers": num_layers,
            "num_heads": model_config.get('num_heads', observed_num_heads),
            "hidden_size": model_config.get('hidden_size', 768),
            "model_name": model_config.get('model_name', 'unknown')
        },
        "input_tokens": input_tokens,
        "output_tokens": output_tokens
    }

    # Optional: expert routing decisions (MoE models only).
    router_logits = model_outputs.get('router_logits')
    if router_logits is not None:
        architectural_data["expert_routing"] = extract_expert_routing(router_logits)

    return architectural_data
def extract_expert_routing(router_logits: torch.Tensor) -> List[Dict[str, Any]]:
    """
    Extract expert routing decisions for MoE models.

    NOTE: placeholder — real extraction is architecture-specific
    (e.g. DeepSeek-MoE, CodeLlama-MoE) and has not been implemented yet.

    Args:
        router_logits: Router logits from the model; shape depends on the
            model architecture.

    Returns:
        Empty list until a model-specific implementation lands.
    """
    logger.info("Expert routing extraction not yet implemented for this model")
    return []
def format_for_study_endpoint(
    architectural_data: Dict[str, Any],
    generation_metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Wrap architectural data in the /api/study/analyze response envelope.

    Args:
        architectural_data: Output of extract_architectural_data()
        generation_metadata: Generation stats (time, tokens, etc.)

    Returns:
        Response dict tagging the payload with its visualization type
        and research context.
    """
    response = {
        "architectural_data": architectural_data,
        "metadata": generation_metadata,
    }
    # Constant tags identifying this payload for the frontend.
    response["visualization_type"] = "architectural_transparency"
    response["research_context"] = "RQ1: Architectural Interpretability"
    return response