# src/model/future_influence.py
# (from commit 8125804: "feat: add src/ module for script imports")
from __future__ import annotations
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
class FutureInfluenceScorer(nn.Module):
def forward(
self,
hidden: torch.Tensor,
logits: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
future_window: int = 16,
) -> dict[str, Any]:
if not hidden.requires_grad:
raise ValueError("hidden must require gradients for future influence scoring")
if input_ids.size(1) < 2:
zero_scores = torch.zeros(input_ids.shape, device=input_ids.device, dtype=hidden.dtype)
return {
"scores": zero_scores,
"raw_scores": zero_scores,
"loss": 0.0,
"target_window": 0,
}
shift_logits = logits[:, :-1, :]
shift_labels = input_ids[:, 1:]
if attention_mask is None:
shift_mask = torch.ones_like(shift_labels, dtype=hidden.dtype)
else:
shift_mask = attention_mask[:, 1:].to(hidden.dtype)
window = max(1, min(int(future_window), shift_labels.size(1)))
target_logits = shift_logits[:, -window:, :]
target_labels = shift_labels[:, -window:]
target_mask = shift_mask[:, -window:]
token_loss = F.cross_entropy(
target_logits.reshape(-1, target_logits.size(-1)),
target_labels.reshape(-1),
reduction="none",
).view_as(target_labels)
masked_loss = token_loss * target_mask
loss = masked_loss.sum() / target_mask.sum().clamp_min(1.0)
grad_hidden = torch.autograd.grad(loss, hidden, retain_graph=False, create_graph=False)[0]
raw_scores = grad_hidden.norm(dim=-1)
denom = raw_scores.amax(dim=-1, keepdim=True).clamp_min(1e-6)
norm_scores = raw_scores / denom
if attention_mask is not None:
norm_scores = norm_scores * attention_mask.to(norm_scores.dtype)
raw_scores = raw_scores * attention_mask.to(raw_scores.dtype)
return {
"scores": norm_scores.detach(),
"raw_scores": raw_scores.detach(),
"loss": float(loss.detach().item()),
"target_window": window,
}