| import torch |
| import numpy as np |
| from .batch_config import get_batch_chunk_size |
|
|
| class AttributionEngine: |
| def __init__(self, model_manager): |
| self.manager = model_manager |
| self.hook_handles = [] |
| self.outputs = None |
| self.input_ids = None |
|
|
| def _parse_node(self, node): |
| if isinstance(node, int): |
| return node, 'post' |
| if isinstance(node, (tuple, list)) and len(node) == 2: |
| return node[0], node[1] |
| raise ValueError(f"Invalid node format: {node}") |
|
|
| def _forward_part1(self, layer_module, hidden_states, position_embeddings=None, attention_mask=None): |
| """ |
| Executes: Norm -> Attn -> Residual Add |
| Returns: resid_mid |
| """ |
| return self.manager.decomposer.forward_part1(layer_module, hidden_states, position_embeddings, attention_mask) |
|
|
| def _forward_part2(self, layer_module, hidden_states): |
| """ |
| Executes: Norm -> MLP -> Residual Add |
| Returns: resid_post |
| """ |
| return self.manager.decomposer.forward_part2(layer_module, hidden_states) |
|
|
| def _hook_hidden_activation(self, module, input, output): |
| """ |
| Hook to save activation and enable gradient retention. |
| """ |
| if isinstance(output, tuple): |
| output = output[0] |
|
|
| |
| module.output = output |
| if module.output.requires_grad: |
| module.output.retain_grad() |
|
|
| def _hook_mid_activation(self, module, input, output): |
| """ |
| Hook to capture resid_mid at the input of post_attention_layernorm. |
| """ |
| |
| val = input[0] |
| |
| module.mid_activation = val |
| if module.mid_activation.requires_grad: |
| module.mid_activation.retain_grad() |
|
|
| def register_hooks(self, capture_mid=False): |
| """ |
| Register forward hooks on all model layers. |
| """ |
| self.remove_hooks() |
| model = self.manager.get_model() |
| if not model: |
| raise ValueError("Model not loaded yet.") |
|
|
| |
| if hasattr(model, "model") and hasattr(model.model, "layers"): |
| layers = model.model.layers |
| elif hasattr(model, "layers"): |
| layers = model.layers |
| else: |
| layers = [] |
|
|
| for layer in layers: |
| |
| handle = layer.register_forward_hook(self._hook_hidden_activation) |
| self.hook_handles.append(handle) |
|
|
| |
| if capture_mid: |
| |
| mid_module = self.manager.decomposer.get_mid_activation_module(layer) |
|
|
| if mid_module: |
| |
| |
| |
| handle_mid = mid_module.register_forward_hook(self._hook_mid_activation) |
| self.hook_handles.append(handle_mid) |
| else: |
| print(f"Warning: Could not identify mid-activation module for layer {layer}. Skipping mid hook.") |
|
|
| def remove_hooks(self): |
| for handle in self.hook_handles: |
| handle.remove() |
| self.hook_handles = [] |
|
|
| def reset(self): |
| """ |
| Clears all internal state and specific temporary data from model layers. |
| """ |
| self.remove_hooks() |
| self.outputs = None |
| self.input_ids = None |
|
|
| |
| model = self.manager.get_model() |
| if model: |
| layers = [] |
| if hasattr(model, "model") and hasattr(model.model, "layers"): |
| layers = model.model.layers |
| elif hasattr(model, "layers"): |
| layers = model.layers |
|
|
| for layer in layers: |
| if hasattr(layer, 'output'): |
| del layer.output |
| if hasattr(layer, 'post_attention_layernorm') and hasattr(layer.post_attention_layernorm, 'mid_activation'): |
| del layer.post_attention_layernorm.mid_activation |
|
|
| torch.cuda.empty_cache() |
|
|
|
|
| def rerun_forward_pass(self, capture_mid=False): |
| """ |
| Re-run the forward pass using stored input_ids to get a fresh computation graph. |
| This is needed before circuit computation to ensure clean gradients. |
| """ |
| if self.input_ids is None: |
| raise ValueError("No input_ids stored. Run compute_logits first.") |
|
|
| model = self.manager.get_model() |
|
|
| |
| self.register_hooks(capture_mid=capture_mid) |
|
|
| |
| self.input_embeddings = model.get_input_embeddings()(self.input_ids).detach() |
| self.input_embeddings.requires_grad_(True) |
|
|
| |
| self.outputs = model( |
| inputs_embeds=self.input_embeddings, |
| use_cache=False |
| ) |
|
|
| print("Re-ran forward pass for fresh computation graph.") |
|
|
| def compute_logits(self, prompt, is_append_bos=False, topk=10, extra_token_ids=None, extra_token_strs=None, capture_mid=False): |
| """ |
| Section 1: Forward pass to get logits and top-k predictions. |
| """ |
| model = self.manager.get_model() |
| tokenizer = self.manager.get_tokenizer() |
|
|
| self.register_hooks(capture_mid=capture_mid) |
|
|
| |
| |
| inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) |
| input_ids = inputs.input_ids.to(model.device) |
|
|
| if is_append_bos: |
| |
| bos_id = tokenizer.bos_token_id |
|
|
| |
| if bos_id is None: |
| bos_id = tokenizer.cls_token_id |
|
|
| |
| if bos_id is None: |
| bos_id = tokenizer.eos_token_id |
|
|
| if bos_id is not None: |
| prefix = torch.tensor([[bos_id]], device=model.device) |
| input_ids = torch.cat([prefix, input_ids], dim=1) |
| print(f"Appended start token ID: {bos_id}") |
| else: |
| print("Warning: Append BOS requested but no suitable start token (BOS/CLS/EOS) found.") |
|
|
| self.input_ids = input_ids |
|
|
| |
| |
| |
| self.input_embeddings = model.get_input_embeddings()(self.input_ids).detach() |
| self.input_embeddings.requires_grad_(True) |
|
|
| |
| |
| |
| self.outputs = model( |
| inputs_embeds=self.input_embeddings, |
| use_cache=False |
| ) |
|
|
| output_logits = self.outputs.logits |
| last_logits = output_logits[0, -1, :] |
|
|
| |
| sorted_logits, sorted_indices = torch.sort(last_logits, dim=-1, descending=True) |
|
|
| |
| topk_data = [] |
| for i in range(topk): |
| idx = sorted_indices[i].item() |
| token_str = tokenizer.decode([idx]) |
| logit_val = sorted_logits[i].item() |
| topk_data.append({ |
| "rank": i + 1, |
| "token_id": idx, |
| "token_str": token_str, |
| "logit": logit_val |
| }) |
|
|
| |
| if extra_token_ids or extra_token_strs: |
| |
| def get_rank(val, sorted_vals): |
| |
| |
| |
| |
| |
| return (sorted_vals > val).sum().item() + 1 |
|
|
| processed_ids = set() |
|
|
| |
| if extra_token_ids: |
| for tid in extra_token_ids: |
| if tid < 0 or tid >= len(last_logits): continue |
| if tid in processed_ids: continue |
|
|
| logit_val = last_logits[tid].item() |
| rank = get_rank(logit_val, sorted_logits) |
| token_str = tokenizer.decode([tid]) |
|
|
| topk_data.append({ |
| "rank": rank, |
| "token_id": tid, |
| "token_str": token_str, |
| "logit": logit_val, |
| "is_extra": True |
| }) |
| processed_ids.add(tid) |
|
|
| |
| if extra_token_strs: |
| print(f"DEBUG: Processing extra strs: {extra_token_strs}") |
| for tstr in extra_token_strs: |
| |
| try: |
| |
| encoded = tokenizer.encode(tstr, add_special_tokens=False) |
| print(f"DEBUG: Encoded '{tstr}' -> {encoded} (Type: {type(encoded)})") |
|
|
| if hasattr(encoded, 'tolist'): encoded = encoded.tolist() |
|
|
| if len(encoded) == 0: |
| print(f"DEBUG: Empty encoding for '{tstr}'") |
| continue |
|
|
| |
| tid = encoded[0] |
| print(f"DEBUG: Using TID {tid} for '{tstr}'") |
|
|
| if tid in processed_ids: |
| print(f"DEBUG: TID {tid} already processed") |
| continue |
|
|
| logit_val = last_logits[tid].item() |
| rank = get_rank(logit_val, sorted_logits) |
| real_str = tokenizer.decode([tid]) |
|
|
| print(f"DEBUG: Added extra token: {real_str} (ID: {tid}, Rank: {rank})") |
|
|
| topk_data.append({ |
| "rank": rank, |
| "token_id": tid, |
| "token_str": real_str, |
| "logit": logit_val, |
| "is_extra": True |
| }) |
| processed_ids.add(tid) |
| except Exception as e: |
| print(f"DEBUG: Error processing extra str '{tstr}': {e}") |
| import traceback |
| traceback.print_exc() |
|
|
| |
| |
| |
| |
| |
| topk_data.sort(key=lambda x: x['rank']) |
|
|
| if self.input_ids is None: |
| raise ValueError("Input IDs not found. Ensure compute_logits was run.") |
|
|
| |
| |
| input_tokens = [] |
| |
| raw_tokens = tokenizer.convert_ids_to_tokens(self.input_ids[0]) |
|
|
| for t in raw_tokens: |
| |
| if isinstance(t, bytes): |
| try: |
| t = t.decode('utf-8') |
| except: |
| |
| t = str(t) |
|
|
| |
| if isinstance(t, str): |
| |
| t = t.replace('\u2581', ' ') |
| |
| t = t.replace('\u0120', ' ') |
| |
| t = t.replace('\u010A', '\n') |
| |
| t = t.replace('', '') |
|
|
| input_tokens.append(t) |
|
|
| return topk_data, last_logits, input_tokens |
|
|
| def get_target_score(self, backprop_config): |
| """ |
| Calculates the target scalar score (e.g. logit diff) based on config. |
| Returns the score tensor (attached to graph). |
| """ |
| if self.outputs is None: |
| raise ValueError("Model outputs not computed. Call compute_logits first.") |
|
|
| mode = backprop_config.get("mode", "max_logit") |
| last_logits = self.outputs.logits[0, -1, :] |
| sorted_logits, sorted_indices = torch.sort(last_logits, dim=-1, descending=True) |
|
|
| target_token_id = backprop_config.get("target_token_id") |
| if target_token_id is not None: |
| target_logit = last_logits[target_token_id] |
| else: |
| target_logit = sorted_logits[0] |
|
|
| if mode == "max_logit": |
| return target_logit |
|
|
| elif mode == "logit_diff": |
| strategy = backprop_config.get("strategy", "by_topk_avg") |
| top_logit = target_logit |
|
|
| if strategy == "by_ref_token": |
| ref_id = backprop_config.get("ref_token_id") |
| if ref_id is None: |
| raise ValueError("ref_token_id required for strategy 'by_ref_token'") |
| contrast_logit = last_logits[ref_id] |
| target_logit = top_logit - contrast_logit |
|
|
| elif strategy == "demean": |
| target_logit = top_logit - last_logits.mean() |
|
|
| elif strategy == "by_topk_avg": |
| k = backprop_config.get("k", 10) |
| k = min(k, len(sorted_logits)) |
| contrast_logit = sorted_logits[:k].mean() |
| target_logit = top_logit - contrast_logit |
|
|
| return target_logit |
|
|
| def run_backward_pass(self, backprop_config): |
| """ |
| Section 2 Part A: execute backward pass based on configuration. |
| """ |
| target_logit = self.get_target_score(backprop_config) |
|
|
| if target_logit is None: |
| raise ValueError(f"Invalid backprop configuration: {backprop_config}") |
|
|
| |
| model = self.manager.get_model() |
| model.zero_grad() |
|
|
| |
| if hasattr(self, 'input_embeddings') and self.input_embeddings is not None: |
| if self.input_embeddings.grad is not None: |
| self.input_embeddings.grad.zero_() |
|
|
| |
| |
| |
| layers = [] |
| if hasattr(model, "model") and hasattr(model.model, "layers"): |
| layers = model.model.layers |
| elif hasattr(model, "layers"): |
| layers = model.layers |
|
|
| for layer in layers: |
| if hasattr(layer, 'output') and layer.output is not None: |
| if hasattr(layer.output, 'grad') and layer.output.grad is not None: |
| layer.output.grad = None |
| |
| mid_module = getattr(layer, 'post_attention_layernorm', None) |
| if mid_module and hasattr(mid_module, 'mid_activation') and mid_module.mid_activation is not None: |
| if hasattr(mid_module.mid_activation, 'grad') and mid_module.mid_activation.grad is not None: |
| mid_module.mid_activation.grad = None |
|
|
| |
| target_logit.backward(retain_graph=True) |
|
|
| def compute_input_attribution(self, backprop_config): |
| """ |
| Compute input attribution (Input * Gradient). |
| """ |
| |
| |
| |
| |
| |
| default_rule = self.manager.current_lrp_rule or "Attn-LRP" |
| required_rule = backprop_config.get("lrp_rule", default_rule) |
|
|
| if self.manager.current_lrp_rule and self.manager.current_lrp_rule != required_rule: |
| print(f"LRP Rule Mismatch detected (Current: {self.manager.current_lrp_rule}, Requested: {required_rule})") |
| print(f"Reloading model {self.manager.current_model_path} with rule={required_rule}...") |
|
|
| old_rule = self.manager.current_lrp_rule |
|
|
| self.manager.load_model( |
| model_path=self.manager.current_model_path, |
| dtype=self.manager.current_dtype, |
| lrp_rule=required_rule |
| ) |
|
|
| |
| |
| raise RuntimeError( |
| f"LRP rule changed from '{old_rule}' to '{required_rule}'. " |
| "The model has been reloaded. You MUST re-run 'compute_logits()' to rebuild the computation graph with the new rule, " |
| "then call 'compute_input_attribution()' again." |
| ) |
| |
| self.run_backward_pass(backprop_config) |
|
|
| |
| |
| if self.input_embeddings.grad is None: |
| raise RuntimeError("No gradient found on input embeddings. Ensure compute_logits was run.") |
|
|
| relevance = (self.input_embeddings * self.input_embeddings.grad).float().sum(-1).detach().cpu()[0] |
| print(f"Computed input attribution with shape: {relevance}") |
| |
| return relevance.tolist() |
|
|
| def compute_input_attribution_gradient(self, backprop_config): |
| """ |
| Compute input attribution using vanilla gradient method (Input * Gradient). |
| This does NOT require LRP monkey-patching - uses standard PyTorch autograd. |
| The gradient flows through normal attention/MLP without LRP decomposition rules. |
| """ |
| if self.outputs is None: |
| raise RuntimeError("No forward pass found. Run compute_logits first.") |
|
|
| |
| self.run_backward_pass(backprop_config) |
|
|
| |
| if self.input_embeddings.grad is None: |
| raise RuntimeError("No gradient found on input embeddings. Ensure compute_logits was run.") |
|
|
| relevance = (self.input_embeddings * self.input_embeddings.grad).float().sum(-1).detach().cpu()[0] |
| print(f"Computed GRADIENT input attribution with shape: {relevance.shape}") |
| return relevance.tolist() |
|
|
| def compute_perturbation_eval(self, attribution_scores, k_values, target_token_id): |
| """ |
| Evaluate attribution quality by perturbing top-attributed tokens. |
| |
| For each k in k_values: |
| 1. Sort tokens by |attribution score| descending |
| 2. Take top-k token indices |
| 3. Clone input embeddings, zero out those k tokens' embeddings |
| 4. Run forward pass with perturbed embeddings |
| 5. Check if the error token (target_token_id) is still top-1 |
| |
| Args: |
| attribution_scores: list of floats (one per input token) |
| k_values: list of int (e.g., [1, 3, 5, 10]) |
| target_token_id: int - the error token ID to check |
| |
| Returns: |
| list of result dicts for each k |
| """ |
| if self.input_ids is None or self.input_embeddings is None: |
| raise RuntimeError("No forward pass found. Run compute_logits first.") |
|
|
| model = self.manager.get_model() |
| tokenizer = self.manager.get_tokenizer() |
| device = model.device |
|
|
| |
| with torch.no_grad(): |
| original_logits = model( |
| inputs_embeds=self.input_embeddings.detach(), |
| use_cache=False |
| ).logits[0, -1, :] |
| original_top1_id = original_logits.argmax().item() |
| original_target_logit = original_logits[target_token_id].item() |
|
|
| |
| scores = torch.tensor(attribution_scores, dtype=torch.float32) |
| sorted_indices = torch.argsort(scores.abs(), descending=True) |
|
|
| seq_len = self.input_embeddings.shape[1] |
|
|
| results = [] |
| for k in k_values: |
| k_clamped = min(k, seq_len) |
| top_k_indices = sorted_indices[:k_clamped].tolist() |
|
|
| |
| perturbed_token_strs = [] |
| for idx in top_k_indices: |
| if idx < len(self.input_ids[0]): |
| tid = self.input_ids[0][idx].item() |
| perturbed_token_strs.append(tokenizer.decode([tid])) |
| else: |
| perturbed_token_strs.append("?") |
|
|
| |
| perturbed_embeddings = self.input_embeddings.detach().clone() |
| for idx in top_k_indices: |
| perturbed_embeddings[0, idx, :] = 0.0 |
|
|
| |
| with torch.no_grad(): |
| perturbed_logits = model( |
| inputs_embeds=perturbed_embeddings, |
| use_cache=False |
| ).logits[0, -1, :] |
|
|
| perturbed_top1_id = perturbed_logits.argmax().item() |
| perturbed_top1_str = tokenizer.decode([perturbed_top1_id]) |
| perturbed_target_logit = perturbed_logits[target_token_id].item() |
|
|
| |
| error_fixed = (perturbed_top1_id != target_token_id) |
|
|
| |
| logit_change = perturbed_target_logit - original_target_logit |
|
|
| |
| sorted_perturbed, sorted_perturbed_idx = torch.sort(perturbed_logits, descending=True) |
| target_rank_after = (sorted_perturbed_idx == target_token_id).nonzero(as_tuple=True)[0].item() + 1 |
|
|
| results.append({ |
| "k": k, |
| "perturbed_tokens": perturbed_token_strs, |
| "perturbed_indices": top_k_indices, |
| "new_top1_token_id": perturbed_top1_id, |
| "new_top1_token_str": perturbed_top1_str, |
| "error_fixed": error_fixed, |
| "original_target_logit": round(original_target_logit, 4), |
| "perturbed_target_logit": round(perturbed_target_logit, 4), |
| "logit_change": round(logit_change, 4), |
| "target_rank_after": target_rank_after |
| }) |
|
|
| print(f"Perturbation k={k}: error_fixed={error_fixed}, " |
| f"new_top1='{perturbed_top1_str}' (ID={perturbed_top1_id}), " |
| f"logit_change={logit_change:.4f}, target_rank={target_rank_after}") |
|
|
| del perturbed_embeddings |
|
|
| torch.cuda.empty_cache() |
| return results |
|
|
| def compute_perturbation_manual(self, perturb_indices, target_token_id): |
| """ |
| Evaluate attribution by perturbing manually selected token positions. |
| |
| Args: |
| perturb_indices: list of int - token position indices to zero out |
| target_token_id: int - the error token ID to check |
| |
| Returns: |
| dict with perturbation result |
| """ |
| if self.input_ids is None or self.input_embeddings is None: |
| raise RuntimeError("No forward pass found. Run compute_logits first.") |
|
|
| model = self.manager.get_model() |
| tokenizer = self.manager.get_tokenizer() |
| seq_len = self.input_embeddings.shape[1] |
|
|
| |
| valid_indices = [idx for idx in perturb_indices if 0 <= idx < seq_len] |
| if len(valid_indices) == 0: |
| raise ValueError("No valid token indices provided.") |
|
|
| |
| with torch.no_grad(): |
| original_logits = model( |
| inputs_embeds=self.input_embeddings.detach(), |
| use_cache=False |
| ).logits[0, -1, :] |
| original_top1_id = original_logits.argmax().item() |
| original_target_logit = original_logits[target_token_id].item() |
|
|
| |
| perturbed_token_strs = [] |
| for idx in valid_indices: |
| if idx < len(self.input_ids[0]): |
| tid = self.input_ids[0][idx].item() |
| perturbed_token_strs.append(tokenizer.decode([tid])) |
| else: |
| perturbed_token_strs.append("?") |
|
|
| |
| perturbed_embeddings = self.input_embeddings.detach().clone() |
| for idx in valid_indices: |
| perturbed_embeddings[0, idx, :] = 0.0 |
|
|
| |
| with torch.no_grad(): |
| perturbed_logits = model( |
| inputs_embeds=perturbed_embeddings, |
| use_cache=False |
| ).logits[0, -1, :] |
|
|
| perturbed_top1_id = perturbed_logits.argmax().item() |
| perturbed_top1_str = tokenizer.decode([perturbed_top1_id]) |
| perturbed_target_logit = perturbed_logits[target_token_id].item() |
|
|
| |
| error_fixed = (perturbed_top1_id != target_token_id) |
|
|
| |
| logit_change = perturbed_target_logit - original_target_logit |
|
|
| |
| sorted_perturbed, sorted_perturbed_idx = torch.sort(perturbed_logits, descending=True) |
| target_rank_after = (sorted_perturbed_idx == target_token_id).nonzero(as_tuple=True)[0].item() + 1 |
|
|
| result = { |
| "k": len(valid_indices), |
| "perturbed_tokens": perturbed_token_strs, |
| "perturbed_indices": valid_indices, |
| "new_top1_token_id": perturbed_top1_id, |
| "new_top1_token_str": perturbed_top1_str, |
| "error_fixed": error_fixed, |
| "original_target_logit": round(original_target_logit, 4), |
| "perturbed_target_logit": round(perturbed_target_logit, 4), |
| "logit_change": round(logit_change, 4), |
| "target_rank_after": target_rank_after |
| } |
|
|
| print(f"Manual Perturbation ({len(valid_indices)} tokens): error_fixed={error_fixed}, " |
| f"new_top1='{perturbed_top1_str}' (ID={perturbed_top1_id}), " |
| f"logit_change={logit_change:.4f}, target_rank={target_rank_after}") |
|
|
| del perturbed_embeddings |
| torch.cuda.empty_cache() |
| return result |
|
|
def compute_connection_matrix_gen(self, source, target, node_threshold=None):
    """
    Section 2 Part B: token-to-token interaction matrix between two nodes.

    Generator version. Before every chunk it yields
    ``{"type": "progress", "current": i, "total": n}``; at the end it
    yields ``{"type": "result", "payload": {...}}`` whose payload holds the
    numpy arrays ``matrix``, ``real_target_rel`` and ``real_source_rel``.

    The layers between source and target are re-run on a detached copy of
    the source activation, one batch row per target position, and the
    stored gradient of the target activation is pushed back through that
    replay to obtain ``grad * input`` per source position.

    Args:
        source, target: int (layer idx, read as that layer's 'post' node)
            or tuple (layer_idx, 'mid'/'post'). A source layer index of -1
            selects the input embeddings.
        node_threshold: target positions whose absolute relevance does not
            exceed this value are skipped (their matrix rows stay zero).
            None means 0.01; a value <= 0 computes every position.

    Raises:
        ValueError: if a requested mid activation module cannot be
            identified or its activation was not captured.
    """
    source_layer_idx, source_type = self._parse_node(source)
    target_layer_idx, target_type = self._parse_node(target)

    model = self.manager.get_model()
    layers = model.model.layers
    target_layer = layers[target_layer_idx]

    # ---- activation stored for the source node ----
    if source_layer_idx == -1:
        source_tensor = self.input_embeddings
    elif source_type == 'mid':
        mid_mod = self.manager.decomposer.get_mid_activation_module(layers[source_layer_idx])
        if not mid_mod:
            raise ValueError(f"Decomposer could not identify mid-activation module for layer {source_layer_idx}")

        source_tensor = getattr(mid_mod, 'mid_activation', None)
        if source_tensor is None:
            raise ValueError(f"Mid activation for layer {source_layer_idx} not captured. Enable capture_mid in compute_logits.")
    else:
        source_tensor = layers[source_layer_idx].output

    # ---- activation stored for the target node ----
    if target_type == 'mid':
        mid_mod = self.manager.decomposer.get_mid_activation_module(target_layer)
        if not mid_mod:
            raise ValueError(f"Decomposer could not identify mid-activation module for target {target_layer_idx}")

        target_tensor = getattr(mid_mod, 'mid_activation', None)
        if target_tensor is None:
            raise ValueError(f"Mid activation for target {target_layer_idx} not captured.")
    else:
        target_tensor = target_layer.output

    target_grad = target_tensor.grad
    if target_grad is None:
        print(f"WARNING: target_grad is None for target layer {target_layer_idx} (type={target_type}). "
              f"This will result in an all-zero interaction matrix. "
              f"Ensure backward pass was run and hooks captured activations correctly.")

    # Gradient checkpointing is switched off for the replay and restored in
    # the ``finally`` clause below.
    was_checkpointing = model.is_gradient_checkpointing
    if was_checkpointing:
        model.gradient_checkpointing_disable()

    try:
        target_layer_input = source_tensor.detach()
        batch_size, seq_len, hidden_dim = target_layer_input.shape

        # Relevance of each target position (activation * gradient).
        if target_grad is None:
            real_target_rel = torch.zeros(seq_len, device=model.device)
        else:
            real_target_rel = (target_tensor * target_grad).sum(dim=-1)[0]

        total_params = model.num_parameters()
        if node_threshold is None:
            node_threshold = 0.01

        # Choose which target positions get a row computed.
        if node_threshold > 0:
            indices_to_compute = torch.nonzero(real_target_rel.abs() > node_threshold).squeeze(-1).tolist()
            if isinstance(indices_to_compute, int):
                indices_to_compute = [indices_to_compute]
            print(f"DEBUG: Node Threshold {node_threshold}. Computing for {len(indices_to_compute)}/{seq_len} nodes.")
        else:
            indices_to_compute = list(range(seq_len))

        position_ids = torch.arange(0, seq_len, dtype=torch.long, device=model.device).unsqueeze(0)

        # ---- ordered list of half-layer steps from source to target ----
        ops = []

        # A 'mid' source still has the second half of its own layer to run.
        if source_layer_idx != -1 and source_type == 'mid':
            ops.append(('part2', layers[source_layer_idx]))

        # Layers strictly between source and target run in full.
        for between in range(source_layer_idx + 1, target_layer_idx):
            ops.extend((('part1', layers[between]), ('part2', layers[between])))

        # The target layer runs up to the requested node. The second
        # condition mirrors the original trailing branch (source is the
        # input and the target index lies below it).
        target_after_source = target_layer_idx > source_layer_idx
        target_below_input = target_layer_idx < source_layer_idx and source_layer_idx == -1
        if target_after_source or target_below_input:
            ops.append(('part1', layers[target_layer_idx]))
            if target_type == 'post':
                ops.append(('part2', layers[target_layer_idx]))

        # Position embeddings for the replay, from whichever rotary module
        # the inner model exposes; None if it exposes none of them.
        rotary_emb = None
        if hasattr(model.model, 'rotary_emb'):
            rotary_emb = model.model.rotary_emb(target_layer_input, position_ids)
        elif hasattr(model.model, 'rotary_embs') and 'full_attention' in model.model.rotary_embs:
            rotary_emb = model.model.rotary_embs['full_attention'](target_layer_input, position_ids)
        elif hasattr(model.model, 'rotary_embs') and len(model.model.rotary_embs) > 0:
            rotary_emb = list(model.model.rotary_embs.values())[0](target_layer_input, position_ids)

        current_dtype = target_layer_input.dtype
        BATCH_CHUNK_SIZE = get_batch_chunk_size(total_params, current_dtype)
        token_interaction = torch.zeros(seq_len, seq_len, device=model.device)
        target_grad_full = target_grad

        total_items = len(indices_to_compute)
        print(f"DEBUG: BATCH_CHUNK_SIZE={BATCH_CHUNK_SIZE}, total_items={total_items}, seq_len={seq_len}, params={total_params/1e9:.2f}B, dtype={current_dtype}")

        for i in range(0, total_items, BATCH_CHUNK_SIZE):
            yield {"type": "progress", "current": i, "total": total_items}

            chunk_indices = indices_to_compute[i : i + BATCH_CHUNK_SIZE]
            current_batch_size = len(chunk_indices)

            # One copy of the source activation per target position.
            expanded_input = target_layer_input.expand(current_batch_size, seq_len, hidden_dim).clone().requires_grad_(True)

            hidden_states = expanded_input
            for op_type, layer_mod in ops:
                if op_type == 'part1':
                    hidden_states = self._forward_part1(layer_mod, hidden_states, position_embeddings=rotary_emb)
                else:
                    hidden_states = self._forward_part2(layer_mod, hidden_states)

            reconstructed_output = hidden_states

            # Row b carries the target gradient only at its own position,
            # which isolates that position's dependence on the source.
            grad_output_chunk = torch.zeros(current_batch_size, seq_len, hidden_dim, dtype=reconstructed_output.dtype, device=model.device)
            if target_grad_full is not None:
                for batch_idx, global_idx in enumerate(chunk_indices):
                    grad_output_chunk[batch_idx, global_idx, :] = target_grad_full[0, global_idx, :]

            grad_input = torch.autograd.grad(outputs=reconstructed_output, inputs=expanded_input, grad_outputs=grad_output_chunk, retain_graph=False)[0]

            chunk_relevance = (grad_input * expanded_input).sum(dim=-1)
            token_interaction[chunk_indices, :] = chunk_relevance.detach().to(token_interaction.dtype)

            del expanded_input, hidden_states, reconstructed_output, grad_output_chunk, grad_input, chunk_relevance
            torch.cuda.empty_cache()

        # Relevance of each source position, from its stored gradient.
        source_basis = self.input_embeddings if source_layer_idx == -1 else source_tensor
        if source_basis.grad is not None:
            real_source_rel = (source_basis * source_basis.grad).sum(dim=-1)[0]
        else:
            real_source_rel = torch.zeros(seq_len, device=model.device)

        yield {
            "type": "result",
            "payload": {
                "matrix": token_interaction.detach().float().cpu().numpy(),
                "real_target_rel": real_target_rel.detach().float().cpu().numpy(),
                "real_source_rel": real_source_rel.detach().float().cpu().numpy()
            }
        }

    finally:
        if was_checkpointing:
            model.gradient_checkpointing_enable()
|
|
| def compute_connection_matrix(self, source, target): |
| for item in self.compute_connection_matrix_gen(source, target): |
| if item.get("type") == "result": |
| return item["payload"] |
| return None |
|
|