| import math |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch.distributions.normal import Normal |
|
|
|
|
def gaussian_loss(y_hat, y, log_std_min=-7.0):
    """Return the mean Gaussian negative log-likelihood of ``y``.

    Args:
        y_hat (Tensor): 3-D tensor whose last axis has size 2: index 0 holds the
            predicted mean and index 1 the predicted log standard deviation.
        y (Tensor): Targets; must broadcast against ``y_hat[:, :, :1]``.
        log_std_min (float): Floor applied to the predicted log std before use.

    Returns:
        Tensor: Scalar mean of
        ``0.5*log(2*pi) + log_std + 0.5*(y - mean)**2 / exp(2*log_std)``,
        i.e. a negative log-likelihood to be minimised.
    """
    assert y_hat.dim() == 3
    assert y_hat.size(2) == 2
    mu = y_hat[:, :, :1]
    log_sigma = torch.clamp(y_hat[:, :, 1:], min=log_std_min)

    squared_err = torch.pow(y - mu, 2)
    inv_var = torch.exp(-2.0 * log_sigma)
    # Written as -0.5 * (log-density terms with flipped sign) == -log N(y; mu, sigma).
    nll = -0.5 * (-math.log(2.0 * math.pi) - 2.0 * log_sigma - squared_err * inv_var)
    return nll.mean()
|
|
|
|
def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0):
    """Draw one sample per position from a Gaussian parameterised by ``y_hat``.

    Args:
        y_hat (Tensor): Tensor whose third axis has size 2: index 0 is the mean,
            index 1 the log standard deviation.
        log_std_min (float): Floor applied to the log standard deviation.
        scale_factor (float): Samples are clipped to
            ``[-scale_factor, scale_factor]``.

    Returns:
        Tensor: Samples with the shape of ``y_hat[:, :, :1]``.
    """
    assert y_hat.size(2) == 2
    loc = y_hat[:, :, :1]
    scale = torch.exp(torch.clamp(y_hat[:, :, 1:], min=log_std_min))
    draw = Normal(loc, scale).sample()
    return torch.clamp(draw, min=-scale_factor, max=scale_factor)
|
|
|
|
def log_sum_exp(x):
    """Numerically stable ``log(sum(exp(x)))`` over the last dimension.

    Delegates to ``torch.logsumexp``, which shifts by the maximum before
    exponentiating so large inputs do not overflow. Unlike a bare max-shift, it
    returns ``-inf`` for a slice that is entirely ``-inf`` and ``+inf`` for a
    slice containing ``+inf``, instead of producing NaN from ``inf - inf``.

    Args:
        x (Tensor): Input tensor, reduced along its last dimension.

    Returns:
        Tensor: ``x`` with its last dimension reduced away.
    """
    return torch.logsumexp(x, dim=-1)
