Spaces:
Running
Running
File size: 1,505 Bytes
c9311b7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 | from abc import ABC, abstractmethod
from typing import List, Iterator
import torch
from torch import nn
class Sampleable(ABC):
    """
    Abstract interface for anything that can produce random samples.

    Concrete subclasses must override :meth:`sample`.
    """

    @abstractmethod
    def sample(self, num_samples: int, **kwargs) -> torch.Tensor:
        """
        Draw a batch of samples.

        :param num_samples: how many samples to draw (size of the leading dimension)
        :param kwargs: extra, implementation-specific options
        :return: tensor of samples with shape (num_samples, ...)
        """
        ...
class IterableSampleable(Sampleable, ABC):
    """
    A :class:`Sampleable` backed by a finite dataset, which can additionally
    be walked through in full.
    """

    @abstractmethod
    def iterate_dataset(self, batch_size: int, mode: str = 'val') -> Iterator[torch.Tensor]:
        """
        Walk over the whole split selected by ``mode`` in mini-batches.

        :param batch_size: number of items (e.g. images) in each yielded batch
        :param mode: which split to iterate: 'train', 'val', or 'test'
        :return: iterator yielding each batch as a torch tensor
        """
        ...
class IsotropicGaussian(nn.Module, Sampleable):
    """
    Zero-mean isotropic Gaussian, N(0, std^2 * I), over tensors of a fixed shape.

    Samples are created on whatever device the module currently lives on, so
    moving the module with ``.to(...)`` also moves where samples are drawn.
    """

    def __init__(self, shape: List[int], std: float = 1.0):
        """
        :param shape: shape of a single sample, without the batch dimension,
            e.g. [4, 128, 128]; any number of dimensions is accepted
        :param std: standard deviation applied to every coordinate
        """
        super().__init__()
        self.shape = shape
        self.std = std
        # One-element buffer used only to track the module's device: buffers
        # follow .to(...), plain tensor attributes do not. register_buffer is
        # used instead of nn.Buffer (only available in recent PyTorch, 2.5+)
        # so this also works on older releases; the buffer stays persistent,
        # so the state_dict key "dummy" is unchanged.
        self.register_buffer("dummy", torch.zeros(1))

    def sample(self, num_samples: int, **kwargs) -> torch.Tensor:
        """
        Draw independent Gaussian samples.

        :param num_samples: number of samples to draw
        :param kwargs: accepted for interface compatibility; ignored
        :return: tensor of shape (num_samples, *shape) on this module's device
        """
        device = self.dummy.device
        # Unpack the full shape instead of assuming exactly (C, H, W), so
        # shapes of any rank work; 3-D shapes behave exactly as before.
        return self.std * torch.randn(num_samples, *self.shape, device=device)