| | import sys |
| | import torch |
| | from torchcfm.conditional_flow_matching import ConditionalFlowMatcher, pad_t_like_x |
| | import torch.nn as nn |
| |
|
| | class BranchSBM(ConditionalFlowMatcher): |
| | def __init__( |
| | self, geopath_nets: nn.ModuleList = None, alpha: float = 1.0, *args, **kwargs |
| | ): |
| | super().__init__(*args, **kwargs) |
| | self.alpha = alpha |
| | self.geopath_nets = geopath_nets |
| | if self.alpha != 0: |
| | assert ( |
| | geopath_nets is not None |
| | ), "GeoPath model must be provided if alpha != 0" |
| | |
| | self.branches = len(geopath_nets) |
| |
|
| | def gamma(self, t, t_min, t_max): |
| | return ( |
| | 1.0 |
| | - ((t - t_min) / (t_max - t_min)) ** 2 |
| | - ((t_max - t) / (t_max - t_min)) ** 2 |
| | ) |
| |
|
| | def d_gamma(self, t, t_min, t_max): |
| | return 2 * (-2 * t + t_max + t_min) / (t_max - t_min) ** 2 |
| |
|
| | def compute_mu_t(self, x0, x1, t, t_min, t_max, branch_idx): |
| | assert branch_idx < self.branches, "Index out of bounds" |
| |
|
| | with torch.enable_grad(): |
| | t = pad_t_like_x(t, x0) |
| | if self.alpha == 0: |
| | return (t_max - t) / (t_max - t_min) * x0 + (t - t_min) / ( |
| | t_max - t_min |
| | ) * x1 |
| | |
| | |
| | self.geopath_net_output = self.geopath_nets[branch_idx](x0, x1, t) |
| | if self.geopath_nets[branch_idx].time_geopath: |
| | self.doutput_dt = torch.autograd.grad( |
| | self.geopath_net_output, |
| | t, |
| | grad_outputs=torch.ones_like(self.geopath_net_output), |
| | create_graph=False, |
| | retain_graph=True, |
| | )[0] |
| | return ( |
| | (t_max - t) / (t_max - t_min) * x0 |
| | + (t - t_min) / (t_max - t_min) * x1 |
| | + self.gamma(t, t_min, t_max) * self.geopath_net_output |
| | ) |
| |
|
| | def sample_xt(self, x0, x1, t, epsilon, t_min, t_max, branch_idx): |
| | assert branch_idx < self.branches, "Index out of bounds" |
| | mu_t = self.compute_mu_t(x0, x1, t, t_min, t_max, branch_idx) |
| | sigma_t = self.compute_sigma_t(t) |
| | sigma_t = pad_t_like_x(sigma_t, x0) |
| | return mu_t + sigma_t * epsilon |
| |
|
| | def sample_location_and_conditional_flow( |
| | self, |
| | x0, |
| | x1, |
| | t_min, |
| | t_max, |
| | branch_idx, |
| | training_geopath_net=False, |
| | midpoint_only=False, |
| | t=None, |
| | ): |
| |
|
| | self.training_geopath_net = training_geopath_net |
| | with torch.enable_grad(): |
| | if t is None: |
| | t = torch.rand(x0.shape[0], requires_grad=True) |
| | t = t.type_as(x0) |
| | t = t * (t_max - t_min) + t_min |
| | if midpoint_only: |
| | t = (t_max + t_min) / 2 * torch.ones_like(t).type_as(x0) |
| | |
| | assert len(t) == x0.shape[0], "t has to have batch size dimension" |
| |
|
| | eps = self.sample_noise_like(x0) |
| | |
| | |
| | xt = self.sample_xt(x0, x1, t, eps, t_min, t_max, branch_idx) |
| | ut = self.compute_conditional_flow(x0, x1, t, xt, t_min, t_max, branch_idx) |
| |
|
| | return t, xt, ut |
| |
|
| | def compute_conditional_flow(self, x0, x1, t, xt, t_min, t_max, branch_idx): |
| | del xt |
| | t = pad_t_like_x(t, x0) |
| | if self.alpha == 0: |
| | return (x1 - x0) / (t_max - t_min) |
| | |
| | return ( |
| | (x1 - x0) / (t_max - t_min) |
| | + self.d_gamma(t, t_min, t_max) * self.geopath_net_output |
| | + ( |
| | self.gamma(t, t_min, t_max) * self.doutput_dt |
| | if self.geopath_nets[branch_idx].time_geopath |
| | else 0 |
| | ) |
| | ) |