"""
Token Attribution utility using Integrated Gradients.

Provides gradient-based attribution to identify which input tokens most
influenced the model's output prediction.
"""

import torch
import torch.nn.functional as F
from typing import List, Dict, Any, Tuple, Optional


# Baseline kinds accepted by compute_integrated_gradients.
_BASELINE_TYPES = ('pad', 'zero', 'mask')


def _get_embedding_layer(model):
    """Return the model's input token-embedding module.

    Checks, in order: GPT-2 style ``model.transformer.wte``, LLaMA/Qwen style
    ``model.model.embed_tokens``, then the generic ``get_input_embeddings()``
    accessor if the model provides one.

    Raises:
        ValueError: if no embedding layer can be located.
    """
    transformer = getattr(model, 'transformer', None)
    if transformer is not None and hasattr(transformer, 'wte'):
        # GPT-2 style
        return transformer.wte
    inner = getattr(model, 'model', None)
    if inner is not None and hasattr(inner, 'embed_tokens'):
        # LLaMA/Qwen style
        return inner.embed_tokens
    getter = getattr(model, 'get_input_embeddings', None)
    if callable(getter):
        layer = getter()
        if layer is not None:
            return layer
    raise ValueError("Could not find embedding layer in model")


def _first_token_id(*candidates: Optional[int]) -> int:
    """Return the first candidate that is not None, or 0 if all are None.

    Unlike chaining with ``or``, this keeps a legitimate token id of 0.
    """
    for candidate in candidates:
        if candidate is not None:
            return candidate
    return 0


def _tokenize(tokenizer, text: str, device) -> torch.Tensor:
    """Tokenize ``text`` into a [1, seq_len] id tensor on ``device``.

    Raises:
        ValueError: if the text produces no tokens (nothing to attribute).
    """
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    if input_ids.shape[-1] == 0:
        raise ValueError("Input text produced no tokens; nothing to attribute")
    return input_ids


def _decode_tokens(tokenizer, input_ids: torch.Tensor) -> List[str]:
    """Decode each id of the first (only) sequence into its own string."""
    return [tokenizer.decode([tid]) for tid in input_ids[0].tolist()]


def _normalize_by_max(values: List[float]) -> List[float]:
    """Divide non-negative ``values`` by their maximum.

    Returns all zeros when the maximum is not positive, and [] for [].
    """
    max_value = max(values, default=0.0)
    if max_value > 0:
        return [v / max_value for v in values]
    return [0.0 for _ in values]


