"""Pivot training losses.

    L_map  = E‖X_θ(s, t, c_s, e_u) − c_t‖²
    L_tan  = E‖v_θ(τ, τ, c_τ, e_u) − (c1 − c0)‖²
    L_semi = E‖X_θ(s, t, c_s, e_u) − X_θ(r, t, X_θ(s, r, c_s, e_u), e_u)‖²
    L_reg  = ‖e_u‖²
    L      = λ_map·L_map + λ_tan·L_tan + λ_semi·L_semi + λ_reg·L_reg

c_s, c_t, c_τ are linear interpolants between the matched control c0 and the
perturbed c1. Each is evaluated at the same time that is handed to the flow as
its start time. Times are drawn uniformly with 0 <= s <= t <= 1 for L_map and
0 <= s <= r <= t <= 1 for L_semi.
"""
from __future__ import annotations

import torch


def _sorted_uniform(n, k, device, dtype=None):
    """Return an (n, k) tensor of U(0, 1) samples with each row sorted ascending.

    ``dtype`` defaults to torch's default floating dtype (the original behavior).
    """
    u = torch.rand(n, k, device=device, dtype=dtype)
    return torch.sort(u, dim=1).values


def interp(c0, c1, frac):
    """Return (1 - frac)·c0 + frac·c1, with one interpolation fraction per row.

    ``frac`` has shape (b,) (or is a scalar). It is reshaped so it broadcasts
    over every trailing dimension of ``c0``/``c1``. For (b, d) inputs this is
    identical to ``frac.view(-1, 1)``, and higher-rank controls also work.
    """
    f = frac.reshape(-1, *([1] * (c0.dim() - 1)))
    return (1 - f) * c0 + f * c1


def compute_losses(flow, e_u, c0, c1, lambdas: dict | None = None):
    """Compute the weighted pivot training loss and its components.

    Args:
        flow: callable ``flow(s, t, c, e_u)`` returning the state transported
            from time ``s`` to time ``t``. It must also expose
            ``flow.velocity(s, t, c, e_u)``. Times are 1-D tensors of length b.
        e_u: embedding tensor, presumably shape (b, m) - confirm with callers.
        c0, c1: matched and perturbed controls, shape (b, d) (higher-rank
            trailing shapes are supported by ``interp``).
        lambdas: optional weights keyed by "map", "tan", "semi", "reg".
            Missing keys default to 1.0, 1.0, 1.0 and 1e-4 respectively.
            ``None`` means use all defaults.

    Returns:
        ``(total, parts)``. ``total`` is the differentiable scalar loss.
        ``parts`` maps "map", "tan", "semi", "reg" and "total" to Python
        floats (via ``.item()``) for logging.
    """
    if lambdas is None:
        lambdas = {}
    B = c0.shape[0]
    dev = c0.device
    dt = c0.dtype  # sample times in the controls' dtype so the flow sees one dtype

    # Map loss: the flow transports the interpolant at s onto the interpolant at t.
    st = _sorted_uniform(B, 2, dev, dt)
    s, t = st[:, 0], st[:, 1]
    c_s = interp(c0, c1, s)
    c_t = interp(c0, c1, t)
    pred_t = flow(s, t, c_s, e_u)
    L_map = ((pred_t - c_t) ** 2).sum(-1).mean()

    # Tangent loss: velocity on the diagonal (s == t == tau) should equal the
    # constant velocity of the linear interpolation path, c1 - c0.
    tau = torch.rand(B, device=dev, dtype=dt)
    c_tau = interp(c0, c1, tau)
    v_diag = flow.velocity(tau, tau, c_tau, e_u)
    L_tan = ((v_diag - (c1 - c0)) ** 2).sum(-1).mean()

    # Semigroup loss: s -> t directly must agree with s -> r -> t. The start
    # state must be the interpolant at s2 itself. Reusing c_s from the map loss
    # would pair a state from time s with the unrelated time s2.
    srt = _sorted_uniform(B, 3, dev, dt)
    s2, r2, t2 = srt[:, 0], srt[:, 1], srt[:, 2]
    c_s2 = interp(c0, c1, s2)
    direct = flow(s2, t2, c_s2, e_u)
    mid = flow(s2, r2, c_s2, e_u)
    composed = flow(r2, t2, mid, e_u)
    L_semi = ((direct - composed) ** 2).sum(-1).mean()

    # L2 regularizer on the embedding.
    L_reg = (e_u ** 2).sum(-1).mean()

    total = (
        lambdas.get("map", 1.0) * L_map
        + lambdas.get("tan", 1.0) * L_tan
        + lambdas.get("semi", 1.0) * L_semi
        + lambdas.get("reg", 1e-4) * L_reg
    )
    return total, {
        "map": L_map.item(),
        "tan": L_tan.item(),
        "semi": L_semi.item(),
        "reg": L_reg.item(),
        "total": total.item(),
    }