File size: 2,612 Bytes
c9311b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from abc import ABC, abstractmethod

import torch

from diff_eq.ode_sde import ODE, SDE


class Simulator(ABC):
    """
    Abstract base class for simulators.
    """
    @abstractmethod
    def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Completes one step of the simulation.
        :param xt: state at time t, shape (bs, c, h, w)
        :param t: time, shape (bs, 1, 1, 1)
        :param dt: time change, shape(bs, 1, 1, 1)
        :return: nxt: state at time t + dt, shape (bs, c, h, w)
        """
        pass

    @torch.no_grad()
    def simulate(self, x: torch.Tensor, ts: torch.Tensor, **kwargs):
        """
        Simulate using discretization given by ts and yield intermediate results.
        :param x: initial state, shape(bs, c, h, w)
        :param ts: timesteps, shape (bs, nts, 1, 1, 1)
        :yield: state at each timestep, shape(bs, c, h, w)
        """
        nts = ts.shape[1]

        for t_idx in range(nts - 1):
            t = ts[:, t_idx]
            h = ts[:, t_idx + 1] - ts[:, t_idx]
            x = self.step(x, t, h, **kwargs)
            yield x  # yield the updated state at this timestep

    @torch.no_grad()
    def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Simulate with trajectory using discretization given by ts.
        :param x: initial state, shape(bs, c, h, w)
        :param ts: timesteps, shape (bs, nts, 1, 1, 1)
        :return: trajectory of xts over ts, shape(bs, c, h, w)
        """
        xs = [x.clone()]
        nts = ts.shape[1]

        for t_idx in range(nts - 1):
            t = ts[:, t_idx]
            h = ts[:, t_idx + 1] - ts[:, t_idx]
            x = self.step(x, t, h, **kwargs)
            xs.append(x)

        return torch.stack(xs, dim=1)


class EulerSimulator(Simulator):
    """
    Simulates an ODE using Euler method.
    """
    def __init__(self, ode: ODE):
        self.ode = ode

    def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs) -> torch.Tensor:
        return xt + self.ode.drift_coefficient(xt, t, **kwargs) * dt


class EulerMaruyamaSimulator(Simulator):
    """
    Simulates an SDE using Euler-Maruyama method.
    """
    def __init__(self, sde: SDE):
        self.sde = sde

    def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs) -> torch.Tensor:
        return xt * self.sde.drift_coefficient(xt, t, **kwargs) * dt + \
            self.sde.diffusion_coefficient(xt, t, **kwargs) * torch.sqrt(dt) * torch.rand_like(xt)