def compute_integrated_gradients(
    model,
    tokenizer,
    text: str,
    target_token_id: Optional[int] = None,
    n_steps: int = 50,
    baseline_type: str = 'pad',
    batch_size: int = 10,
) -> Dict[str, Any]:
    """
    Compute Integrated Gradients attribution for input tokens.

    This method computes how much each input token contributes to the model's
    prediction of the target token at the last position (or the top predicted
    token if not specified).

    The input embeddings are linearly interpolated from a baseline embedding
    at ``n_steps + 1`` evenly spaced points ``alpha = k / n_steps``
    (k = 0..n_steps, both endpoints included). The gradients of the target
    logit at those points are averaged (accumulated in float32). The average
    is multiplied element-wise by ``(input - baseline)`` and summed over the
    hidden dimension, which gives one score per token.

    Args:
        model: HuggingFace transformer model. It must accept
            ``inputs_embeds=`` and return an object with ``.logits``.
        tokenizer: Tokenizer for the model.
        text: Input text to analyze.
        target_token_id: Optional specific token ID to compute attribution
            for. If None, uses the model's top predicted token.
        n_steps: Number of interpolation steps. Higher is more accurate but
            slower. Must be >= 1.
        baseline_type: Type of baseline embedding:
            - 'pad': pad token, falling back to EOS, then id 0.
            - 'zero': all-zero embeddings.
            - 'mask': mask token, falling back to UNK, then id 0.
        batch_size: How many interpolation points go through the model per
            forward/backward pass. Must be >= 1.

    Returns:
        Dict with:
            - 'tokens': List of input token strings
            - 'token_ids': List of input token IDs
            - 'attributions': List of signed attribution scores (one per token)
            - 'normalized_attributions': |attribution| divided by the max
              |attribution|, so values lie in [0, 1]
            - 'target_token': The token being attributed (string)
            - 'target_token_id': The token ID being attributed

    Raises:
        ValueError: on invalid ``n_steps``, ``batch_size`` or
            ``baseline_type``, on text that produces no tokens, or when no
            embedding layer can be found.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if baseline_type not in _BASELINE_TYPES:
        raise ValueError(
            f"baseline_type must be one of {_BASELINE_TYPES}, got {baseline_type!r}"
        )

    model.eval()
    device = next(model.parameters()).device
    input_ids = _tokenize(tokenizer, text, device)
    embedding_layer = _get_embedding_layer(model)

    # Embeddings are constants here; gradients are taken only with respect
    # to the interpolated inputs, so no graph into the embedding weights
    # is needed.
    with torch.no_grad():
        input_embeds = embedding_layer(input_ids)  # [1, seq_len, hidden_dim]

        if baseline_type == 'zero':
            baseline_embeds = torch.zeros_like(input_embeds)
        else:
            if baseline_type == 'pad':
                baseline_id = _first_token_id(
                    getattr(tokenizer, 'pad_token_id', None),
                    getattr(tokenizer, 'eos_token_id', None),
                )
            else:  # 'mask'
                # HF causal-LM tokenizers expose mask_token_id but set it to
                # None, so getattr's default alone is not enough.
                baseline_id = _first_token_id(
                    getattr(tokenizer, 'mask_token_id', None),
                    getattr(tokenizer, 'unk_token_id', None),
                )
            baseline_ids = torch.full_like(input_ids, baseline_id)
            baseline_embeds = embedding_layer(baseline_ids)

        # If no target specified, get the model's top prediction
        if target_token_id is None:
            logits = model(inputs_embeds=input_embeds).logits[0, -1, :]  # [vocab_size]
            target_token_id = int(logits.argmax().item())

    target_token = tokenizer.decode([target_token_id])

    delta = input_embeds - baseline_embeds  # [1, seq_len, hidden_dim]
    alphas = (
        torch.arange(n_steps + 1, device=input_embeds.device, dtype=torch.float32)
        / n_steps
    ).to(input_embeds.dtype)
    grad_sum = torch.zeros(
        delta.shape[1:], device=input_embeds.device, dtype=torch.float32
    )  # [seq_len, hidden_dim]

    # enable_grad lets this work even when the caller is inside torch.no_grad().
    with torch.enable_grad():
        for start in range(0, n_steps + 1, batch_size):
            # Build only this chunk of interpolated inputs: [b, seq_len, hidden_dim]
            batch_alphas = alphas[start:start + batch_size].view(-1, 1, 1)
            batch_inputs = (baseline_embeds + batch_alphas * delta).requires_grad_(True)

            outputs = model(inputs_embeds=batch_inputs)
            # Logit of the target token at the last position: [b]
            target_logits = outputs.logits[:, -1, target_token_id]

            # autograd.grad (not .backward()) so nothing accumulates into
            # the model parameters' .grad.
            (grads,) = torch.autograd.grad(target_logits.sum(), batch_inputs)
            grad_sum += grads.float().sum(dim=0)

    # Riemann average over the n_steps + 1 sampled points
    avg_gradients = grad_sum / (n_steps + 1)  # [seq_len, hidden_dim]

    # (input - baseline) * avg_gradient, summed over the hidden dimension
    attributions = (delta[0].float() * avg_gradients).sum(dim=-1)  # [seq_len]
    attributions_list = attributions.cpu().tolist()

    # Normalize magnitudes to [0, 1] for visualization
    normalized = _normalize_by_max([abs(a) for a in attributions_list])

    return {
        'tokens': _decode_tokens(tokenizer, input_ids),
        'token_ids': input_ids[0].tolist(),
        'attributions': attributions_list,
        'normalized_attributions': normalized,
        'target_token': target_token,
        'target_token_id': target_token_id
    }


def compute_simple_gradient_attribution(
    model,
    tokenizer,
    text: str,
    target_token_id: Optional[int] = None
) -> Dict[str, Any]:
    """
    Compute simple gradient-based attribution (faster than Integrated Gradients).

    Uses a single forward/backward pass. It takes the gradient of the target
    logit at the last position with respect to the input embeddings, and
    scores each token by the L2 norm of its embedding gradient.

    Args:
        model: HuggingFace transformer model. It must accept
            ``inputs_embeds=`` and return an object with ``.logits``.
        tokenizer: Tokenizer for the model.
        text: Input text to analyze.
        target_token_id: Optional specific token ID to compute attribution
            for. If None, uses the model's top predicted token.

    Returns:
        Dict with the same keys as ``compute_integrated_gradients``. Here
        'attributions' are non-negative gradient norms, and
        'normalized_attributions' divides them by their maximum.

    Raises:
        ValueError: if the text produces no tokens or no embedding layer can
            be found.
    """
    model.eval()
    device = next(model.parameters()).device
    input_ids = _tokenize(tokenizer, text, device)
    embedding_layer = _get_embedding_layer(model)

    with torch.no_grad():
        input_embeds = embedding_layer(input_ids)  # [1, seq_len, hidden_dim]

    # enable_grad lets this work even when the caller is inside torch.no_grad().
    with torch.enable_grad():
        input_embeds.requires_grad_(True)
        outputs = model(inputs_embeds=input_embeds)
        logits = outputs.logits[0, -1, :]  # [vocab_size]

        # If no target specified, use top prediction
        if target_token_id is None:
            target_token_id = int(logits.argmax().item())

        # autograd.grad (not .backward()) so nothing accumulates into the
        # model parameters' .grad.
        (gradients,) = torch.autograd.grad(logits[target_token_id], input_embeds)

    target_token = tokenizer.decode([target_token_id])

    # L2 norm over the hidden dimension -> one score per token
    attributions = gradients[0].float().norm(dim=-1)  # [seq_len]
    attributions_list = attributions.cpu().tolist()

    return {
        'tokens': _decode_tokens(tokenizer, input_ids),
        'token_ids': input_ids[0].tolist(),
        'attributions': attributions_list,
        'normalized_attributions': _normalize_by_max(attributions_list),
        'target_token': target_token,
        'target_token_id': target_token_id
    }


def create_attribution_visualization_data(attribution_result: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Format attribution results for visualization.

    Each token gets an RGB background colour:
        - Non-negative raw attributions fade from white toward red.
        - Negative raw attributions fade from white toward blue.
    The fade strength is set by the normalized score, clamped to [0, 1] for
    colouring only. Tokens with a normalized score of 0.5 or more get white
    text; the rest get black text.

    Args:
        attribution_result: Output from compute_integrated_gradients or
            compute_simple_gradient_attribution. It must contain 'tokens',
            'attributions' and 'normalized_attributions'.

    Returns:
        List of dicts (one per token) with keys 'token', 'index',
        'attribution', 'normalized', 'color' and 'text_color'.
    """
    tokens = attribution_result['tokens']
    normalized = attribution_result['normalized_attributions']
    raw = attribution_result['attributions']

    viz_data = []
    for i, (tok, norm, raw_val) in enumerate(zip(tokens, normalized, raw)):
        # Clamp so an out-of-range score cannot produce an invalid CSS colour.
        intensity = min(max(norm, 0.0), 1.0)
        # 255 (white) at intensity 0, down to about 30% at intensity 1
        shade = int(255 * (1 - intensity * 0.7))
        if raw_val >= 0:
            r, g, b = 255, shade, shade  # positive -> red
        else:
            r, g, b = shade, shade, 255  # negative -> blue

        viz_data.append({
            'token': tok,
            'index': i,
            'attribution': raw_val,
            'normalized': norm,
            'color': f'rgb({r},{g},{b})',
            'text_color': '#000000' if norm < 0.5 else '#ffffff'
        })

    return viz_data