Spaces:
Sleeping
Sleeping
File size: 7,576 Bytes
96da58e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
"""
This file contains a collection of useful loss functions for use with torch tensors.
"""
import math
import numpy as np
import torch
import torch.nn.functional as F
def cosine_loss(preds, labels):
    """
    Cosine loss between two tensors.

    Computes cosine similarity along the last dimension and returns the
    mean of (1 - similarity), so perfectly-aligned inputs give a loss of 0.

    Args:
        preds (torch.Tensor): torch tensor
        labels (torch.Tensor): torch tensor

    Returns:
        loss (torch.Tensor): cosine loss
    """
    similarity = F.cosine_similarity(preds, labels, dim=-1)
    return torch.mean(1.0 - similarity)
def KLD_0_1_loss(mu, logvar):
    """
    KL divergence loss. Computes D_KL( N(mu, sigma) || N(0, 1) ). Note that
    this function averages across the batch dimension, but sums across dimension.

    Args:
        mu (torch.Tensor): mean tensor of shape (B, D)
        logvar (torch.Tensor): logvar tensor of shape (B, D)

    Returns:
        loss (torch.Tensor): KL divergence loss between the input gaussian distribution
            and N(0, 1)
    """
    # closed-form KL to the standard normal, per dimension:
    # 0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1)
    kld_per_dim = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1.)
    # sum over the Gaussian dimension, average over the batch
    return kld_per_dim.sum(dim=1).mean()
def KLD_gaussian_loss(mu_1, logvar_1, mu_2, logvar_2):
    """
    KL divergence loss between two Gaussian distributions. This function
    computes the average loss across the batch.

    Args:
        mu_1 (torch.Tensor): first means tensor of shape (B, D)
        logvar_1 (torch.Tensor): first logvars tensor of shape (B, D)
        mu_2 (torch.Tensor): second means tensor of shape (B, D)
        logvar_2 (torch.Tensor): second logvars tensor of shape (B, D)

    Returns:
        loss (torch.Tensor): KL divergence loss between the two gaussian distributions
    """
    var_1 = logvar_1.exp()
    var_2 = logvar_2.exp()
    # closed-form KL( N(mu_1, var_1) || N(mu_2, var_2) ), per dimension:
    # 0.5 * (log(var_2 / var_1) + (var_1 + (mu_1 - mu_2)^2) / var_2 - 1)
    kld_per_dim = 0.5 * (
        logvar_2 - logvar_1
        + (var_1 + (mu_1 - mu_2).pow(2)) / var_2
        - 1.
    )
    # sum over the Gaussian dimension, average over the batch
    return kld_per_dim.sum(dim=1).mean()
def log_normal(x, m, v):
    """
    Log probability of tensor x under diagonal multivariate normal with
    mean m and variance v. The last dimension of the tensors is treated
    as the dimension of the Gaussian distribution - all other dimensions
    are treated as independent Gaussians. Adapted from CS 236 at Stanford.

    Args:
        x (torch.Tensor): tensor with shape (B, ..., D)
        m (torch.Tensor): means tensor with shape (B, ..., D) or (1, ..., D)
        v (torch.Tensor): variances tensor with shape (B, ..., D) or (1, ..., D)

    Returns:
        log_prob (torch.Tensor): log probabilities of shape (B, ...)
    """
    # per-dimension Gaussian log density: -0.5 * (log v + (x - m)^2 / v + log(2*pi))
    log2pi = math.log(2 * math.pi)
    per_dim = -0.5 * (torch.log(v) + (x - m) ** 2 / v + log2pi)
    # independent dimensions: densities multiply, so log densities sum
    return per_dim.sum(dim=-1)
def log_normal_mixture(x, m, v, w=None, log_w=None):
    """
    Log probability of tensor x under a uniform mixture of Gaussians.
    Adapted from CS 236 at Stanford.

    Args:
        x (torch.Tensor): tensor with shape (B, D)
        m (torch.Tensor): means tensor with shape (B, M, D) or (1, M, D), where
            M is number of mixture components
        v (torch.Tensor): variances tensor with shape (B, M, D) or (1, M, D) where
            M is number of mixture components
        w (torch.Tensor): weights tensor - if provided, should be
            shape (B, M) or (1, M)
        log_w (torch.Tensor): log-weights tensor - if provided, should be
            shape (B, M) or (1, M)

    Returns:
        log_prob (torch.Tensor): log probabilities of shape (B,)
    """
    # broadcast x against the component axis: (B, D) -> (B, 1, D)
    # then per-component log densities: (B, M, D) -> (B, M)
    component_log_probs = log_normal(x.unsqueeze(1), m, v)
    if w is None and log_w is None:
        # uniform mixture weights: log(mean_i N(x | m_i, v_i))
        # (B, M) -> (B,)
        return log_mean_exp(component_log_probs, dim=1)
    if w is not None:
        # accept weights in exactly one form
        assert log_w is None
        log_w = torch.log(w)
    # weighted mixture: log sum_i exp[ log w_i + log N(x | m_i, v_i) ]
    # (B, M) -> (B,)
    return log_sum_exp(component_log_probs + log_w, dim=1)
