""" Instrumentation layer for capturing model internals during generation. Designed for PhD study on architectural transparency. Captures: - Attention tensors A[L,H,T,T] per layer/head - Residual norms ||x_l|| per layer - Logits, logprobs, entropy per token - Timing per layer """ import torch import numpy as np from typing import Dict, List, Optional, Tuple from dataclasses import dataclass, field from datetime import datetime import time import logging logger = logging.getLogger(__name__) @dataclass class TokenMetadata: """Metadata for a single generated token""" token_id: int text: str position: int logprob: float entropy: float top_k_tokens: List[Tuple[str, float]] # (token_text, probability) byte_length: int timestamp_ms: float @dataclass class LayerMetadata: """Metadata captured per layer during forward pass""" layer_idx: int residual_norm: float time_ms: float attention_output_norm: Optional[float] = None ffn_output_norm: Optional[float] = None @dataclass class InstrumentationData: """Complete instrumentation capture for a generation run""" # Run identification run_id: str seed: int model_name: str timestamp: float # Generation parameters prompt: str max_tokens: int temperature: float top_k: Optional[int] top_p: Optional[float] # Token-level data tokens: List[TokenMetadata] = field(default_factory=list) # Tensor data (will be stored separately in Zarr) attention_tensors: Optional[torch.Tensor] = None # [num_tokens, num_layers, num_heads, seq_len, seq_len] logits_history: Optional[torch.Tensor] = None # [num_tokens, vocab_size] # Layer-level metadata layer_metadata: List[List[LayerMetadata]] = field(default_factory=list) # [num_tokens][num_layers] # Summary statistics total_time_ms: float = 0.0 num_layers: int = 0 num_heads: int = 0 seq_length: int = 0 class ModelInstrumentor: """ Attaches PyTorch hooks to capture model internals during generation. Usage: instrumentor = ModelInstrumentor(model, tokenizer) with instrumentor.capture(): outputs = model.generate(...) data = instrumentor.get_data() """ def __init__(self, model, tokenizer, device): self.model = model self.tokenizer = tokenizer self.device = device # Hook handles (for cleanup) self.hook_handles = [] # Capture buffers self.attention_buffer = [] self.residual_buffer = [] self.timing_buffer = [] self.logits_buffer = [] # Metadata self.config = model.config self.num_layers = getattr(self.config, 'num_hidden_layers', getattr(self.config, 'n_layer', 0)) self.num_heads = getattr(self.config, 'num_attention_heads', getattr(self.config, 'n_head', 0)) # State self.capturing = False self.start_time = None def _create_attention_hook(self, layer_idx: int): """ Create forward hook to capture attention weights for a specific layer. Attention outputs vary by model: - GPT-2/CodeGen: (attention_weights, present_key_value) - Llama: (hidden_states, attention_weights, ...) We extract the attention_weights tensor which has shape: [batch_size, num_heads, seq_len, seq_len] """ def hook(module, input, output): if not self.capturing: return start_time = time.perf_counter() try: # Extract attention weights from output # For most models, attention_weights is the second element if isinstance(output, tuple) and len(output) >= 2: attention_weights = output[1] if attention_weights is not None and torch.is_tensor(attention_weights): # Store attention weights # Shape: [batch_size, num_heads, seq_len, seq_len] self.attention_buffer.append({ 'layer_idx': layer_idx, 'weights': attention_weights.detach().cpu(), 'timestamp': time.perf_counter() }) except Exception as e: logger.warning(f"Attention hook failed for layer {layer_idx}: {e}") elapsed_ms = (time.perf_counter() - start_time) * 1000 self.timing_buffer.append({ 'layer_idx': layer_idx, 'time_ms': elapsed_ms, 'stage': 'attention' }) return hook def _create_residual_hook(self, layer_idx: int): """ Create forward hook to capture residual stream norms. For transformer layers, the output includes the hidden states (residual stream). We compute ||x_l|| to track representation magnitude. """ def hook(module, input, output): if not self.capturing: return try: # Output is typically (hidden_states, ...) or just hidden_states hidden_states = output[0] if isinstance(output, tuple) else output if torch.is_tensor(hidden_states): # Compute L2 norm across the hidden dimension # Shape: [batch_size, seq_len, hidden_dim] -> [batch_size, seq_len] residual_norm = torch.norm(hidden_states, p=2, dim=-1) # Store mean norm across batch and sequence mean_norm = residual_norm.mean().item() self.residual_buffer.append({ 'layer_idx': layer_idx, 'norm': mean_norm, 'timestamp': time.perf_counter() }) except Exception as e: logger.warning(f"Residual hook failed for layer {layer_idx}: {e}") return hook def attach_hooks(self): """Attach forward hooks to all transformer layers""" logger.info(f"Attaching instrumentation hooks to {self.num_layers} layers...") # Get model layers based on architecture # Most models: model.transformer.h (GPT-2, CodeGen) or model.model.layers (Llama) if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'): layers = self.model.transformer.h elif hasattr(self.model, 'model') and hasattr(self.model.model, 'layers'): layers = self.model.model.layers else: logger.error("Could not find transformer layers in model") return for layer_idx, layer in enumerate(layers): # Attention hook attn_hook = self._create_attention_hook(layer_idx) handle = layer.register_forward_hook(attn_hook) self.hook_handles.append(handle) # Residual hook (attach to layer output) res_hook = self._create_residual_hook(layer_idx) handle = layer.register_forward_hook(res_hook) self.hook_handles.append(handle) logger.info(f"✅ Attached {len(self.hook_handles)} hooks") def remove_hooks(self): """Remove all forward hooks""" for handle in self.hook_handles: handle.remove() self.hook_handles = [] logger.info("Removed instrumentation hooks") def capture(self): """Context manager for capturing generation""" class CaptureContext: def __init__(self, instrumentor): self.instrumentor = instrumentor def __enter__(self): self.instrumentor.start_capture() return self.instrumentor def __exit__(self, exc_type, exc_val, exc_tb): self.instrumentor.stop_capture() return False return CaptureContext(self) def start_capture(self): """Start capturing data""" self.capturing = True self.start_time = time.perf_counter() self.clear_buffers() self.attach_hooks() logger.info("Started instrumentation capture") def stop_capture(self): """Stop capturing data""" self.capturing = False self.remove_hooks() logger.info("Stopped instrumentation capture") def clear_buffers(self): """Clear all capture buffers""" self.attention_buffer = [] self.residual_buffer = [] self.timing_buffer = [] self.logits_buffer = [] def compute_token_metadata(self, token_ids: torch.Tensor, logits: torch.Tensor, position: int) -> TokenMetadata: """ Compute metadata for a single token from logits. Args: token_ids: Generated token IDs [batch_size] logits: Model logits [batch_size, vocab_size] position: Position in sequence Returns: TokenMetadata with probabilities, entropy, top-k alternatives """ # Get probabilities via softmax probs = torch.softmax(logits[0], dim=-1) # [vocab_size] # Get generated token info token_id = token_ids[0].item() token_text = self.tokenizer.decode([token_id]) token_prob = probs[token_id].item() logprob = np.log(token_prob + 1e-10) # Compute entropy # H = -sum(p * log(p)) entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item() # Get top-k alternatives top_k = 5 top_probs, top_indices = torch.topk(probs, k=top_k) top_k_tokens = [ (self.tokenizer.decode([idx.item()]), prob.item()) for idx, prob in zip(top_indices, top_probs) ] # Byte length byte_length = len(token_text.encode('utf-8')) return TokenMetadata( token_id=token_id, text=token_text, position=position, logprob=logprob, entropy=entropy, top_k_tokens=top_k_tokens, byte_length=byte_length, timestamp_ms=(time.perf_counter() - self.start_time) * 1000 ) def process_buffers(self) -> Tuple[torch.Tensor, List[List[LayerMetadata]]]: """ Process captured buffers into structured tensors. Returns: attention_tensor: [num_tokens, num_layers, num_heads, seq_len, seq_len] layer_metadata: [num_tokens][num_layers] """ # Group attention by token step # Each forward pass captures attention for all layers # Estimate number of tokens from buffer size # Each token generates num_layers attention captures num_tokens = len(self.attention_buffer) // self.num_layers if self.attention_buffer else 0 if num_tokens == 0: logger.warning("No attention data captured") return None, [] # Organize attention tensors by token and layer attention_list = [] layer_metadata_list = [] for token_idx in range(num_tokens): token_attentions = [] token_layer_meta = [] for layer_idx in range(self.num_layers): buffer_idx = token_idx * self.num_layers + layer_idx if buffer_idx < len(self.attention_buffer): attn_data = self.attention_buffer[buffer_idx] token_attentions.append(attn_data['weights']) # Get residual norm residual_norm = 0.0 if buffer_idx < len(self.residual_buffer): residual_norm = self.residual_buffer[buffer_idx]['norm'] # Get timing time_ms = 0.0 if buffer_idx < len(self.timing_buffer): time_ms = self.timing_buffer[buffer_idx]['time_ms'] token_layer_meta.append(LayerMetadata( layer_idx=layer_idx, residual_norm=residual_norm, time_ms=time_ms )) if token_attentions: # Stack layer attentions: [num_layers, num_heads, seq_len, seq_len] attention_list.append(torch.stack(token_attentions)) layer_metadata_list.append(token_layer_meta) # Stack token attentions with padding for varying sequence lengths # During autoregressive generation, seq_len grows with each token if attention_list: # Find maximum sequence length across all tokens max_seq_len = max(attn.shape[-1] for attn in attention_list) # Pad all tensors to max_seq_len padded_attentions = [] for attn in attention_list: # attn shape: [num_layers, num_heads, seq_len, seq_len] current_seq_len = attn.shape[-1] if current_seq_len < max_seq_len: pad_size = max_seq_len - current_seq_len # Create zero tensor with correct dtype for padding pad_shape = list(attn.shape) pad_shape[-1] = max_seq_len pad_shape[-2] = max_seq_len padded = torch.zeros(pad_shape, dtype=attn.dtype, device=attn.device) # Copy original data into padded tensor padded[..., :current_seq_len, :current_seq_len] = attn attn = padded padded_attentions.append(attn) # Now stack: [num_tokens, num_layers, num_heads, max_seq_len, max_seq_len] attention_tensor = torch.stack(padded_attentions) else: attention_tensor = None return attention_tensor, layer_metadata_list def get_data(self, run_id: str, prompt: str, max_tokens: int, temperature: float, seed: int, tokens: List[TokenMetadata], top_k: Optional[int] = None, top_p: Optional[float] = None) -> InstrumentationData: """ Package all captured data into InstrumentationData structure. Args: run_id: Unique run identifier prompt: Original prompt max_tokens: Max tokens setting temperature: Temperature setting seed: Random seed used tokens: List of TokenMetadata for generated tokens top_k: Top-k sampling parameter top_p: Top-p sampling parameter Returns: InstrumentationData with all captured tensors and metadata """ # Process buffers attention_tensor, layer_metadata = self.process_buffers() # Calculate total time total_time_ms = (time.perf_counter() - self.start_time) * 1000 if self.start_time else 0.0 # Get sequence length from attention tensor seq_length = attention_tensor.shape[-1] if attention_tensor is not None else 0 data = InstrumentationData( run_id=run_id, seed=seed, model_name=self.model.config._name_or_path, timestamp=datetime.now().timestamp(), prompt=prompt, max_tokens=max_tokens, temperature=temperature, top_k=top_k, top_p=top_p, tokens=tokens, attention_tensors=attention_tensor, logits_history=None, # Could capture this if needed layer_metadata=layer_metadata, total_time_ms=total_time_ms, num_layers=self.num_layers, num_heads=self.num_heads, seq_length=seq_length ) logger.info(f"Instrumentation data: {len(tokens)} tokens, " f"{self.num_layers} layers, {self.num_heads} heads, " f"seq_len={seq_length}, total_time={total_time_ms:.1f}ms") return data