Spaces:
Running
Running
| from abc import ABC, abstractmethod | |
| import torch | |
| from models.conditional_vector_field import ConditionalVectorField | |
| class ODE(ABC): | |
| def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor: | |
| """ | |
| Computes the drift coefficient of the ODE. | |
| :param xt: state at time t, shape (bs, c, h, w) | |
| :param t: time, shape (bs, 1) | |
| :return: drift coefficient, shape (bs, c, h, w) | |
| """ | |
| pass | |
| class SDE(ABC): | |
| def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor: | |
| """ | |
| Computes the drift coefficient of the SDE. | |
| :param xt: state at time t, shape (bs, c, h, w) | |
| :param t: time, shape (bs, 1, 1, 1) | |
| :return: drift coefficient, shape (bs, c, h, w) | |
| """ | |
| pass | |
| def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor: | |
| """ | |
| Returns the diffusion coefficient of the SDE. | |
| :param xt: state at time t, shape (bs, c, h, w) | |
| :param t: shape (bs, 1, 1, 1) | |
| :return: diffusion coefficient, shape (bs, c, h, w) | |
| """ | |
| pass | |
| class UnguidedVectorFieldODE(ODE): | |
| def __init__(self, net: ConditionalVectorField): | |
| self.net = net | |
| def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor: | |
| return self.net(xt, t) |