Phi2-Fine-Tuning
/
phivenv
/Lib
/site-packages
/torch
/distributions
/lowrank_multivariate_normal.py
| # mypy: allow-untyped-defs | |
| import math | |
| from typing import Optional | |
| import torch | |
| from torch import Tensor | |
| from torch.distributions import constraints | |
| from torch.distributions.distribution import Distribution | |
| from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv | |
| from torch.distributions.utils import _standard_normal, lazy_property | |
| from torch.types import _size | |
| __all__ = ["LowRankMultivariateNormal"] | |
def _batch_capacitance_tril(W, D):
    r"""
    Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
    and a batch of vectors :math:`D`.

    Args:
        W: batch of factor matrices; the last two dimensions are treated as
            ``(n, m)``.
        D: batch of diagonal entries; its last dimension is aligned with the
            ``n`` dimension of ``W``.

    Returns:
        The result of :func:`torch.linalg.cholesky` applied to the
        capacitance matrix, whose last two dimensions are ``(m, m)``.
    """
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    K = torch.matmul(Wt_Dinv, W).contiguous()
    # Add the identity matrix in place through a view of the diagonal.
    # Flattening with ``view(-1, m * m)`` and striding by ``m + 1`` gives the
    # same result for m >= 1, but raises for ``m == 0`` because the ``-1``
    # dimension of an empty tensor is ambiguous.
    K.diagonal(dim1=-2, dim2=-1).add_(1)
    return torch.linalg.cholesky(K)
