| """Container for latent space posterior.""" | |
| import torch | |
| class LatentDistribution: | |
| def __init__(self, mean: torch.Tensor, logvar: torch.Tensor): | |
| """Initialize latent distribution. | |
| Args: | |
| mean: Mean of the distribution. Shape: [B, C, T, H, W]. | |
| logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W]. | |
| """ | |
| assert mean.shape == logvar.shape | |
| self.mean = mean | |
| self.logvar = logvar | |
| def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None): | |
| if temperature == 0.0: | |
| return self.mean | |
| if noise is None: | |
| noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator) | |
| else: | |
| assert noise.device == self.mean.device | |
| noise = noise.to(self.mean.dtype) | |
| if temperature != 1.0: | |
| raise NotImplementedError(f"Temperature {temperature} is not supported.") | |
| # Just Gaussian sample with no scaling of variance. | |
| return noise * torch.exp(self.logvar * 0.5) + self.mean | |
| def mode(self): | |
| return self.mean | |