Spaces:
Sleeping
Sleeping
| """ | |
| Architectural Analysis for RQ1 - Architectural Interpretability | |
| Purpose: Extract and format raw architectural signals for transparency visualization | |
| Focus: Internal mechanisms (NOT post-hoc feature attribution) | |
| Key differences from SHAP/explainability: | |
| - Preserves per-head, per-layer granularity (no aggregation) | |
| - Captures activation patterns and confidence metrics | |
| - Supports causal intervention (ablation) | |
| - Real-time architectural transparency | |
| Based on PhD proposal RQ1: | |
| "Transform opaque architectural mechanisms into interpretable visual representations" | |
| """ | |
# Standard library
import logging
from typing import Any, Dict, List, Optional, Tuple

# Third-party
import numpy as np
import torch

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
def compute_head_entropy(attention_weights: torch.Tensor) -> float:
    """
    Compute the Shannon entropy (in bits) of one head's attention distribution.

    High entropy = diffuse attention (many tokens attended roughly equally).
    Low entropy = focused attention (a few tokens dominate).

    Args:
        attention_weights: [seq_len, seq_len] attention matrix for one head.

    Returns:
        Entropy in bits, clamped to [0, 1e10]; 0.0 for non-finite results.
    """
    # Average across query positions to get a single distribution over keys.
    avg_dist = attention_weights.mean(dim=0)
    # Add epsilon to avoid log(0), then renormalize so probabilities still
    # sum to 1 (the previous version skipped renormalization, which slightly
    # biased the entropy and broke on unnormalized inputs).
    eps = 1e-10
    avg_dist = avg_dist + eps
    avg_dist = avg_dist / avg_dist.sum()
    # Shannon entropy: -sum(p * log2(p))
    entropy = -(avg_dist * torch.log2(avg_dist)).sum().item()
    # Guard against NaN/Inf from pathological inputs.
    if not np.isfinite(entropy):
        return 0.0
    return float(np.clip(entropy, 0.0, 1e10))
def identify_head_role(attention_weights: torch.Tensor, tokens: List[str]) -> str:
    """
    Heuristically label the specialization of one attention head.

    Labels:
      - 'positional': strong self/diagonal attention (next-token, same-token)
      - 'delimiter':  mass concentrated on delimiter/special tokens
      - 'content':    sharply focused on some content token
      - 'mixed':      no clear specialization

    Args:
        attention_weights: [seq_len, seq_len] attention matrix.
        tokens: token strings aligned with the attention axes.

    Returns:
        One of 'positional', 'delimiter', 'content', 'mixed'.
    """
    DELIMITERS = {'{', '}', '(', ')', '[', ']', ';', ',', ':'}

    # Guard clause: diagonal dominance marks a positional head.
    if torch.diag(attention_weights).mean().item() > 0.3:
        return 'positional'

    # Mean attention mass landing on delimiter columns, if any exist.
    delim_cols = [idx for idx, tok in enumerate(tokens) if tok in DELIMITERS]
    if delim_cols and attention_weights[:, delim_cols].mean().item() > 0.3:
        return 'delimiter'

    # A single strong peak (without delimiter focus) suggests content attention.
    if attention_weights.max().item() > 0.5:
        return 'content'

    return 'mixed'
def extract_per_head_attention(
    attention_tensor: torch.Tensor,
    layer_idx: int,
    tokens: List[str]
) -> List[Dict[str, Any]]:
    """
    Build per-head transparency records for one layer.

    Args:
        attention_tensor: [num_heads, seq_len, seq_len] attention weights.
        layer_idx: index of the layer being processed (kept for interface
            stability; not used in the computation itself).
        tokens: token strings aligned with the attention axes.

    Returns:
        One dict per head with keys: head_idx, attention_matrix (nested
        lists), entropy, max_weight, role.
    """
    heads_data: List[Dict[str, Any]] = []
    for head_idx in range(attention_tensor.shape[0]):
        # Sanitize the [seq_len, seq_len] matrix: NaN -> 0, +Inf -> 1,
        # -Inf -> 0, then clamp everything into [0, 1].
        raw = attention_tensor[head_idx].cpu().numpy()
        cleaned = np.clip(
            np.nan_to_num(raw, nan=0.0, posinf=1.0, neginf=0.0), 0.0, 1.0)
        cleaned_t = torch.from_numpy(cleaned)

        # Max is finite after sanitization, but guard defensively anyway.
        peak = float(cleaned.max())
        if not np.isfinite(peak):
            peak = 0.0

        heads_data.append({
            "head_idx": head_idx,
            "attention_matrix": cleaned.tolist(),
            "entropy": compute_head_entropy(cleaned_t),
            "max_weight": peak,
            "role": identify_head_role(cleaned_t, tokens),
        })
    return heads_data
