PIVOT / src /models /flow_map.py
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
2.57 kB
"""perturbation-conditioned flow map X_θ.
X_θ(s, t, c_s, e_u) = c_s + (t - s) · v_θ(s, t, c_s, e_u), 0 ≤ s ≤ t ≤ 1
ĉ1(u) = X_θ(0, 1, c0, e_u)
v_θ is an mlp over [γ(s), γ(t), c_s, e_u], γ is a fourier time embedding.
the map is differentiable in e_u, which is what enables the inverse-design
jacobian ∇_{e_u} X_{0,1} used by reward guidance (Alg 4).
"""
from __future__ import annotations
import math
import torch
import torch.nn as nn
class FourierTime(nn.Module):
    """Fourier embedding of scalar times t ∈ [0, 1].

    Output layout per time value is
        [t, sin(ω_0 t), …, sin(ω_{K-1} t), cos(ω_0 t), …, cos(ω_{K-1} t)]
    with ω_k = 2π · 2^k and K = n_freq, i.e. a width of 1 + 2·n_freq.
    """

    def __init__(self, n_freq: int = 6):
        super().__init__()
        # Registered as a (persistent) buffer so it follows .to(device/dtype);
        # it is therefore also part of the state_dict — keep the name "freqs".
        self.register_buffer("freqs", 2 * math.pi * (2.0 ** torch.arange(n_freq)).float())

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t`` (any shape; flattened to b = t.numel() rows).

        Returns a (b, 1 + 2·n_freq) tensor.
        """
        # reshape (not view): also accepts non-contiguous inputs such as
        # transposed tensors or strided slices, where view() would raise.
        t = t.reshape(-1, 1)
        ang = t * self.freqs.view(1, -1)
        return torch.cat([t, torch.sin(ang), torch.cos(ang)], dim=-1)
class FlowMap(nn.Module):
    """Perturbation-conditioned flow map X_θ(s, t, c_s, e_u) = c_s + (t - s) · v_θ(s, t, c_s, e_u).

    v_θ is an MLP over [γ(s), γ(t), c_s, e_u], where γ is the FourierTime embedding.

    Args:
        d_state: width of the state c_s and of the velocity output.
        d_pert: width of the perturbation embedding e_u.
        hidden: width of every hidden layer.
        depth: number of Linear+SiLU hidden layers (values < 1 behave like 1).
        n_freq: number of Fourier frequencies per time input.
        dropout: dropout probability after each hidden-to-hidden layer; 0 disables it.
    """

    def __init__(self, d_state: int, d_pert: int, hidden: int = 512, depth: int = 4,
                 n_freq: int = 6, dropout: float = 0.0):
        super().__init__()
        self.d_state = d_state
        self.d_pert = d_pert
        self.time = FourierTime(n_freq)
        t_dim = 1 + 2 * n_freq  # FourierTime width: [t, sin × n_freq, cos × n_freq]
        in_dim = 2 * t_dim + d_state + d_pert
        # The layer order below defines the net.<i> state_dict keys of saved
        # checkpoints; do not reorder or insert layers.
        layers = [nn.Linear(in_dim, hidden), nn.SiLU()]
        for _ in range(depth - 1):
            layers += [nn.Linear(hidden, hidden), nn.SiLU()]
            # NOTE(review): confirm dropout belongs inside this loop (one per
            # hidden-to-hidden layer); its placement shifts net.<i> indices
            # when dropout > 0.
            if dropout > 0:
                layers += [nn.Dropout(dropout)]
        layers += [nn.Linear(hidden, d_state)]
        self.net = nn.Sequential(*layers)
        # zero-init final layer => v≈0 at init => X≈identity (stable warm start)
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def _batch_time(self, x, ref: torch.Tensor) -> torch.Tensor:
        """Normalize a time input to a (b,) tensor on ref's device/dtype, b = ref.shape[0].

        Accepts a Python number, a 0-dim or one-element tensor (broadcast over
        the batch), or a tensor with b elements. Tensors that already match
        are returned as views, so autograd through the time is preserved.
        """
        if torch.is_tensor(x):
            x = x.to(device=ref.device, dtype=ref.dtype)
        else:
            x = torch.as_tensor(x, dtype=ref.dtype, device=ref.device)
        return x.reshape(-1).expand(ref.shape[0])

    def velocity(self, s, t, c_s, e_u):
        """Evaluate v_θ(s, t, c_s, e_u).

        s, t: Python numbers, 0-dim / one-element tensors, or (b,) tensors.
        c_s: (b, d_state) state; e_u: (b, d_pert) perturbation embedding.
        Returns a (b, d_state) velocity.
        """
        s = self._batch_time(s, c_s)
        t = self._batch_time(t, c_s)
        h = torch.cat([self.time(s), self.time(t), c_s, e_u], dim=-1)
        return self.net(h)

    def forward(self, s, t, c_s, e_u):
        """X_θ(s,t,c_s,e_u) = c_s + (t - s) · v_θ(s,t,c_s,e_u).

        s, t: times (expected in [0, 1]) given as Python numbers, 0-dim /
        one-element tensors, or (b,) tensors. Returns a (b, d_state) tensor.
        """
        s = self._batch_time(s, c_s)
        t = self._batch_time(t, c_s)
        v = self.velocity(s, t, c_s, e_u)
        # (b, 1) so the per-sample step size broadcasts over the state dims
        dt = (t - s).unsqueeze(-1)
        return c_s + dt * v

    def endpoint(self, c0, e_u):
        """ĉ1 = X_θ(0, 1, c0, e_u): apply the full unit-time map to c0."""
        return self.forward(0.0, 1.0, c0, e_u)