File size: 1,505 Bytes
c9311b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from abc import ABC, abstractmethod
from typing import List, Iterator

import torch
from torch import nn


class Sampleable(ABC):
    """
    Abstract interface for a distribution that can be sampled from.

    Concrete subclasses must override :meth:`sample`.
    """

    @abstractmethod
    def sample(self, num_samples: int, **kwargs) -> torch.Tensor:
        """
        Draw a batch of samples from the distribution.

        :param num_samples: how many samples to draw
        :param kwargs: extra, subclass-specific sampling options
        :return: samples: expected shape (num_samples, ...)
        """
        ...


class IterableSampleable(Sampleable, ABC):
    """
    Sampleable backed by a finite dataset that can also be traversed in full.

    Concrete subclasses must implement both the inherited abstract
    :meth:`sample` and :meth:`iterate_dataset`.
    """

    @abstractmethod
    def iterate_dataset(self, batch_size: int, mode: str = 'val') -> Iterator[torch.Tensor]:
        """
        Walk over one whole split of the dataset in mini-batches.

        :param batch_size: number of items in each yielded batch
        :param mode: which split to traverse: 'train', 'val', or 'test' (default 'val')
        :return: iterator yielding the batches as torch tensors
        """
        ...


class IsotropicGaussian(nn.Module, Sampleable):
    """
    Zero-mean isotropic Gaussian, i.e. every element is drawn i.i.d. from N(0, std^2).

    Being an ``nn.Module``, it can be moved with ``.to(...)``/``.cuda()``; samples are
    then generated directly on that device.
    """

    def __init__(self, shape: List[int], std: float = 1.0):
        """
        :param shape: per-sample shape of the sampled data, e.g. [4, 128, 128];
            any number of dimensions is supported
        :param std: standard deviation for sampling
        """
        super().__init__()
        self.shape = shape
        self.std = std

        # Tiny tracking buffer: registered buffers move with self.to(...), so its
        # device tells sample() where to allocate. register_buffer is equivalent to
        # assigning nn.Buffer(...) (same name, persistent -> appears in state_dict
        # as "dummy") but also works on PyTorch versions that predate nn.Buffer.
        self.register_buffer("dummy", torch.zeros(1))

    def sample(self, num_samples: int, **kwargs) -> torch.Tensor:
        """
        Draw samples from N(0, std^2 I).

        :param num_samples: number of samples to draw
        :param kwargs: accepted for compatibility with the Sampleable interface; ignored
        :return: tensor of shape (num_samples, *shape) on this module's device,
            in torch's default floating-point dtype
        """
        device = self.dummy.device
        # Unpack the full shape instead of assuming exactly (C, H, W), so 1-D, 2-D,
        # 4-D, ... shapes work; 3-D shapes behave exactly as before.
        return self.std * torch.randn(num_samples, *self.shape, device=device)