from abc import ABC, abstractmethod
from typing import List, Iterator

import torch
from torch import nn


class Sampleable(ABC):
    """
    Distribution to sample from.
    """

    @abstractmethod
    def sample(self, num_samples: int, **kwargs) -> torch.Tensor:
        """
        Draw samples from the distribution.

        :param num_samples: number of samples to draw (the leading batch dimension)
        :return: samples: shape (batch_size, ...)
        """
        pass


class IterableSampleable(Sampleable, ABC):
    """
    Sampleable for finite datasets.
    """

    @abstractmethod
    def iterate_dataset(self, batch_size: int, mode: str = 'val') -> Iterator[torch.Tensor]:
        """
        Iterates over the entire split selected by `mode` in mini-batches of size `batch_size`.

        :param batch_size: number of images per batch
        :param mode: 'train', 'val', or 'test'
        :return: yields batches as torch tensors
        """
        pass


class IsotropicGaussian(nn.Module, Sampleable):
    """
    Zero-mean isotropic Gaussian N(0, std^2 * I) over tensors of a fixed shape.
    """

    def __init__(self, shape: List[int], std: float = 1.0):
        """
        :param shape: shape of a single sample (without the batch dimension),
            e.g. [4, 128, 128]; any number of dimensions is supported
        :param std: standard deviation for sampling
        """
        super().__init__()
        self.shape = shape
        self.std = std
        # The module has no parameters, so this (persistent) buffer is the only way to
        # know which device the module lives on; buffers are moved by self.to(...).
        # register_buffer is used instead of nn.Buffer (PyTorch >= 2.5 only) for
        # compatibility; the state_dict key stays "dummy".
        self.register_buffer("dummy", torch.zeros(1))

    def sample(self, num_samples: int, **kwargs) -> torch.Tensor:
        """
        Draw `num_samples` i.i.d. samples.

        :param num_samples: number of samples (leading batch dimension)
        :param kwargs: accepted for interface compatibility; ignored
        :return: tensor of shape (num_samples, *self.shape) on the module's device
        """
        device = self.dummy.device
        # Unpack the full per-sample shape instead of assuming exactly (C, H, W).
        return self.std * torch.randn(num_samples, *self.shape, device=device)