Spaces:
Sleeping
Sleeping
| """ | |
| Contains distribution models used as parts of other networks. These | |
| classes usually inherit or emulate torch distributions. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.distributions as D | |
class TanhWrappedDistribution(D.Distribution):
    """
    Wraps a base torch distribution so that samples are squashed through a
    tanh (and then scaled). Log-probabilities are corrected with the
    change-of-variables term for the tanh transform.

    Adapted from rlkit and the CQL codebase
    (https://github.com/aviralkumar2907/CQL/blob/d67dbe9cf5d2b96e3b462b6146f249b3d6569796/d4rl/rlkit/torch/distributions.py#L6).
    """
    def __init__(self, base_dist, scale=1.0, epsilon=1e-6):
        """
        Args:
            base_dist (Distribution): distribution to wrap with a tanh output
            scale (float): scale applied to the tanh-squashed output
            epsilon (float): numerical-stability constant used in log-prob
        """
        self.base_dist = base_dist
        self.scale = scale
        self.tanh_epsilon = epsilon
        super(TanhWrappedDistribution, self).__init__()

    def log_prob(self, value, pre_tanh_value=None):
        """
        Log-probability of @value under the squashed distribution.

        Args:
            value (torch.Tensor): tensor (in the post-tanh, scaled space) to
                compute log probabilities for
            pre_tanh_value (torch.Tensor): optional pre-tanh sample. If given,
                the (less numerically stable) atanh reconstruction is skipped.
        """
        squashed = value / self.scale
        if pre_tanh_value is None:
            # Numerically-stable atanh: clamp both sides of the log ratio
            # away from zero before dividing.
            eps = self.tanh_epsilon
            ratio = (1. + squashed).clamp(min=eps) / (1. - squashed).clamp(min=eps)
            pre_tanh_value = 0.5 * torch.log(ratio)
        base_lp = self.base_dist.log_prob(pre_tanh_value)
        # Change-of-variables correction: log d(tanh(z))/dz = log(1 - tanh(z)^2)
        correction = torch.log(1 - squashed * squashed + self.tanh_epsilon)
        # If the base dist already reduced over the event dim, reduce the
        # correction the same way before subtracting.
        if len(base_lp.shape) == len(correction.shape):
            return base_lp - correction
        return base_lp - correction.sum(-1)

    def sample(self, sample_shape=torch.Size(), return_pretanh_value=False):
        """
        Non-differentiable sampling. Gradients will and should *not* pass
        through this operation — see
        https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        z = self.base_dist.sample(sample_shape=sample_shape).detach()
        squashed = torch.tanh(z) * self.scale
        return (squashed, z) if return_pretanh_value else squashed

    def rsample(self, sample_shape=torch.Size(), return_pretanh_value=False):
        """
        Reparameterized sampling — produces differentiable samples.
        """
        z = self.base_dist.rsample(sample_shape=sample_shape)
        squashed = torch.tanh(z) * self.scale
        return (squashed, z) if return_pretanh_value else squashed

    def mean(self):
        # Mean of the underlying (pre-tanh) distribution.
        return self.base_dist.mean

    def stddev(self):
        # Standard deviation of the underlying (pre-tanh) distribution.
        return self.base_dist.stddev
class DiscreteValueDistribution(object):
    """
    Extension to torch categorical probability distribution in order to keep
    track of the support (categorical values, or in this case, value atoms).
    This is used for distributional value networks.
    """
    def __init__(self, values, probs=None, logits=None):
        """
        Creates a categorical distribution parameterized by either @probs or
        @logits (but not both). Expects inputs to be consistent in shape
        for broadcasting operations (e.g. multiplication).

        Args:
            values (torch.Tensor): value atoms (the support of the distribution)
            probs (torch.Tensor): event probabilities (mutually exclusive with @logits)
            logits (torch.Tensor): event logits (mutually exclusive with @probs)
        """
        self._values = values
        self._categorical_dist = D.Categorical(probs=probs, logits=logits)

    def values(self):
        # Value atoms (the support).
        return self._values

    def probs(self):
        # Event probabilities of the underlying categorical distribution.
        return self._categorical_dist.probs

    def logits(self):
        # Event logits of the underlying categorical distribution.
        return self._categorical_dist.logits

    def mean(self):
        """
        Categorical distribution mean, taking the value support into account.
        """
        return (self._categorical_dist.probs * self._values).sum(dim=-1)

    def variance(self):
        """
        Categorical distribution variance, taking the value support into account.
        """
        # BUG FIX: was `self.values` (a bound method) — use the tensor attribute.
        dist_squared = (self.mean().unsqueeze(-1) - self._values).pow(2)
        return (self._categorical_dist.probs * dist_squared).sum(dim=-1)

    def sample(self, sample_shape=torch.Size()):
        """
        Sample from the distribution. Make sure to return value atoms, not
        categorical class indices.
        """
        inds = self._categorical_dist.sample(sample_shape=sample_shape)
        # BUG FIX: torch.gather's signature is (input, dim, index) — the old
        # call passed `inds` positionally as `dim`, and `self.values` is a
        # bound method, not the tensor of atoms.
        return torch.gather(self._values, dim=-1, index=inds)