# torch/distributions/generalized_pareto.py (vendored from torch 2.x site-packages)
# mypy: allow-untyped-defs
import math
from numbers import Number, Real

import torch
from torch import inf, nan
from torch.distributions import constraints, Distribution
from torch.distributions.utils import broadcast_all

__all__ = ["GeneralizedPareto"]
class GeneralizedPareto(Distribution):
    r"""
    Creates a Generalized Pareto distribution parameterized by :attr:`loc`,
    :attr:`scale`, and :attr:`concentration`.

    The Generalized Pareto distribution is a family of continuous probability
    distributions on the real line. Special cases include Exponential (when
    :attr:`loc` = 0, :attr:`concentration` = 0), Pareto (when
    :attr:`concentration` > 0, :attr:`loc` = :attr:`scale` / :attr:`concentration`),
    and Uniform (when :attr:`concentration` = -1).

    This distribution is often used to model the tails of other distributions.
    This implementation is based on the implementation in TensorFlow Probability.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = GeneralizedPareto(torch.tensor([0.1]), torch.tensor([2.0]), torch.tensor([0.4]))
        >>> m.sample()  # sample from a Generalized Pareto distribution with loc=0.1, scale=2.0, and concentration=0.4
        tensor([ 1.5623])

    Args:
        loc (float or Tensor): Location parameter of the distribution
        scale (float or Tensor): Scale parameter of the distribution
        concentration (float or Tensor): Concentration parameter of the distribution
    """

    arg_constraints = {
        "loc": constraints.real,
        "scale": constraints.positive,
        "concentration": constraints.real,
    }
    has_rsample = True

    def __init__(self, loc, scale, concentration, validate_args=None):
        self.loc, self.scale, self.concentration = broadcast_all(
            loc, scale, concentration
        )
        # Scalar (Python-number) parameters give an empty batch shape;
        # otherwise the batch shape is the broadcasted parameter shape.
        if (
            isinstance(loc, Number)
            and isinstance(scale, Number)
            and isinstance(concentration, Number)
        ):
            batch_shape = torch.Size()
        else:
            batch_shape = self.loc.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new distribution instance with parameters expanded to `batch_shape`."""
        new = self._get_checked_instance(GeneralizedPareto, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        new.concentration = self.concentration.expand(batch_shape)
        # Skip arg validation during expand; copy the flag from self afterwards.
        super(GeneralizedPareto, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape=torch.Size()):
        """Reparameterized sample via inverse-transform sampling: x = icdf(u), u ~ U(0, 1)."""
        shape = self._extended_shape(sample_shape)
        u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device)
        return self.icdf(u)

    def log_prob(self, value):
        """Log of the probability density at `value`.

        Uses the exponential limit form -z when concentration is (numerically)
        zero; otherwise -(1/k + 1) * log1p(k * z), all minus log(scale).
        """
        if self._validate_args:
            self._validate_sample(value)
        z = self._z(value)
        eq_zero = torch.isclose(self.concentration, torch.tensor(0.0))
        # Substitute 1 for concentration where it is ~0 so the nonzero branch
        # never divides by zero; torch.where discards that branch there anyway.
        safe_conc = torch.where(
            eq_zero, torch.ones_like(self.concentration), self.concentration
        )
        y = 1 / safe_conc + torch.ones_like(z)
        where_nonzero = torch.where(y == 0, y, y * torch.log1p(safe_conc * z))
        log_scale = (
            math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log()
        )
        return -log_scale - torch.where(eq_zero, z, where_nonzero)

    def log_survival_function(self, value):
        """Log of P(X > value): -z in the exponential limit, else -log1p(k*z)/k."""
        if self._validate_args:
            self._validate_sample(value)
        z = self._z(value)
        eq_zero = torch.isclose(self.concentration, torch.tensor(0.0))
        safe_conc = torch.where(
            eq_zero, torch.ones_like(self.concentration), self.concentration
        )
        where_nonzero = -torch.log1p(safe_conc * z) / safe_conc
        return torch.where(eq_zero, -z, where_nonzero)

    def log_cdf(self, value):
        """Log of the CDF, computed as log(1 - survival)."""
        return torch.log1p(-torch.exp(self.log_survival_function(value)))

    def cdf(self, value):
        return torch.exp(self.log_cdf(value))

    def icdf(self, value):
        """Inverse CDF (quantile function) at probability `value` in [0, 1)."""
        loc = self.loc
        scale = self.scale
        concentration = self.concentration
        eq_zero = torch.isclose(concentration, torch.zeros_like(concentration))
        # Same safe-divisor trick as in log_prob.
        safe_conc = torch.where(eq_zero, torch.ones_like(concentration), concentration)
        logu = torch.log1p(-value)
        where_nonzero = loc + scale / safe_conc * torch.expm1(-safe_conc * logu)
        where_zero = loc - scale * logu  # exponential limit
        return torch.where(eq_zero, where_zero, where_nonzero)

    def _z(self, x):
        """Standardize `x` to (x - loc) / scale."""
        return (x - self.loc) / self.scale

    @property
    def mean(self):
        """Mean loc + scale / (1 - k); NaN where concentration >= 1 (undefined)."""
        concentration = self.concentration
        valid = concentration < 1
        # Dummy value in the invalid region to avoid division by zero; masked below.
        safe_conc = torch.where(valid, concentration, 0.5)
        result = self.loc + self.scale / (1 - safe_conc)
        return torch.where(valid, result, nan)

    @property
    def variance(self):
        """Variance scale^2 / ((1 - k)^2 (1 - 2k)); NaN where concentration >= 0.5."""
        concentration = self.concentration
        valid = concentration < 0.5
        safe_conc = torch.where(valid, concentration, 0.25)
        result = self.scale**2 / ((1 - safe_conc) ** 2 * (1 - 2 * safe_conc))
        return torch.where(valid, result, nan)

    def entropy(self):
        """Differential entropy log(scale) + concentration + 1."""
        ans = torch.log(self.scale) + self.concentration + 1
        return torch.broadcast_to(ans, self._batch_shape)

    @property
    def mode(self):
        """Mode of the density, attained at the lower support bound `loc`."""
        return self.loc

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        """Support: [loc, inf) for k >= 0, [loc, loc - scale/k] for k < 0."""
        lower = self.loc
        upper = torch.where(
            self.concentration < 0, lower - self.scale / self.concentration, inf
        )
        return constraints.interval(lower, upper)