""" JacobianScopes: Fisher, Temperature, and Semantic scope implementations. Uses JCBScope_utils.customize_forward_pass for the forward interface. """ from __future__ import annotations import numpy as np import torch import torch.nn.functional as F try: from tqdm import tqdm except ImportError: tqdm = None import JCBScope_utils def fisher_hidden_from_logits_and_W(logits_t, W, chunk_size=8192): """ Fisher information matrix at hidden state for a single position. Fx [d,d] = W^T (Diag(p) - p p^T) W. Args: logits_t: [V] logits at the target position W: [V, d] lm_head weight chunk_size: chunk size for vocab loop to control memory Returns: Fx: [d, d] symmetric positive semi-definite matrix """ if logits_t.ndim != 1: raise ValueError(f"logits_t must be 1D [V], got {tuple(logits_t.shape)}") if logits_t.numel() != W.shape[0]: raise ValueError(f"logits_t has {logits_t.numel()} entries but W has vocab {W.shape[0]}") p = F.softmax(logits_t, dim=-1).to(dtype=W.dtype) V, d = W.shape S = torch.zeros(d, d, device=W.device, dtype=W.dtype) mu = torch.zeros(d, device=W.device, dtype=W.dtype) for i in range(0, V, chunk_size): Wi = W[i : i + chunk_size] pi = p[i : i + chunk_size].unsqueeze(1) Ww = Wi * pi S += Wi.T @ Ww mu += Ww.sum(dim=0) return S - torch.outer(mu, mu) def fisher_scope_scores( forward_pass, residual, loss_position, lm_head, method="full", batch_size=16, k=1, n_hutchinson_samples=8, eps_finite_diff=1e-5, progress=False, ): """ Compute Fisher scope influence scores per token: tr(J^T Fx J) per token. Args: forward_pass: from JCBScope_utils.customize_forward_pass residual: nn.Parameter [n_tokens, d_model] loss_position: int or tensor, position for loss lm_head: model lm_head (for W and logits) n_tokens: len(grad_idx) d_model: residual.shape[-1] method: 'full' | 'low_rank' | 'finite_diff' batch_size: for full Jacobian backward batches k: top-k for low_rank n_hutchinson_samples: for finite_diff eps_finite_diff: finite-diff step for finite_diff progress: if True, use tqdm for finite_diff loop Returns: (scores, logits) for method 'full' or 'low_rank'; (scores,) for 'finite_diff'. scores: np.ndarray [n_tokens], float32 """ W = lm_head.weight.to(residual.device, dtype=residual.dtype) n_tokens = residual.shape[0] d_model = residual.shape[1] if method == "full": projection_probes = torch.eye(d_model, d_model, device=residual.device, dtype=residual.dtype) losses, logits = forward_pass(loss_position=loss_position, projection_probe=projection_probes) num_losses = losses.numel() grads_list = [] for i in range(0, num_losses, batch_size): end = min(i + batch_size, num_losses) eye_batch = torch.eye(num_losses, device=losses.device, dtype=losses.dtype)[i:end] g = torch.autograd.grad( outputs=losses, inputs=residual, grad_outputs=eye_batch, is_grads_batched=True, retain_graph=(end < num_losses), )[0] grads_list.append(g) grads = torch.cat(grads_list, dim=0) del grads_list, losses J_all = grads.detach().swapaxes(0, 1) del grads logits_t = logits[loss_position].to(residual.device) Fx = fisher_hidden_from_logits_and_W(logits_t, W) tmp = Fx.unsqueeze(0) @ J_all F_all = J_all.transpose(1, 2) @ tmp scores = np.array([torch.trace(F_all[i]).item() for i in range(n_tokens)], dtype=np.float32) return scores, logits if method == "low_rank": _, logits = forward_pass(loss_position=loss_position, hidden_norm_as_loss=True) del _ logits_t = logits[loss_position].to(residual.device) with torch.no_grad(): Fx = fisher_hidden_from_logits_and_W(logits_t, W) orig_device = Fx.device Fx_cpu = Fx.float().contiguous().cpu() U_full, eigvals, _ = torch.linalg.svd(Fx_cpu) eigvals = eigvals.numpy() U_full = U_full.to(orig_device) k_actual = min(k, len(eigvals)) idx_top = np.argsort(eigvals)[::-1][:k_actual].copy() U_k = U_full[:, idx_top].clone().to(residual.device) S_k = torch.tensor(eigvals[idx_top].copy(), device=residual.device, dtype=W.dtype) projection_probes_lowrank = U_k.T losses_lr, _ = forward_pass(loss_position=loss_position, projection_probe=projection_probes_lowrank) eye_k = torch.eye(k_actual, device=losses_lr.device, dtype=losses_lr.dtype) grads_lr = torch.autograd.grad( outputs=losses_lr, inputs=residual, grad_outputs=eye_k, is_grads_batched=True )[0] grads_lr = grads_lr.contiguous() S_k_diag = torch.diag(S_k).contiguous().to(grads_lr.device) scores = np.zeros(n_tokens, dtype=np.float32) for tau in range(grads_lr.shape[1]): UkT_J = grads_lr[:, tau, :] M = UkT_J @ UkT_J.T # (U_k^T J)(J^T U_k) scores[tau] = (S_k_diag @ M).trace().item() return scores, logits if method == "finite_diff": _, logits = forward_pass(loss_position=loss_position) del _ logits_t = logits[loss_position].to(residual.device) with torch.no_grad(): Fx = fisher_hidden_from_logits_and_W(logits_t, W) torch.manual_seed(42) hutch_probes = ( torch.randint(0, 2, (n_hutchinson_samples, d_model), device=residual.device) * 2 - 1 ).to(residual.dtype) hidden_base = forward_pass(loss_position=loss_position, return_hidden=True) hutchinson_Jz_all = torch.zeros( residual.shape[0], n_hutchinson_samples, d_model, device=residual.device, dtype=residual.dtype ) iterator = range(residual.shape[0]) if progress and tqdm is not None: iterator = tqdm(iterator, desc="Finite-diff JVP") for tau in iterator: perturb = torch.zeros( n_hutchinson_samples, residual.shape[0], d_model, device=residual.device, dtype=residual.dtype ) perturb[:, tau, :] = hutch_probes r_batch = residual.unsqueeze(0) + eps_finite_diff * perturb hidden_pert = forward_pass(loss_position=loss_position, residual_batch=r_batch, return_hidden=True) Jz = (hidden_pert - hidden_base.unsqueeze(0)) / eps_finite_diff hutchinson_Jz_all[tau] = Jz.detach() g = hutchinson_Jz_all.to(Fx.device, dtype=Fx.dtype) FxJz = g @ Fx quad = (FxJz * g).sum(dim=-1) scores = quad.mean(dim=1).cpu().numpy().astype(np.float32) return scores,logits raise ValueError(f"method must be 'full', 'low_rank', or 'finite_diff', got {method!r}") def temperature_scope_scores(forward_pass, residual, loss_position): """ Temperature scope: gradient of hidden-norm loss w.r.t. residual; score = grad norm per token. Returns: scores: np.ndarray [n_tokens], float32 """ loss, logits = forward_pass(loss_position=loss_position, hidden_norm_as_loss=True) grads = torch.autograd.grad(loss, residual, retain_graph=False)[0] scores = grads.norm(dim=-1).squeeze().cpu().numpy().astype(np.float32) if scores.ndim > 1: scores = scores.squeeze() return scores, logits def semantic_scope_scores( forward_pass, residual, loss_position, path_integral=False, presence_ratios=None, grad_idx=None, return_grads_per_step=False, target_id=None ): """ Semantic scope: single-pass (hidden_norm_as_loss=False, unnormalized_logits) or path-integrated. When path_integral=False: gradient norm per token. When path_integral=True: Path_integrated_grad = mean(grads) * (x_final - x_initial) at grad_idx, scores = norm. Args: forward_pass: from JCBScope_utils.customize_forward_pass residual: nn.Parameter [n_tokens, d_model] loss_position: int or tensor path_integral: if True, use path integration as in Path_Integrated_Semantic_Scope.ipynb presence_ratios: for path_integral; default np.linspace(0.01, 1, 100) grad_idx: indices of token positions for path_integral slice; if None use range(n_tokens) return_grads_per_step: if True and path_integral, also return (grads_per_step, input_embeds_per_step) Returns: scores: np.ndarray [n_tokens], float32 If return_grads_per_step and path_integral: (scores, grads_per_step, input_embeds_per_step) """ if not path_integral: loss, logits = forward_pass( loss_position=loss_position, hidden_norm_as_loss=False, unnormalized_logits=True, target_id=target_id ) grads = torch.autograd.grad(loss, residual, retain_graph=False)[0] scores = grads.norm(dim=-1).squeeze().cpu().numpy().astype(np.float32) if scores.ndim > 1: scores = scores.squeeze() return scores, logits if grad_idx is None: grad_idx = list(range(residual.shape[0])) presence_ratios = presence_ratios if presence_ratios is not None else np.linspace(0.01, 1.0, 10) grads_per_step = [] input_embeds_per_step = [] for presence_ratio in presence_ratios: loss, logits, input_embeds = forward_pass( loss_position=loss_position, hidden_norm_as_loss=False, unnormalized_logits=True, return_input_embeds=True, alpha=float(presence_ratio), ) input_embeds_per_step.append(input_embeds.detach().cpu().clone()) residual_grad = torch.autograd.grad(loss, residual, retain_graph=False)[0] / presence_ratio grads_per_step.append(residual_grad.detach().cpu().clone()) grad_stack = torch.stack(grads_per_step) input_final = input_embeds_per_step[-1][0, grad_idx, :] input_initial = input_embeds_per_step[0][0, grad_idx, :] path_integrated_grad = grad_stack.mean(dim=0) * (input_final - input_initial) scores = path_integrated_grad.norm(dim=-1).numpy().astype(np.float32) if return_grads_per_step: return scores, grads_per_step, input_embeds_per_step return scores def gradient_x_input_scores(forward_pass, residual, loss_position, embedding_layer, input_ids, grad_idx): """ Gradient times input (grad * input_embeds) norm per token. Returns: scores: np.ndarray [n_tokens], float32 """ loss, logits = forward_pass( loss_position=loss_position, hidden_norm_as_loss=False, unnormalized_logits=False, ) grads = torch.autograd.grad(loss, residual, retain_graph=False)[0] with torch.no_grad(): token_embeds = JCBScope_utils.embedding_lookup(input_ids[0, grad_idx], embedding_layer) scores = (grads * token_embeds.to(grads.device)).norm(dim=-1).squeeze().cpu().numpy().astype(np.float32) if scores.ndim > 1: scores = scores.squeeze() return scores, logits def setup_scope_context(model, tokenizer, string, front_pad=2, back_pad=0, front_strip=0, eos_token_id=None): """ Build input_ids, attention_mask, grad_idx, decoded_tokens, residual, presence, forward_pass, d_model, embed_device for use with scope functions. Caller must still set model.eval() and pass loss_position. Returns: dict with keys: input_ids, attention_mask, grad_idx, decoded_tokens, residual, presence, forward_pass, d_model, embed_device """ input_ids_list = tokenizer(string, add_special_tokens=False)["input_ids"] if eos_token_id is not None: input_ids_list += [eos_token_id] * back_pad decoded_tokens = tokenizer.batch_decode([[tid] for tid in input_ids_list], skip_special_tokens=True) grad_idx = list(range(front_pad, len(decoded_tokens)))[front_strip:] embedding_layer = model.get_input_embeddings() embed_device = embedding_layer.weight.device d_model = embedding_layer.embedding_dim input_ids = torch.tensor([input_ids_list], dtype=torch.long).to(embed_device) attention_mask = torch.ones_like(input_ids, device=embed_device) residual = torch.nn.Parameter(torch.zeros(len(grad_idx), d_model, device=embed_device)) presence = torch.ones(len(decoded_tokens), 1, device=embed_device) forward_pass = JCBScope_utils.customize_forward_pass( model, residual, presence, input_ids, grad_idx, attention_mask ) return { "input_ids": input_ids, "attention_mask": attention_mask, "grad_idx": grad_idx, "decoded_tokens": decoded_tokens, "residual": residual, "presence": presence, "forward_pass": forward_pass, "d_model": d_model, "embed_device": embed_device, }