File size: 3,781 Bytes
5a87d8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import sys
sys.path.append("./BranchSBM")
import torch
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher, pad_t_like_x
import torch.nn as nn
class BranchSBM(ConditionalFlowMatcher):
    """Branched conditional flow matcher with learned (GeoPath) interpolants.

    Each branch ``k`` owns a network ``geopath_nets[k]`` that bends the
    straight line between ``x0`` and ``x1`` on the interval ``[t_min, t_max]``::

        mu_t = (t_max - t) / L * x0 + (t - t_min) / L * x1 + gamma(t) * phi_k(x0, x1, t)

    where ``L = t_max - t_min`` and ``gamma(t) = 2 s (1 - s)`` with
    ``s = (t - t_min) / L``. The learned correction therefore vanishes at both
    interval endpoints. When ``alpha == 0`` no network is used and the
    interpolant is purely linear.

    ``compute_mu_t`` caches the network output (``self.geopath_net_output``)
    and, for time-dependent nets, its time derivative (``self.doutput_dt``) on
    the instance. ``compute_conditional_flow`` reads that cache, so it must run
    after ``compute_mu_t`` for the same inputs and branch.
    ``sample_location_and_conditional_flow`` calls them in that order.
    """

    def __init__(
        self, geopath_nets: nn.ModuleList = None, alpha: float = 1.0, *args, **kwargs
    ):
        """
        Args:
            geopath_nets: one GeoPath network per branch; required when
                ``alpha != 0``. Each is called as ``net(x0, x1, t)`` and must
                expose a boolean ``time_geopath`` attribute.
            alpha: only compared with 0 in this class. 0 selects the plain
                linear interpolant; any other value enables the learned term.
            *args, **kwargs: forwarded unchanged to ``ConditionalFlowMatcher``.
        """
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.geopath_nets = geopath_nets
        if self.alpha != 0:
            assert (
                geopath_nets is not None
            ), "GeoPath model must be provided if alpha != 0"
        # Number of branches. It is None when no nets were supplied, which is
        # only legal with alpha == 0. In that case every branch uses the same
        # linear path and branch indices are not range-checked.
        self.branches = len(geopath_nets) if geopath_nets is not None else None

    def _check_branch_idx(self, branch_idx):
        """Assert that ``branch_idx`` addresses an existing branch, when the branch count is known."""
        if self.branches is not None:
            assert branch_idx < self.branches, "Index out of bounds"

    def gamma(self, t, t_min, t_max):
        """Envelope ``1 - s**2 - (1 - s)**2``, equal to ``2 s (1 - s)``, with ``s = (t - t_min) / (t_max - t_min)``.

        It is 0 at ``t_min`` and ``t_max`` and peaks at 0.5 at the midpoint.
        """
        return (
            1.0
            - ((t - t_min) / (t_max - t_min)) ** 2
            - ((t_max - t) / (t_max - t_min)) ** 2
        )

    def d_gamma(self, t, t_min, t_max):
        """Analytic derivative of ``gamma`` with respect to ``t``."""
        return 2 * (-2 * t + t_max + t_min) / (t_max - t_min) ** 2

    @staticmethod
    def _time_derivative(output, t):
        """Return the per-component derivative ``d(output)/dt``, with the same shape as ``output``.

        A single reverse-mode call with an all-ones cotangent returns
        ``sum_j d(output_j)/dt``, shaped like ``t``, which collapses the
        feature dimension. Differentiating that vector-Jacobian product with
        respect to its (dummy) cotangent gives the Jacobian-vector product
        ``J @ 1``, i.e. the derivative of every output component (the
        "double-backward trick"). This requires the net's backward to be
        differentiable, which holds for standard torch layers.

        The result is detached, matching the previous ``create_graph=False``
        behaviour. The forward graph of ``output`` is retained so the caller
        can still backpropagate through ``mu_t``.
        """
        cotangent = torch.ones_like(output, requires_grad=True)
        (vjp,) = torch.autograd.grad(
            output, t, grad_outputs=cotangent, create_graph=True
        )
        (jvp,) = torch.autograd.grad(
            vjp, cotangent, grad_outputs=torch.ones_like(vjp), retain_graph=True
        )
        return jvp

    def compute_mu_t(self, x0, x1, t, t_min, t_max, branch_idx):
        """Mean of the branch-``branch_idx`` interpolant at time ``t``.

        For ``alpha != 0`` this has side effects: it stores
        ``self.geopath_net_output`` and ``self.doutput_dt`` (None unless the
        branch net has ``time_geopath``) for ``compute_conditional_flow``.
        """
        self._check_branch_idx(branch_idx)
        with torch.enable_grad():
            # pad_t_like_x (torchcfm) reshapes t to broadcast against x0,
            # presumably (batch, 1, ..., 1). TODO confirm for non-2D data.
            t = pad_t_like_x(t, x0)
            if self.alpha == 0:
                return (t_max - t) / (t_max - t_min) * x0 + (t - t_min) / (
                    t_max - t_min
                ) * x1
            net = self.geopath_nets[branch_idx]
            if net.time_geopath and not t.requires_grad:
                # Differentiating w.r.t. t needs t in the autograd graph.
                # Midpoint or caller-supplied times do not require grad by
                # default, so use a grad-enabled copy with identical values.
                t = t.detach().requires_grad_(True)
            self.geopath_net_output = net(x0, x1, t)
            # Reset so a derivative cached by an earlier call (another batch
            # or branch) can never leak into compute_conditional_flow.
            self.doutput_dt = None
            if net.time_geopath:
                self.doutput_dt = self._time_derivative(self.geopath_net_output, t)
        return (
            (t_max - t) / (t_max - t_min) * x0
            + (t - t_min) / (t_max - t_min) * x1
            + self.gamma(t, t_min, t_max) * self.geopath_net_output
        )

    def sample_xt(self, x0, x1, t, epsilon, t_min, t_max, branch_idx):
        """Return ``mu_t + sigma_t * epsilon`` for branch ``branch_idx``.

        ``sigma_t`` comes from the base class's ``compute_sigma_t``.
        """
        self._check_branch_idx(branch_idx)
        mu_t = self.compute_mu_t(x0, x1, t, t_min, t_max, branch_idx)
        sigma_t = self.compute_sigma_t(t)
        sigma_t = pad_t_like_x(sigma_t, x0)
        return mu_t + sigma_t * epsilon

    def sample_location_and_conditional_flow(
        self,
        x0,
        x1,
        t_min,
        t_max,
        branch_idx,
        training_geopath_net=False,
        midpoint_only=False,
        t=None,
    ):
        """Draw times, noisy locations ``xt`` and target velocities ``ut`` for one branch.

        Args:
            x0, x1: batched endpoint samples; ``t`` must have one entry per
                row of ``x0``.
            t_min, t_max: the time interval of this segment.
            branch_idx: which GeoPath network to use.
            training_geopath_net: stored as ``self.training_geopath_net``. It
                is not read in this class; presumably it is consumed elsewhere.
            midpoint_only: if True, every time is set to ``(t_min + t_max) / 2``.
            t: optional per-sample times. Like the random draw, they are
                rescaled with ``t * (t_max - t_min) + t_min``, so pass
                fractions of the interval.

        Returns:
            ``(t, xt, ut)``.
        """
        self.training_geopath_net = training_geopath_net
        with torch.enable_grad():
            if t is None:
                t = torch.rand(x0.shape[0], requires_grad=True)
            t = t.type_as(x0)
            t = t * (t_max - t_min) + t_min
            if midpoint_only:
                t = (t_max + t_min) / 2 * torch.ones_like(t).type_as(x0)
        assert len(t) == x0.shape[0], "t has to have batch size dimension"
        eps = self.sample_noise_like(x0)
        # compute_mu_t (via sample_xt) fills the cache that
        # compute_conditional_flow reads, so this order matters.
        xt = self.sample_xt(x0, x1, t, eps, t_min, t_max, branch_idx)
        ut = self.compute_conditional_flow(x0, x1, t, xt, t_min, t_max, branch_idx)
        return t, xt, ut

    def compute_conditional_flow(self, x0, x1, t, xt, t_min, t_max, branch_idx):
        """Time derivative of ``compute_mu_t``: the target velocity for branch ``branch_idx``.

        ``xt`` is accepted for interface compatibility and ignored. For
        ``alpha != 0`` this reads the values cached by the preceding
        ``compute_mu_t`` call, which must have used the same inputs and branch.

        Raises:
            RuntimeError: if the branch net is time-dependent but no matching
                derivative was cached by ``compute_mu_t``.
        """
        del xt
        t = pad_t_like_x(t, x0)
        if self.alpha == 0:
            return (x1 - x0) / (t_max - t_min)
        ut = (x1 - x0) / (t_max - t_min) + self.d_gamma(
            t, t_min, t_max
        ) * self.geopath_net_output
        if self.geopath_nets[branch_idx].time_geopath:
            doutput_dt = getattr(self, "doutput_dt", None)
            if doutput_dt is None:
                raise RuntimeError(
                    "compute_mu_t must be called for this branch before "
                    "compute_conditional_flow"
                )
            # Product rule: d/dt [gamma * phi] = gamma' * phi + gamma * dphi/dt.
            ut = ut + self.gamma(t, t_min, t_max) * doutput_dt
        return ut