File size: 3,781 Bytes
5a87d8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import sys
sys.path.append("./BranchSBM")
import torch
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher, pad_t_like_x
import torch.nn as nn

class BranchSBM(ConditionalFlowMatcher):
    """Conditional flow matcher with one learned GeoPath interpolant per branch.

    For branch ``k`` the conditional mean path from ``x0`` (at ``t_min``) to
    ``x1`` (at ``t_max``) is the straight line plus a learned correction::

        mu_t = (t_max - t) / (t_max - t_min) * x0
             + (t - t_min) / (t_max - t_min) * x1
             + gamma(t) * geopath_nets[k](x0, x1, t)

    ``gamma`` is zero at both ``t_min`` and ``t_max``, so the endpoints are
    unchanged. When ``alpha == 0`` the correction term is dropped and the path
    is the straight line; within this class ``alpha`` is only compared to 0.

    Note: ``compute_mu_t`` caches the network output (and, for nets whose
    ``time_geopath`` is true, its time derivative) on ``self``;
    ``compute_conditional_flow`` reads that cache. It must therefore be called
    after ``compute_mu_t`` with the same ``x0``, ``x1``, ``t`` and
    ``branch_idx``, as ``sample_location_and_conditional_flow`` does.
    """

    def __init__(
        self, geopath_nets: nn.ModuleList = None, alpha: float = 1.0, *args, **kwargs
    ):
        """
        Args:
            geopath_nets: one network per branch. Each is called as
                ``net(x0, x1, t)`` and must expose a boolean ``time_geopath``
                attribute. Required unless ``alpha == 0``.
            alpha: if 0, use straight-line interpolation and ignore the nets.
            *args, **kwargs: forwarded to ``ConditionalFlowMatcher.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.geopath_nets = geopath_nets
        if self.alpha != 0:
            assert (
                geopath_nets is not None
            ), "GeoPath model must be provided if alpha != 0"

        # geopath_nets may legitimately be None when alpha == 0; there are then
        # no branch-specific networks to count.
        self.branches = len(geopath_nets) if geopath_nets is not None else 0

    def _check_branch_idx(self, branch_idx):
        """Assert that ``branch_idx`` addresses a network, if networks were given."""
        # Without networks (only allowed when alpha == 0) every branch uses the
        # same straight-line path, so there is nothing to index.
        if self.geopath_nets is not None:
            assert branch_idx < self.branches, "Index out of bounds"

    def gamma(self, t, t_min, t_max):
        """Envelope ``1 - s**2 - (1 - s)**2`` with ``s = (t - t_min) / (t_max - t_min)``.

        It equals 0 at ``t = t_min`` and ``t = t_max``.
        """
        return (
            1.0
            - ((t - t_min) / (t_max - t_min)) ** 2
            - ((t_max - t) / (t_max - t_min)) ** 2
        )

    def d_gamma(self, t, t_min, t_max):
        """Derivative of ``gamma`` with respect to ``t``."""
        return 2 * (-2 * t + t_max + t_min) / (t_max - t_min) ** 2

    @staticmethod
    def _time_derivative(output, t):
        """Element-wise ``d output / d t`` with the shape of ``output``, detached.

        A single reverse-mode call with ``grad_outputs=ones`` returns
        ``sum_d d output[..., d] / d t`` with the shape of ``t``, which mixes all
        output dimensions. This method uses the double-backward trick instead.
        The VJP ``g(v) = J^T v`` is linear in ``v``, so differentiating it with
        respect to ``v`` along ``ones_like(t)`` gives the forward-mode product
        ``J @ ones``. That is, each output element's derivative with respect to
        its sample's ``t``. This assumes the network treats batch samples
        independently; cross-sample terms would otherwise be summed in.
        """
        v = torch.ones_like(output, requires_grad=True)
        (vjp,) = torch.autograd.grad(
            output, t, grad_outputs=v, create_graph=True, retain_graph=True
        )
        (jvp,) = torch.autograd.grad(
            vjp, v, grad_outputs=torch.ones_like(vjp), retain_graph=True
        )
        # Like the original single-grad version, do not back-propagate through dt.
        return jvp.detach()

    def compute_mu_t(self, x0, x1, t, t_min, t_max, branch_idx):
        """Mean of the conditional path at time ``t`` for ``branch_idx``.

        Args:
            x0, x1: source and target samples, batch on dim 0.
            t: per-sample times in ``[t_min, t_max]``. They are padded with
                ``pad_t_like_x`` to broadcast against ``x0``.
            t_min, t_max: time interval of this segment.
            branch_idx: which GeoPath network to use.

        Returns:
            The straight-line interpolation, plus ``gamma(t) * net(x0, x1, t)``
            when ``alpha != 0``.

        Side effects (only when ``alpha != 0``): sets ``self.geopath_net_output``.
        If the network's ``time_geopath`` is true, also sets ``self.doutput_dt``,
        the element-wise time derivative of that output (detached).
        """
        self._check_branch_idx(branch_idx)

        with torch.enable_grad():
            t = pad_t_like_x(t, x0)
            if self.alpha == 0:
                return (t_max - t) / (t_max - t_min) * x0 + (t - t_min) / (
                    t_max - t_min
                ) * x1

            net = self.geopath_nets[branch_idx]
            if net.time_geopath and torch.is_tensor(t) and not t.requires_grad:
                # Differentiating w.r.t. t requires t to be in the autograd
                # graph. A caller-supplied t, or the midpoint_only t, may not be.
                t = t.detach().requires_grad_(True)

            # compute value for specific branch
            self.geopath_net_output = net(x0, x1, t)
            if net.time_geopath:
                self.doutput_dt = self._time_derivative(self.geopath_net_output, t)
        return (
            (t_max - t) / (t_max - t_min) * x0
            + (t - t_min) / (t_max - t_min) * x1
            + self.gamma(t, t_min, t_max) * self.geopath_net_output
        )

    def sample_xt(self, x0, x1, t, epsilon, t_min, t_max, branch_idx):
        """Return ``mu_t + sigma_t * epsilon`` for ``branch_idx``.

        ``sigma_t`` comes from the base class's ``compute_sigma_t`` and is
        padded to broadcast against ``x0``.
        """
        self._check_branch_idx(branch_idx)
        mu_t = self.compute_mu_t(x0, x1, t, t_min, t_max, branch_idx)
        sigma_t = self.compute_sigma_t(t)
        sigma_t = pad_t_like_x(sigma_t, x0)
        return mu_t + sigma_t * epsilon

    def sample_location_and_conditional_flow(
        self,
        x0,
        x1,
        t_min,
        t_max,
        branch_idx,
        training_geopath_net=False,
        midpoint_only=False,
        t=None,
    ):
        """Sample ``(t, x_t, u_t)`` on ``[t_min, t_max]`` for ``branch_idx``.

        Args:
            x0, x1: source and target samples, batch on dim 0.
            t_min, t_max: time interval of this segment.
            branch_idx: which GeoPath network to use.
            training_geopath_net: stored on ``self.training_geopath_net``; it is
                not read by this class.
            midpoint_only: if true, every sample uses ``(t_min + t_max) / 2``.
            t: optional per-sample times treated as fractions of the interval
                and rescaled to ``[t_min, t_max]``. Drawn with ``torch.rand``
                when None.

        Returns:
            ``(t, xt, ut)``: the times used, the noisy path samples, and the
            conditional flow at those samples.
        """
        self.training_geopath_net = training_geopath_net
        with torch.enable_grad():
            if t is None:
                t = torch.rand(x0.shape[0], requires_grad=True)
            t = t.type_as(x0)
            # Map unit-interval times onto [t_min, t_max].
            t = t * (t_max - t_min) + t_min
            if midpoint_only:
                t = (t_max + t_min) / 2 * torch.ones_like(t).type_as(x0)

        assert len(t) == x0.shape[0], "t has to have batch size dimension"

        eps = self.sample_noise_like(x0)

        # compute xt and ut for branch_idx; compute_conditional_flow relies on
        # the network output cached by sample_xt -> compute_mu_t.
        xt = self.sample_xt(x0, x1, t, eps, t_min, t_max, branch_idx)
        ut = self.compute_conditional_flow(x0, x1, t, xt, t_min, t_max, branch_idx)

        return t, xt, ut

    def compute_conditional_flow(self, x0, x1, t, xt, t_min, t_max, branch_idx):
        """Time derivative of ``mu_t``, i.e. the conditional vector field.

        When ``alpha != 0`` this reads ``self.geopath_net_output``, and for nets
        whose ``time_geopath`` is true also ``self.doutput_dt``. Both are set by
        the most recent ``compute_mu_t`` call, which must have used the same
        inputs and ``branch_idx``. ``xt`` is accepted for interface
        compatibility and ignored.
        """
        del xt
        t = pad_t_like_x(t, x0)
        if self.alpha == 0:
            return (x1 - x0) / (t_max - t_min)

        return (
            (x1 - x0) / (t_max - t_min)
            + self.d_gamma(t, t_min, t_max) * self.geopath_net_output
            + (
                self.gamma(t, t_min, t_max) * self.doutput_dt
                if self.geopath_nets[branch_idx].time_geopath
                else 0
            )
        )