| | |
| | import warnings |
| | from typing import Optional |
| | from typing_extensions import deprecated |
| |
|
| | import torch |
| | from torch import Tensor |
| | from torch.distributions import constraints |
| | from torch.distributions.utils import lazy_property |
| | from torch.types import _size |
| |
|
| |
|
| | __all__ = ["Distribution"] |
| |
|
| |
|
class Distribution:
    r"""
    Distribution is the abstract base class for probability distributions.

    Args:
        batch_shape (torch.Size): The shape over which parameters are batched.
        event_shape (torch.Size): The shape of a single sample (without batching).
        validate_args (bool, optional): Whether to validate arguments. Default: None.
    """

    # Capability flags; subclasses override to advertise reparameterized
    # sampling and support enumeration.
    has_rsample = False
    has_enumerate_support = False
    # Mirrors Python's ``assert``: validation defaults to on, but is off
    # when the interpreter runs in optimized mode (``python -O``).
    _validate_args = __debug__

    @staticmethod
    def set_default_validate_args(value: bool) -> None:
        """
        Sets whether validation is enabled or disabled.

        The default behavior mimics Python's ``assert`` statement: validation
        is on by default, but is disabled if Python is run in optimized mode
        (via ``python -O``). Validation may be expensive, so you may want to
        disable it once a model is working.

        Args:
            value (bool): Whether to enable validation.

        Raises:
            ValueError: If ``value`` is neither ``True`` nor ``False``.
        """
        if value not in [True, False]:
            # Previously raised a bare ValueError; include the offending
            # value so the caller can see what was rejected.
            raise ValueError(
                f"Expected `value` to be True or False, but got {value!r}."
            )
        Distribution._validate_args = value

    def __init__(
        self,
        batch_shape: torch.Size = torch.Size(),
        event_shape: torch.Size = torch.Size(),
        validate_args: Optional[bool] = None,
    ) -> None:
        self._batch_shape = batch_shape
        self._event_shape = event_shape
        if validate_args is not None:
            # Per-instance override of the class-level default.
            self._validate_args = validate_args
        if self._validate_args:
            try:
                arg_constraints = self.arg_constraints
            except NotImplementedError:
                arg_constraints = {}
                warnings.warn(
                    f"{self.__class__} does not define `arg_constraints`. "
                    + "Please set `arg_constraints = {}` or initialize the distribution "
                    + "with `validate_args=False` to turn off validation."
                )
            for param, constraint in arg_constraints.items():
                if constraints.is_dependent(constraint):
                    continue  # skip constraints that cannot be checked
                if param not in self.__dict__ and isinstance(
                    getattr(type(self), param), lazy_property
                ):
                    continue  # skip checking lazily-constructed args
                value = getattr(self, param)
                valid = constraint.check(value)
                if not torch._is_all_true(valid):
                    raise ValueError(
                        f"Expected parameter {param} "
                        f"({type(value).__name__} of shape {tuple(value.shape)}) "
                        f"of distribution {repr(self)} "
                        f"to satisfy the constraint {repr(constraint)}, "
                        f"but found invalid values:\n{value}"
                    )
        super().__init__()

    def expand(self, batch_shape: _size, _instance=None):
        """
        Returns a new distribution instance (or populates an existing instance
        provided by a derived class) with batch dimensions expanded to
        `batch_shape`. This method calls :class:`~torch.Tensor.expand` on
        the distribution's parameters. As such, this does not allocate new
        memory for the expanded distribution instance. Additionally,
        this does not repeat any args checking or parameter broadcasting in
        `__init__`, when an instance is first created.

        Args:
            batch_shape (torch.Size): the desired expanded size.
            _instance: new instance provided by subclasses that
                need to override `.expand`.

        Returns:
            New distribution instance with batch dimensions expanded to
            `batch_shape`.
        """
        raise NotImplementedError

    @property
    def batch_shape(self) -> torch.Size:
        """
        Returns the shape over which parameters are batched.
        """
        return self._batch_shape

    @property
    def event_shape(self) -> torch.Size:
        """
        Returns the shape of a single sample (without batching).
        """
        return self._event_shape

    @property
    def arg_constraints(self) -> dict[str, constraints.Constraint]:
        """
        Returns a dictionary from argument names to
        :class:`~torch.distributions.constraints.Constraint` objects that
        should be satisfied by each argument of this distribution. Args that
        are not tensors need not appear in this dict.
        """
        raise NotImplementedError

    @property
    def support(self) -> Optional[constraints.Constraint]:
        """
        Returns a :class:`~torch.distributions.constraints.Constraint` object
        representing this distribution's support.
        """
        raise NotImplementedError

    @property
    def mean(self) -> Tensor:
        """
        Returns the mean of the distribution.
        """
        raise NotImplementedError

    @property
    def mode(self) -> Tensor:
        """
        Returns the mode of the distribution.
        """
        raise NotImplementedError(f"{self.__class__} does not implement mode")

    @property
    def variance(self) -> Tensor:
        """
        Returns the variance of the distribution.
        """
        raise NotImplementedError

    @property
    def stddev(self) -> Tensor:
        """
        Returns the standard deviation of the distribution.
        """
        return self.variance.sqrt()

    def sample(self, sample_shape: _size = torch.Size()) -> Tensor:
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched.
        """
        # Sampling should not be differentiated through; rsample provides
        # the reparameterized (differentiable) path.
        with torch.no_grad():
            return self.rsample(sample_shape)

    def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched.
        """
        raise NotImplementedError

    @deprecated(
        "`sample_n(n)` will be deprecated. Use `sample((n,))` instead.",
        category=FutureWarning,
    )
    def sample_n(self, n: int) -> Tensor:
        """
        Generates n samples or n batches of samples if the distribution
        parameters are batched.
        """
        return self.sample(torch.Size((n,)))

    def log_prob(self, value: Tensor) -> Tensor:
        """
        Returns the log of the probability density/mass function evaluated at
        `value`.

        Args:
            value (Tensor):
        """
        raise NotImplementedError

    def cdf(self, value: Tensor) -> Tensor:
        """
        Returns the cumulative density/mass function evaluated at
        `value`.

        Args:
            value (Tensor):
        """
        raise NotImplementedError

    def icdf(self, value: Tensor) -> Tensor:
        """
        Returns the inverse cumulative density/mass function evaluated at
        `value`.

        Args:
            value (Tensor):
        """
        raise NotImplementedError

    def enumerate_support(self, expand: bool = True) -> Tensor:
        """
        Returns tensor containing all values supported by a discrete
        distribution. The result will enumerate over dimension 0, so the shape
        of the result will be `(cardinality,) + batch_shape + event_shape`
        (where `event_shape = ()` for univariate distributions).

        Note that this enumerates over all batched tensors in lock-step
        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
        along dim 0, but with the remaining batch dimensions being
        singleton dimensions, `[[0], [1], ..`.

        To iterate over the full Cartesian product use
        `itertools.product(m.enumerate_support())`.

        Args:
            expand (bool): whether to expand the support over the
                batch dims to match the distribution's `batch_shape`.

        Returns:
            Tensor iterating over dimension 0.
        """
        raise NotImplementedError

    def entropy(self) -> Tensor:
        """
        Returns entropy of distribution, batched over batch_shape.

        Returns:
            Tensor of shape batch_shape.
        """
        raise NotImplementedError

    def perplexity(self) -> Tensor:
        """
        Returns perplexity of distribution, batched over batch_shape.

        Returns:
            Tensor of shape batch_shape.
        """
        # Perplexity is defined as exp(entropy).
        return torch.exp(self.entropy())

    def _extended_shape(self, sample_shape: _size = torch.Size()) -> torch.Size:
        """
        Returns the size of the sample returned by the distribution, given
        a `sample_shape`. Note, that the batch and event shapes of a distribution
        instance are fixed at the time of construction. If this is empty, the
        returned shape is upcast to (1,).

        Args:
            sample_shape (torch.Size): the size of the sample to be drawn.
        """
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        return torch.Size(sample_shape + self._batch_shape + self._event_shape)

    def _validate_sample(self, value: Tensor) -> None:
        """
        Argument validation for distribution methods such as `log_prob`,
        `cdf` and `icdf`. The rightmost dimensions of a value to be
        scored via these methods must agree with the distribution's batch
        and event shapes.

        Args:
            value (Tensor): the tensor whose log probability is to be
                computed by the `log_prob` method.

        Raises:
            ValueError: when the rightmost dimensions of `value` do not match the
                distribution's batch and event shapes.
        """
        if not isinstance(value, torch.Tensor):
            raise ValueError("The value argument to log_prob must be a Tensor")

        # The trailing dims of `value` must exactly equal event_shape.
        event_dim_start = len(value.size()) - len(self._event_shape)
        if value.size()[event_dim_start:] != self._event_shape:
            raise ValueError(
                f"The right-most size of value must match event_shape: {value.size()} vs {self._event_shape}."
            )

        # Check broadcastability against batch_shape + event_shape by
        # comparing dims right-to-left (standard broadcasting rule).
        actual_shape = value.size()
        expected_shape = self._batch_shape + self._event_shape
        for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
            if i != 1 and j != 1 and i != j:
                raise ValueError(
                    f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}."
                )
        try:
            support = self.support
        except NotImplementedError:
            warnings.warn(
                f"{self.__class__} does not define `support` to enable "
                + "sample validation. Please initialize the distribution with "
                + "`validate_args=False` to turn off validation."
            )
            return
        assert support is not None
        valid = support.check(value)
        if not torch._is_all_true(valid):
            raise ValueError(
                "Expected value argument "
                f"({type(value).__name__} of shape {tuple(value.shape)}) "
                f"to be within the support ({repr(support)}) "
                f"of the distribution {repr(self)}, "
                f"but found invalid values:\n{value}"
            )

    def _get_checked_instance(self, cls, _instance=None):
        # Guard: a subclass with a custom __init__ must also implement
        # .expand(), since the default path bypasses __init__ entirely.
        if _instance is None and type(self).__init__ != cls.__init__:
            raise NotImplementedError(
                f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method "
                "must also define a custom .expand() method."
            )
        return self.__new__(type(self)) if _instance is None else _instance

    def __repr__(self) -> str:
        # Only show parameters that were actually materialized on the
        # instance (lazy params not yet computed are omitted).
        param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
        args_string = ", ".join(
            [
                f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
                for p in param_names
            ]
        )
        return self.__class__.__name__ + "(" + args_string + ")"
| |
|