Spaces:
Running
Running
| from abc import ABC, abstractmethod | |
| from typing import List, Iterator | |
| import torch | |
| from torch import nn | |
class Sampleable(ABC):
    """
    Distribution to sample from.

    Abstract interface: concrete subclasses must implement ``sample``.
    """

    @abstractmethod
    def sample(self, num_samples: int, **kwargs) -> torch.Tensor:
        """
        Draw a batch of samples from the distribution.

        :param num_samples: number of samples to draw (the leading batch dimension)
        :param kwargs: optional subclass-specific sampling options
        :return: samples of shape (num_samples, ...)
        """
        raise NotImplementedError
class IterableSampleable(Sampleable, ABC):
    """
    Sampleable for finite datasets.

    In addition to random sampling, subclasses must support an exhaustive,
    batched pass over a chosen dataset split via ``iterate_dataset``.
    """

    @abstractmethod
    def iterate_dataset(self, batch_size: int, mode: str = 'val') -> Iterator[torch.Tensor]:
        """
        Iterate over the entire dataset split selected by `mode` in mini-batches
        of size `batch_size`.

        :param batch_size: number of images per batch
        :param mode: dataset split to iterate: 'train', 'val', or 'test'
        :return: iterator yielding batches as torch tensors
        """
        raise NotImplementedError
class IsotropicGaussian(nn.Module, Sampleable):
    """
    Zero-mean isotropic Gaussian: every entry is drawn i.i.d. from N(0, std^2).

    Implemented as an ``nn.Module`` so that ``.to(device)`` relocates it; samples
    are created on whatever device the module currently lives on.
    """

    def __init__(self, shape: List[int], std: float = 1.0):
        """
        :param shape: shape of a single sample (without the batch dimension),
            e.g. [4, 128, 128]; any number of dimensions is accepted
        :param std: standard deviation for sampling
        """
        super().__init__()
        self.shape = shape
        self.std = std
        # Device tracker: buffers are moved by self.to(...), so self.dummy.device
        # always reflects the module's current device. register_buffer is used
        # instead of nn.Buffer (PyTorch >= 2.5 only) for compatibility with older
        # releases; the buffer name and state_dict entry are the same.
        self.register_buffer("dummy", torch.zeros(1))

    def sample(self, num_samples: int, **kwargs) -> torch.Tensor:
        """
        :param num_samples: number of independent samples to draw
        :param kwargs: ignored; accepted for compatibility with the Sampleable interface
        :return: tensor of shape (num_samples, *shape) with entries ~ N(0, std^2)
        """
        device = self.dummy.device
        # Unpack the full shape rather than assuming exactly (C, H, W).
        return self.std * torch.randn(num_samples, *self.shape, device=device)