# xfu314's picture
# Add phantom project with submodules and dependencies
# 96da58e
"""
This file contains a collection of useful loss functions for use with torch tensors.
"""
import math
import numpy as np
import torch
import torch.nn.functional as F
def cosine_loss(preds, labels):
    """
    Mean cosine distance (1 - cosine similarity) between two tensors.

    The similarity is taken along the last dimension of @preds and the
    result is averaged over all remaining entries.

    Args:
        preds (torch.Tensor): torch tensor
        labels (torch.Tensor): torch tensor

    Returns:
        loss (torch.Tensor): scalar cosine loss
    """
    # index of the last axis of @preds, along which vectors are compared
    feature_dim = preds.dim() - 1
    similarity = F.cosine_similarity(preds, labels, dim=feature_dim)
    # -(sim - 1) == 1 - sim, so identical directions give a loss of 0
    return -(similarity - 1.0).mean()
def KLD_0_1_loss(mu, logvar):
    """
    KL divergence D_KL( N(mu, sigma) || N(0, 1) ) for a diagonal Gaussian.

    The divergence is summed over dimension 1 and then averaged over the
    remaining (batch) dimension.

    Args:
        mu (torch.Tensor): mean tensor of shape (B, D)
        logvar (torch.Tensor): logvar tensor of shape (B, D)

    Returns:
        loss (torch.Tensor): KL divergence loss between the input gaussian
            distribution and N(0, 1)
    """
    # per-element term 1 + log(sigma^2) - mu^2 - sigma^2
    elementwise = 1. + logvar - mu.pow(2) - logvar.exp()
    per_sample = torch.sum(elementwise, dim=1)
    return -0.5 * torch.mean(per_sample)
def KLD_gaussian_loss(mu_1, logvar_1, mu_2, logvar_2):
    """
    KL divergence D_KL( N(mu_1, sigma_1) || N(mu_2, sigma_2) ) between two
    diagonal Gaussians. The divergence is summed over dimension 1 and then
    averaged across the batch.

    Args:
        mu_1 (torch.Tensor): first means tensor of shape (B, D)
        logvar_1 (torch.Tensor): first logvars tensor of shape (B, D)
        mu_2 (torch.Tensor): second means tensor of shape (B, D)
        logvar_2 (torch.Tensor): second logvars tensor of shape (B, D)

    Returns:
        loss (torch.Tensor): KL divergence loss between the two gaussian
            distributions
    """
    var_2 = logvar_2.exp()
    # squared mean difference, scaled by the second distribution's variance
    mean_term = (mu_2 - mu_1).pow(2) / var_2
    # ratio of variances sigma_1^2 / sigma_2^2
    var_ratio = logvar_1.exp() / var_2
    elementwise = 1. + logvar_1 - logvar_2 - mean_term - var_ratio
    per_sample = torch.sum(elementwise, dim=1)
    return -0.5 * torch.mean(per_sample)
def log_normal(x, m, v):
    """
    Log-density of x under a diagonal Gaussian with mean m and variance v.

    The last dimension is treated as the event dimension of the Gaussian and
    is summed out; all leading dimensions are kept and treated as independent
    Gaussians. Adapted from CS 236 at Stanford.

    Args:
        x (torch.Tensor): tensor with shape (B, ..., D)
        m (torch.Tensor): means tensor with shape (B, ..., D) or (1, ..., D)
        v (torch.Tensor): variances tensor with shape (B, ..., D) or (1, ..., D)

    Returns:
        log_prob (torch.Tensor): log probabilities of shape (B, ...)
    """
    log_var = torch.log(v)
    # squared deviation from the mean, normalized by the variance
    scaled_sq_dev = (x - m).pow(2) / v
    per_dim = -0.5 * (log_var + scaled_sq_dev + np.log(2 * np.pi))
    # independent dimensions: the joint log-density is the sum over D
    return per_dim.sum(-1)
def log_normal_mixture(x, m, v, w=None, log_w=None):
    """
    Log probability of tensor x under a mixture of Gaussians. If neither @w
    nor @log_w is given, the mixture components are weighted uniformly.
    Adapted from CS 236 at Stanford.

    Args:
        x (torch.Tensor): tensor with shape (B, D)
        m (torch.Tensor): means tensor with shape (B, M, D) or (1, M, D), where
            M is number of mixture components
        v (torch.Tensor): variances tensor with shape (B, M, D) or (1, M, D) where
            M is number of mixture components
        w (torch.Tensor): weights tensor - if provided, should be
            shape (B, M) or (1, M). Must not be passed together with @log_w.
        log_w (torch.Tensor): log-weights tensor - if provided, should be
            shape (B, M) or (1, M)

    Returns:
        log_prob (torch.Tensor): log probabilities of shape (B,)
    """
    # (B, D) -> (B, 1, D) so x broadcasts against the M components,
    # then per-component log-density of shape (B, M)
    component_log_prob = log_normal(x.unsqueeze(1), m, v)

    if w is None and log_w is None:
        # uniform weights: log of the mean of component densities, (B, M) -> (B,)
        return log_mean_exp(component_log_prob, dim=1)

    if w is not None:
        assert log_w is None
        log_w = torch.log(w)

    # log(w_i * N(x | m_i, v_i)) for each component i
    component_log_prob += log_w
    # log sum_i exp [log(w_i * N(x | m_i, v_i))], (B, M) -> (B,)
    return log_sum_exp(component_log_prob, dim=1)