|
|
|
|
| |
def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True):
    """Negative log-likelihood of ``y`` under a discretized mixture of logistics.

    Targets are assumed to be scaled to [-1, 1] and quantized into
    ``num_classes`` bins of half-width ``1 / (num_classes - 1)``. The two
    outermost bins (targets below -0.999 / above 0.999) are open-ended. Every
    other bin's probability is the mixture mass between its edges.

    Args:
        y_hat (Tensor): 3-D tensor whose last axis has ``3 * nr_mix`` entries
            laid out as [mixture logits | means | log scales].
        y (Tensor): Targets, expanded to ``y_hat.shape[:2] + (nr_mix,)``
            (e.g. a trailing singleton axis is broadcast across mixtures).
        num_classes (int): Number of quantization levels.
        log_scale_min (float | None): Floor for the log scales; ``None`` means
            ``log(1e-14)``.
        reduce (bool): If True, return the mean loss over all positions;
            otherwise return the per-position loss.

    Returns:
        Tensor: A scalar if ``reduce`` else a tensor of shape
        ``y_hat.shape[:2] + (1,)``.
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    # Validate before touching the tensor. Previously a permute() ran first, so
    # a non-3-D input crashed there and this check could never trigger.
    assert y_hat.dim() == 3
    assert y_hat.size(2) % 3 == 0
    nr_mix = y_hat.size(2) // 3

    # Parameters are read straight from the last axis. The former
    # permute(0, 2, 1) followed by transpose(1, 2) was a no-op round trip.
    logit_probs = y_hat[:, :, :nr_mix]
    means = y_hat[:, :, nr_mix : 2 * nr_mix]
    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min)

    y = y.expand_as(means)

    half_bin = 1.0 / (num_classes - 1)
    centered_y = y - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + half_bin)
    min_in = inv_stdv * (centered_y - half_bin)
    # Probability mass between the two bin edges.
    cdf_delta = torch.sigmoid(plus_in) - torch.sigmoid(min_in)

    # log(sigmoid(plus_in)) written stably: mass of the lowest, left-open bin.
    log_cdf_plus = plus_in - F.softplus(plus_in)
    # log(1 - sigmoid(min_in)) written stably: mass of the highest, right-open bin.
    log_one_minus_cdf_min = -F.softplus(min_in)
    # Log logistic density at the bin centre. Combined with log(bin width) below,
    # it approximates the bin's log-mass when cdf_delta is too small for log().
    mid_in = inv_stdv * centered_y
    log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)

    # NOTE(review): the 1e-5 threshold is unchanged. With many classes each
    # bin is narrow, so interior bins may often use the density
    # approximation — confirm this is intended.
    interior = torch.where(
        cdf_delta > 1e-5,
        torch.log(torch.clamp(cdf_delta, min=1e-12)),
        log_pdf_mid - np.log((num_classes - 1) / 2),
    )
    # torch.where selects a branch rather than blending with 0/1 float masks,
    # so an infinite value in an unselected branch cannot become NaN (0 * inf).
    log_probs = torch.where(
        y < -0.999,
        log_cdf_plus,
        torch.where(y > 0.999, log_one_minus_cdf_min, interior),
    )

    # Weight each component by its mixture probability, then marginalise.
    log_probs = log_probs + F.log_softmax(logit_probs, -1)
    log_likelihood = torch.logsumexp(log_probs, dim=-1)

    if reduce:
        return -torch.mean(log_likelihood)
    return -log_likelihood.unsqueeze(-1)
|
|
|
|
def sample_from_discretized_mix_logistic(y, log_scale_min=None):
    """Draw one value per time step from a discretized mixture of logistics.

    A component is picked per step with the Gumbel-max trick. A value is then
    drawn from that component's logistic distribution by inverting its CDF, and
    the result is clipped to [-1, 1].

    Args:
        y (Tensor): :math:`[B, C, T]` parameters with ``C = 3 * nr_mix``
            channels ordered as [mixture logits | means | log scales].
        log_scale_min (float): Log scale minimum value; ``None`` means
            ``log(1e-14)``.

    Returns:
        Tensor: :math:`[B, T]` sample in range of [-1, 1].
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    assert y.size(1) % 3 == 0
    k = y.size(1) // 3

    # Channels last: [B, C, T] -> [B, T, C].
    params = y.transpose(1, 2)
    logits = params[:, :, :k]

    # Gumbel-max: argmax of logits perturbed by -log(-log(U)).
    # U stays strictly inside (0, 1) so neither log() sees 0.
    noise = logits.new_empty(logits.size()).uniform_(1e-5, 1.0 - 1e-5)
    perturbed = logits.detach() - torch.log(-torch.log(noise))
    chosen = torch.max(perturbed, dim=-1).indices

    # int64 one-hot mask over the k components; multiplying and summing picks
    # out the chosen component's parameters.
    mask = F.one_hot(chosen, k)
    mu = (params[:, :, k : 2 * k] * mask).sum(dim=-1)
    log_s = torch.clamp((params[:, :, 2 * k : 3 * k] * mask).sum(dim=-1), min=log_scale_min)

    # Inverse-CDF sample of a logistic: mu + s * (log(U) - log(1 - U)).
    uniform = mu.new_empty(mu.size()).uniform_(1e-5, 1.0 - 1e-5)
    sample = mu + torch.exp(log_s) * (torch.log(uniform) - torch.log(1.0 - uniform))

    return torch.clamp(sample, min=-1.0, max=1.0)
|
|
|
|
def to_one_hot(tensor, n, fill_with=1.0, dtype=torch.float32):
    """One-hot encode integer indices along a new trailing dimension.

    Args:
        tensor (LongTensor): Integer class indices, each in ``[0, n)``.
        n (int): Number of classes, i.e. the size of the new last dimension.
        fill_with (float): Value written at each selected position.
        dtype (torch.dtype): Dtype of the result (float32 by default).

    Returns:
        Tensor: Tensor of shape ``tensor.shape + (n,)`` on ``tensor``'s device,
        zero everywhere except ``fill_with`` at the indexed positions.
    """
    # Allocate directly on the index tensor's device with an explicit dtype.
    # The previous ``torch.FloatTensor(...).type_as(tensor)`` allocated on the
    # CPU and then cast to the *index* dtype (int64), so a fractional
    # ``fill_with`` such as 0.5 was silently truncated to 0.
    one_hot = torch.zeros(tuple(tensor.size()) + (n,), dtype=dtype, device=tensor.device)
    one_hot.scatter_(tensor.dim(), tensor.unsqueeze(-1), fill_with)
    return one_hot
|
|