File size: 9,800 Bytes
2571f24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import abc
import torch
from ShapeID.DiffEqs.misc import _assert_increasing, _handle_unused_kwargs

def set_BC_2D(X, BCs): # X: (n_batch, spatial_size); BCs: (batch, 4, BC_shape, data_dim)
    BC_size = BCs.size(2)
    X[:, : BC_size] = BCs[:, 0]
    X[:, - BC_size :] = BCs[:, 1]
    X[:, :, : BC_size] = BCs[:, 2].permute(0, 2, 1) # (batch, BC_shape, r) -> (batch, r, BC_shape)
    X[:, :, - BC_size :] = BCs[:, 3].permute(0, 2, 1) # (batch, BC_shape, r) -> (batch, r, BC_shape)
    del BCs
    return X
def set_BC_3D(X, BCs): # X: (n_batch, spatial_size); BCs: (batch, 6, BC_shape, data_dim, dta_dim)
    BC_size = BCs.size(2)
    X[:, : BC_size] = BCs[:, 0]
    X[:, - BC_size :] = BCs[:, 1]
    X[:, :, : BC_size] = BCs[:, 2].permute(0, 2, 1, 3) # (batch, BC_shape, s, c) -> (batch, s, BC_shape, c)
    X[:, :, - BC_size :] = BCs[:, 3].permute(0, 2, 1, 3) # (batch, BC_shape, s, c) -> (batch, s, BC_shape, c)
    X[:, :, :, : BC_size] = BCs[:, 4].permute(0, 2, 3, 1) # (batch, BC_shape, s, r) -> (batch, s, r, BC_shape)
    X[:, :, :, - BC_size :] = BCs[:, 5].permute(0, 2, 3, 1) # (batch, BC_shape, s, r) -> (batch, s, r, BC_shape)
    del BCs
    return X

''' X[t] = X[t] + dBC[t] (dBC[t] = BC[t+1] - BC[t]) '''
def add_dBC_2D(X, dBCs): # X: (n_batch, spatial_size); BCs: (batch, 4, BC_shape, data_dim)
    BC_size = dBCs.size(2)
    X[:, : BC_size] += dBCs[:, 0]
    X[:, - BC_size :] += dBCs[:, 1]
    X[:, :, : BC_size] += dBCs[:, 2].permute(0, 2, 1) # (batch, BC_shape, r) -> (batch, r, BC_shape)
    X[:, :, - BC_size :] += dBCs[:, 3].permute(0, 2, 1) # (batch, BC_shape, r) -> (batch, r, BC_shape)
    del dBCs
    return X
def add_dBC_3D(X, dBCs): # X: (n_batch, spatial_size); BCs: (batch, 6, BC_shape, data_dim, dta_dim)
    BC_size = dBCs.size(2)
    X[:, : BC_size] += dBCs[:, 0]
    X[:, - BC_size :] += dBCs[:, 1]
    X[:, :, : BC_size] += dBCs[:, 2].permute(0, 2, 1, 3) # (batch, BC_shape, s, c) -> (batch, s, BC_shape, c)
    X[:, :, - BC_size :] += dBCs[:, 3].permute(0, 2, 1, 3) # (batch, BC_shape, s, c) -> (batch, s, BC_shape, c)
    X[:, :, :, : BC_size] += dBCs[:, 4].permute(0, 2, 3, 1) # (batch, BC_shape, s, r) -> (batch, s, r, BC_shape)
    X[:, :, :, - BC_size :] += dBCs[:, 5].permute(0, 2, 3, 1) # (batch, BC_shape, s, r) -> (batch, s, r, BC_shape)
    del dBCs
    return X

class AdaptiveStepsizeODESolver(object):
    __metaclass__ = abc.ABCMeta

    def __init__(self, func, y0, atol, rtol, options= None):
        
       # _handle_unused_kwargs(self, options)
        #del options
        self.func = func
        self.y0 = y0
        self.atol = atol
        self.rtol = rtol

    def before_integrate(self, t):
        pass

    @abc.abstractmethod
    def advance(self, next_t):
        raise NotImplementedError

    def integrate(self, t):
        _assert_increasing(t)
        solution = [self.y0]
        t = t.to(self.y0[0].device, torch.float64)
        self.before_integrate(t)
        for i in range(1, len(t)):
            y = self.advance(t[i])
            solution.append(y)
        '''if self.contours is not None: # contours: (n_batch, nT, 4 / 6, BC_size, c)
            if self.adjoint:
                for i in range(1, len(t)):
                    ys = list(self.advance(t[i])) # tuple: (y0, **back_grad) -> y0: (n_batch, spatial_shape)
                    #print(len(t))
                    #print(ys[0].size())
                    #print(self.contours.size())
                    ys[0] = self.set_BC(ys[0], self.contours[:, i]) # (n_batch, 4 / 6, BC_size, c)
                    solution.append(tuple(ys))
            else:
                for i in range(1, len(t)):
                    y = torch.stack(self.advance(t[i])) # y: (n_batch, 1, spatial_shape) 
                    y = self.set_BC(y[:, 0], self.contours[:, i]).unsqueeze(1) 
                    solution.append(tuple(y))
        elif self.dcontours is not None: # dcontours: (n_batch, nT, 4 / 6, BC_size, c)
            if self.adjoint:
                for i in range(1, len(t)):
                    ys = list(self.advance(t[i])) # ys - tuple: (y0, **back_grad) -> y0: (n_batch, spatial_shape)
                    ys[0] = self.add_dBC(ys[0], self.dcontours[:, i]) # (n_batch, 4 / 6, BC_size, c)
                    solution.append(tuple(ys))
            else:
                for i in range(1, len(t)):
                    y = torch.stack(self.advance(t[i])) # (n_batch, 1, spatial_shape)
                    y = self.add_dBC(y[:, 0], self.dcontours[:, i]).unsqueeze(1)
                    solution.append(tuple(y))
        else:
            for i in range(1, len(t)):
                y = self.advance(t[i])
                solution.append(y)'''
        return tuple(map(torch.stack, tuple(zip(*solution))))


