JacobianScopes / src /JacobianScopes.py
Typony's picture
Update src/JacobianScopes.py
dba2c2b verified
"""
JacobianScopes: Fisher, Temperature, and Semantic scope implementations.
Uses JCBScope_utils.customize_forward_pass for the forward interface.
"""
from __future__ import annotations
import numpy as np
import torch
import torch.nn.functional as F
try:
from tqdm import tqdm
except ImportError:
tqdm = None
import JCBScope_utils
def fisher_hidden_from_logits_and_W(logits_t, W, chunk_size=8192):
    """
    Fisher information matrix at hidden state for a single position.

    Fx [d,d] = W^T (Diag(p) - p p^T) W with p = softmax(logits_t). The sum over the
    vocabulary is accumulated chunk by chunk so that no [V, V] matrix is ever built.

    Args:
        logits_t: [V] logits at the target position
        W: [V, d] lm_head weight
        chunk_size: number of vocabulary rows processed per step (must be >= 1)
    Returns:
        Fx: [d, d] symmetric positive semi-definite matrix, on W's device and in W's dtype
    Raises:
        ValueError: if logits_t is not 1D, W is not 2D, the vocabulary sizes disagree,
            or chunk_size is smaller than 1
    """
    if logits_t.ndim != 1:
        raise ValueError(f"logits_t must be 1D [V], got {tuple(logits_t.shape)}")
    if W.ndim != 2:
        raise ValueError(f"W must be 2D [V, d], got {tuple(W.shape)}")
    if logits_t.numel() != W.shape[0]:
        raise ValueError(f"logits_t has {logits_t.numel()} entries but W has vocab {W.shape[0]}")
    if chunk_size < 1:
        # A negative step would make the range below empty and silently return zeros.
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    # Align the probabilities with W so logits living on another device still work.
    p = F.softmax(logits_t, dim=-1).to(device=W.device, dtype=W.dtype)
    V, d = W.shape
    second_moment = torch.zeros(d, d, device=W.device, dtype=W.dtype)  # W^T Diag(p) W
    mean = torch.zeros(d, device=W.device, dtype=W.dtype)  # W^T p
    for start in range(0, V, chunk_size):
        W_chunk = W[start : start + chunk_size]
        weighted = W_chunk * p[start : start + chunk_size].unsqueeze(1)
        second_moment += W_chunk.T @ weighted
        mean += weighted.sum(dim=0)
    # W^T (Diag(p) - p p^T) W = W^T Diag(p) W - (W^T p)(W^T p)^T
    return second_moment - torch.outer(mean, mean)
def fisher_scope_scores(
    forward_pass,
    residual,
    loss_position,
    lm_head,
    method="full",
    batch_size=16,
    k=1,
    n_hutchinson_samples=8,
    eps_finite_diff=1e-5,
    progress=False,
):
    """
    Compute Fisher scope influence scores per token: tr(J^T Fx J) per token.

    Args:
        forward_pass: from JCBScope_utils.customize_forward_pass
        residual: nn.Parameter [n_tokens, d_model]
        loss_position: int or tensor, position for loss
        lm_head: model lm_head (for W and logits)
        method: 'full' | 'low_rank' | 'finite_diff'
        batch_size: for full Jacobian backward batches
        k: top-k for low_rank
        n_hutchinson_samples: for finite_diff
        eps_finite_diff: finite-diff step for finite_diff
        progress: if True, use tqdm for finite_diff loop (when tqdm is installed)
    Returns:
        (scores, logits) for every method.
        scores: np.ndarray [n_tokens], float32
        logits: the logits returned by forward_pass
    Raises:
        ValueError: if method is not 'full', 'low_rank', or 'finite_diff'
    """
    if method not in ("full", "low_rank", "finite_diff"):
        raise ValueError(f"method must be 'full', 'low_rank', or 'finite_diff', got {method!r}")
    W = lm_head.weight.to(residual.device, dtype=residual.dtype)
    if method == "full":
        return _fisher_scores_full(forward_pass, residual, loss_position, W, batch_size)
    if method == "low_rank":
        return _fisher_scores_low_rank(forward_pass, residual, loss_position, W, k)
    return _fisher_scores_finite_diff(
        forward_pass, residual, loss_position, W, n_hutchinson_samples, eps_finite_diff, progress
    )


