"""Perturbation-conditioned flow map X_θ.

    X_θ(s, t, c_s, e_u) = c_s + (t - s) · v_θ(s, t, c_s, e_u),   0 ≤ s ≤ t ≤ 1
    ĉ1(u) = X_θ(0, 1, c0, e_u)

v_θ is an MLP over [γ(s), γ(t), c_s, e_u], where γ is a Fourier time embedding.
The map is differentiable in e_u, which is what enables the inverse-design
Jacobian ∇_{e_u} X_{0,1} used by reward guidance (Alg 4).
"""
| from __future__ import annotations |
|
|
| import math |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class FourierTime(nn.Module):
    """Fourier embedding of a scalar time t ∈ [0, 1].

    Each time value is mapped to
    ``[t, sin(ω_0 t), …, sin(ω_{K-1} t), cos(ω_0 t), …, cos(ω_{K-1} t)]``
    with octave-spaced angular frequencies ``ω_k = 2π · 2^k`` for
    ``k = 0 … n_freq - 1``, so the output width is ``1 + 2 * n_freq``.
    """

    def __init__(self, n_freq: int = 6):
        """Precompute the ``n_freq`` angular frequencies ``2π · 2^k``."""
        super().__init__()
        # Registered as a buffer (not a parameter): the frequencies are fixed,
        # but should still follow the module through .to() and state_dict().
        self.register_buffer("freqs", 2 * math.pi * (2.0 ** torch.arange(n_freq)).float())

    def forward(self, t: "float | torch.Tensor") -> torch.Tensor:
        """Embed time value(s) ``t``.

        Args:
            t: a Python number or a tensor of any shape; it is flattened so
                that every element becomes one row of the output.

        Returns:
            Tensor of shape ``(t.numel(), 1 + 2 * n_freq)``.
        """
        if not torch.is_tensor(t):
            # Python scalars are placed next to the frequency buffer.
            t = torch.as_tensor(t, dtype=self.freqs.dtype, device=self.freqs.device)
        # reshape (rather than view) also accepts non-contiguous inputs.
        t = t.reshape(-1, 1)
        ang = t * self.freqs.view(1, -1)
        return torch.cat([t, torch.sin(ang), torch.cos(ang)], dim=-1)
|
|
|
|
class FlowMap(nn.Module):
    """Perturbation-conditioned flow map.

    ``X_θ(s, t, c_s, e_u) = c_s + (t - s) · v_θ(s, t, c_s, e_u)`` where
    ``v_θ`` is an MLP over the concatenation
    ``[γ(s), γ(t), c_s, e_u]`` and ``γ`` is the ``FourierTime`` embedding.

    The last linear layer is zero-initialised, so a freshly constructed
    module has ``v_θ ≡ 0`` and ``X_θ`` is the identity in ``c_s``.
    """

    def __init__(self, d_state: int, d_pert: int, hidden: int = 512, depth: int = 4,
                 n_freq: int = 6, dropout: float = 0.0):
        """Build the velocity MLP.

        Args:
            d_state: width of the state vector ``c_s`` (also the output width).
            d_pert: width of the perturbation embedding ``e_u``.
            hidden: width of every hidden layer.
            depth: number of hidden ``Linear + SiLU`` layers (at least one is
                always created).
            n_freq: number of Fourier frequencies per time input.
            dropout: dropout probability; no ``Dropout`` layers are added
                when it is ``0``.
        """
        super().__init__()
        self.d_state = d_state
        self.d_pert = d_pert
        self.time = FourierTime(n_freq)
        t_dim = 1 + 2 * n_freq  # raw t plus one sin and one cos per frequency
        in_dim = 2 * t_dim + d_state + d_pert  # γ(s), γ(t), c_s, e_u
        layers = [nn.Linear(in_dim, hidden), nn.SiLU()]
        for _ in range(depth - 1):
            layers += [nn.Linear(hidden, hidden), nn.SiLU()]
            # NOTE(review): dropout is placed after every additional hidden
            # layer; confirm this nesting against the original layout.
            if dropout > 0:
                layers += [nn.Dropout(dropout)]
        layers += [nn.Linear(hidden, d_state)]
        self.net = nn.Sequential(*layers)

        # Zero-init the output layer so the map starts as the identity.
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    @staticmethod
    def _as_time(x, ref: torch.Tensor) -> torch.Tensor:
        """Return ``x`` unchanged if it is a tensor, else a 0-dim tensor on ``ref``'s device."""
        if torch.is_tensor(x):
            return x
        dtype = ref.dtype if ref.is_floating_point() else torch.get_default_dtype()
        return torch.as_tensor(x, dtype=dtype, device=ref.device)

    def velocity(self, s, t, c_s: torch.Tensor, e_u: torch.Tensor) -> torch.Tensor:
        """Evaluate ``v_θ(s, t, c_s, e_u)``.

        Args:
            s, t: Python numbers, 0-dim tensors, or ``(b,)`` tensors. Scalars
                are broadcast over the batch dimension of ``c_s``.
            c_s: state tensor whose first dimension is the batch size.
            e_u: perturbation embedding, concatenated with ``c_s`` along the
                last dimension.

        Returns:
            The MLP output for the concatenated input.
        """
        batch = c_s.shape[0]
        s = self._as_time(s, c_s)
        t = self._as_time(t, c_s)
        if s.dim() == 0:
            s = s.expand(batch)
        if t.dim() == 0:
            t = t.expand(batch)
        h = torch.cat([self.time(s), self.time(t), c_s, e_u], dim=-1)
        return self.net(h)

    def forward(self, s, t, c_s: torch.Tensor, e_u: torch.Tensor) -> torch.Tensor:
        """X_θ(s,t,c_s,e_u). s,t are scalars or (b,) tensors in [0,1]."""
        # Coerce once so Python-number times work in both the velocity
        # network and the (t - s) step below.
        s = self._as_time(s, c_s)
        t = self._as_time(t, c_s)
        v = self.velocity(s, t, c_s, e_u)
        dt = t - s
        if dt.dim() >= 1:
            # Per-sample step sizes: make them a column so they broadcast
            # across the feature dimension of v.
            dt = dt.reshape(-1, 1)
        return c_s + dt * v

    def endpoint(self, c0: torch.Tensor, e_u: torch.Tensor) -> torch.Tensor:
        """ĉ1 = X_θ(0,1,c0,e_u)."""
        batch = c0.shape[0]
        # Match c0's floating dtype so a module cast with .half()/.double()
        # is not fed default-dtype time inputs.
        dtype = c0.dtype if c0.is_floating_point() else torch.get_default_dtype()
        zeros = torch.zeros(batch, device=c0.device, dtype=dtype)
        ones = torch.ones(batch, device=c0.device, dtype=dtype)
        return self.forward(zeros, ones, c0, e_u)
|
|