# PIVOT / src/training/losses.py
# Uploaded by bryan7264 -- "pivot: code + trained checkpoints (norman, replogle k562)"
# Commit 3b4941f (verified), 2.17 kB
"""pivot training losses:
L_map = E‖X_θ(s,t,c_s,e_u) − c_t‖²
L_tan = E‖v_θ(τ,τ,c_τ,e_u) − (c1−c0)‖²
L_semi = E‖X_θ(s,t,c_s,e_u) − X_θ(r,t, X_θ(s,r,c_s,e_u), e_u)‖²
L_reg = ‖e_u‖²
L = L_map + λ_tan·L_tan + λ_semi·L_semi + λ_reg·L_reg
c_s, c_t, c_τ are linear interpolants between matched control c0 and perturbed c1.
"""
from __future__ import annotations
import torch
def _sorted_uniform(n, k, device):
    """Draw an (n, k) tensor of U(0,1) samples with every row in ascending order."""
    # One torch.rand call for the whole (n, k) block, then an independent sort
    # of each row along the column axis; the sort indices are not needed.
    ordered, _ = torch.rand(n, k, device=device).sort(dim=1)
    return ordered
def interp(c0: torch.Tensor, c1: torch.Tensor, frac: torch.Tensor) -> torch.Tensor:
    """Linear interpolant (1-frac)·c0 + frac·c1 with one fraction per batch row.

    Args:
        c0: start points, batch on the first axis (e.g. shape (b, d)).
        c1: end points, same shape as ``c0``.
        frac: per-row interpolation fractions, shape (b,).

    Returns:
        Tensor shaped like ``c0``: row ``i`` is
        ``(1 - frac[i]) * c0[i] + frac[i] * c1[i]``.
    """
    # Give frac one singleton axis per trailing (non-batch) axis of c0 so it
    # broadcasts over (b, d) as well as higher-rank (b, d1, d2, ...) inputs.
    # At least one singleton axis is kept so the (b, 1) shape used for
    # rank-2 (and lower) inputs is unchanged.
    trailing = max(c0.dim() - 1, 1)
    f = frac.view((-1,) + (1,) * trailing)
    return (1 - f) * c0 + f * c1
def compute_losses(flow, e_u, c0, c1, lambdas: dict):
    """Return ``(total, components)`` for one batch of matched control/perturbed pairs.

    The four terms follow the module-level definitions:
    map (two-time prediction vs. interpolant), tan (diagonal velocity vs.
    ``c1 - c0``), semi (direct vs. composed flow) and reg (squared norm of
    ``e_u``).

    Args:
        flow: model invoked as ``flow(s, t, x, e_u)`` and exposing
            ``flow.velocity(s, t, x, e_u)``.
        e_u: embedding, shape (b, m).
        c0: matched control states, shape (b, d).
        c1: perturbed states, shape (b, d).
        lambdas: optional weights under keys ``"map"``, ``"tan"``, ``"semi"``
            (each defaulting to 1.0) and ``"reg"`` (defaulting to 1e-4).

    Returns:
        ``total`` as a tensor (still attached to the graph) and a dict of
        Python floats with keys ``"map"``, ``"tan"``, ``"semi"``, ``"reg"``,
        ``"total"``.
    """
    B = c0.shape[0]
    dev = c0.device

    # map loss: 0 <= s <= t <= 1
    st = _sorted_uniform(B, 2, dev)
    s, t = st[:, 0], st[:, 1]
    c_s = interp(c0, c1, s)
    c_t = interp(c0, c1, t)
    pred_t = flow(s, t, c_s, e_u)
    L_map = ((pred_t - c_t) ** 2).sum(-1).mean()

    # tangent loss: diagonal velocity matches interpolation velocity (c1 - c0)
    tau = torch.rand(B, device=dev)
    c_tau = interp(c0, c1, tau)
    v_diag = flow.velocity(tau, tau, c_tau, e_u)
    L_tan = ((v_diag - (c1 - c0)) ** 2).sum(-1).mean()

    # semigroup loss: 0 <= s <= r <= t <= 1, direct vs composed
    srt = _sorted_uniform(B, 3, dev)
    s2, r2, t2 = srt[:, 0], srt[:, 1], srt[:, 2]
    # The start state must be the interpolant at s2 itself. Reusing the
    # map-loss c_s (taken at the unrelated time s) would feed the flow a state
    # that does not correspond to the start time it is told.
    c_s2 = interp(c0, c1, s2)
    direct = flow(s2, t2, c_s2, e_u)
    mid = flow(s2, r2, c_s2, e_u)
    composed = flow(r2, t2, mid, e_u)
    L_semi = ((direct - composed) ** 2).sum(-1).mean()

    # regularizer
    L_reg = (e_u ** 2).sum(-1).mean()

    total = (
        lambdas.get("map", 1.0) * L_map
        + lambdas.get("tan", 1.0) * L_tan
        + lambdas.get("semi", 1.0) * L_semi
        + lambdas.get("reg", 1e-4) * L_reg
    )
    # .item() detaches each component to a plain float for logging; only
    # `total` is returned as a tensor.
    return total, {
        "map": L_map.item(),
        "tan": L_tan.item(),
        "semi": L_semi.item(),
        "reg": L_reg.item(),
        "total": total.item(),
    }