def _fisher_scores_full(forward_pass, residual, loss_position, W, batch_size):
    """Exact scores: back-propagate one identity probe per hidden dimension to get the full Jacobian."""
    d_model = residual.shape[1]
    probes = torch.eye(d_model, d_model, device=residual.device, dtype=residual.dtype)
    losses, logits = forward_pass(loss_position=loss_position, projection_probe=probes)
    num_losses = losses.numel()
    # Build the one-hot grad_outputs once; every batch below is a row slice of it.
    one_hot = torch.eye(num_losses, device=losses.device, dtype=losses.dtype)
    grad_batches = []
    for start in range(0, num_losses, batch_size):
        stop = min(start + batch_size, num_losses)
        grad_batches.append(
            torch.autograd.grad(
                outputs=losses,
                inputs=residual,
                grad_outputs=one_hot[start:stop],
                is_grads_batched=True,
                # The graph is only needed until the last batch has been processed.
                retain_graph=(stop < num_losses),
            )[0]
        )
    # Batched grads are stacked probe-first; move the token axis to the front.
    J_all = torch.cat(grad_batches, dim=0).detach().swapaxes(0, 1)
    del grad_batches, losses, one_hot
    # no_grad: logits may still carry a graph, and the result is exported to NumPy.
    with torch.no_grad():
        logits_t = logits[loss_position].to(residual.device)
        Fx = fisher_hidden_from_logits_and_W(logits_t, W)
        FxJ = Fx.unsqueeze(0) @ J_all
        # tr(J^T Fx J) = sum_ij J_ij (Fx J)_ij, so the per-token [d, d] product is never formed.
        traces = (J_all * FxJ).sum(dim=(1, 2))
    scores = traces.to(torch.float32).cpu().numpy()
    return scores, logits


def _fisher_scores_low_rank(forward_pass, residual, loss_position, W, k):
    """Approximate scores that keep only the k leading singular directions of Fx."""
    logits = forward_pass(loss_position=loss_position, hidden_norm_as_loss=True)[1]
    logits_t = logits[loss_position].to(residual.device)
    with torch.no_grad():
        Fx = fisher_hidden_from_logits_and_W(logits_t, W)
        # The decomposition is run on CPU in float32.
        U_full, eigvals, _ = torch.linalg.svd(Fx.float().contiguous().cpu())
    eigvals = eigvals.numpy()
    k_actual = min(k, len(eigvals))
    idx_top = np.argsort(eigvals)[::-1][:k_actual].copy()
    U_k = U_full[:, idx_top].to(residual.device)
    S_k = torch.tensor(eigvals[idx_top].copy(), device=residual.device, dtype=W.dtype)

    # Probing with U_k^T yields U_k^T J directly from k batched backward passes.
    losses_lr, _ = forward_pass(loss_position=loss_position, projection_probe=U_k.T)
    eye_k = torch.eye(k_actual, device=losses_lr.device, dtype=losses_lr.dtype)
    grads_lr = torch.autograd.grad(
        outputs=losses_lr, inputs=residual, grad_outputs=eye_k, is_grads_batched=True
    )[0]
    # tr(Diag(S_k) (U_k^T J)(J^T U_k)) = sum_i S_k[i] * ||(U_k^T J)_i||^2, evaluated for all tokens at once.
    row_sq_norms = grads_lr.detach().pow(2).sum(dim=-1)
    weighted = S_k.to(row_sq_norms.device).unsqueeze(1) * row_sq_norms
    scores = weighted.sum(dim=0).to(torch.float32).cpu().numpy()
    return scores, logits


def _fisher_scores_finite_diff(
    forward_pass, residual, loss_position, W, n_hutchinson_samples, eps_finite_diff, progress
):
    """Hutchinson estimate of the scores using finite-difference Jacobian-vector products."""
    n_tokens, d_model = residual.shape[0], residual.shape[1]
    logits = forward_pass(loss_position=loss_position)[1]
    logits_t = logits[loss_position].to(residual.device)
    with torch.no_grad():
        Fx = fisher_hidden_from_logits_and_W(logits_t, W)

    # Private, fixed-seed generator: probes stay reproducible without reseeding the
    # caller's global RNG.
    generator = torch.Generator(device=residual.device)
    generator.manual_seed(42)
    # Rademacher (+1/-1) probe vectors.
    hutch_probes = (
        torch.randint(
            0, 2, (n_hutchinson_samples, d_model), generator=generator, device=residual.device
        )
        * 2
        - 1
    ).to(residual.dtype)

    hidden_base = forward_pass(loss_position=loss_position, return_hidden=True)
    hutchinson_Jz_all = torch.zeros(
        n_tokens, n_hutchinson_samples, d_model, device=residual.device, dtype=residual.dtype
    )
    # One reusable perturbation buffer; only the row of the current token is non-zero.
    perturb = torch.zeros(
        n_hutchinson_samples, n_tokens, d_model, device=residual.device, dtype=residual.dtype
    )
    iterator = range(n_tokens)
    if progress and tqdm is not None:
        iterator = tqdm(iterator, desc="Finite-diff JVP")
    for tau in iterator:
        perturb[:, tau, :] = hutch_probes
        r_batch = residual.unsqueeze(0) + eps_finite_diff * perturb
        perturb[:, tau, :] = 0  # r_batch is a fresh tensor, so the buffer can be reset now
        hidden_pert = forward_pass(loss_position=loss_position, residual_batch=r_batch, return_hidden=True)
        Jz = (hidden_pert - hidden_base.unsqueeze(0)) / eps_finite_diff
        hutchinson_Jz_all[tau] = Jz.detach()

    g = hutchinson_Jz_all.to(Fx.device, dtype=Fx.dtype)
    # z^T J^T Fx J z for every (token, probe), then averaged over the probes.
    quad = ((g @ Fx) * g).sum(dim=-1)
    scores = quad.mean(dim=1).to(torch.float32).cpu().numpy()
    return scores, logits
