| from typing import Union, List |
|
|
| import torch |
| import torch.nn as nn |
| from typing import Callable |
| from src.diffusion.base.scheduling import BaseScheduler |
|
|
| class BaseSampler(nn.Module): |
| def __init__(self, |
| scheduler: BaseScheduler = None, |
| guidance_fn: Callable = None, |
| num_steps: int = 250, |
| guidance: Union[float, List[float]] = 1.0, |
| *args, |
| **kwargs |
| ): |
| super(BaseSampler, self).__init__() |
| self.num_steps = num_steps |
| self.guidance = guidance |
| self.guidance_fn = guidance_fn |
| self.scheduler = scheduler |
|
|
| |
| def _impl_sampling(self, net, noise, condition, uncondition): |
| raise NotImplementedError |
|
|
| def forward(self, net, noise, condition, uncondition, return_x_trajs=False, return_v_trajs=False): |
| x_trajs, v_trajs = self._impl_sampling(net, noise, condition, uncondition) |
| if return_x_trajs and return_v_trajs: |
| return x_trajs[-1], x_trajs, v_trajs |
| elif return_x_trajs: |
| return x_trajs[-1], x_trajs |
| elif return_v_trajs: |
| return x_trajs[-1], v_trajs |
| else: |
| return x_trajs[-1] |
|
|
|
|
|
|