def log_mean_exp(x, dim):
    """
    Compute log(mean(exp(x), dim)) via log_sum_exp, using the identity
    log(mean(exp(x))) = log(sum(exp(x))) - log(N), where N is the size of
    @x along @dim. Adapted from CS 236 at Stanford.

    Args:
        x (torch.Tensor): a tensor
        dim (int): dimension along which mean is computed

    Returns:
        y (torch.Tensor): log(mean(exp(x), dim))
    """
    log_total = log_sum_exp(x, dim)
    num_terms = x.size(dim)
    return log_total - np.log(num_terms)
def log_sum_exp(x, dim=0):
    """
    Compute the log(sum(exp(x), dim)) in a numerically stable manner.
    Adapted from CS 236 at Stanford.

    The maximum along @dim is subtracted before exponentiating so that the
    largest exponent is exp(0) = 1. If that maximum is not finite (for
    example every entry along @dim is -inf), no shift is applied, because
    subtracting it would compute inf - inf = NaN. In that case the result is
    -inf for an all -inf slice and +inf for a slice containing +inf.

    Args:
        x (torch.Tensor): a tensor
        dim (int): dimension along which sum is computed

    Returns:
        y (torch.Tensor): log(sum(exp(x), dim)), with @dim reduced away
    """
    max_x = torch.max(x, dim)[0]
    # use a zero shift wherever the max is +/-inf or NaN; for finite maxima
    # this is exactly the original computation
    shift = torch.where(torch.isfinite(max_x), max_x, torch.zeros_like(max_x))
    shifted = x - shift.unsqueeze(dim)
    return shift + shifted.exp().sum(dim).log()
def project_values_onto_atoms(values, probabilities, atoms):
    """
    Project the categorical distribution given by @probabilities on the
    grid of values given by @values onto a grid of values given by @atoms.
    This is useful when computing a bellman backup where the backed up
    values from the original grid will not be in the original support,
    requiring L2 projection.

    Each value in @values has a corresponding probability in @probabilities -
    this probability mass is shifted to the closest neighboring grid points in
    @atoms in proportion. For example, if the value in question is 0.2, and the
    neighboring atoms are 0 and 1, then 0.8 of the probability weight goes to
    atom 0 and 0.2 of the probability weight will go to 1. Values outside
    [atoms[0], atoms[-1]] are clipped to the nearest end of the support first.

    Adapted from https://github.com/deepmind/acme/blob/master/acme/tf/losses/distributional.py#L42

    Args:
        values: value grid to project, of shape (batch_size, n_atoms)
        probabilities: probabilities for categorical distribution on @values, shape (batch_size, n_atoms)
        atoms: value grid to project onto, of shape (n_atoms,) or (1, n_atoms).
            The first and last entries are used as the lower and upper ends of
            the support, so the grid is expected to be sorted in increasing order.

    Returns:
        new probability vectors that correspond to the L2 projection of the categorical distribution
        onto @atoms
    """
    # make sure @atoms is shape (n_atoms,)
    if len(atoms.shape) > 1:
        atoms = atoms.squeeze(0)

    # ends of the support: the first and the LAST atom. (Using atoms[1] as the
    # upper end would clip every value into the first bin.)
    vmin, vmax = atoms[0], atoms[-1]

    # shifted copies of @atoms, used to get the spacing to each neighbor.
    # The wrap-around entries (vmin appended / vmax prepended) only produce
    # spacings for the two boundary atoms, which multiply zero-weight terms.
    next_atoms = torch.cat([atoms, vmin[None]], dim=0)[1:]
    prev_atoms = torch.cat([vmax[None], atoms], dim=0)[:-1]

    # ensure that @values grid is within the support of @atoms
    clipped_values = values.clamp(min=vmin, max=vmax)[:, None, :]  # (batch_size, 1, n_atoms)
    clipped_atoms = atoms[None, :, None]  # (1, n_atoms, 1)

    # distance between atom values in support
    d_pos = (next_atoms - atoms)[None, :, None]  # atoms[i + 1] - atoms[i], shape (1, n_atoms, 1)
    d_neg = (atoms - prev_atoms)[None, :, None]  # atoms[i] - atoms[i - 1], shape (1, n_atoms, 1)

    # distances between all pairs of grid values: deltas[b, i, j] = values[b, j] - atoms[i]
    deltas = clipped_values - clipped_atoms  # (batch_size, n_atoms, n_atoms)

    # computes eqn (7) in distributional RL paper by doing the following - for each
    # output atom in @atoms, consider values that are close enough, and weight their
    # probability mass contribution by the normalized distance in [0, 1] given
    # by (1. - (z_j - z_i) / (delta_z)).
    d_sign = (deltas >= 0.).float()
    # normalized |distance|: values above the atom are scaled by the spacing to
    # the next atom, values below it by the spacing to the previous atom
    delta_hat = (d_sign * deltas / d_pos) - ((1. - d_sign) * deltas / d_neg)
    delta_hat = (1. - delta_hat).clamp(min=0., max=1.)

    # weight each source value's mass and accumulate it per output atom
    probabilities = probabilities[:, None, :]
    return (delta_hat * probabilities).sum(dim=2)