class FixedGridODESolver(object):
    __metaclass__ = abc.ABCMeta

    def __init__(self, func, y0, step_size=None, grid_constructor=None, atol=None, rtol=None, dt=None, options = None):
        '''if 'dirichlet' in options.BC or 'cauchy' in options.BC and options.contours is not None:
            self.contours = options.contours  # (n_batch, nT, 4 / 6, BC_size, sub_spatial_shape)
            self.BC_size = self.contours.size(3)
            self.set_BC = set_BC_2D if self.contours.size(2) == 4 else set_BC_3D
        else:
            self.contours = None
        if 'source' in options.BC and options.dcontours is not None:
            self.dcontours = options.dcontours  # (n_batch, nT, 4 / 6, BC_size, sub_spatial_shape)
            self.BC_size = self.dcontours.size(3)
            self.add_dBC = add_dBC_2D if self.dcontours.size(2) == 4 else add_dBC_3D
        else:
            self.dcontours = None'''
        #self.adjoint = options.adjoint
        #options.pop('rtol', None)
        #options.pop('atol', None)
        #_handle_unused_kwargs(self, options)
        #del options

        self.func = func
        self.y0 = y0

        if step_size is not None and grid_constructor is None:
            self.grid_constructor = self._grid_constructor_from_step_size(step_size)
        elif grid_constructor is None:
            self.grid_constructor = lambda f, y0, t: t # Same time step as time interval
        else:
            raise ValueError("step_size and grid_constructor are exclusive arguments.")

    def _grid_constructor_from_step_size(self, step_size):

        def _grid_constructor(func, y0, t):
            start_time = t[0]
            end_time = t[-1]

            niters = torch.ceil((end_time - start_time) / step_size + 1).item()
            t_infer = torch.arange(0, niters).to(t) * step_size + start_time
            if t_infer[-1] > t[-1]:
                t_infer[-1] = t[-1]
            return t_infer

        return _grid_constructor

    @property
    @abc.abstractmethod
    def order(self):
        pass

    @abc.abstractmethod
    def step_func(self, func, t, dt, y):
        pass

    def integrate(self, t):
        _assert_increasing(t)
        t = t.type_as(self.y0[0]) # (n_time, )
        time_grid = self.grid_constructor(self.func, self.y0, t)
        #print('time_grid:', time_grid.size())
        #print('t:', t.size())
        assert time_grid[0] == t[0] and time_grid[-1] == t[-1]
        time_grid = time_grid.to(self.y0[0])

        solution = [self.y0]

        j = 1
        y0 = self.y0
        for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
            dy = self.step_func(self.func, t0, t1 - t0, y0)
            y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy))
            y0 = y1
            while j < len(t) and t1 >= t[j]:
                solution.append(self._linear_interp(t0, t1, y0, y1, t[j]))
                j += 1
            '''if self.contours is not None:
                if self.adjoint:
                    for i in range(1, len(t)):
                        ys = list(self._linear_interp(t0, t1, y0, y1, t[j])) # tuple: (y0, **back_grad) -> y0: (n_batch, spatial_shape)
                        ys[0] = self.set_BC(ys[0], self.contours[:, i]) # (n_batch, 4 / 6, BC_size, c)
                        solution.append(tuple(ys))
                        j += 1
                else:
                    while j < len(t) and t1 >= t[j]:
                        y = torch.stack(self._linear_interp(t0, t1, y0, y1, t[j])) # (n_batch, 1, spatial_shape)
                        y = self.set_BC(y[:, 0], self.contours[:, j]).unsqueeze(1) 
                        solution.append(tuple(y))
                        j += 1
            elif self.dcontours is not None:
                if self.adjoint:
                    for i in range(1, len(t)):
                        ys = list(self._linear_interp(t0, t1, y0, y1, t[j])) # tuple: (y0, **back_grad) -> y0: (n_batch, spatial_shape)
                        ys[0] = self.add_dBC(ys[0], self.dcontours[:, j]) # (n_batch, 4 / 6, BC_size, c)
                        solution.append(tuple(ys))
                else:
                    while j < len(t) and t1 >= t[j]:
                        y = torch.stack(self._linear_interp(t0, t1, y0, y1, t[j])) # (n_batch, 1, spatial_shape)
                        y = self.add_dBC(y[:, 0], self.dcontours[:, j]).unsqueeze(1)
                        solution.append(tuple(y))
                        j += 1
            else:
                while j < len(t) and t1 >= t[j]:
                    solution.append(self._linear_interp(t0, t1, y0, y1, t[j]))
                    j += 1'''
        return tuple(map(torch.stack, tuple(zip(*solution)))) # (batch, time)

    def _linear_interp(self, t0, t1, y0, y1, t):
        if t == t0:
            return y0
        if t == t1:
            return y1
        t0, t1, t = t0.to(y0[0]), t1.to(y0[0]), t.to(y0[0])
        slope = tuple((y1_ - y0_) / (t1 - t0) for y0_, y1_, in zip(y0, y1))
        return tuple(y0_ + slope_ * (t - t0) for y0_, slope_ in zip(y0, slope))