""" This file contains a collection of useful loss functions for use with torch tensors. """ import math import numpy as np import torch import torch.nn.functional as F def cosine_loss(preds, labels): """ Cosine loss between two tensors. Args: preds (torch.Tensor): torch tensor labels (torch.Tensor): torch tensor Returns: loss (torch.Tensor): cosine loss """ sim = torch.nn.CosineSimilarity(dim=len(preds.shape) - 1)(preds, labels) return -torch.mean(sim - 1.0) def KLD_0_1_loss(mu, logvar): """ KL divergence loss. Computes D_KL( N(mu, sigma) || N(0, 1) ). Note that this function averages across the batch dimension, but sums across dimension. Args: mu (torch.Tensor): mean tensor of shape (B, D) logvar (torch.Tensor): logvar tensor of shape (B, D) Returns: loss (torch.Tensor): KL divergence loss between the input gaussian distribution and N(0, 1) """ return -0.5 * (1. + logvar - mu.pow(2) - logvar.exp()).sum(dim=1).mean() def KLD_gaussian_loss(mu_1, logvar_1, mu_2, logvar_2): """ KL divergence loss between two Gaussian distributions. This function computes the average loss across the batch. Args: mu_1 (torch.Tensor): first means tensor of shape (B, D) logvar_1 (torch.Tensor): first logvars tensor of shape (B, D) mu_2 (torch.Tensor): second means tensor of shape (B, D) logvar_2 (torch.Tensor): second logvars tensor of shape (B, D) Returns: loss (torch.Tensor): KL divergence loss between the two gaussian distributions """ return -0.5 * (1. + \ logvar_1 - logvar_2 \ - ((mu_2 - mu_1).pow(2) / logvar_2.exp()) \ - (logvar_1.exp() / logvar_2.exp()) \ ).sum(dim=1).mean() def log_normal(x, m, v): """ Log probability of tensor x under diagonal multivariate normal with mean m and variance v. The last dimension of the tensors is treated as the dimension of the Gaussian distribution - all other dimensions are treated as independent Gaussians. Adapted from CS 236 at Stanford. Args: x (torch.Tensor): tensor with shape (B, ..., D) m (torch.Tensor): means tensor with shape (B, ..., D) or (1, ..., D) v (torch.Tensor): variances tensor with shape (B, ..., D) or (1, ..., D) Returns: log_prob (torch.Tensor): log probabilities of shape (B, ...) """ element_wise = -0.5 * (torch.log(v) + (x - m).pow(2) / v + np.log(2 * np.pi)) log_prob = element_wise.sum(-1) return log_prob def log_normal_mixture(x, m, v, w=None, log_w=None): """ Log probability of tensor x under a uniform mixture of Gaussians. Adapted from CS 236 at Stanford. Args: x (torch.Tensor): tensor with shape (B, D) m (torch.Tensor): means tensor with shape (B, M, D) or (1, M, D), where M is number of mixture components v (torch.Tensor): variances tensor with shape (B, M, D) or (1, M, D) where M is number of mixture components w (torch.Tensor): weights tensor - if provided, should be shape (B, M) or (1, M) log_w (torch.Tensor): log-weights tensor - if provided, should be shape (B, M) or (1, M) Returns: log_prob (torch.Tensor): log probabilities of shape (B,) """ # (B , D) -> (B , 1, D) x = x.unsqueeze(1) # (B, 1, D) -> (B, M, D) -> (B, M) log_prob = log_normal(x, m, v) if w is not None or log_w is not None: # this weights the log probabilities by the mixture weights so we have log(w_i * N(x | m_i, v_i)) if w is not None: assert log_w is None log_w = torch.log(w) log_prob += log_w # then compute log sum_i exp [log(w_i * N(x | m_i, v_i))] # (B, M) -> (B,) log_prob = log_sum_exp(log_prob , dim=1) else: # (B, M) -> (B,) log_prob = log_mean_exp(log_prob , dim=1) # mean accounts for uniform weights return log_prob def log_mean_exp(x, dim): """ Compute the log(mean(exp(x), dim)) in a numerically stable manner. Adapted from CS 236 at Stanford. Args: x (torch.Tensor): a tensor dim (int): dimension along which mean is computed Returns: y (torch.Tensor): log(mean(exp(x), dim)) """ return log_sum_exp(x, dim) - np.log(x.size(dim)) def log_sum_exp(x, dim=0): """ Compute the log(sum(exp(x), dim)) in a numerically stable manner. Adapted from CS 236 at Stanford. Args: x (torch.Tensor): a tensor dim (int): dimension along which sum is computed Returns: y (torch.Tensor): log(sum(exp(x), dim)) """ max_x = torch.max(x, dim)[0] new_x = x - max_x.unsqueeze(dim).expand_as(x) return max_x + (new_x.exp().sum(dim)).log() def project_values_onto_atoms(values, probabilities, atoms): """ Project the categorical distribution given by @probabilities on the grid of values given by @values onto a grid of values given by @atoms. This is useful when computing a bellman backup where the backed up values from the original grid will not be in the original support, requiring L2 projection. Each value in @values has a corresponding probability in @probabilities - this probability mass is shifted to the closest neighboring grid points in @atoms in proportion. For example, if the value in question is 0.2, and the neighboring atoms are 0 and 1, then 0.8 of the probability weight goes to atom 0 and 0.2 of the probability weight will go to 1. Adapted from https://github.com/deepmind/acme/blob/master/acme/tf/losses/distributional.py#L42 Args: values: value grid to project, of shape (batch_size, n_atoms) probabilities: probabilities for categorical distribution on @values, shape (batch_size, n_atoms) atoms: value grid to project onto, of shape (n_atoms,) or (1, n_atoms) Returns: new probability vectors that correspond to the L2 projection of the categorical distribution onto @atoms """ # make sure @atoms is shape (n_atoms,) if len(atoms.shape) > 1: atoms = atoms.squeeze(0) # helper tensors from @atoms vmin, vmax = atoms[0], atoms[1] d_pos = torch.cat([atoms, vmin[None]], dim=0)[1:] d_neg = torch.cat([vmax[None], atoms], dim=0)[:-1] # ensure that @values grid is within the support of @atoms clipped_values = values.clamp(min=vmin, max=vmax)[:, None, :] # (batch_size, 1, n_atoms) clipped_atoms = atoms[None, :, None] # (1, n_atoms, 1) # distance between atom values in support d_pos = (d_pos - atoms)[None, :, None] # atoms[i + 1] - atoms[i], shape (1, n_atoms, 1) d_neg = (atoms - d_neg)[None, :, None] # atoms[i] - atoms[i - 1], shape (1, n_atoms, 1) # distances between all pairs of grid values deltas = clipped_values - clipped_atoms # (batch_size, n_atoms, n_atoms) # computes eqn (7) in distributional RL paper by doing the following - for each # output atom in @atoms, consider values that are close enough, and weight their # probability mass contribution by the normalized distance in [0, 1] given # by (1. - (z_j - z_i) / (delta_z)). d_sign = (deltas >= 0.).float() delta_hat = (d_sign * deltas / d_pos) - ((1. - d_sign) * deltas / d_neg) delta_hat = (1. - delta_hat).clamp(min=0., max=1.) probabilities = probabilities[:, None, :] return (delta_hat * probabilities).sum(dim=2)