"""
Architectural Analysis for RQ1 - Architectural Interpretability
Purpose: Extract and format raw architectural signals for transparency visualization
Focus: Internal mechanisms (NOT post-hoc feature attribution)
Key differences from SHAP/explainability:
- Preserves per-head, per-layer granularity (no aggregation)
- Captures activation patterns and confidence metrics
- Supports causal intervention (ablation)
- Real-time architectural transparency
Based on PhD proposal RQ1:
"Transform opaque architectural mechanisms into interpretable visual representations"
"""
import torch
import numpy as np
from typing import Dict, List, Optional, Any
import logging

logger = logging.getLogger(__name__)

def compute_head_entropy(attention_weights: torch.Tensor) -> float:
    """
    Compute entropy of attention distribution for a single head.

    High entropy = diffuse attention (many tokens attended equally)
    Low entropy = focused attention (few tokens dominate)

    Args:
        attention_weights: [seq_len, seq_len] attention matrix for one head

    Returns:
        Entropy value (bits)
    """
    # Average across query positions to get a single distribution over key positions
    avg_dist = attention_weights.mean(dim=0)

    # Add small epsilon to avoid log(0), then renormalize so the distribution sums to 1
    eps = 1e-10
    avg_dist = avg_dist + eps
    avg_dist = avg_dist / avg_dist.sum()

    # Compute Shannon entropy: -sum(p * log2(p))
    entropy = -(avg_dist * torch.log2(avg_dist)).sum().item()

    # Guard against NaN/Inf, then clamp to a sane range
    if not np.isfinite(entropy):
        return 0.0
    return float(np.clip(entropy, 0.0, 1e10))

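# Illustrative sanity check (a sketch, not part of the module API): a uniform
# attention matrix over 4 tokens should score ~log2(4) = 2 bits, while a head
# that sends every query to token 0 should score ~0 bits.
def _demo_head_entropy() -> None:
    uniform = torch.full((4, 4), 0.25)
    focused = torch.zeros(4, 4)
    focused[:, 0] = 1.0  # every query attends to token 0
    print(f"uniform: {compute_head_entropy(uniform):.3f} bits")  # ~2.0
    print(f"focused: {compute_head_entropy(focused):.3f} bits")  # ~0.0
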
def identify_head_role(attention_weights: torch.Tensor, tokens: List[str]) -> str:
    """
    Classify attention head role based on attention patterns.

    Roles:
    - 'positional': Attends primarily to specific positions (diagonal, next-token, etc.)
    - 'delimiter': Focuses on delimiters/special tokens (braces, semicolons, etc.)
    - 'content': Attends to semantic content tokens (identifiers, keywords)
    - 'mixed': No clear specialization

    Args:
        attention_weights: [seq_len, seq_len]
        tokens: List of token strings

    Returns:
        Role classification string
    """
    # Compute statistics for the simple threshold heuristics below
    # (thresholds can be refined with more research)
    diagonal_strength = torch.diag(attention_weights).mean().item()
    max_weight = attention_weights.max().item()

    # Strong diagonal -> head tracks position
    if diagonal_strength > 0.3:
        return 'positional'

    # Check if the head attends primarily to delimiter tokens.
    # Truncate the token list to seq_len in case more tokens were passed
    # than the attention matrix covers.
    seq_len = attention_weights.shape[-1]
    delimiter_tokens = {'{', '}', '(', ')', '[', ']', ';', ',', ':'}
    delimiter_indices = [i for i, tok in enumerate(tokens[:seq_len]) if tok in delimiter_tokens]
    if delimiter_indices:
        delimiter_attention = attention_weights[:, delimiter_indices].mean().item()
        if delimiter_attention > 0.3:
            return 'delimiter'

    # A single dominant weight suggests focused content attention
    if max_weight > 0.5:
        return 'content'

    return 'mixed'

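# Illustrative sketch of the delimiter heuristic (tokens and weights are made
# up for the example): every query puts most of its mass on the ';' column,
# so the head is classified as 'delimiter'.
def _demo_head_role() -> None:
    tokens = ["x", "=", "1", ";"]
    attn = torch.full((4, 4), 0.1)
    attn[:, 3] = 0.7  # rows sum to 1.0; mass concentrated on ';'
    print(identify_head_role(attn, tokens))  # 'delimiter'
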
def extract_per_head_attention(
    attention_tensor: torch.Tensor,
    layer_idx: int,
    tokens: List[str]
) -> List[Dict[str, Any]]:
    """
    Extract per-head attention data for a specific layer.

    Args:
        attention_tensor: [num_heads, seq_len, seq_len]
        layer_idx: Layer index (kept for API symmetry; not embedded in the per-head dicts)
        tokens: Token strings

    Returns:
        List of dicts, one per head
    """
    num_heads = attention_tensor.shape[0]
    heads_data = []

    for head_idx in range(num_heads):
        head_attn = attention_tensor[head_idx]  # [seq_len, seq_len]

        # Clean attention matrix - replace NaN/Inf and clamp to [0, 1]
        head_attn_np = head_attn.detach().cpu().numpy()
        head_attn_np = np.nan_to_num(head_attn_np, nan=0.0, posinf=1.0, neginf=0.0)
        head_attn_np = np.clip(head_attn_np, 0.0, 1.0)

        # Rebuild a tensor for entropy/role calculations
        head_attn_clean = torch.from_numpy(head_attn_np)

        entropy = compute_head_entropy(head_attn_clean)
        max_weight = float(head_attn_np.max())
        if not np.isfinite(max_weight):
            max_weight = 0.0
        role = identify_head_role(head_attn_clean, tokens)

        heads_data.append({
            "head_idx": head_idx,
            "attention_matrix": head_attn_np.tolist(),
            "entropy": entropy,
            "max_weight": max_weight,
            "role": role
        })

    return heads_data

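# Illustrative sketch: per-head extraction on a random, row-normalized 2-head
# attention tensor (shapes and token strings are made up).
def _demo_extract_per_head() -> None:
    attn = torch.softmax(torch.randn(2, 4, 4), dim=-1)  # [num_heads, seq, seq]
    tokens = ["def", "f", "(", ")"]
    for head in extract_per_head_attention(attn, layer_idx=0, tokens=tokens):
        print(head["head_idx"], head["role"], f"{head['entropy']:.2f} bits")
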
def compute_activation_metrics(
    hidden_states: torch.Tensor,
    prev_hidden_states: Optional[torch.Tensor] = None
) -> Dict[str, float]:
    """
    Compute activation-related metrics for a layer.

    Args:
        hidden_states: [seq_len, hidden_dim] output of layer
        prev_hidden_states: Previous layer hidden states (for drift computation)

    Returns:
        Dict with activation magnitude, entropy, norm, drift
    """
    # Activation magnitude: per-token L2 norm averaged across the sequence
    activation_magnitude = torch.norm(hidden_states, dim=-1).mean().item()
    if not np.isfinite(activation_magnitude):
        activation_magnitude = 0.0
    activation_magnitude = float(np.clip(activation_magnitude, -1e10, 1e10))

    # Activation entropy: how varied are the activations?
    # (softmax over all flattened activations, treated as a distribution)
    flat_activations = hidden_states.flatten()
    probs = torch.softmax(flat_activations, dim=0)
    activation_entropy = -(probs * torch.log2(probs + 1e-10)).sum().item()
    if not np.isfinite(activation_entropy):
        activation_entropy = 0.0
    activation_entropy = float(np.clip(activation_entropy, 0.0, 1e10))

    # Frobenius norm of the full hidden-state matrix
    hidden_state_norm = torch.norm(hidden_states).item()
    if not np.isfinite(hidden_state_norm):
        hidden_state_norm = 0.0
    hidden_state_norm = float(np.clip(hidden_state_norm, -1e10, 1e10))

    # Hidden state drift: how far the representation moved since the previous layer
    hidden_state_drift = None
    if prev_hidden_states is not None:
        drift = torch.norm(hidden_states - prev_hidden_states).item()
        if np.isfinite(drift):
            hidden_state_drift = float(np.clip(drift, -1e10, 1e10))

    return {
        "activation_magnitude": activation_magnitude,
        "activation_entropy": activation_entropy,
        "hidden_state_norm": hidden_state_norm,
        "hidden_state_drift": hidden_state_drift
    }

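# Illustrative sketch: drift is only reported when the previous layer's hidden
# states are supplied; the tensors here are random stand-ins.
def _demo_activation_metrics() -> None:
    h_prev = torch.randn(4, 8)                 # [seq_len, hidden_dim]
    h_curr = h_prev + 0.1 * torch.randn(4, 8)  # small perturbation of h_prev
    print(compute_activation_metrics(h_curr, prev_hidden_states=h_prev))
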
def extract_architectural_data(
    model_outputs: Dict[str, Any],
    input_tokens: List[str],
    output_tokens: List[str],
    model_config: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """
    Extract complete architectural transparency data for visualization.

    This is the main function that formats all data needed for the
    ArchitecturalAttentionExplorer component.

    Args:
        model_outputs: Dict containing 'attentions', 'hidden_states', etc.
        input_tokens: Input token strings
        output_tokens: Generated token strings
        model_config: Model configuration (num_layers, num_heads, etc.)

    Returns:
        Complete architectural data dict, or None if no attention weights are available
    """
    # Extract attention from model outputs.
    # Expected shape: attentions is a tuple of [batch, num_heads, seq_len, seq_len]
    attentions = model_outputs.get('attentions', None)
    hidden_states = model_outputs.get('hidden_states', None)

    if not attentions:
        logger.warning("No attention weights in model outputs")
        return None

    # Token strings aligned with the attention axes (input followed by generated tokens)
    all_tokens = input_tokens + output_tokens

    # Process each layer
    layers_data = []
    prev_hidden = None
    num_layers = len(attentions)

    for layer_idx in range(num_layers):
        layer_attn = attentions[layer_idx]  # [batch, num_heads, seq_len, seq_len]

        # Remove batch dimension (assuming batch_size=1)
        if layer_attn.dim() == 4:
            layer_attn = layer_attn[0]  # [num_heads, seq_len, seq_len]

        # Extract per-head attention
        heads_data = extract_per_head_attention(layer_attn, layer_idx, all_tokens)

        # Compute activation metrics. In Hugging Face outputs, hidden_states[0]
        # is the embedding output, so the output of layer i is hidden_states[i + 1].
        activation_metrics = {
            "activation_magnitude": 0.0,
            "activation_entropy": 0.0,
            "hidden_state_norm": 0.0,
            "hidden_state_drift": None
        }
        if hidden_states is not None and layer_idx + 1 < len(hidden_states):
            current_hidden = hidden_states[layer_idx + 1]
            if current_hidden.dim() == 3:  # [batch, seq_len, hidden_dim]
                current_hidden = current_hidden[0]  # Remove batch
            activation_metrics = compute_activation_metrics(current_hidden, prev_hidden)
            prev_hidden = current_hidden

        # Combine data for this layer
        layers_data.append({
            "layer_idx": layer_idx,
            "attention_heads": heads_data,
            **activation_metrics
        })

    # Build complete response (layers_data is non-empty because attentions is non-empty)
    architectural_data = {
        "layers": layers_data,
        "model_info": {
            "num_layers": num_layers,
            "num_heads": model_config.get('num_heads', len(layers_data[0]["attention_heads"])),
            "hidden_size": model_config.get('hidden_size', 768),
            "model_name": model_config.get('model_name', 'unknown')
        },
        "input_tokens": input_tokens,
        "output_tokens": output_tokens
    }

    # Optional: Expert routing (for MoE models)
    expert_routing = model_outputs.get('router_logits', None)
    if expert_routing is not None:
        architectural_data["expert_routing"] = extract_expert_routing(expert_routing)

    return architectural_data

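# Illustrative wiring sketch, assuming a Hugging Face causal LM called with
# output_attentions=True and output_hidden_states=True. 'gpt2' is only a
# stand-in checkpoint; any decoder-only model that exposes attentions works.
def _demo_extract_architectural_data() -> None:
    from transformers import AutoModelForCausalLM, AutoTokenizer
    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    enc = tok("def add(a, b):", return_tensors="pt")
    with torch.no_grad():
        out = model(**enc, output_attentions=True, output_hidden_states=True)
    data = extract_architectural_data(
        model_outputs={"attentions": out.attentions, "hidden_states": out.hidden_states},
        input_tokens=tok.convert_ids_to_tokens(enc["input_ids"][0]),
        output_tokens=[],
        model_config={
            "model_name": "gpt2",
            "num_heads": model.config.n_head,
            "hidden_size": model.config.n_embd,
        },
    )
    print(f"extracted {len(data['layers'])} layers")
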
def extract_expert_routing(router_logits: torch.Tensor) -> List[Dict[str, Any]]:
    """
    Extract expert routing decisions for MoE models.

    Args:
        router_logits: Router logits from model
                       Shape depends on model architecture

    Returns:
        List of routing decisions per layer/token
    """
    # This is model-specific and would need to be adapted
    # for DeepSeek-MoE, CodeLlama-MoE, etc.
    # Placeholder implementation
    routing_data = []
    logger.info("Expert routing extraction not yet implemented for this model")
    return routing_data

def format_for_study_endpoint(
    architectural_data: Dict[str, Any],
    generation_metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Format architectural data for the /api/study/analyze endpoint response.

    Args:
        architectural_data: Output from extract_architectural_data()
        generation_metadata: Generation stats (time, tokens, etc.)

    Returns:
        Complete response dict
    """
    return {
        "architectural_data": architectural_data,
        "metadata": generation_metadata,
        "visualization_type": "architectural_transparency",
        "research_context": "RQ1: Architectural Interpretability"
    }

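# Optional demo entry point (illustrative only): runs the lightweight sketches
# above. The model-based demo is left out to avoid downloading weights on run.
if __name__ == "__main__":
    _demo_head_entropy()
    _demo_head_role()
    _demo_extract_per_head()
    _demo_activation_metrics()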