def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Computes the log determinant of :math:`W @ W.T + D` using the
    "matrix determinant lemma"::

        log|W @ W.T + D| = log|C| + log|D|,

    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`.

    Args:
        W: factor matrices (not read here; kept for a signature parallel to
            the other low-rank helpers).
        D: batch of diagonal entries.
        capacitance_tril: Cholesky factor of the capacitance matrix.

    Returns:
        The sum of the two log-determinant terms, reduced over the last
        dimension.
    """
    # log|C| is twice the sum of the logs of the Cholesky factor's diagonal.
    tril_diag = capacitance_tril.diagonal(dim1=-2, dim2=-1)
    logdet_capacitance = 2 * tril_diag.log().sum(-1)
    # log|D| for a diagonal matrix is the sum of the logs of its entries.
    logdet_diag = D.log().sum(-1)
    return logdet_capacitance + logdet_diag
def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Computes the squared Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`
    using the "Woodbury matrix identity"::

        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),

    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`.

    Args:
        W: batch of factor matrices.
        D: batch of diagonal entries.
        x: batch of vectors to measure.
        capacitance_tril: Cholesky factor of the capacitance matrix.

    Returns:
        The difference between the diagonal term and the low-rank correction
        term of the identity above.
    """
    # Diagonal part: x.T @ inv(D) @ x.
    diag_term = (x.pow(2) / D).sum(-1)
    # Low-rank part: project x with W.T @ inv(D), then measure the projection
    # against the capacitance matrix through its Cholesky factor.
    projected = _batch_mv(W.mT / D.unsqueeze(-2), x)
    lowrank_term = _batch_mahalanobis(capacitance_tril, projected)
    return diag_term - lowrank_term
class LowRankMultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal distribution with covariance matrix having a low-rank form
    parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = LowRankMultivariateNormal(
        ...     torch.zeros(2), torch.tensor([[1.0], [0.0]]), torch.ones(2)
        ... )
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`

    Note:
        The computation for determinant and inverse of covariance matrix is avoided when
        `cov_factor.shape[1] << cov_factor.shape[0]` thanks to `Woodbury matrix identity
        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and
        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        Thanks to these formulas, we just need to compute the determinant and inverse of
        the small size "capacitance" matrix::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "cov_factor": constraints.independent(constraints.real, 2),
        "cov_diag": constraints.independent(constraints.positive, 1),
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(
        self,
        loc: Tensor,
        cov_factor: Tensor,
        cov_diag: Tensor,
        validate_args: Optional[bool] = None,
    ) -> None:
        """Validate the parameter shapes, broadcast them, and cache the capacitance factor.

        Raises:
            ValueError: if ``loc`` has no dimensions, ``cov_factor`` has fewer
                than two, the event dimensions of ``cov_factor`` or
                ``cov_diag`` disagree with ``loc``, or the batch shapes cannot
                be broadcast together.
        """
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]
        if cov_factor.dim() < 2:
            raise ValueError(
                "cov_factor must be at least two-dimensional, "
                "with optional leading batch dimensions"
            )
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError(
                f"cov_factor must be a batch of matrices with shape {event_shape[0]} x m"
            )
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError(
                f"cov_diag must be a batch of vectors with shape {event_shape}"
            )

        # Give loc and cov_diag a trailing singleton dimension so that all
        # three tensors can be broadcast against each other in one call.
        loc_ = loc.unsqueeze(-1)
        cov_diag_ = cov_diag.unsqueeze(-1)
        try:
            loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(
                loc_, cov_factor, cov_diag_
            )
        except RuntimeError as e:
            raise ValueError(
                f"Incompatible batch shapes: loc {loc.shape}, cov_factor {cov_factor.shape}, cov_diag {cov_diag.shape}"
            ) from e
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape = self.loc.shape[:-1]

        # The un-broadcast inputs are kept for the linear-algebra routines;
        # their results are expanded to the full batch shape where needed.
        self._unbroadcasted_cov_factor = cov_factor
        self._unbroadcasted_cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return an instance whose broadcast parameters are expanded to ``batch_shape``.

        The un-broadcast tensors and the capacitance factor are shared with
        ``self`` rather than recomputed.
        """
        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new.cov_diag = self.cov_diag.expand(loc_shape)
        new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
        new._capacitance_tril = self._capacitance_tril
        super(LowRankMultivariateNormal, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    # ``mean``, ``mode`` and ``variance`` override properties of the base
    # ``Distribution`` (whose ``stddev`` reads ``self.variance`` as an
    # attribute), so they must be declared as properties here as well.

    @property
    def mean(self) -> Tensor:
        """Mean of the distribution, i.e. the broadcast ``loc``."""
        return self.loc

    @property
    def mode(self) -> Tensor:
        """Mode of the distribution, i.e. the broadcast ``loc``."""
        return self.loc

    @lazy_property
    def variance(self) -> Tensor:  # type: ignore[override]
        """Diagonal of the covariance matrix, expanded to ``batch_shape + event_shape``."""
        return (
            self._unbroadcasted_cov_factor.pow(2).sum(-1) + self._unbroadcasted_cov_diag
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def scale_tril(self) -> Tensor:
        """Lower-triangular Cholesky factor of the covariance matrix."""
        # The following identity is used to improve the numerical stability
        # of the Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D1/2 @ (I + D-1/2 @ W @ W.T @ D-1/2) @ D1/2
        # The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from below by 1,
        # hence it is well-conditioned and safe to take Cholesky decomposition.
        n = self._event_shape[0]
        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous()
        K.view(-1, n * n)[:, :: n + 1] += 1  # add identity matrix to K
        scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K)
        return scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self) -> Tensor:
        """Dense covariance matrix ``cov_factor @ cov_factor.T + diag(cov_diag)``."""
        covariance_matrix = torch.matmul(
            self._unbroadcasted_cov_factor, self._unbroadcasted_cov_factor.mT
        ) + torch.diag_embed(self._unbroadcasted_cov_diag)
        return covariance_matrix.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def precision_matrix(self) -> Tensor:
        """Dense inverse of the covariance matrix."""
        # We use "Woodbury matrix identity" to take advantage of low rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = (
            self._unbroadcasted_cov_factor.mT
            / self._unbroadcasted_cov_diag.unsqueeze(-2)
        )
        # With C = L @ L.T, solving L @ A = W.T @ inv(D) gives
        # A.T @ A = inv(D) @ W @ inv(C) @ W.T @ inv(D).
        A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False)
        precision_matrix = (
            torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A
        )
        return precision_matrix.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
        """Draw reparameterized samples of shape ``sample_shape + batch_shape + event_shape``.

        Samples are built as ``loc + cov_factor @ eps_W + sqrt(cov_diag) * eps_D``
        from two independent standard-normal draws.
        """
        shape = self._extended_shape(sample_shape)
        # eps_W has one entry per column of cov_factor instead of one per event element.
        W_shape = shape[:-1] + self.cov_factor.shape[-1:]
        eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
        eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return (
            self.loc
            + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
            + self._unbroadcasted_cov_diag.sqrt() * eps_D
        )

    def log_prob(self, value):
        """Return the log density evaluated at ``value``.

        ``value`` is checked with ``_validate_sample`` first when argument
        validation is enabled.
        """
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_lowrank_mahalanobis(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            diff,
            self._capacitance_tril,
        )
        log_det = _batch_lowrank_logdet(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            self._capacitance_tril,
        )
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)

    def entropy(self):
        """Return the differential entropy, expanded to ``batch_shape`` when it is non-empty."""
        log_det = _batch_lowrank_logdet(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            self._capacitance_tril,
        )
        H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)