"""
Contains distribution models used as parts of other networks. These classes usually inherit
or emulate torch distributions.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D


class TanhWrappedDistribution(D.Distribution):
    """
    Class that wraps another valid torch distribution, such that sampled values from the
    base distribution are passed through a tanh layer. The corresponding (log) probabilities
    are also modified accordingly.

    Tanh Normal distribution - adapted from rlkit and CQL codebase
    (https://github.com/aviralkumar2907/CQL/blob/d67dbe9cf5d2b96e3b462b6146f249b3d6569796/d4rl/rlkit/torch/distributions.py#L6).
    """

    # This wrapper owns no tensor parameters of its own (they all live on the wrapped
    # distribution). Declaring an empty constraint dict keeps the base-class constructor
    # from warning about a missing `arg_constraints` and makes `repr()` usable.
    arg_constraints = {}

    def __init__(self, base_dist, scale=1.0, epsilon=1e-6):
        """
        Args:
            base_dist (Distribution): Distribution to wrap with tanh output
            scale (float): Scale of output
            epsilon (float): Numerical stability epsilon when computing log-prob.
        """
        self.base_dist = base_dist
        self.scale = scale
        self.tanh_epsilon = epsilon
        super(TanhWrappedDistribution, self).__init__()

    def log_prob(self, value, pre_tanh_value=None):
        """
        Compute the log probability of @value under the tanh-squashed distribution.

        Args:
            value (torch.Tensor): some tensor to compute log probabilities for
            pre_tanh_value: If specified, will not calculate atanh manually from @value.
                More numerically stable

        Returns:
            torch.Tensor: base log-prob of the pre-tanh value minus the tanh correction
                term. The correction is summed over the last dimension when the base
                distribution's log-prob has a different rank than the correction term.
        """
        # Undo the output scaling so that @value lies in the range of tanh.
        value = value / self.scale
        if pre_tanh_value is None:
            # atanh(x) = 0.5 * log((1 + x) / (1 - x)); clamp both terms away from zero
            # so values at (or beyond) the +/-1 boundary do not yield inf / nan.
            one_plus_x = (1. + value).clamp(min=self.tanh_epsilon)
            one_minus_x = (1. - value).clamp(min=self.tanh_epsilon)
            pre_tanh_value = 0.5 * torch.log(one_plus_x / one_minus_x)
        lp = self.base_dist.log_prob(pre_tanh_value)
        # Log of the tanh derivative (1 - tanh(z)^2), with epsilon for stability.
        # NOTE(review): this term does not include log(scale), so for scale != 1 the
        # result looks offset by a constant per dimension - confirm this is intended.
        tanh_lp = torch.log(1 - value * value + self.tanh_epsilon)
        # In case the base dist already sums up the log probs, make sure we do the same
        return lp - tanh_lp if len(lp.shape) == len(tanh_lp.shape) else lp - tanh_lp.sum(-1)

    def sample(self, sample_shape=torch.Size(), return_pretanh_value=False):
        """
        Gradients will and should *not* pass through this operation.
        See https://github.com/pytorch/pytorch/issues/4620 for discussion.

        Args:
            sample_shape (torch.Size): shape of samples to draw from the base distribution
            return_pretanh_value (bool): if True, also return the pre-tanh sample

        Returns:
            torch.Tensor or tuple: scaled tanh sample, or (scaled tanh sample, pre-tanh
                sample) if @return_pretanh_value is True
        """
        z = self.base_dist.sample(sample_shape=sample_shape).detach()
        squashed = torch.tanh(z) * self.scale
        if return_pretanh_value:
            return squashed, z
        return squashed

    def rsample(self, sample_shape=torch.Size(), return_pretanh_value=False):
        """
        Sampling in the reparameterization case - for differentiable samples.

        Args:
            sample_shape (torch.Size): shape of samples to draw from the base distribution
            return_pretanh_value (bool): if True, also return the pre-tanh sample

        Returns:
            torch.Tensor or tuple: scaled tanh sample, or (scaled tanh sample, pre-tanh
                sample) if @return_pretanh_value is True
        """
        z = self.base_dist.rsample(sample_shape=sample_shape)
        squashed = torch.tanh(z) * self.scale
        if return_pretanh_value:
            return squashed, z
        return squashed

    @property
    def mean(self):
        """
        Mean of the wrapped base distribution (returned as-is, without applying the
        tanh or the output scale).
        """
        return self.base_dist.mean

    @property
    def stddev(self):
        """
        Standard deviation of the wrapped base distribution (returned as-is, without
        applying the tanh or the output scale).
        """
        return self.base_dist.stddev


class DiscreteValueDistribution(object):
    """
    Extension to torch categorical probability distribution in order to keep track
    of the support (categorical values, or in this case, value atoms). This is
    used for distributional value networks.
    """

    def __init__(self, values, probs=None, logits=None):
        """
        Creates a categorical distribution parameterized by either @probs or
        @logits (but not both). Expects inputs to be consistent in shape
        for broadcasting operations (e.g. multiplication).

        Args:
            values (torch.Tensor): value atoms (support); last dimension indexes the atoms
            probs (torch.Tensor): category probabilities, passed to torch Categorical
            logits (torch.Tensor): category logits, passed to torch Categorical
        """
        self._values = values
        self._categorical_dist = D.Categorical(probs=probs, logits=logits)

    @property
    def values(self):
        """Value atoms (support) of the distribution."""
        return self._values

    @property
    def probs(self):
        """Probabilities of the underlying categorical distribution."""
        return self._categorical_dist.probs

    @property
    def logits(self):
        """Logits of the underlying categorical distribution."""
        return self._categorical_dist.logits

    def mean(self):
        """
        Categorical distribution mean, taking the value support into account.
        """
        return (self._categorical_dist.probs * self._values).sum(dim=-1)

    def variance(self):
        """
        Categorical distribution variance, taking the value support into account.
        """
        dist_squared = (self.mean().unsqueeze(-1) - self.values).pow(2)
        return (self._categorical_dist.probs * dist_squared).sum(dim=-1)

    def sample(self, sample_shape=torch.Size()):
        """
        Sample from the distribution. Make sure to return value atoms, not categorical
        class indices.

        Args:
            sample_shape (torch.Size): shape of samples to draw

        Returns:
            torch.Tensor: sampled value atoms, one per sampled class index (leading
                dimensions are broadcast against those of the value support)
        """
        # Sampled class indices get a trailing axis so that they have the same rank as
        # the value support, which torch.gather requires.
        inds = self._categorical_dist.sample(sample_shape=sample_shape).unsqueeze(-1)
        # Broadcast support and indices against each other so that extra sample
        # dimensions (or a support shared across the batch) line up.
        atoms, inds = torch.broadcast_tensors(self.values, inds)
        # After broadcasting, the index is repeated along the atom axis; one copy suffices.
        return torch.gather(atoms, dim=-1, index=inds[..., :1]).squeeze(-1)