File size: 7,576 Bytes
96da58e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""
This file contains a collection of useful loss functions for use with torch tensors.
"""

import math
import numpy as np
import torch
import torch.nn.functional as F


def cosine_loss(preds, labels):
    """
    Cosine distance loss between two tensors. The cosine similarity is
    computed along the last dimension of @preds, and (1 - similarity) is
    averaged over all remaining entries.

    Args:
        preds (torch.Tensor): torch tensor
        labels (torch.Tensor): torch tensor

    Returns:
        loss (torch.Tensor): cosine loss
    """
    last_dim = preds.dim() - 1
    similarity = F.cosine_similarity(preds, labels, dim=last_dim)
    shifted = similarity - 1.0
    return -shifted.mean()


def KLD_0_1_loss(mu, logvar):
    """
    KL divergence between a diagonal Gaussian and the standard normal,
    i.e. D_KL( N(mu, sigma) || N(0, 1) ). The per-dimension terms are summed
    over dimension 1 and the result is averaged over the batch dimension.

    Args:
        mu (torch.Tensor): mean tensor of shape (B, D)
        logvar (torch.Tensor): logvar tensor of shape (B, D)

    Returns:
        loss (torch.Tensor): KL divergence loss between the input gaussian distribution
            and N(0, 1)
    """
    # per-dimension term: 1 + log(sigma^2) - mu^2 - sigma^2
    per_dim = 1. + logvar - mu ** 2 - torch.exp(logvar)
    # sum over D for each sample, then average over the batch
    per_sample = per_dim.sum(dim=1)
    return -0.5 * per_sample.mean()


def KLD_gaussian_loss(mu_1, logvar_1, mu_2, logvar_2):
    """
    KL divergence loss between two diagonal Gaussian distributions,
    D_KL( N(mu_1, sigma_1) || N(mu_2, sigma_2) ). The per-dimension terms are
    summed over dimension 1 and the result is averaged across the batch.

    Args:
        mu_1 (torch.Tensor): first means tensor of shape (B, D)
        logvar_1 (torch.Tensor): first logvars tensor of shape (B, D)
        mu_2 (torch.Tensor): second means tensor of shape (B, D)
        logvar_2 (torch.Tensor): second logvars tensor of shape (B, D)

    Returns:
        loss (torch.Tensor): KL divergence loss between the two gaussian distributions
    """
    # The variance ratio is evaluated in log space: exp(logvar_1 - logvar_2)
    # rather than exp(logvar_1) / exp(logvar_2). The latter overflows to
    # inf / inf = NaN for large log-variances even when the ratio is benign.
    var_ratio = torch.exp(logvar_1 - logvar_2)
    # Multiplying by exp(-logvar_2) avoids dividing by an overflowed variance.
    scaled_sq_diff = (mu_2 - mu_1).pow(2) * torch.exp(-logvar_2)
    per_dim = 1. + logvar_1 - logvar_2 - scaled_sq_diff - var_ratio
    return -0.5 * per_dim.sum(dim=1).mean()


def log_normal(x, m, v):
    """
    Log-density of x under a diagonal Gaussian with mean m and variance v.
    The per-element log-densities are summed over the last dimension, so
    that dimension acts as the event dimension of the Gaussian while every
    leading dimension indexes an independent distribution.
    Adapted from CS 236 at Stanford.

    Args:
        x (torch.Tensor): tensor with shape (B, ..., D)
        m (torch.Tensor): means tensor with shape (B, ..., D) or (1, ..., D)
        v (torch.Tensor): variances tensor with shape (B, ..., D) or (1, ..., D)

    Returns:
        log_prob (torch.Tensor): log probabilities of shape (B, ...)
    """
    log_two_pi = np.log(2 * np.pi)
    squared_error = (x - m).pow(2)
    # per-element Gaussian log-density: -0.5 * (log v + (x - m)^2 / v + log 2*pi)
    per_dim = -0.5 * (torch.log(v) + squared_error / v + log_two_pi)
    return per_dim.sum(-1)


def log_normal_mixture(x, m, v, w=None, log_w=None):
    """
    Log probability of tensor x under a mixture of Gaussians. When neither
    @w nor @log_w is given, the components are weighted uniformly; otherwise
    the provided (log-)weights are used. At most one of @w and @log_w may be
    provided. Adapted from CS 236 at Stanford.

    Args:
        x (torch.Tensor): tensor with shape (B, D)
        m (torch.Tensor): means tensor with shape (B, M, D) or (1, M, D), where 
            M is number of mixture components
        v (torch.Tensor): variances tensor with shape (B, M, D) or (1, M, D) where 
            M is number of mixture components
        w (torch.Tensor): weights tensor - if provided, should be 
            shape (B, M) or (1, M)
        log_w (torch.Tensor): log-weights tensor - if provided, should be 
            shape (B, M) or (1, M)

    Returns:
        log_prob (torch.Tensor): log probabilities of shape (B,)
    """

    # add a component axis so x broadcasts against every mixture component:
    # (B, D) -> (B, 1, D)
    expanded_x = x.unsqueeze(1)
    # per-component log-density, shape (B, M)
    component_log_prob = log_normal(expanded_x, m, v)

    if w is None and log_w is None:
        # uniform weights: the mean over components accounts for the 1 / M factor
        # (B, M) -> (B,)
        return log_mean_exp(component_log_prob, dim=1)

    if w is not None:
        assert log_w is None
        log_w = torch.log(w)

    # log(w_i * N(x | m_i, v_i)) = log w_i + log N(x | m_i, v_i)
    component_log_prob += log_w
    # log sum_i exp [log(w_i * N(x | m_i, v_i))], (B, M) -> (B,)
    return log_sum_exp(component_log_prob, dim=1)


def log_mean_exp(x, dim):
    """
    Numerically stable evaluation of log(mean(exp(x), dim)), obtained as
    log(sum(exp(x), dim)) minus the log of the number of entries along @dim.
    Adapted from CS 236 at Stanford.

    Args:
        x (torch.Tensor): a tensor 
        dim (int): dimension along which mean is computed

    Returns:
        y (torch.Tensor): log(mean(exp(x), dim))
    """
    num_terms = x.size(dim)
    log_total = log_sum_exp(x, dim)
    return log_total - np.log(num_terms)


def log_sum_exp(x, dim=0):
    """
    Compute the log(sum(exp(x), dim)) in a numerically stable manner.
    Adapted from CS 236 at Stanford.

    Delegates to torch.logsumexp, which performs the max-shift internally and
    also handles slices whose maximum is infinite (e.g. all entries are -inf,
    as produced by log(0) weights) without producing NaN.

    Args:
        x (torch.Tensor): a tensor 
        dim (int): dimension along which sum is computed

    Returns:
        y (torch.Tensor): log(sum(exp(x), dim)), with dimension @dim removed
    """
    return torch.logsumexp(x, dim)


def project_values_onto_atoms(values, probabilities, atoms):
    """
    Project the categorical distribution given by @probabilities on the
    grid of values given by @values onto a grid of values given by @atoms.
    This is useful when computing a bellman backup where the backed up
    values from the original grid will not be in the original support,
    requiring L2 projection. 

    Each value in @values has a corresponding probability in @probabilities -
    this probability mass is shifted to the closest neighboring grid points in
    @atoms in proportion. For example, if the value in question is 0.2, and the
    neighboring atoms are 0 and 1, then 0.8 of the probability weight goes to 
    atom 0 and 0.2 of the probability weight will go to 1.

    Values outside the range [atoms[0], atoms[-1]] are clipped to that range
    before projection, so their mass lands on the first / last atom.

    Adapted from https://github.com/deepmind/acme/blob/master/acme/tf/losses/distributional.py#L42
    
    Args:
        values: value grid to project, of shape (batch_size, n_atoms)
        probabilities: probabilities for categorical distribution on @values, shape (batch_size, n_atoms)
        atoms: value grid to project onto, of shape (n_atoms,) or (1, n_atoms)

    Returns:
        new probability vectors that correspond to the L2 projection of the categorical distribution
        onto @atoms
    """

    # make sure @atoms is shape (n_atoms,)
    if len(atoms.shape) > 1:
        atoms = atoms.squeeze(0)

    # helper tensors from @atoms - the support spans from the first atom to
    # the LAST atom (not the second one)
    vmin, vmax = atoms[0], atoms[-1]
    # atoms shifted by one position in each direction; the wrapped-around end
    # entries are placeholders that only ever divide a zero delta, because
    # values are clipped to [vmin, vmax] below
    d_pos = torch.cat([atoms, vmin[None]], dim=0)[1:]
    d_neg = torch.cat([vmax[None], atoms], dim=0)[:-1]

    # ensure that @values grid is within the support of @atoms
    clipped_values = values.clamp(min=vmin, max=vmax)[:, None, :] # (batch_size, 1, n_atoms)
    clipped_atoms = atoms[None, :, None] # (1, n_atoms, 1)

    # distance between atom values in support
    d_pos = (d_pos - atoms)[None, :, None] # atoms[i + 1] - atoms[i], shape (1, n_atoms, 1)
    d_neg = (atoms - d_neg)[None, :, None] # atoms[i] - atoms[i - 1], shape (1, n_atoms, 1)

    # distances between all pairs of grid values
    deltas = clipped_values - clipped_atoms # (batch_size, n_atoms, n_atoms)

    # computes eqn (7) in distributional RL paper by doing the following - for each
    # output atom in @atoms, consider values that are close enough, and weight their
    # probability mass contribution by the normalized distance in [0, 1] given 
    # by (1. - (z_j - z_i) / (delta_z)).
    d_sign = (deltas >= 0.).float()
    delta_hat = (d_sign * deltas / d_pos) - ((1. - d_sign) * deltas / d_neg)
    delta_hat = (1. - delta_hat).clamp(min=0., max=1.)
    probabilities = probabilities[:, None, :]
    # sum the contributions of every source value to each output atom
    return (delta_hat * probabilities).sum(dim=2)