from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F


class FutureInfluenceScorer(nn.Module):
    """Scores each position by how strongly its hidden state influences the
    next-token loss over a trailing window of future tokens, measured as the
    norm of the loss gradient with respect to the hidden states."""

    def forward(
        self,
        hidden: torch.Tensor,
        logits: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        future_window: int = 16,
    ) -> dict[str, Any]:
        if not hidden.requires_grad:
            raise ValueError("hidden must require gradients for future influence scoring")
        # Degenerate case: with fewer than two tokens there is no next-token
        # target, so every score is zero and no gradient is taken.
        if input_ids.size(1) < 2:
            zero_scores = torch.zeros(input_ids.shape, device=input_ids.device, dtype=hidden.dtype)
            return {
                "scores": zero_scores,
                "raw_scores": zero_scores,
                "loss": 0.0,
                "target_window": 0,
            }
        # Standard causal LM shift: position t predicts token t + 1.
        shift_logits = logits[:, :-1, :]
        shift_labels = input_ids[:, 1:]
        if attention_mask is None:
            shift_mask = torch.ones_like(shift_labels, dtype=hidden.dtype)
        else:
            shift_mask = attention_mask[:, 1:].to(hidden.dtype)
        # Restrict the loss to the last `future_window` predictions, clamped
        # to the available sequence length.
        window = max(1, min(int(future_window), shift_labels.size(1)))
        target_logits = shift_logits[:, -window:, :]
        target_labels = shift_labels[:, -window:]
        target_mask = shift_mask[:, -window:]
        token_loss = F.cross_entropy(
            target_logits.reshape(-1, target_logits.size(-1)),
            target_labels.reshape(-1),
            reduction="none",
        ).view_as(target_labels)
        # Mask out padding positions before averaging; clamp_min guards
        # against division by zero when the window is fully masked.
        masked_loss = token_loss * target_mask
        loss = masked_loss.sum() / target_mask.sum().clamp_min(1.0)
        # Backpropagate the windowed loss to the hidden states; the per-token
        # gradient norm is the raw influence score.
        grad_hidden = torch.autograd.grad(loss, hidden, retain_graph=False, create_graph=False)[0]
        raw_scores = grad_hidden.norm(dim=-1)
        # Normalize per sequence so scores lie in [0, 1].
        denom = raw_scores.amax(dim=-1, keepdim=True).clamp_min(1e-6)
        norm_scores = raw_scores / denom
        if attention_mask is not None:
            norm_scores = norm_scores * attention_mask.to(norm_scores.dtype)
            raw_scores = raw_scores * attention_mask.to(raw_scores.dtype)
        return {
            "scores": norm_scores.detach(),
            "raw_scores": raw_scores.detach(),
            "loss": float(loss.item()),
            "target_window": window,
        }