sprite-flow / diff_eq /ode_sde.py
mradovic38's picture
Upload app
c9311b7
from abc import ABC, abstractmethod
import torch
from models.conditional_vector_field import ConditionalVectorField
class ODE(ABC):
    """Abstract interface for an ordinary differential equation.

    Concrete subclasses supply the drift term by implementing
    :meth:`drift_coefficient`.
    """

    @abstractmethod
    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Evaluate the drift coefficient of the ODE at the given state and time.

        NOTE(review): ``t`` is documented here as (bs, 1), while ``SDE`` in
        this file documents (bs, 1, 1, 1) -- confirm which shape callers
        actually pass.

        :param xt: state at time t, shape (bs, c, h, w)
        :param t: time, shape (bs, 1)
        :param kwargs: extra keyword arguments available to implementations
        :return: drift coefficient, shape (bs, c, h, w)
        """
class SDE(ABC):
    """Abstract interface for a stochastic differential equation.

    Concrete subclasses must implement both :meth:`drift_coefficient` and
    :meth:`diffusion_coefficient`.
    """

    @abstractmethod
    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Evaluate the drift coefficient of the SDE at the given state and time.

        :param xt: state at time t, shape (bs, c, h, w)
        :param t: time, shape (bs, 1, 1, 1)
        :param kwargs: extra keyword arguments available to implementations
        :return: drift coefficient, shape (bs, c, h, w)
        """

    @abstractmethod
    def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Evaluate the diffusion coefficient of the SDE at the given state and time.

        :param xt: state at time t, shape (bs, c, h, w)
        :param t: time, shape (bs, 1, 1, 1)
        :param kwargs: extra keyword arguments available to implementations
        :return: diffusion coefficient, shape (bs, c, h, w)
        """
class UnguidedVectorFieldODE(ODE):
    """ODE whose drift is the output of a wrapped vector-field network."""

    def __init__(self, net: ConditionalVectorField):
        """
        Store the network that is evaluated as the drift.

        :param net: callable invoked as ``net(xt, t)`` by :meth:`drift_coefficient`
        """
        self.net = net

    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Evaluate the wrapped network at ``(xt, t)`` and return its output.

        Extra keyword arguments are accepted for interface compatibility but
        are not forwarded to the network.

        :param xt: state at time t (presumably shape (bs, c, h, w), as on the base class)
        :param t: time (expected shape depends on the wrapped network -- TODO confirm)
        :param kwargs: ignored
        :return: whatever ``self.net(xt, t)`` returns
        """
        drift = self.net(xt, t)
        return drift