Spaces:
Running
Running
File size: 1,466 Bytes
c9311b7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 | from abc import ABC, abstractmethod
import torch
from models.conditional_vector_field import ConditionalVectorField
class ODE(ABC):
@abstractmethod
def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Computes the drift coefficient of the ODE.
:param xt: state at time t, shape (bs, c, h, w)
:param t: time, shape (bs, 1)
:return: drift coefficient, shape (bs, c, h, w)
"""
pass
class SDE(ABC):
@abstractmethod
def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Computes the drift coefficient of the SDE.
:param xt: state at time t, shape (bs, c, h, w)
:param t: time, shape (bs, 1, 1, 1)
:return: drift coefficient, shape (bs, c, h, w)
"""
pass
@abstractmethod
def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Returns the diffusion coefficient of the SDE.
:param xt: state at time t, shape (bs, c, h, w)
:param t: shape (bs, 1, 1, 1)
:return: diffusion coefficient, shape (bs, c, h, w)
"""
pass
class UnguidedVectorFieldODE(ODE):
def __init__(self, net: ConditionalVectorField):
self.net = net
def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
return self.net(xt, t) |