from abc import ABC, abstractmethod

import torch

from diff_eq.ode_sde import ODE, SDE


class Simulator(ABC):
    """
    Abstract base class for fixed-grid simulators.

    Subclasses implement :meth:`step` (a single update from t to t + dt); the
    base class drives that update over a user-supplied grid of timesteps.
    """

    @abstractmethod
    def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Completes one step of the simulation.

        :param xt: state at time t, shape (bs, c, h, w)
        :param t: time, shape (bs, 1, 1, 1)
        :param dt: time change, shape (bs, 1, 1, 1)
        :return: nxt: state at time t + dt, shape (bs, c, h, w)
        """
        pass

    @torch.no_grad()
    def simulate(self, x: torch.Tensor, ts: torch.Tensor, **kwargs):
        """
        Simulate using the discretization given by ts and yield intermediate results.

        One state is yielded per interval [ts[:, i], ts[:, i + 1]], i.e. nts - 1
        states in total; the initial state itself is not yielded.

        :param x: initial state, shape (bs, c, h, w)
        :param ts: timesteps, shape (bs, nts, 1, 1, 1)
        :param kwargs: forwarded unchanged to :meth:`step`
        :yield: state after each step, shape (bs, c, h, w)
        """
        nts = ts.shape[1]
        for t_idx in range(nts - 1):
            t = ts[:, t_idx]
            # Per-sample step size; named dt (not h) to avoid confusion with
            # the spatial height dimension in the shape annotations.
            dt = ts[:, t_idx + 1] - t
            x = self.step(x, t, dt, **kwargs)
            yield x  # updated state at time ts[:, t_idx + 1]

    @torch.no_grad()
    def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Simulate using the discretization given by ts and return the full trajectory.

        :param x: initial state, shape (bs, c, h, w)
        :param ts: timesteps, shape (bs, nts, 1, 1, 1)
        :param kwargs: forwarded unchanged to :meth:`step`
        :return: trajectory of states over ts, shape (bs, nts, c, h, w); entry 0
            along dim 1 is a copy of the initial state
        """
        # Clone so the recorded initial state is not affected if a step
        # implementation modifies its input in place.
        xs = [x.clone()]
        # Reuse the stepping loop from simulate() instead of duplicating it.
        xs.extend(self.simulate(x, ts, **kwargs))
        return torch.stack(xs, dim=1)


class EulerSimulator(Simulator):
    """
    Simulates an ODE using the explicit Euler method:

        x_{t+dt} = x_t + drift(x_t, t) * dt
    """

    def __init__(self, ode: ODE):
        self.ode = ode

    def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Advance the state by one Euler step.

        :param xt: state at time t, shape (bs, c, h, w)
        :param t: time, shape (bs, 1, 1, 1)
        :param dt: time change, shape (bs, 1, 1, 1)
        :param kwargs: forwarded to ``ode.drift_coefficient``
        :return: state at time t + dt, shape (bs, c, h, w)
        """
        return xt + self.ode.drift_coefficient(xt, t, **kwargs) * dt


class EulerMaruyamaSimulator(Simulator):
    """
    Simulates an SDE using the Euler-Maruyama method:

        x_{t+dt} = x_t + drift(x_t, t) * dt + diffusion(x_t, t) * sqrt(dt) * z,
        z ~ N(0, I)
    """

    def __init__(self, sde: SDE):
        self.sde = sde

    def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Advance the state by one Euler-Maruyama step.

        :param xt: state at time t, shape (bs, c, h, w)
        :param t: time, shape (bs, 1, 1, 1)
        :param dt: time change, shape (bs, 1, 1, 1); torch.sqrt yields NaN for
            negative entries, so dt is expected to be non-negative
        :param kwargs: forwarded to ``sde.drift_coefficient`` and
            ``sde.diffusion_coefficient``
        :return: state at time t + dt, shape (bs, c, h, w)
        """
        drift = self.sde.drift_coefficient(xt, t, **kwargs)
        diffusion = self.sde.diffusion_coefficient(xt, t, **kwargs)
        # The Brownian increment dW has distribution N(0, dt): sample standard
        # Gaussian noise (randn_like, not the uniform rand_like) and scale it
        # by sqrt(dt). The update is additive in the current state.
        noise = torch.randn_like(xt)
        return xt + drift * dt + diffusion * torch.sqrt(dt) * noise