File size: 1,466 Bytes
c9311b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from abc import ABC, abstractmethod

import torch

from models.conditional_vector_field import ConditionalVectorField


class ODE(ABC):
    """Abstract interface for an ODE, specified solely by its drift coefficient."""

    @abstractmethod
    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Evaluate the drift of the ODE at state ``xt`` and time ``t``.

        Concrete subclasses must override this method.

        :param xt: current state, shape (bs, c, h, w)
        :param t: current time, shape (bs, 1)
            (NOTE(review): SDE.drift_coefficient documents t as (bs, 1, 1, 1);
            confirm which shape callers actually pass)
        :param kwargs: extra, implementation-specific inputs
        :return: drift evaluated at (xt, t), shape (bs, c, h, w)
        """
        ...


class SDE(ABC):
    """Abstract interface for an SDE, specified by a drift and a diffusion coefficient."""

    @abstractmethod
    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Evaluate the drift of the SDE at state ``xt`` and time ``t``.

        Concrete subclasses must override this method.

        :param xt: current state, shape (bs, c, h, w)
        :param t: current time, shape (bs, 1, 1, 1)
        :param kwargs: extra, implementation-specific inputs
        :return: drift evaluated at (xt, t), shape (bs, c, h, w)
        """
        ...

    @abstractmethod
    def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Evaluate the diffusion (noise) coefficient of the SDE at ``xt`` and ``t``.

        Concrete subclasses must override this method.

        :param xt: current state, shape (bs, c, h, w)
        :param t: current time, shape (bs, 1, 1, 1)
        :param kwargs: extra, implementation-specific inputs
        :return: diffusion coefficient at (xt, t), shape (bs, c, h, w)
        """
        ...


class UnguidedVectorFieldODE(ODE):
    """
    ODE whose drift is the raw output of a vector-field network.

    The network is called with only ``(xt, t)``; no guidance/conditioning input is
    passed, and any extra keyword arguments given to ``drift_coefficient`` are ignored.
    """

    def __init__(self, net: ConditionalVectorField):
        """
        :param net: network invoked as ``net(xt, t)`` to produce the drift
        """
        self.net = net

    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Use the network's prediction at (xt, t) as the drift.

        :param xt: state at time t, shape (bs, c, h, w)
        :param t: time tensor, forwarded to the network unchanged
        :param kwargs: accepted for compatibility with the ODE interface; unused
        :return: ``self.net(xt, t)``
        """
        vector_field = self.net
        drift = vector_field(xt, t)
        return drift