Spaces:
Sleeping
Sleeping
| import math | |
| from tqdm import trange, tqdm | |
| import torch | |
def matrix_log_density_gaussian(x, mu, logvar):
    """Calculates log density of a Gaussian for all combinations of batch pairs of
    `x` and `mu`. I.e. return tensor of shape `(batch_size, batch_size, dim)`
    instead of `(batch_size, dim)` in the usual log density.

    Entry ``[i, j, k]`` is the log density of ``x[i, k]`` under a Gaussian with
    mean ``mu[j, k]`` and log variance ``logvar[j, k]``.

    Parameters
    ----------
    x: torch.Tensor
        Value at which to compute the density. Shape: (batch_size, dim).
    mu: torch.Tensor
        Mean. Shape: (batch_size, dim). More generally (n_mu, dim), in which
        case the output has shape (batch_size, n_mu, dim).
    logvar: torch.Tensor
        Log variance. Same shape as `mu`.

    Returns
    -------
    torch.Tensor
        Pairwise log densities of shape (batch_size, batch_size, dim), or
        (batch_size, n_mu, dim) when `mu` has n_mu rows.

    Raises
    ------
    ValueError
        If `x` is not 2-dimensional.
    """
    if x.dim() != 2:
        raise ValueError(
            f"x must have shape (batch_size, dim), got {tuple(x.shape)}")
    # (batch_size, 1, dim) broadcast against (1, n_mu, dim) pairs every row of
    # x with every (mu, logvar) row; identical to the previous .view() calls
    # when mu/logvar have batch_size rows, but no longer requires that.
    return log_density_gaussian(x.unsqueeze(1), mu.unsqueeze(0),
                                logvar.unsqueeze(0))
def log_density_gaussian(x, mu, logvar):
    """Calculates the element-wise log density of a diagonal Gaussian.

    Computes ``log N(x; mu, exp(logvar))`` element by element, broadcasting
    `x`, `mu` and `logvar` against each other with the usual torch rules.

    Parameters
    ----------
    x: torch.Tensor or np.ndarray or float
        Value at which to compute the density.
    mu: torch.Tensor or np.ndarray or float
        Mean.
    logvar: torch.Tensor or np.ndarray or float
        Log variance.

    Returns
    -------
    torch.Tensor
        Log density with the broadcast shape of the inputs. Tensor inputs are
        used as-is (their autograd graph is preserved); non-tensor inputs are
        converted with ``torch.as_tensor``.
    """
    # torch.exp only accepts tensors, so the float / ndarray inputs promised
    # above used to raise TypeError. as_tensor returns tensors unchanged.
    x, mu, logvar = (torch.as_tensor(v) for v in (x, mu, logvar))
    normalization = - 0.5 * (math.log(2 * math.pi) + logvar)
    inv_var = torch.exp(-logvar)
    log_density = normalization - 0.5 * ((x - mu)**2 * inv_var)
    return log_density
def log_importance_weight_matrix(batch_size, dataset_size):
    """
    Calculates a log importance weight matrix (minibatch stratified sampling).

    With ``N = dataset_size`` and ``M = batch_size - 1``, row ``i`` of the
    (exponentiated) matrix gives:

    - ``1 / N`` to element ``i`` itself (the diagonal),
    - ``(N - M) / (N * M)`` to its cyclic neighbour ``(i + 1) % batch_size``,
    - ``1 / M`` to every other element,

    so every row sums to one. Presumably meant to be added to a pairwise
    log-density matrix such as the output of `matrix_log_density_gaussian`
    before a logsumexp over the second axis — confirm against the caller.

    Parameters
    ----------
    batch_size: int
        number of training images in the batch
    dataset_size: int
        number of training images in the dataset

    Returns
    -------
    torch.Tensor
        Log weights, shape (batch_size, batch_size), default float dtype.

    Raises
    ------
    ValueError
        If ``batch_size < 2`` (the weights divide by ``batch_size - 1``).
    """
    if batch_size < 2:
        raise ValueError(f"batch_size must be at least 2, got {batch_size}")
    N = dataset_size
    M = batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = torch.full((batch_size, batch_size), 1 / M)
    # The former flat-index stride was M + 1 == batch_size, which walks down
    # columns 0 and 1 rather than the diagonal / superdiagonal of a
    # batch_size x batch_size matrix; address them directly instead.
    W.fill_diagonal_(1 / N)
    W.diagonal(offset=1).fill_(strat_weight)
    # Wrap-around neighbour for the last row.
    W[-1, 0] = strat_weight
    return W.log()