import torch
import numpy as np

from .batch_config import get_batch_chunk_size


class AttributionEngine:
    """Runs forward/backward passes on a managed model and derives attributions.

    The engine keeps the most recent forward pass (``input_ids``,
    ``input_embeddings``, ``outputs``) as instance state.  Forward hooks store
    per-layer activations directly on the hooked modules (``module.output`` and
    ``module.mid_activation``) so their gradients can be read after a backward
    pass.
    """

    def __init__(self, model_manager):
        """Store the model manager and start with no forward-pass state."""
        self.manager = model_manager
        self.hook_handles = []
        self.outputs = None
        self.input_ids = None
        # Initialised here so that later `is None` checks never hit AttributeError.
        self.input_embeddings = None

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _get_backbone(model):
        """Return the object that owns ``layers``: ``model.model`` if it has them, else ``model``."""
        if hasattr(model, "model") and hasattr(model.model, "layers"):
            return model.model
        return model

    def _get_layers(self, model):
        """Return the model's layer container, or ``[]`` if none can be located.

        Looks at ``model.model.layers`` first and falls back to ``model.layers``.
        """
        if hasattr(model, "model") and hasattr(model.model, "layers"):
            return model.model.layers
        if hasattr(model, "layers"):
            return model.layers
        return []

    def _candidate_mid_modules(self, layer):
        """Return the distinct modules of ``layer`` that may hold a ``mid_activation``.

        Hooks attach ``mid_activation`` to whatever module the decomposer
        reports, so cleanup code has to look there as well as at the
        conventional ``post_attention_layernorm`` attribute.
        """
        candidates = []
        decomposer = getattr(self.manager, "decomposer", None)
        if decomposer is not None:
            found = decomposer.get_mid_activation_module(layer)
            if found:
                candidates.append(found)
        fallback = getattr(layer, "post_attention_layernorm", None)
        if fallback and all(fallback is not known for known in candidates):
            candidates.append(fallback)
        return candidates

    def _embed_and_forward(self, model):
        """Embed ``self.input_ids`` and run the model, storing embeddings and outputs.

        The embeddings are detached and re-marked as requiring grad so that
        attribution can be computed w.r.t. them even if the model itself is
        frozen/quantized.
        """
        self.input_embeddings = model.get_input_embeddings()(self.input_ids).detach()
        self.input_embeddings.requires_grad_(True)
        self.outputs = model(
            inputs_embeds=self.input_embeddings,
            use_cache=False
        )

    @staticmethod
    def _last_token_logits(model, embeddings):
        """Run ``model`` on ``embeddings`` without grad; return logits at ``[0, -1, :]``."""
        with torch.no_grad():
            return model(
                inputs_embeds=embeddings,
                use_cache=False
            ).logits[0, -1, :]

    @staticmethod
    def _readable_token(token):
        """Turn a raw tokenizer token into display text with real whitespace."""
        # Handle bytes (common in tiktoken-based tokenizers like Qwen)
        if isinstance(token, bytes):
            try:
                token = token.decode('utf-8')
            except UnicodeDecodeError:
                # Not valid UTF-8 on its own: fall back to the bytes repr.
                token = str(token)
        if isinstance(token, str):
            # SentencePiece underline (U+2581) marks a leading space.
            token = token.replace('\u2581', ' ')
            # GPT-2/RoBERTa G-dot (U+0120) marks a leading space.
            token = token.replace('\u0120', ' ')
            # U+010A is the byte-level stand-in for a newline.
            token = token.replace('\u010A', '\n')
            # Drop the Unicode replacement character (U+FFFD).
            token = token.replace('\ufffd', '')
        return token

    def _input_times_gradient(self):
        """Return ``(input_embeddings * grad).sum(-1)`` for batch item 0 as a CPU tensor.

        Raises:
            RuntimeError: if the input embeddings carry no gradient.
        """
        if self.input_embeddings is None or self.input_embeddings.grad is None:
            raise RuntimeError("No gradient found on input embeddings. Ensure compute_logits was run.")
        # self.input_embeddings is [Batch, Seq, Dim]
        return (self.input_embeddings * self.input_embeddings.grad).float().sum(-1).detach().cpu()[0]

    def _parse_node(self, node):
        """Normalise a node spec to ``(layer_idx, kind)``.

        An ``int`` means ``(node, 'post')``; a 2-item tuple/list is returned as
        its two elements.

        Raises:
            ValueError: for any other format.
        """
        if isinstance(node, int):
            return node, 'post'
        if isinstance(node, (tuple, list)) and len(node) == 2:
            return node[0], node[1]
        raise ValueError(f"Invalid node format: {node}")

    def _forward_part1(self, layer_module, hidden_states, position_embeddings=None, attention_mask=None):
        """
        Executes: Norm -> Attn -> Residual Add
        Returns: resid_mid

        Delegates to the manager's decomposer.
        """
        return self.manager.decomposer.forward_part1(layer_module, hidden_states, position_embeddings, attention_mask)

    def _forward_part2(self, layer_module, hidden_states):
        """
        Executes: Norm -> MLP -> Residual Add
        Returns: resid_post

        Delegates to the manager's decomposer.
        """
        return self.manager.decomposer.forward_part2(layer_module, hidden_states)

    # ------------------------------------------------------------------
    # Hooks
    # ------------------------------------------------------------------

    def _hook_hidden_activation(self, module, input, output):
        """
        Hook to save activation and enable gradient retention.

        If the module returns a tuple, only its first element is stored.
        """
        if isinstance(output, tuple):
            output = output[0]
        # Save output to the module for later access
        module.output = output
        if module.output.requires_grad:
            module.output.retain_grad()

    def _hook_mid_activation(self, module, input, output):
        """
        Hook to capture resid_mid at the input of post_attention_layernorm.
        """
        # input is a tuple (tensor,); attach the tensor to the hooked module.
        module.mid_activation = input[0]
        if module.mid_activation.requires_grad:
            module.mid_activation.retain_grad()

    def register_hooks(self, capture_mid=False):
        """
        Register forward hooks on all model layers.

        Existing hooks are removed first.  With ``capture_mid=True`` an
        additional forward hook is placed on the module the decomposer reports
        for each layer, so the module's *input* (resid_mid) is captured.

        Raises:
            ValueError: if no model is loaded.
        """
        self.remove_hooks()  # Clear existing

        model = self.manager.get_model()
        if not model:
            raise ValueError("Model not loaded yet.")

        for layer in self._get_layers(model):
            # Hook Output (Resid Post)
            handle = layer.register_forward_hook(self._hook_hidden_activation)
            self.hook_handles.append(handle)

            if not capture_mid:
                continue

            # Hook Mid (Resid Mid): a forward hook receives (module, input, output)
            # and input is (resid_mid,).
            mid_module = self.manager.decomposer.get_mid_activation_module(layer)
            if mid_module:
                handle_mid = mid_module.register_forward_hook(self._hook_mid_activation)
                self.hook_handles.append(handle_mid)
            else:
                print(f"Warning: Could not identify mid-activation module for layer {layer}. Skipping mid hook.")

    def remove_hooks(self):
        """Remove every registered hook and forget the handles."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []

    def reset(self):
        """
        Clears all internal state and specific temporary data from model layers.
        """
        self.remove_hooks()
        self.outputs = None
        self.input_ids = None
        self.input_embeddings = None

        # Manually clear activation tensors attached to layers to free the graph.
        model = self.manager.get_model()
        if model:
            for layer in self._get_layers(model):
                if hasattr(layer, 'output'):
                    del layer.output
                for mid_module in self._candidate_mid_modules(layer):
                    if hasattr(mid_module, 'mid_activation'):
                        del mid_module.mid_activation

        torch.cuda.empty_cache()

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------

    def rerun_forward_pass(self, capture_mid=False):
        """
        Re-run the forward pass using stored input_ids to get a fresh computation graph.
        This is needed before circuit computation to ensure clean gradients.

        Raises:
            ValueError: if no ``input_ids`` are stored.
        """
        if self.input_ids is None:
            raise ValueError("No input_ids stored. Run compute_logits first.")

        model = self.manager.get_model()

        # Re-register hooks to capture fresh activations
        self.register_hooks(capture_mid=capture_mid)

        # Fresh embeddings + forward pass -> fresh computation graph
        self._embed_and_forward(model)
        print("Re-ran forward pass for fresh computation graph.")

    def compute_logits(self, prompt, is_append_bos=False, topk=10, extra_token_ids=None, extra_token_strs=None, capture_mid=False):
        """
        Section 1: Forward pass to get logits and top-k predictions.

        Args:
            prompt: text to tokenize (special tokens are not added automatically).
            is_append_bos: prepend a start token (BOS, else CLS, else EOS) if one exists.
            topk: number of top predictions to report (clamped to the vocabulary size).
            extra_token_ids: optional token IDs to report in addition to the top-k.
            extra_token_strs: optional strings; the first token of each encoding is reported.
            capture_mid: also hook mid-layer activations (see ``register_hooks``).

        Returns:
            ``(topk_data, last_logits, input_tokens)``: a list of dicts with
            ``rank``/``token_id``/``token_str``/``logit`` (extras carry
            ``is_extra``), the last-position logits tensor, and the
            display-ready input tokens.
        """
        model = self.manager.get_model()
        tokenizer = self.manager.get_tokenizer()

        self.register_hooks(capture_mid=capture_mid)

        # Tokenize with add_special_tokens=False to manually control the BOS/Start token
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        input_ids = inputs.input_ids.to(model.device)

        if is_append_bos:
            # 1. Explicit BOS, 2. CLS (BERT-like), 3. EOS as a last resort.
            bos_id = tokenizer.bos_token_id
            if bos_id is None:
                bos_id = tokenizer.cls_token_id
            if bos_id is None:
                bos_id = tokenizer.eos_token_id

            if bos_id is not None:
                prefix = torch.tensor([[bos_id]], device=model.device)
                input_ids = torch.cat([prefix, input_ids], dim=1)
                print(f"Appended start token ID: {bos_id}")
            else:
                print("Warning: Append BOS requested but no suitable start token (BOS/CLS/EOS) found.")

        self.input_ids = input_ids
        self._embed_and_forward(model)

        last_logits = self.outputs.logits[0, -1, :]
        sorted_logits, sorted_indices = torch.sort(last_logits, dim=-1, descending=True)
        vocab_size = len(last_logits)

        # Top-K table (never index past the vocabulary)
        topk_data = []
        for i in range(min(topk, vocab_size)):
            idx = sorted_indices[i].item()
            topk_data.append({
                "rank": i + 1,
                "token_id": idx,
                "token_str": tokenizer.decode([idx]),
                "logit": sorted_logits[i].item()
            })

        # Handle Extra Tokens (if requested)
        if extra_token_ids or extra_token_strs:

            def get_rank(val, sorted_vals):
                # 1-based rank = number of logits strictly greater than val, plus one.
                return (sorted_vals > val).sum().item() + 1

            processed_ids = set()

            if extra_token_ids:
                for tid in extra_token_ids:
                    if tid < 0 or tid >= vocab_size:
                        continue
                    if tid in processed_ids:
                        continue

                    logit_val = last_logits[tid].item()
                    topk_data.append({
                        "rank": get_rank(logit_val, sorted_logits),
                        "token_id": tid,
                        "token_str": tokenizer.decode([tid]),
                        "logit": logit_val,
                        "is_extra": True
                    })
                    processed_ids.add(tid)

            if extra_token_strs:
                print(f"DEBUG: Processing extra strs: {extra_token_strs}")
                for tstr in extra_token_strs:
                    # Best effort: a failure on one string is logged and must not
                    # abort the whole request.
                    try:
                        encoded = tokenizer.encode(tstr, add_special_tokens=False)
                        print(f"DEBUG: Encoded '{tstr}' -> {encoded} (Type: {type(encoded)})")

                        # Ensure we get a list of ints
                        if hasattr(encoded, 'tolist'):
                            encoded = encoded.tolist()

                        if len(encoded) == 0:
                            print(f"DEBUG: Empty encoding for '{tstr}'")
                            continue

                        # Only the first token of the encoding is reported.
                        tid = encoded[0]
                        print(f"DEBUG: Using TID {tid} for '{tstr}'")

                        if tid in processed_ids:
                            print(f"DEBUG: TID {tid} already processed")
                            continue

                        logit_val = last_logits[tid].item()
                        rank = get_rank(logit_val, sorted_logits)
                        real_str = tokenizer.decode([tid])

                        print(f"DEBUG: Added extra token: {real_str} (ID: {tid}, Rank: {rank})")

                        topk_data.append({
                            "rank": rank,
                            "token_id": tid,
                            "token_str": real_str,
                            "logit": logit_val,
                            "is_extra": True
                        })
                        processed_ids.add(tid)
                    except Exception as e:
                        print(f"DEBUG: Error processing extra str '{tstr}': {e}")
                        import traceback
                        traceback.print_exc()

            # Sort so extras mix into the table by rank (low-ranked ones sink to the bottom).
            topk_data.sort(key=lambda x: x['rank'])

        # Input tokens for visualization; convert_ids_to_tokens keeps the
        # tokenizer's special whitespace markers, which _readable_token maps back.
        raw_tokens = tokenizer.convert_ids_to_tokens(self.input_ids[0].tolist())
        input_tokens = [self._readable_token(t) for t in raw_tokens]

        return topk_data, last_logits, input_tokens

    # ------------------------------------------------------------------
    # Backward pass / input attribution
    # ------------------------------------------------------------------

    def get_target_score(self, backprop_config):
        """
        Calculates the target scalar score (e.g. logit diff) based on config.
        Returns the score tensor (attached to graph).

        Config keys read: ``mode`` (``"max_logit"`` default, or ``"logit_diff"``),
        ``target_token_id`` (defaults to the top-1 token), and for
        ``logit_diff``: ``strategy`` (``"by_topk_avg"`` default,
        ``"by_ref_token"``, ``"demean"``), ``ref_token_id``, ``k``.

        Returns ``None`` for an unrecognised ``mode``.  An unrecognised
        ``strategy`` leaves the plain target logit unchanged.

        Raises:
            ValueError: if no forward pass has been run, or ``ref_token_id`` is
                missing for strategy ``by_ref_token``.
        """
        if self.outputs is None:
            raise ValueError("Model outputs not computed. Call compute_logits first.")

        mode = backprop_config.get("mode", "max_logit")
        last_logits = self.outputs.logits[0, -1, :]
        sorted_logits, sorted_indices = torch.sort(last_logits, dim=-1, descending=True)

        target_token_id = backprop_config.get("target_token_id")
        if target_token_id is not None:
            target_logit = last_logits[target_token_id]
        else:
            target_logit = sorted_logits[0]  # Default to Top 1

        if mode == "max_logit":
            return target_logit

        if mode == "logit_diff":
            strategy = backprop_config.get("strategy", "by_topk_avg")
            top_logit = target_logit

            if strategy == "by_ref_token":
                ref_id = backprop_config.get("ref_token_id")
                if ref_id is None:
                    raise ValueError("ref_token_id required for strategy 'by_ref_token'")
                target_logit = top_logit - last_logits[ref_id]
            elif strategy == "demean":
                target_logit = top_logit - last_logits.mean()
            elif strategy == "by_topk_avg":
                k = backprop_config.get("k", 10)  # default K=10
                k = min(k, len(sorted_logits))
                # NOTE: the top-k slice includes the top-1 logit itself.
                target_logit = top_logit - sorted_logits[:k].mean()

            return target_logit

        # Unknown mode: callers (run_backward_pass) treat None as invalid config.
        return None

    def run_backward_pass(self, backprop_config):
        """
        Section 2 Part A: execute backward pass based on configuration.

        Clears gradients on parameters, input embeddings and hooked
        activations first, then back-propagates the target score with
        ``retain_graph=True``.

        Raises:
            ValueError: if the configuration yields no target score.
        """
        target_logit = self.get_target_score(backprop_config)

        if target_logit is None:
            raise ValueError(f"Invalid backprop configuration: {backprop_config}")

        # Clear previous gradients
        model = self.manager.get_model()
        model.zero_grad()

        # Also clear gradients on input embeddings if they exist
        if self.input_embeddings is not None and self.input_embeddings.grad is not None:
            self.input_embeddings.grad.zero_()

        # Clear gradients on intermediate activations (layer.output, mid_activation).
        # These are non-parameter tensors with retain_grad(), so model.zero_grad()
        # does NOT clear them; without this, gradients accumulate across
        # multiple backward passes.
        for layer in self._get_layers(model):
            layer_out = getattr(layer, 'output', None)
            if layer_out is not None and getattr(layer_out, 'grad', None) is not None:
                layer_out.grad = None

            for mid_module in self._candidate_mid_modules(layer):
                mid_act = getattr(mid_module, 'mid_activation', None)
                if mid_act is not None and getattr(mid_act, 'grad', None) is not None:
                    mid_act.grad = None

        # retain_graph is needed for interactive exploration where backward runs multiple times
        target_logit.backward(retain_graph=True)

    def compute_input_attribution(self, backprop_config):
        """
        Compute input attribution (Input * Gradient).

        If the requested ``lrp_rule`` differs from the manager's current one,
        the model is reloaded with the new rule and a ``RuntimeError`` is
        raised, because the stored forward graph was built with the old rule.

        Returns:
            list of floats, one relevance value per input position (batch item 0).

        Raises:
            RuntimeError: on an LRP rule change, or if no gradient reached the
                input embeddings.
        """
        # Default to the currently loaded rule (not a hardcoded "Attn-LRP") to
        # avoid a false mismatch when the caller omits lrp_rule.
        default_rule = self.manager.current_lrp_rule or "Attn-LRP"
        required_rule = backprop_config.get("lrp_rule", default_rule)

        if self.manager.current_lrp_rule and self.manager.current_lrp_rule != required_rule:
            print(f"LRP Rule Mismatch detected (Current: {self.manager.current_lrp_rule}, Requested: {required_rule})")
            print(f"Reloading model {self.manager.current_model_path} with rule={required_rule}...")

            old_rule = self.manager.current_lrp_rule
            self.manager.load_model(
                model_path=self.manager.current_model_path,
                dtype=self.manager.current_dtype,
                lrp_rule=required_rule
            )

            # The forward graph (self.outputs) was built with the OLD rule, so
            # we cannot proceed; the user must re-run compute_logits.
            raise RuntimeError(
                f"LRP rule changed from '{old_rule}' to '{required_rule}'. "
                "The model has been reloaded. You MUST re-run 'compute_logits()' to rebuild the computation graph with the new rule, "
                "then call 'compute_input_attribution()' again."
            )

        self.run_backward_pass(backprop_config)

        relevance = self._input_times_gradient()
        print(f"Computed input attribution with shape: {relevance.shape}")

        # Return raw relevance
        return relevance.tolist()

    def compute_input_attribution_gradient(self, backprop_config):
        """
        Compute input attribution using vanilla gradient method (Input * Gradient).
        This does NOT require LRP monkey-patching - uses standard PyTorch autograd.
        The gradient flows through normal attention/MLP without LRP decomposition rules.

        Returns:
            list of floats, one relevance value per input position (batch item 0).

        Raises:
            RuntimeError: if no forward pass exists or no gradient reached the
                input embeddings.
        """
        if self.outputs is None:
            raise RuntimeError("No forward pass found. Run compute_logits first.")

        # Run backward pass (works on vanilla model without LRP)
        self.run_backward_pass(backprop_config)

        relevance = self._input_times_gradient()
        print(f"Computed GRADIENT input attribution with shape: {relevance.shape}")

        return relevance.tolist()

    # ------------------------------------------------------------------
    # Perturbation evaluation
    # ------------------------------------------------------------------

    def _evaluate_perturbation(self, model, tokenizer, indices, target_token_id, original_target_logit, k):
        """Zero the embeddings at ``indices``, re-run the model and summarise the effect.

        Returns the result dict shared by both perturbation entry points; ``k``
        is stored verbatim under the ``"k"`` key.
        """
        # Token strings being perturbed ("?" for a position without an input id)
        perturbed_token_strs = []
        for idx in indices:
            if idx < len(self.input_ids[0]):
                tid = self.input_ids[0][idx].item()
                perturbed_token_strs.append(tokenizer.decode([tid]))
            else:
                perturbed_token_strs.append("?")

        # Clone embeddings and zero out the selected positions
        perturbed_embeddings = self.input_embeddings.detach().clone()
        for idx in indices:
            perturbed_embeddings[0, idx, :] = 0.0

        perturbed_logits = self._last_token_logits(model, perturbed_embeddings)

        perturbed_top1_id = perturbed_logits.argmax().item()
        perturbed_top1_str = tokenizer.decode([perturbed_top1_id])
        perturbed_target_logit = perturbed_logits[target_token_id].item()

        # Error is "fixed" if the target token is no longer top-1
        error_fixed = (perturbed_top1_id != target_token_id)
        logit_change = perturbed_target_logit - original_target_logit

        # 1-based rank of the target token after perturbation
        _, sorted_perturbed_idx = torch.sort(perturbed_logits, descending=True)
        target_rank_after = (sorted_perturbed_idx == target_token_id).nonzero(as_tuple=True)[0].item() + 1

        del perturbed_embeddings
        torch.cuda.empty_cache()

        return {
            "k": k,
            "perturbed_tokens": perturbed_token_strs,
            "perturbed_indices": indices,
            "new_top1_token_id": perturbed_top1_id,
            "new_top1_token_str": perturbed_top1_str,
            "error_fixed": error_fixed,
            "original_target_logit": round(original_target_logit, 4),
            "perturbed_target_logit": round(perturbed_target_logit, 4),
            "logit_change": round(logit_change, 4),
            "target_rank_after": target_rank_after
        }

    def compute_perturbation_eval(self, attribution_scores, k_values, target_token_id):
        """
        Evaluate attribution quality by perturbing top-attributed tokens.

        For each k in k_values:
        1. Sort tokens by |attribution score| descending
        2. Take top-k token indices
        3. Clone input embeddings, zero out those k tokens' embeddings
        4. Run forward pass with perturbed embeddings
        5. Check if the error token (target_token_id) is still top-1

        Args:
            attribution_scores: list of floats (one per input token)
            k_values: list of int (e.g., [1, 3, 5, 10])
            target_token_id: int - the error token ID to check

        Returns:
            list of result dicts for each k

        Raises:
            RuntimeError: if no forward pass has been run.
        """
        if self.input_ids is None or self.input_embeddings is None:
            raise RuntimeError("No forward pass found. Run compute_logits first.")

        model = self.manager.get_model()
        tokenizer = self.manager.get_tokenizer()

        # Unperturbed logit of the target token, for reference
        original_logits = self._last_token_logits(model, self.input_embeddings.detach())
        original_target_logit = original_logits[target_token_id].item()

        # Sort tokens by |attribution score| descending
        scores = torch.tensor(attribution_scores, dtype=torch.float32)
        sorted_indices = torch.argsort(scores.abs(), descending=True)

        seq_len = self.input_embeddings.shape[1]
        results = []

        for k in k_values:
            top_k_indices = sorted_indices[:min(k, seq_len)].tolist()

            result = self._evaluate_perturbation(
                model, tokenizer, top_k_indices, target_token_id, original_target_logit, k
            )
            results.append(result)

            print(f"Perturbation k={k}: error_fixed={result['error_fixed']}, "
                  f"new_top1='{result['new_top1_token_str']}' (ID={result['new_top1_token_id']}), "
                  f"logit_change={result['perturbed_target_logit'] - original_target_logit:.4f}, target_rank={result['target_rank_after']}")

        return results

    def compute_perturbation_manual(self, perturb_indices, target_token_id):
        """
        Evaluate attribution by perturbing manually selected token positions.

        Args:
            perturb_indices: list of int - token position indices to zero out
            target_token_id: int - the error token ID to check

        Returns:
            dict with perturbation result

        Raises:
            RuntimeError: if no forward pass has been run.
            ValueError: if no index falls inside the sequence.
        """
        if self.input_ids is None or self.input_embeddings is None:
            raise RuntimeError("No forward pass found. Run compute_logits first.")

        model = self.manager.get_model()
        tokenizer = self.manager.get_tokenizer()
        seq_len = self.input_embeddings.shape[1]

        # Out-of-range indices are silently dropped
        valid_indices = [idx for idx in perturb_indices if 0 <= idx < seq_len]
        if not valid_indices:
            raise ValueError("No valid token indices provided.")

        # Unperturbed logit of the target token, for reference
        original_logits = self._last_token_logits(model, self.input_embeddings.detach())
        original_target_logit = original_logits[target_token_id].item()

        result = self._evaluate_perturbation(
            model, tokenizer, valid_indices, target_token_id, original_target_logit, len(valid_indices)
        )

        print(f"Manual Perturbation ({len(valid_indices)} tokens): error_fixed={result['error_fixed']}, "
              f"new_top1='{result['new_top1_token_str']}' (ID={result['new_top1_token_id']}), "
              f"logit_change={result['perturbed_target_logit'] - original_target_logit:.4f}, target_rank={result['target_rank_after']}")

        return result

    # ------------------------------------------------------------------
    # Token-to-token connection matrix
    # ------------------------------------------------------------------

    @staticmethod
    def _build_ops(layers, source_layer_idx, source_type, target_layer_idx, target_type):
        """List the ``('part1'|'part2', layer)`` steps that map the source node to the target node."""
        ops = []

        # Starting from a mid point: finish the source layer first.
        if source_layer_idx != -1 and source_type == 'mid':
            ops.append(('part2', layers[source_layer_idx]))

        # Intermediate layers run in full.
        for i in range(source_layer_idx + 1, target_layer_idx):
            ops.append(('part1', layers[i]))
            ops.append(('part2', layers[i]))

        # Target layer
        if target_layer_idx > source_layer_idx:
            ops.append(('part1', layers[target_layer_idx]))
            if target_type == 'post':
                ops.append(('part2', layers[target_layer_idx]))
        elif target_layer_idx == source_layer_idx:
            pass  # Already handled or identity
        elif source_layer_idx == -1:
            # Only reachable for target_layer_idx < -1; kept from the original logic.
            ops.append(('part1', layers[target_layer_idx]))
            if target_type == 'post':
                ops.append(('part2', layers[target_layer_idx]))

        return ops

    @staticmethod
    def _rotary_embeddings(backbone, hidden_states, position_ids):
        """Return rotary position embeddings from ``backbone``, or ``None`` if it exposes none.

        Tries ``rotary_emb``, then ``rotary_embs['full_attention']``, then the
        first entry of ``rotary_embs``.
        """
        if hasattr(backbone, 'rotary_emb'):
            return backbone.rotary_emb(hidden_states, position_ids)
        if hasattr(backbone, 'rotary_embs'):
            if 'full_attention' in backbone.rotary_embs:
                return backbone.rotary_embs['full_attention'](hidden_states, position_ids)
            if len(backbone.rotary_embs) > 0:
                return list(backbone.rotary_embs.values())[0](hidden_states, position_ids)
        return None

    def compute_connection_matrix_gen(self, source, target, node_threshold=None):
        """
        Section 2 Part B: Compute Token-to-Token interaction matrix between two nodes.
        Generator version that yields progress.
        source, target: int (layer idx) or tuple (layer_idx, 'mid'/'post')

        A source layer index of ``-1`` means the input embeddings.
        ``node_threshold`` (default 0.01) restricts computation to target
        positions whose |relevance| exceeds it; a value <= 0 computes all.

        Yields:
            ``{"type": "progress", "current", "total"}`` before each chunk, then
            one ``{"type": "result", "payload": {...}}`` with NumPy arrays
            ``matrix``, ``real_target_rel`` and ``real_source_rel``.

        Raises:
            ValueError: if a required mid-activation module or tensor is missing.
        """
        source_layer_idx, source_type = self._parse_node(source)
        target_layer_idx, target_type = self._parse_node(target)

        model = self.manager.get_model()
        layers = self._get_layers(model)
        backbone = self._get_backbone(model)
        target_layer = layers[target_layer_idx]

        # 1. Identify Source Tensor
        if source_layer_idx == -1:
            source_tensor = self.input_embeddings
        else:
            layer = layers[source_layer_idx]
            if source_type == 'mid':
                mid_mod = self.manager.decomposer.get_mid_activation_module(layer)
                if not mid_mod:
                    raise ValueError(f"Decomposer could not identify mid-activation module for layer {source_layer_idx}")

                source_tensor = getattr(mid_mod, 'mid_activation', None)
                if source_tensor is None:
                    raise ValueError(f"Mid activation for layer {source_layer_idx} not captured. Enable capture_mid in compute_logits.")
            else:
                source_tensor = layer.output

        # 2. Identify Target Tensor and Gradient
        if target_type == 'mid':
            # We need the gradient at the mid point (input to post_attn_norm)
            mid_mod = self.manager.decomposer.get_mid_activation_module(target_layer)
            if not mid_mod:
                raise ValueError(f"Decomposer could not identify mid-activation module for target {target_layer_idx}")

            target_tensor = getattr(mid_mod, 'mid_activation', None)
            if target_tensor is None:
                raise ValueError(f"Mid activation for target {target_layer_idx} not captured.")
        else:
            target_tensor = target_layer.output

        target_grad = target_tensor.grad

        if target_grad is None:
            print(f"WARNING: target_grad is None for target layer {target_layer_idx} (type={target_type}). "
                  f"This will result in an all-zero interaction matrix. "
                  f"Ensure backward pass was run and hooks captured activations correctly.")

        # Disable gradient checkpointing temporarily
        was_checkpointing = model.is_gradient_checkpointing
        if was_checkpointing:
            model.gradient_checkpointing_disable()

        try:
            # Prepare Input
            target_layer_input = source_tensor.detach()
            batch_size, seq_len, hidden_dim = target_layer_input.shape

            # Target Real Relevance
            if target_grad is not None:
                real_target_rel = (target_tensor * target_grad).sum(dim=-1)[0]
            else:
                real_target_rel = torch.zeros(seq_len, device=model.device)

            # Filter Indices
            total_params = model.num_parameters()
            if node_threshold is None:
                node_threshold = 0.01

            if node_threshold > 0:
                indices_to_compute = torch.nonzero(real_target_rel.abs() > node_threshold).squeeze(-1).tolist()
                if isinstance(indices_to_compute, int):
                    indices_to_compute = [indices_to_compute]
                print(f"DEBUG: Node Threshold {node_threshold}. Computing for {len(indices_to_compute)}/{seq_len} nodes.")
            else:
                indices_to_compute = list(range(seq_len))

            # Fixed Position IDs (for Rotary)
            position_ids = torch.arange(0, seq_len, dtype=torch.long, device=model.device).unsqueeze(0)

            ops = self._build_ops(layers, source_layer_idx, source_type, target_layer_idx, target_type)

            # Pre-calc rotary embedding; assumed to depend only on position_ids and shape.
            rotary_emb = self._rotary_embeddings(backbone, target_layer_input, position_ids)

            # Chunk Processing
            current_dtype = target_layer_input.dtype
            BATCH_CHUNK_SIZE = get_batch_chunk_size(total_params, current_dtype)
            token_interaction = torch.zeros(seq_len, seq_len, device=model.device)

            total_items = len(indices_to_compute)
            print(f"DEBUG: BATCH_CHUNK_SIZE={BATCH_CHUNK_SIZE}, total_items={total_items}, seq_len={seq_len}, params={total_params/1e9:.2f}B, dtype={current_dtype}")

            for i in range(0, total_items, BATCH_CHUNK_SIZE):
                yield {"type": "progress", "current": i, "total": total_items}

                chunk_indices = indices_to_compute[i : i + BATCH_CHUNK_SIZE]
                current_batch_size = len(chunk_indices)

                # One copy of the source activations per target position in the chunk.
                expanded_input = target_layer_input.expand(current_batch_size, seq_len, hidden_dim).clone().requires_grad_(True)

                # Execute Ops
                hidden_states = expanded_input
                for op_type, layer_mod in ops:
                    if op_type == 'part1':
                        hidden_states = self._forward_part1(layer_mod, hidden_states, position_embeddings=rotary_emb)
                    else:
                        hidden_states = self._forward_part2(layer_mod, hidden_states)

                reconstructed_output = hidden_states

                # Backward: batch row b only receives the target gradient at its
                # own position, isolating that position's contribution.
                grad_output_chunk = torch.zeros(current_batch_size, seq_len, hidden_dim,
                                                dtype=reconstructed_output.dtype, device=model.device)
                if target_grad is not None:
                    for batch_idx, global_idx in enumerate(chunk_indices):
                        grad_output_chunk[batch_idx, global_idx, :] = target_grad[0, global_idx, :]

                grad_input = torch.autograd.grad(outputs=reconstructed_output, inputs=expanded_input,
                                                 grad_outputs=grad_output_chunk, retain_graph=False)[0]

                chunk_relevance = (grad_input * expanded_input).sum(dim=-1)
                token_interaction[chunk_indices, :] = chunk_relevance.detach().to(token_interaction.dtype)

                del expanded_input, hidden_states, reconstructed_output, grad_output_chunk, grad_input, chunk_relevance
                torch.cuda.empty_cache()

            # Source Real Relevance (source_tensor is the input embeddings when source is -1)
            if source_tensor.grad is not None:
                real_source_rel = (source_tensor * source_tensor.grad).sum(dim=-1)[0]
            else:
                real_source_rel = torch.zeros(seq_len, device=model.device)

            yield {
                "type": "result",
                "payload": {
                    "matrix": token_interaction.detach().float().cpu().numpy(),
                    "real_target_rel": real_target_rel.detach().float().cpu().numpy(),
                    "real_source_rel": real_source_rel.detach().float().cpu().numpy()
                }
            }

        finally:
            if was_checkpointing:
                model.gradient_checkpointing_enable()

    def compute_connection_matrix(self, source, target, node_threshold=None):
        """Blocking wrapper around ``compute_connection_matrix_gen``.

        Returns the payload of the first ``"result"`` item, or ``None`` if the
        generator finishes without producing one.
        """
        gen = self.compute_connection_matrix_gen(source, target, node_threshold=node_threshold)
        try:
            for item in gen:
                if item.get("type") == "result":
                    return item["payload"]
            return None
        finally:
            # Close explicitly so the generator's `finally` (which restores
            # gradient checkpointing) runs now rather than at garbage collection.
            gen.close()