| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| from torch.distributions.categorical import Categorical |
| from torchmetrics import Metric |
|
|
__all__ = ['Perplexity']


class Perplexity(Metric):
    """
    This class computes mean perplexity of distributions in the last dimension of inputs. It is a wrapper around
    :doc:`torch.distributions.Categorical.perplexity<pytorch:distributions>` method. You have to provide either
    ``probs`` or ``logits`` to the :meth:`update` method. The class computes perplexities for distributions passed to
    :meth:`update` method in ``probs`` or ``logits`` arguments and averages the perplexities. Reducing results between
    all workers is done via SUM operations.
    See `PyTorch Lightning Metrics <https://pytorch-lightning.readthedocs.io/en/stable/ecosystem/metrics.html>`_ for the metric usage instructions.

    Args:
        compute_on_step:
            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True``
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: ``False``
        process_group:
            Specify the process group on which synchronization is called. default: ``None`` (which selects the entire
            world)
        validate_args:
            If ``True`` values of :meth:`update` method parameters are checked. ``logits`` has to not contain NaNs and
            ``probs`` last dim has to be valid probability distribution. default: ``True``
    """

    # torchmetrics class-level flag controlling how ``forward`` handles state; kept as in the original.
    full_state_update = True

    def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, validate_args=True):
        super().__init__(
            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group
        )
        self.validate_args = validate_args
        # Running sum of per-distribution perplexities. Kept in float64 to limit rounding error over many updates.
        self.add_state('perplexities_sum', torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx='sum')
        # Number of distributions accumulated into ``perplexities_sum``.
        self.add_state('num_distributions', torch.tensor(0, dtype=torch.int64), dist_reduce_fx='sum')

    def update(self, probs=None, logits=None):
        """
        Updates :attr:`perplexities_sum` and :attr:`num_distributions`.

        Exactly one of ``probs`` and ``logits`` must be given; the underlying
        :class:`~torch.distributions.categorical.Categorical` raises ``ValueError`` otherwise.

        Args:
            probs: A ``torch.Tensor`` which innermost dimension is valid probability distribution.
            logits: A ``torch.Tensor`` without NaNs.
        """
        # Detach so the metric only records statistics and does not hold on to the autograd graph.
        d = Categorical(
            None if probs is None else probs.detach(),
            None if logits is None else logits.detach(),
            validate_args=self.validate_args,
        )
        # One perplexity per distribution, i.e. per element of all dimensions except the last one.
        ppl = d.perplexity()
        self.num_distributions += ppl.numel()
        # Reduce directly in float64 (the accumulator dtype). Summing in the input dtype first can overflow
        # to ``inf`` for float16 inputs (max ~65504), which would poison the running sum, and loses precision
        # for large batches even in float32.
        self.perplexities_sum += ppl.sum(dtype=torch.float64)

    def compute(self):
        """
        Returns the mean perplexity across all workers, i.e. :attr:`perplexities_sum` divided by
        :attr:`num_distributions`, as a float64 tensor, or ``None`` if no distributions have been accumulated.

        This method does not modify :attr:`perplexities_sum` or :attr:`num_distributions`; it does not reset them.
        """
        if self.num_distributions.eq(0):
            return None
        return self.perplexities_sum / self.num_distributions
|
|