def log_mean_exp(x, dim):
    """
    Compute the log(mean(exp(x), dim)) in a numerically stable manner.
    Adapted from CS 236 at Stanford.

    Args:
        x (torch.Tensor): a tensor
        dim (int): dimension along which mean is computed

    Returns:
        y (torch.Tensor): log(mean(exp(x), dim))
    """
    # log(mean) = log(sum) - log(count)
    count = x.size(dim)
    return log_sum_exp(x, dim) - math.log(count)
def log_sum_exp(x, dim=0):
    """
    Compute the log(sum(exp(x), dim)) in a numerically stable manner.
    Adapted from CS 236 at Stanford.

    Args:
        x (torch.Tensor): a tensor
        dim (int): dimension along which sum is computed

    Returns:
        y (torch.Tensor): log(sum(exp(x), dim))
    """
    # torch.logsumexp implements the same max-shift stabilization as the
    # previous hand-rolled version, and additionally handles -inf entries
    # (zero-probability components) without producing NaN.
    return torch.logsumexp(x, dim=dim)
def project_values_onto_atoms(values, probabilities, atoms):
    """
    Project the categorical distribution given by @probabilities on the
    grid of values given by @values onto a grid of values given by @atoms.
    This is useful when computing a bellman backup where the backed up
    values from the original grid will not be in the original support,
    requiring L2 projection.

    Each value in @values has a corresponding probability in @probabilities -
    this probability mass is shifted to the closest neighboring grid points in
    @atoms in proportion. For example, if the value in question is 0.2, and the
    neighboring atoms are 0 and 1, then 0.8 of the probability weight goes to
    atom 0 and 0.2 of the probability weight will go to 1.

    Adapted from https://github.com/deepmind/acme/blob/master/acme/tf/losses/distributional.py#L42

    Args:
        values: value grid to project, of shape (batch_size, n_atoms)
        probabilities: probabilities for categorical distribution on @values, shape (batch_size, n_atoms)
        atoms: value grid to project onto, of shape (n_atoms,) or (1, n_atoms);
            assumed sorted in ascending order

    Returns:
        new probability vectors that correspond to the L2 projection of the categorical distribution
        onto @atoms
    """
    # make sure @atoms is shape (n_atoms,)
    if len(atoms.shape) > 1:
        atoms = atoms.squeeze(0)
    # support bounds are the first and LAST atoms (bug fix: previously vmax was
    # taken as atoms[1], the second atom, which clamped values to the wrong
    # upper bound and mis-projected all mass above it)
    vmin, vmax = atoms[0], atoms[-1]
    # shifted copies of the grid used to compute neighbor spacings below;
    # the appended vmin / prepended vmax entries are discarded by the slicing
    d_pos = torch.cat([atoms, vmin[None]], dim=0)[1:]
    d_neg = torch.cat([vmax[None], atoms], dim=0)[:-1]
    # ensure that @values grid is within the support of @atoms
    clipped_values = values.clamp(min=vmin, max=vmax)[:, None, :]  # (batch_size, 1, n_atoms)
    clipped_atoms = atoms[None, :, None]  # (1, n_atoms, 1)
    # distance between atom values in support
    d_pos = (d_pos - atoms)[None, :, None]  # atoms[i + 1] - atoms[i], shape (1, n_atoms, 1)
    d_neg = (atoms - d_neg)[None, :, None]  # atoms[i] - atoms[i - 1], shape (1, n_atoms, 1)
    # distances between all pairs of grid values
    deltas = clipped_values - clipped_atoms  # (batch_size, n_atoms, n_atoms)
    # computes eqn (7) in distributional RL paper by doing the following - for each
    # output atom in @atoms, consider values that are close enough, and weight their
    # probability mass contribution by the normalized distance in [0, 1] given
    # by (1. - (z_j - z_i) / (delta_z)).
    d_sign = (deltas >= 0.).float()  # which side of each atom the value falls on
    delta_hat = (d_sign * deltas / d_pos) - ((1. - d_sign) * deltas / d_neg)
    # values farther than one grid spacing contribute nothing (clamped to 0)
    delta_hat = (1. - delta_hat).clamp(min=0., max=1.)
    probabilities = probabilities[:, None, :]
    return (delta_hat * probabilities).sum(dim=2)
|