def temperature_scope_scores(forward_pass, residual, loss_position):
    """
    Temperature scope: gradient of hidden-norm loss w.r.t. residual; score = grad norm per token.

    Args:
        forward_pass: callable invoked as forward_pass(loss_position=..., hidden_norm_as_loss=True);
            must return (loss, logits)
        residual: tensor the loss is differentiated against; the norm is taken over its last axis
        loss_position: forwarded unchanged to forward_pass
    Returns:
        (scores, logits)
        scores: np.ndarray [n_tokens], float32 (stays 1-D even when there is a single token)
        logits: second value returned by forward_pass, unchanged
    """
    loss, logits = forward_pass(loss_position=loss_position, hidden_norm_as_loss=True)
    grads = torch.autograd.grad(loss, residual, retain_graph=False)[0]
    token_norms = grads.detach().norm(dim=-1)
    # reshape(-1) instead of squeeze(): squeeze() would turn a single-token result into a 0-d array.
    scores = token_norms.reshape(-1).to(torch.float32).cpu().numpy()
    return scores, logits
def semantic_scope_scores(
    forward_pass,
    residual,
    loss_position,
    path_integral=False,
    presence_ratios=None,
    grad_idx=None,
    return_grads_per_step=False,
    target_id=None,
):
    """
    Semantic scope: single-pass (hidden_norm_as_loss=False, unnormalized_logits) or path-integrated.

    When path_integral=False: gradient norm per token.
    When path_integral=True: Path_integrated_grad = mean(grads) * (x_final - x_initial) at grad_idx, scores = norm.

    Args:
        forward_pass: from JCBScope_utils.customize_forward_pass
        residual: nn.Parameter [n_tokens, d_model]
        loss_position: int or tensor
        path_integral: if True, use path integration as in Path_Integrated_Semantic_Scope.ipynb
        presence_ratios: for path_integral; default np.linspace(0.01, 1, 10). Each ratio is passed to
            forward_pass as alpha and the resulting gradient is divided by it, so ratios must be non-zero
        grad_idx: indices of token positions for path_integral slice; if None use range(n_tokens)
        return_grads_per_step: if True and path_integral, also return (grads_per_step, input_embeds_per_step)
        target_id: forwarded to forward_pass as target_id; in path-integral mode it is only forwarded
            when it is not None
    Returns:
        path_integral=False: (scores, logits)
        path_integral=True: scores, or (scores, grads_per_step, input_embeds_per_step) when
            return_grads_per_step is True
        scores: np.ndarray [n_tokens], float32 (stays 1-D even when there is a single token)
    Raises:
        ValueError: if path_integral is True and presence_ratios is empty
    """
    if not path_integral:
        loss, logits = forward_pass(
            loss_position=loss_position,
            hidden_norm_as_loss=False,
            unnormalized_logits=True,
            target_id=target_id,
        )
        grads = torch.autograd.grad(loss, residual, retain_graph=False)[0]
        # reshape(-1) instead of squeeze(): squeeze() would turn a single-token result into a 0-d array.
        scores = grads.detach().norm(dim=-1).reshape(-1).to(torch.float32).cpu().numpy()
        return scores, logits

    if grad_idx is None:
        grad_idx = list(range(residual.shape[0]))
    if presence_ratios is None:
        presence_ratios = np.linspace(0.01, 1.0, 10)
    # Forward target_id only when the caller supplied one, so the default call is unchanged.
    extra_kwargs = {} if target_id is None else {"target_id": target_id}

    grads_per_step = []
    input_embeds_per_step = []
    for presence_ratio in presence_ratios:
        alpha = float(presence_ratio)
        loss, _step_logits, input_embeds = forward_pass(
            loss_position=loss_position,
            hidden_norm_as_loss=False,
            unnormalized_logits=True,
            return_input_embeds=True,
            alpha=alpha,
            **extra_kwargs,
        )
        input_embeds_per_step.append(input_embeds.detach().cpu().clone())
        step_grad = torch.autograd.grad(loss, residual, retain_graph=False)[0] / alpha
        grads_per_step.append(step_grad.detach().cpu().clone())
    if not grads_per_step:
        raise ValueError("presence_ratios must contain at least one value")

    mean_grad = torch.stack(grads_per_step).mean(dim=0)
    # Displacement of the input embeddings between the first and the last presence ratio.
    input_final = input_embeds_per_step[-1][0, grad_idx, :]
    input_initial = input_embeds_per_step[0][0, grad_idx, :]
    path_integrated_grad = mean_grad * (input_final - input_initial)
    scores = path_integrated_grad.norm(dim=-1).numpy().astype(np.float32)
    if return_grads_per_step:
        return scores, grads_per_step, input_embeds_per_step
    return scores
