"""Pivot training losses.

    L_map  = E‖X_θ(s, t, c_s, e_u) − c_t‖²
    L_tan  = E‖v_θ(τ, τ, c_τ, e_u) − (c1 − c0)‖²
    L_semi = E‖X_θ(s, t, c_s, e_u) − X_θ(r, t, X_θ(s, r, c_s, e_u), e_u)‖²
    L_reg  = E‖e_u‖²
    L      = λ_map·L_map + λ_tan·L_tan + λ_semi·L_semi + λ_reg·L_reg

c_s, c_t, c_τ are linear interpolants between matched control c0 and perturbed c1.
"""
| from __future__ import annotations |
|
|
| import torch |
|
|
|
|
def _sorted_uniform(n, k, device):
    """Draw an (n, k) tensor of U(0,1) samples with every row in ascending order."""
    draws = torch.rand(n, k, device=device)
    # torch.sort returns (values, indices); only the sorted values are needed.
    ascending, _ = torch.sort(draws, dim=1)
    return ascending
|
|
|
|
def interp(c0, c1, frac):
    """Blend c0 and c1 row-wise: (1 - frac)·c0 + frac·c1, with frac of shape (b,)."""
    # Reshape the per-row fraction to a column so it broadcasts across features.
    weight_c1 = frac.view(-1, 1)
    weight_c0 = 1 - weight_c1
    return weight_c0 * c0 + weight_c1 * c1
|
|
|
|
def compute_losses(
    flow, e_u, c0, c1, lambdas: dict
) -> tuple[torch.Tensor, dict[str, float]]:
    """Return ``(total, components)`` for one batch of matched pairs.

    Args:
        flow: model called as ``flow(s, t, x, e_u)`` (the map X_θ) and exposing
            ``flow.velocity(s, t, x, e_u)`` (the velocity v_θ).
        e_u: embedding passed to every model call, shape (b, m).
        c0: matched control states, shape (b, d).
        c1: perturbed states, shape (b, d).
        lambdas: optional loss weights under the keys ``"map"``, ``"tan"``,
            ``"semi"`` (each defaulting to 1.0) and ``"reg"`` (default 1e-4).

    Returns:
        ``total`` as a tensor (the weighted sum of the four terms), and a dict
        of Python floats with keys ``"map"``, ``"tan"``, ``"semi"``, ``"reg"``
        and ``"total"``.
    """
    B = c0.shape[0]
    dev = c0.device

    # Map loss: transport the interpolant at time s to time t (s <= t) and
    # compare against the interpolant at t.
    st = _sorted_uniform(B, 2, dev)
    s, t = st[:, 0], st[:, 1]
    c_s = interp(c0, c1, s)
    c_t = interp(c0, c1, t)
    pred_t = flow(s, t, c_s, e_u)
    L_map = ((pred_t - c_t) ** 2).sum(-1).mean()

    # Tangent loss: on the diagonal (s == t == tau) the velocity should equal
    # the slope of the linear interpolant, c1 - c0.
    tau = torch.rand(B, device=dev)
    c_tau = interp(c0, c1, tau)
    v_diag = flow.velocity(tau, tau, c_tau, e_u)
    L_tan = ((v_diag - (c1 - c0)) ** 2).sum(-1).mean()

    # Semigroup loss: going s2 -> t2 directly should match s2 -> r2 -> t2
    # (s2 <= r2 <= t2). The start state must be the interpolant at s2 itself;
    # reusing the map-loss c_s (taken at an unrelated time s) would hand the
    # model a state whose time label does not match it.
    srt = _sorted_uniform(B, 3, dev)
    s2, r2, t2 = srt[:, 0], srt[:, 1], srt[:, 2]
    c_s2 = interp(c0, c1, s2)
    direct = flow(s2, t2, c_s2, e_u)
    mid = flow(s2, r2, c_s2, e_u)
    composed = flow(r2, t2, mid, e_u)
    L_semi = ((direct - composed) ** 2).sum(-1).mean()

    # Regulariser: mean squared norm of the embedding.
    L_reg = (e_u ** 2).sum(-1).mean()

    total = (
        lambdas.get("map", 1.0) * L_map
        + lambdas.get("tan", 1.0) * L_tan
        + lambdas.get("semi", 1.0) * L_semi
        + lambdas.get("reg", 1e-4) * L_reg
    )
    # Components are reported as detached Python floats for logging; only
    # `total` stays a tensor.
    return total, {
        "map": L_map.item(),
        "tan": L_tan.item(),
        "semi": L_semi.item(),
        "reg": L_reg.item(),
        "total": total.item(),
    }
|
|