# Listing metadata: file size 4,547 bytes, revision f4cade0.
# mypy: allow-untyped-defs
from typing import Optional
import torch
from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.types import _size
# Names exported by ``from <this module> import *``: only the Dirichlet class.
__all__ = ["Dirichlet"]
# This helper is exposed for testing.
def _Dirichlet_backward(x, concentration, grad_output):
    """Backward formula for a reparameterized Dirichlet sample.

    Args:
        x: the sample that was drawn for ``concentration``.
        concentration: the concentration tensor the sample was drawn from;
            its last dimension is the one summed over.
        grad_output: incoming gradient with respect to ``x``.

    Returns:
        Tensor: ``grad * (grad_output - (x * grad_output).sum(-1, True))``,
        where ``grad = torch._dirichlet_grad(x, concentration, total)``.
    """
    # Sum of concentrations over the last dim, broadcast back to the full
    # shape because `_dirichlet_grad` is called with three same-shaped tensors.
    total = concentration.sum(-1, True).expand_as(concentration)
    grad = torch._dirichlet_grad(x, concentration, total)
    # Remove the x-weighted sum of the incoming gradient from every component
    # before scaling by the elementwise factor.
    return grad * (grad_output - (x * grad_output).sum(-1, True))
class _Dirichlet(Function):
    """Autograd ``Function`` that draws a Dirichlet sample with a custom backward.

    ``forward`` samples via ``torch._sample_dirichlet`` and saves what the
    backward pass needs; ``backward`` delegates to ``_Dirichlet_backward``.
    """

    @staticmethod
    def forward(ctx, concentration):
        """Draw a sample for ``concentration`` and stash tensors for backward.

        Args:
            ctx: autograd context used to save tensors.
            concentration: concentration tensor to sample from.

        Returns:
            Tensor: the result of ``torch._sample_dirichlet(concentration)``.
        """
        x = torch._sample_dirichlet(concentration)
        # Both the sample and the parameter are needed by the backward formula.
        ctx.save_for_backward(x, concentration)
        return x

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return the gradient with respect to ``concentration``.

        Args:
            ctx: autograd context holding the tensors saved in ``forward``.
            grad_output: incoming gradient with respect to the sample.

        Returns:
            Tensor: the value of ``_Dirichlet_backward`` for the saved tensors.
        """
        x, concentration = ctx.saved_tensors
        return _Dirichlet_backward(x, concentration, grad_output)
class Dirichlet(ExponentialFamily):
    r"""
    Creates a Dirichlet distribution parameterized by concentration :attr:`concentration`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Dirichlet(torch.tensor([0.5, 0.5]))
        >>> m.sample()  # Dirichlet distributed with concentration [0.5, 0.5]
        tensor([ 0.1046,  0.8954])

    Args:
        concentration (Tensor): concentration parameter of the distribution
            (often referred to as alpha)
    """

    # The last dimension of `concentration` is treated as the event dimension,
    # and every entry along it must be positive.
    arg_constraints = {
        "concentration": constraints.independent(constraints.positive, 1)
    }
    support = constraints.simplex
    has_rsample = True

    def __init__(
        self,
        concentration: Tensor,
        validate_args: Optional[bool] = None,
    ) -> None:
        """Store ``concentration`` and derive batch/event shapes from it.

        Args:
            concentration: tensor with at least one dimension; the last
                dimension becomes the event shape, the rest the batch shape.
            validate_args: forwarded unchanged to the base-class constructor.

        Raises:
            ValueError: if ``concentration`` has fewer than one dimension.
        """
        if concentration.dim() < 1:
            raise ValueError(
                "`concentration` parameter must be at least one-dimensional."
            )
        self.concentration = concentration
        batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a distribution whose concentration is expanded to ``batch_shape``.

        Args:
            batch_shape: desired batch shape; converted to ``torch.Size``.
            _instance: passed through to ``_get_checked_instance``.

        Returns:
            The new instance, carrying this instance's ``_validate_args``.
        """
        new = self._get_checked_instance(Dirichlet, _instance)
        batch_shape = torch.Size(batch_shape)
        new.concentration = self.concentration.expand(batch_shape + self.event_shape)
        # Initialise the base class with validation off, then copy the
        # original validation flag onto the new instance.
        super(Dirichlet, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape: _size = ()) -> Tensor:
        """Draw a sample through ``_Dirichlet.apply`` so gradients can flow.

        Args:
            sample_shape: extra leading shape for the drawn samples.

        Returns:
            Tensor of shape ``self._extended_shape(sample_shape)``.
        """
        shape = self._extended_shape(sample_shape)
        concentration = self.concentration.expand(shape)
        return _Dirichlet.apply(concentration)

    def log_prob(self, value):
        """Return the log density evaluated at ``value``.

        Args:
            value: point(s) at which to evaluate; validated against the
                support only when ``self._validate_args`` is truthy.

        Returns:
            Tensor: ``sum((alpha - 1) * log(value)) + lgamma(sum(alpha))
            - sum(lgamma(alpha))``, with sums over the last dimension.
        """
        if self._validate_args:
            self._validate_sample(value)
        return (
            # (alpha - 1) * log(value), computed with torch.xlogy.
            torch.xlogy(self.concentration - 1.0, value).sum(-1)
            + torch.lgamma(self.concentration.sum(-1))
            - torch.lgamma(self.concentration).sum(-1)
        )

    @property
    def mean(self) -> Tensor:
        """Concentration divided by its sum over the last dimension."""
        return self.concentration / self.concentration.sum(-1, True)

    @property
    def mode(self) -> Tensor:
        """Mode computed from ``(concentration - 1)`` clamped at zero.

        Rows in which every concentration is below 1 are replaced by a
        one-hot vector.
        """
        concentrationm1 = (self.concentration - 1).clamp(min=0.0)
        mode = concentrationm1 / concentrationm1.sum(-1, True)
        # Where every concentration is < 1 the clamped numerators and their
        # sum are all zero, so the division above is 0/0 for that row.
        mask = (self.concentration < 1).all(dim=-1)
        # Overwrite those rows with a one-hot vector at the row's argmax,
        # cast to the dtype/device of `mode`.
        mode[mask] = torch.nn.functional.one_hot(
            mode[mask].argmax(dim=-1), concentrationm1.shape[-1]
        ).to(mode)
        return mode

    @property
    def variance(self) -> Tensor:
        """Elementwise ``alpha * (a0 - alpha) / (a0**2 * (a0 + 1))``.

        ``a0`` is the sum of the concentration over the last dimension.
        """
        con0 = self.concentration.sum(-1, True)
        return (
            self.concentration
            * (con0 - self.concentration)
            / (con0.pow(2) * (con0 + 1))
        )

    def entropy(self):
        """Return the entropy, reduced over the last dimension.

        With ``k`` the size of the last dimension and ``a0`` the sum of the
        concentration over it, the value is ``sum(lgamma(alpha)) - lgamma(a0)
        - (k - a0) * digamma(a0) - sum((alpha - 1) * digamma(alpha))``.
        """
        k = self.concentration.size(-1)
        a0 = self.concentration.sum(-1)
        return (
            torch.lgamma(self.concentration).sum(-1)
            - torch.lgamma(a0)
            - (k - a0) * torch.digamma(a0)
            - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1)
        )

    @property
    def _natural_params(self) -> tuple[Tensor]:
        """One-element tuple holding the concentration tensor."""
        return (self.concentration,)

    def _log_normalizer(self, x):
        """Return ``sum(lgamma(x)) - lgamma(sum(x))`` over the last dimension."""
        return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))
|