| """ | |
| Q, K, V Matrix Extractor for Attention Mechanism Visualization | |
| Extracts Query, Key, and Value matrices from transformer attention layers | |
| along with attention scores and token embeddings for deep visualization. | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from typing import List, Dict, Tuple, Optional, Any | |
| from dataclasses import dataclass | |
| import logging | |
| logger = logging.getLogger(__name__) | |

@dataclass
class QKVData:
    """Stores Q, K, V matrices and attention data for a single head"""
    layer: int
    head: int
    query: np.ndarray  # [seq_len, head_dim]
    key: np.ndarray  # [seq_len, head_dim]
    value: np.ndarray  # [seq_len, head_dim]
    attention_scores_raw: np.ndarray  # [seq_len, seq_len] before softmax
    attention_weights: np.ndarray  # [seq_len, seq_len] after softmax
    head_dim: int

@dataclass
class TokenEmbedding:
    """Token embedding at a specific layer"""
    token: str
    token_id: int
    position: int
    layer: int
    embedding: np.ndarray  # Full embedding vector
    embedding_2d: Tuple[float, float]  # Reduced to 2D for visualization
    embedding_3d: Tuple[float, float, float]  # Reduced to 3D for visualization

@dataclass
class AttentionAnalysis:
    """Complete attention analysis for a sequence"""
    tokens: List[str]
    token_ids: List[int]
    qkv_data: List[QKVData]  # QKV for each layer/head
    token_embeddings: List[TokenEmbedding]  # Embeddings at each layer
    positional_encodings: Optional[np.ndarray]
    layer_count: int
    head_count: int
    sequence_length: int
    model_dimension: int

class QKVExtractor:
    """Extracts Q, K, V matrices and attention patterns from transformer models"""

    def __init__(self, model, tokenizer, adapter=None):
        self.model = model
        self.tokenizer = tokenizer
        self.adapter = adapter  # ModelAdapter for accessing Q/K/V projections
        self.device = next(model.parameters()).device

        # Storage for extracted data
        self.qkv_data = []
        self.embeddings = []
        self.handles = []

        # Storage for Q/K/V projections captured by hooks
        self.layer_qkv_outputs = {}  # {layer_idx: {'Q': tensor, 'K': tensor, 'V': tensor}}

        # Get model configuration - always prefer the adapter if available
        if adapter:
            self.n_layers = adapter.get_num_layers()
            self.n_heads = adapter.get_num_heads()
            self.d_model = adapter.model_dimension
            self.head_dim = self.d_model // self.n_heads
            self.n_kv_heads = adapter.get_num_kv_heads()
        else:
            # Fallback to model attributes (CodeGen style)
            if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
                self.n_layers = len(model.transformer.h)
            else:
                self.n_layers = 12
            self.n_heads = getattr(model.config, 'n_head', 16)
            self.d_model = getattr(model.config, 'n_embd', 768)
            self.head_dim = self.d_model // self.n_heads
            self.n_kv_heads = None

    def register_hooks(self):
        """Register hooks to capture Q, K, V matrices"""
        self.clear_hooks()
        self.layer_qkv_outputs = {}

        if not self.adapter:
            logger.warning("No adapter provided - cannot extract real Q/K/V matrices")
            return

        # Hook into each transformer layer
        for layer_idx in range(self.n_layers):
            try:
                # Get Q, K, V projection modules
                q_proj, k_proj, v_proj = self.adapter.get_qkv_projections(layer_idx)

                # Initialize storage for this layer
                self.layer_qkv_outputs[layer_idx] = {'Q': None, 'K': None, 'V': None, 'combined': None}

                # Check whether this is a combined QKV projection (CodeGen style):
                # if all three names point to the same module, it is a single fused projection
                is_combined = (q_proj is k_proj) and (k_proj is v_proj) and (q_proj is not None)

                if is_combined:
                    # Hook the combined QKV projection once
                    combined_handle = q_proj.register_forward_hook(
                        lambda module, input, output, l_idx=layer_idx:
                            self._combined_qkv_hook(module, input, output, l_idx)
                    )
                    self.handles.append(combined_handle)
                else:
                    # Hook Q, K, V projections separately (LLaMA style)
                    if q_proj is not None:
                        q_handle = q_proj.register_forward_hook(
                            lambda module, input, output, l_idx=layer_idx:
                                self._q_proj_hook(module, input, output, l_idx)
                        )
                        self.handles.append(q_handle)
                    if k_proj is not None:
                        k_handle = k_proj.register_forward_hook(
                            lambda module, input, output, l_idx=layer_idx:
                                self._k_proj_hook(module, input, output, l_idx)
                        )
                        self.handles.append(k_handle)
                    if v_proj is not None:
                        v_handle = v_proj.register_forward_hook(
                            lambda module, input, output, l_idx=layer_idx:
                                self._v_proj_hook(module, input, output, l_idx)
                        )
                        self.handles.append(v_handle)

                # Hook to capture embeddings after each layer
                layer_module = self.adapter.get_layer_module(layer_idx)
                layer_handle = layer_module.register_forward_hook(
                    lambda module, input, output, l_idx=layer_idx:
                        self._embedding_hook(module, input, output, l_idx)
                )
                self.handles.append(layer_handle)
            except Exception as e:
                logger.warning(f"Failed to register hooks for layer {layer_idx}: {e}")

        logger.info(f"Registered {len(self.handles)} hooks for QKV extraction")

    def _combined_qkv_hook(self, module, input, output, layer_idx):
        """Hook to capture combined QKV projection output (CodeGen style)"""
        try:
            # Store the combined QKV output.
            # Output shape: [batch, seq_len, 3 * n_heads * head_dim]; it is split later in _process_qkv_data.
            if layer_idx in self.layer_qkv_outputs:
                self.layer_qkv_outputs[layer_idx]['combined'] = output.detach()
                logger.info(f"Captured combined QKV at layer {layer_idx}, shape={output.shape}")
        except Exception as e:
            logger.warning(f"Failed to capture combined QKV at layer {layer_idx}: {e}")

    def _q_proj_hook(self, module, input, output, layer_idx):
        """Hook to capture Query projection output"""
        try:
            # Output shape: [batch, seq_len, n_heads * head_dim]
            if layer_idx in self.layer_qkv_outputs:
                self.layer_qkv_outputs[layer_idx]['Q'] = output.detach()
        except Exception as e:
            logger.warning(f"Failed to capture Q at layer {layer_idx}: {e}")

    def _k_proj_hook(self, module, input, output, layer_idx):
        """Hook to capture Key projection output"""
        try:
            # Output shape: [batch, seq_len, n_kv_heads * head_dim] (GQA) or [batch, seq_len, n_heads * head_dim] (MHA)
            if layer_idx in self.layer_qkv_outputs:
                self.layer_qkv_outputs[layer_idx]['K'] = output.detach()
        except Exception as e:
            logger.warning(f"Failed to capture K at layer {layer_idx}: {e}")

    def _v_proj_hook(self, module, input, output, layer_idx):
        """Hook to capture Value projection output"""
        try:
            # Output shape: [batch, seq_len, n_kv_heads * head_dim] (GQA) or [batch, seq_len, n_heads * head_dim] (MHA)
            if layer_idx in self.layer_qkv_outputs:
                self.layer_qkv_outputs[layer_idx]['V'] = output.detach()
        except Exception as e:
            logger.warning(f"Failed to capture V at layer {layer_idx}: {e}")

    def _embedding_hook(self, module, input, output, layer_idx):
        """Hook to capture token embeddings after each layer"""
        try:
            # Output is the hidden states after this layer
            if isinstance(output, tuple):
                hidden_states = output[0]
            else:
                hidden_states = output
            # Store embeddings [seq_len, d_model] for the first batch element
            embeddings = hidden_states[0].detach().cpu().numpy()
            self.embeddings.append({
                'layer': layer_idx,
                'embeddings': embeddings
            })
        except Exception as e:
            logger.warning(f"Failed to extract embeddings at layer {layer_idx}: {e}")

    def _process_qkv_data(self, attention_outputs):
        """
        Process captured Q/K/V tensors and combine them with attention weights

        Args:
            attention_outputs: Per-layer attention tensors from the model's output_attentions
        """
        if not attention_outputs:
            logger.warning("No attention outputs available")
            return

        for layer_idx in range(self.n_layers):
            try:
                # Get captured Q/K/V for this layer
                if layer_idx not in self.layer_qkv_outputs:
                    continue
                qkv = self.layer_qkv_outputs[layer_idx]

                # Check whether we have a combined QKV (CodeGen) or separate Q/K/V (LLaMA)
                if qkv['combined'] is not None:
                    # Combined QKV projection - split it
                    combined = qkv['combined']  # [batch, seq_len, 3 * n_heads * head_dim]
                    batch_size, seq_len, _ = combined.shape
                    logger.info(f"Layer {layer_idx}: Using combined QKV, shape={combined.shape}")

                    # Split into Q, K, V; each is [batch, seq_len, n_heads * head_dim]
                    qkv_dim = self.n_heads * self.head_dim
                    Q = combined[:, :, 0:qkv_dim]
                    K = combined[:, :, qkv_dim:2 * qkv_dim]
                    V = combined[:, :, 2 * qkv_dim:3 * qkv_dim]
                    logger.info(f"Layer {layer_idx}: Split Q={Q.shape}, K={K.shape}, V={V.shape}")
                else:
                    # Separate projections
                    Q = qkv['Q']  # [batch, seq_len, n_heads * head_dim]
                    K = qkv['K']  # [batch, seq_len, n_kv_heads * head_dim]
                    V = qkv['V']  # [batch, seq_len, n_kv_heads * head_dim]
                    logger.info(f"Layer {layer_idx}: Using separate Q/K/V, Q={Q.shape if Q is not None else None}")

                if Q is None or K is None or V is None:
                    continue

                # Get attention weights for this layer
                attn_weights = attention_outputs[layer_idx]  # [batch, n_heads, seq_len, seq_len]
                batch_size, seq_len, _ = Q.shape

                # Reshape Q: [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim]
                Q_reshaped = Q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

                # For K and V, handle grouped-query attention (GQA)
                if self.n_kv_heads is not None:
                    # GQA: replicate KV heads to match the number of Q heads
                    kv_head_dim = K.shape[-1] // self.n_kv_heads
                    # Reshape K/V: [batch, seq_len, n_kv_heads, head_dim] -> [batch, n_kv_heads, seq_len, head_dim]
                    K_reshaped = K.view(batch_size, seq_len, self.n_kv_heads, kv_head_dim).transpose(1, 2)
                    V_reshaped = V.view(batch_size, seq_len, self.n_kv_heads, kv_head_dim).transpose(1, 2)
                    # Replicate to match n_heads
                    repeat_factor = self.n_heads // self.n_kv_heads
                    K_reshaped = K_reshaped.repeat_interleave(repeat_factor, dim=1)
                    V_reshaped = V_reshaped.repeat_interleave(repeat_factor, dim=1)
                else:
                    # Standard multi-head attention
                    K_reshaped = K.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
                    V_reshaped = V.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

                # Now Q, K, V are all [batch, n_heads, seq_len, head_dim];
                # convert to numpy and take the first batch element
                Q_np = Q_reshaped[0].cpu().numpy()  # [n_heads, seq_len, head_dim]
                K_np = K_reshaped[0].cpu().numpy()
                V_np = V_reshaped[0].cpu().numpy()
                attn_np = attn_weights[0].cpu().numpy()  # [n_heads, seq_len, seq_len]

                # Sample every 4th head to reduce data volume
                for head_idx in range(0, self.n_heads, 4):
                    # Extract Q/K/V for this head
                    q_head = Q_np[head_idx]  # [seq_len, head_dim]
                    k_head = K_np[head_idx]  # [seq_len, head_dim]
                    v_head = V_np[head_idx]  # [seq_len, head_dim]
                    attn_head = attn_np[head_idx]  # [seq_len, seq_len]

                    # Compute raw attention scores Q·K^T / sqrt(d_k),
                    # i.e. what the model computes before the softmax
                    scale = np.sqrt(self.head_dim)
                    attn_scores_raw = (q_head @ k_head.T) / scale

                    qkv_data = QKVData(
                        layer=layer_idx,
                        head=head_idx,
                        query=q_head,
                        key=k_head,
                        value=v_head,
                        attention_scores_raw=attn_scores_raw,
                        attention_weights=attn_head,
                        head_dim=self.head_dim
                    )
                    self.qkv_data.append(qkv_data)

                logger.info(f"Processed real Q/K/V data for layer {layer_idx}")
            except Exception as e:
                logger.warning(f"Failed to process QKV data at layer {layer_idx}: {e}")
                import traceback
                logger.warning(traceback.format_exc())

    def clear_hooks(self):
        """Remove all registered hooks"""
        for handle in self.handles:
            handle.remove()
        self.handles = []
        # Don't clear extracted data here - it is needed for the return value

    def extract_attention_data(self, text: str) -> AttentionAnalysis:
        """
        Extract a complete attention analysis for the input text

        Args:
            text: Input text to analyze

        Returns:
            AttentionAnalysis object with all extracted data
        """
        # Tokenize input
        inputs = self.tokenizer(text, return_tensors="pt", padding=False, truncation=True)
        input_ids = inputs["input_ids"].to(self.device)

        # Get tokens
        tokens = [self.tokenizer.decode([tid]) for tid in input_ids[0]]
        token_ids = input_ids[0].tolist()

        # Register hooks and run the forward pass
        self.register_hooks()
        self.qkv_data = []
        self.embeddings = []
        positional_encodings = None

        try:
            with torch.no_grad():
                # Forward pass to trigger hooks - attention outputs must be requested explicitly
                outputs = self.model(
                    input_ids,
                    output_hidden_states=True,
                    output_attentions=True  # Required to get attention weights
                )

                # Process captured Q/K/V data together with the attention weights
                if hasattr(outputs, 'attentions') and outputs.attentions:
                    self._process_qkv_data(outputs.attentions)
                    logger.info(f"Extracted {len(self.qkv_data)} QKV data points")
                else:
                    logger.warning("No attention outputs available - cannot extract Q/K/V")

                # Get initial embeddings (before any layers), if the model exposes them
                if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'wte'):
                    initial_embeddings = self.model.transformer.wte(input_ids)  # currently unused
                    # Add positional encodings if available
                    if hasattr(self.model.transformer, 'wpe'):
                        positions = torch.arange(0, input_ids.shape[1], device=self.device).unsqueeze(0)
                        positional_encodings = self.model.transformer.wpe(positions)  # [1, seq_len, d_model]
                        positional_encodings = positional_encodings.detach().cpu().numpy()
        finally:
            self.clear_hooks()

        # Process token embeddings with dimensionality reduction
        token_embeddings = self._process_embeddings(tokens, token_ids)

        return AttentionAnalysis(
            tokens=tokens,
            token_ids=token_ids,
            qkv_data=self.qkv_data,
            token_embeddings=token_embeddings,
            positional_encodings=positional_encodings[0] if positional_encodings is not None else None,
            layer_count=self.n_layers,
            head_count=self.n_heads,
            sequence_length=len(tokens),
            model_dimension=self.d_model
        )

    def _process_embeddings(self, tokens: List[str], token_ids: List[int]) -> List[TokenEmbedding]:
        """Process embeddings and reduce their dimensionality for visualization"""
        token_embeddings = []
        for emb_data in self.embeddings:
            layer = emb_data['layer']
            embeddings = emb_data['embeddings']  # [seq_len, d_model]
            for pos, (token, token_id, embedding) in enumerate(zip(tokens, token_ids, embeddings)):
                # Reduce to 2D with a crude mean-pooled projection
                # (in production, use sklearn PCA or t-SNE; see the sketch after this method)
                embedding_2d = (
                    float(np.mean(embedding[:self.d_model // 2])),
                    float(np.mean(embedding[self.d_model // 2:]))
                )
                # Reduce to 3D the same way
                third = self.d_model // 3
                embedding_3d = (
                    float(np.mean(embedding[:third])),
                    float(np.mean(embedding[third:2 * third])),
                    float(np.mean(embedding[2 * third:]))
                )
                token_embeddings.append(TokenEmbedding(
                    token=token,
                    token_id=token_id,
                    position=pos,
                    layer=layer,
                    embedding=embedding,
                    embedding_2d=embedding_2d,
                    embedding_3d=embedding_3d
                ))
        return token_embeddings
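
    # Optional helper: a minimal PCA-based reduction sketch. This is not part of the original
    # pipeline; it assumes scikit-learn is installed (it is not imported at module level) and
    # takes a [seq_len, d_model] numpy array, returning [seq_len, n_components]. It could
    # replace the mean-pooling above if a proper projection is wanted.
    def _reduce_with_pca(self, embeddings: np.ndarray, n_components: int = 2) -> np.ndarray:
        """Sketch: project embeddings to n_components dimensions with PCA (requires scikit-learn)."""
        from sklearn.decomposition import PCA
        # PCA cannot return more components than min(n_samples, n_features)
        n_components = min(n_components, embeddings.shape[0], embeddings.shape[1])
        return PCA(n_components=n_components).fit_transform(embeddings)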

    def get_attention_flow(self, analysis: AttentionAnalysis,
                           source_token: int,
                           layer: Optional[int] = None) -> Dict[str, Any]:
        """
        Get attention flow from a specific token across layers/heads

        Args:
            analysis: AttentionAnalysis object
            source_token: Token position to analyze
            layer: Specific layer to analyze (None for all layers)

        Returns:
            Dictionary with attention flow data
        """
        flow_data = {
            'source_token': analysis.tokens[source_token],
            'source_position': source_token,
            'attention_targets': []
        }

        # Filter QKV data by layer if specified
        qkv_subset = [q for q in analysis.qkv_data if layer is None or q.layer == layer]

        for qkv in qkv_subset:
            # Attention from the source token to all other tokens
            attention_from_source = qkv.attention_weights[source_token, :]

            # Find the top attended tokens
            top_k = min(5, len(attention_from_source))
            top_indices = np.argsort(attention_from_source)[-top_k:][::-1]

            for target_idx in top_indices:
                flow_data['attention_targets'].append({
                    'layer': qkv.layer,
                    'head': qkv.head,
                    'target_position': int(target_idx),
                    'target_token': analysis.tokens[target_idx],
                    'attention_weight': float(attention_from_source[target_idx])
                })

        return flow_data
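

# Usage sketch (assumptions flagged): a minimal, hedged example of how the extractor might be
# driven. The "gpt2" model name, the Hugging Face AutoModelForCausalLM/AutoTokenizer classes,
# and running without an adapter are illustrative assumptions, not part of the original module.
# Without an adapter, hooks are skipped, so only attention weights and positional encodings are
# available; real Q/K/V extraction requires a ModelAdapter passed as `adapter`.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "gpt2"  # hypothetical choice; any causal LM that supports output_attentions
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    extractor = QKVExtractor(model, tokenizer, adapter=None)
    analysis = extractor.extract_attention_data("def add(a, b): return a + b")
    print(f"{analysis.sequence_length} tokens, {analysis.layer_count} layers, "
          f"{analysis.head_count} heads, d_model={analysis.model_dimension}")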