def gradient_x_input_scores(forward_pass, residual, loss_position, embedding_layer, input_ids, grad_idx):
    """
    Gradient times input (grad * input_embeds) norm per token.

    Args:
        forward_pass: callable invoked with loss_position, hidden_norm_as_loss=False and
            unnormalized_logits=False; must return (loss, logits)
        residual: tensor the loss is differentiated against
        loss_position: forwarded unchanged to forward_pass
        embedding_layer: passed to JCBScope_utils.embedding_lookup to embed the selected tokens
        input_ids: token ids; row 0 is indexed with grad_idx
        grad_idx: token positions whose embeddings are multiplied with the gradient
    Returns:
        (scores, logits)
        scores: np.ndarray [n_tokens], float32 (stays 1-D even when there is a single token)
        logits: second value returned by forward_pass, unchanged
    """
    loss, logits = forward_pass(
        loss_position=loss_position,
        hidden_norm_as_loss=False,
        unnormalized_logits=False,
    )
    grads = torch.autograd.grad(loss, residual, retain_graph=False)[0]
    with torch.no_grad():
        token_embeds = JCBScope_utils.embedding_lookup(input_ids[0, grad_idx], embedding_layer)
        token_norms = (grads * token_embeds.to(grads.device)).norm(dim=-1)
    # reshape(-1) instead of squeeze(): squeeze() would turn a single-token result into a 0-d array.
    scores = token_norms.reshape(-1).to(torch.float32).cpu().numpy()
    return scores, logits
def setup_scope_context(model, tokenizer, string, front_pad=2, back_pad=0, front_strip=0, eos_token_id=None):
    """
    Assemble everything the scope functions need for one input string.

    The string is tokenized without special tokens. When eos_token_id is given, back_pad copies
    of it are appended (back_pad has no effect otherwise). grad_idx covers the token positions
    from front_pad onward, minus its first front_strip entries. A zero-initialised residual
    parameter with one row per grad_idx entry and an all-ones presence tensor are created on the
    embedding device and handed to JCBScope_utils.customize_forward_pass.

    Caller must still set model.eval() and pass loss_position.

    Returns:
        dict with keys: input_ids, attention_mask, grad_idx, decoded_tokens, residual, presence,
        forward_pass, d_model, embed_device
    """
    token_ids = tokenizer(string, add_special_tokens=False)["input_ids"]
    if eos_token_id is not None:
        token_ids += [eos_token_id] * back_pad

    # Decode every id on its own so there is exactly one string per token position.
    singletons = [[token_id] for token_id in token_ids]
    decoded_tokens = tokenizer.batch_decode(singletons, skip_special_tokens=True)
    n_positions = len(decoded_tokens)
    candidate_positions = list(range(front_pad, n_positions))
    grad_idx = candidate_positions[front_strip:]

    embeddings = model.get_input_embeddings()
    embed_device = embeddings.weight.device
    d_model = embeddings.embedding_dim

    input_ids = torch.tensor([token_ids], dtype=torch.long).to(embed_device)
    attention_mask = torch.ones_like(input_ids, device=embed_device)
    zero_init = torch.zeros(len(grad_idx), d_model, device=embed_device)
    residual = torch.nn.Parameter(zero_init)
    presence = torch.ones(n_positions, 1, device=embed_device)

    forward_pass = JCBScope_utils.customize_forward_pass(
        model, residual, presence, input_ids, grad_idx, attention_mask
    )

    context = {}
    context["input_ids"] = input_ids
    context["attention_mask"] = attention_mask
    context["grad_idx"] = grad_idx
    context["decoded_tokens"] = decoded_tokens
    context["residual"] = residual
    context["presence"] = presence
    context["forward_pass"] = forward_pass
    context["d_model"] = d_model
    context["embed_device"] = embed_device
    return context