def compute_activation_metrics(
    hidden_states: torch.Tensor,
    prev_hidden_states: Optional[torch.Tensor] = None
) -> Dict[str, float]:
    """
    Summarize a layer's hidden-state activations.

    Args:
        hidden_states: [seq_len, hidden_dim] layer output.
        prev_hidden_states: previous layer's hidden states, if available,
            used to measure how far representations moved between layers.

    Returns:
        Dict with keys activation_magnitude, activation_entropy,
        hidden_state_norm, hidden_state_drift (None when no previous layer
        or the drift is non-finite).
    """
    def _finite_or_zero(value: float, lo: float, hi: float) -> float:
        # Clamp into [lo, hi]; NaN survives np.clip, so map non-finite to 0.
        clipped = float(np.clip(value, lo, hi))
        return clipped if np.isfinite(clipped) else 0.0

    # Mean per-token L2 norm of the activations.
    magnitude = _finite_or_zero(
        torch.norm(hidden_states, dim=-1).mean().item(), -1e10, 1e10)

    # Entropy (bits) of the softmax over all activation values: how spread
    # out the activation mass is across the whole layer.
    probs = torch.softmax(hidden_states.flatten(), dim=0)
    entropy = _finite_or_zero(
        -(probs * torch.log2(probs + 1e-10)).sum().item(), 0.0, 1e10)

    # Frobenius norm of the full hidden-state matrix.
    norm = _finite_or_zero(torch.norm(hidden_states).item(), -1e10, 1e10)

    # Drift relative to the previous layer; stays None when unavailable
    # or non-finite (unlike the metrics above, which fall back to 0.0).
    drift = None
    if prev_hidden_states is not None:
        candidate = float(np.clip(
            torch.norm(hidden_states - prev_hidden_states).item(), -1e10, 1e10))
        if np.isfinite(candidate):
            drift = candidate

    return {
        "activation_magnitude": magnitude,
        "activation_entropy": entropy,
        "hidden_state_norm": norm,
        "hidden_state_drift": drift,
    }
def extract_architectural_data(
    model_outputs: Dict[str, Any],
    input_tokens: List[str],
    output_tokens: List[str],
    model_config: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """
    Extract complete architectural transparency data for visualization.

    Main entry point that formats everything needed by the
    ArchitecturalAttentionExplorer component.

    Args:
        model_outputs: dict with 'attentions' (tuple of
            [batch, num_heads, seq_len, seq_len] tensors), and optionally
            'hidden_states' and 'router_logits'.
        input_tokens: prompt token strings.
        output_tokens: generated token strings.
        model_config: model configuration (num_heads, hidden_size, model_name).

    Returns:
        Complete architectural data dict, or None when no attention weights
        are present in model_outputs.
    """
    attentions = model_outputs.get('attentions', None)
    hidden_states = model_outputs.get('hidden_states', None)
    if attentions is None:
        logger.warning("No attention weights in model outputs")
        return None

    # Attention axes cover prompt + generated tokens; hoisted out of the
    # layer loop (it was recomputed on every iteration).
    all_tokens = input_tokens + output_tokens

    layers_data = []
    prev_hidden = None
    num_layers = len(attentions)
    for layer_idx in range(num_layers):
        layer_attn = attentions[layer_idx]  # [batch, num_heads, seq_len, seq_len]
        # Drop the batch dimension (assumes batch_size == 1 — TODO confirm
        # this holds for all callers).
        if layer_attn.dim() == 4:
            layer_attn = layer_attn[0]  # [num_heads, seq_len, seq_len]

        heads_data = extract_per_head_attention(layer_attn, layer_idx, all_tokens)

        # Activation metrics default to zeros when hidden states are absent.
        activation_metrics = {
            "activation_magnitude": 0.0,
            "activation_entropy": 0.0,
            "hidden_state_norm": 0.0,
        }
        if hidden_states is not None and layer_idx < len(hidden_states):
            current_hidden = hidden_states[layer_idx]
            if current_hidden.dim() == 3:  # [batch, seq_len, hidden_dim]
                current_hidden = current_hidden[0]  # Remove batch
            activation_metrics = compute_activation_metrics(current_hidden, prev_hidden)
            prev_hidden = current_hidden

        layers_data.append({
            "layer_idx": layer_idx,
            "attention_heads": heads_data,
            **activation_metrics,
        })

    # Fallback head count from the first processed layer. The previous
    # version referenced the loop-local `heads_data` here, which raised
    # NameError whenever `attentions` was an empty tuple.
    fallback_num_heads = len(layers_data[0]["attention_heads"]) if layers_data else 0

    architectural_data = {
        "layers": layers_data,
        "model_info": {
            "num_layers": num_layers,
            "num_heads": model_config.get('num_heads', fallback_num_heads),
            "hidden_size": model_config.get('hidden_size', 768),
            "model_name": model_config.get('model_name', 'unknown')
        },
        "input_tokens": input_tokens,
        "output_tokens": output_tokens
    }

    # Optional: expert routing (MoE models only).
    expert_routing = model_outputs.get('router_logits', None)
    if expert_routing is not None:
        architectural_data["expert_routing"] = extract_expert_routing(expert_routing)

    return architectural_data
def extract_expert_routing(router_logits: torch.Tensor) -> List[Dict[str, Any]]:
    """
    Extract expert routing decisions for Mixture-of-Experts models.

    Args:
        router_logits: raw router logits; shape depends on the model
            architecture (e.g. DeepSeek-MoE, CodeLlama-MoE).

    Returns:
        List of per-layer/per-token routing records. Currently always
        empty: routing extraction is model-specific and not implemented yet.
    """
    # Placeholder: log once per call so callers can see the gap at runtime.
    logging.getLogger(__name__).info(
        "Expert routing extraction not yet implemented for this model")
    return []
def format_for_study_endpoint(
    architectural_data: Dict[str, Any],
    generation_metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Wrap architectural data for the /api/study/analyze endpoint response.

    Args:
        architectural_data: output of extract_architectural_data().
        generation_metadata: generation statistics (time, token counts, etc.).

    Returns:
        Response dict tagging the payload with its visualization type and
        research context.
    """
    response: Dict[str, Any] = {
        "architectural_data": architectural_data,
        "metadata": generation_metadata,
    }
    # Fixed tags identifying the payload for the frontend and study logs.
    response["visualization_type"] = "architectural_transparency"
    response["research_context"] = "RQ1: Architectural Interpretability"
    return response