# Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py
# But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll))
# Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py
# But we pass in the loss to avoid recomputation
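#
# For intuition: with per-token NLLs [0.0, 2.0], exp(mean) = exp(1.0) ≈ 2.72,
# while mean(exp) = (exp(0) + exp(2)) / 2 ≈ 4.19, so averaging per-batch
# perplexities systematically overweights the hard batches.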

from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torchmetrics import Metric

# Prefer flash-attn's fused cross-entropy kernel when it's installed; otherwise
# fall back to PyTorch's implementation, which has the same call signature here.
try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
    CrossEntropyLoss = torch.nn.CrossEntropyLoss

__all__ = ['Perplexity']

class Perplexity(Metric):
    r"""
    Perplexity measures how well a language model predicts a text sample. Here it is
    computed as the exponential of the average per-token negative log-likelihood,
    exp(average(nll)), with the average taken over all non-ignored tokens seen so far.
    Args:
        kwargs:
            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
    Examples:
        >>> import torch
        >>> logits = torch.randn(2, 8, 5, generator=torch.manual_seed(22))
        >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
        >>> target[0, 6:] = -100  # skipped via CrossEntropyLoss's default ignore_index
        >>> metric = Perplexity()
        >>> perplexity = metric(logits, target)
    """
    is_differentiable = True
    higher_is_better = False
    full_state_update = False
    total_log_probs: Tensor
    count: Tensor

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        # Accumulate the NLL sum in float64 so precision survives many updates;
        # summing counts and sums across processes keeps the distributed mean exact.
        self.add_state("total_log_probs", default=torch.tensor(0.0, dtype=torch.float64),
                       dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum")
        self.loss_fn = CrossEntropyLoss()

    def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None:  # type: ignore
        """Compute and store intermediate statistics for Perplexity.
        Args:
            preds:
                Unnormalized logits for each token in a sequence with shape
                [batch_size, seq_len, vocab_size].
            target:
                Ground truth token indices with shape [batch_size, seq_len].
            loss:
                Optional pre-computed mean cross-entropy for (preds, target); pass it
                in to avoid recomputing the loss here.
        """
        # Tokens equal to ignore_index contribute nothing to the mean loss, so they
        # must be excluded from the count as well.
        count = (target != self.loss_fn.ignore_index).sum()
        if loss is None:
            # CrossEntropyLoss expects 2D logits and 1D targets, so fold the batch
            # and sequence dimensions together.
            loss = self.loss_fn(preds.reshape(-1, preds.size(-1)), target.reshape(-1))
        # Undo the mean so updates of different sizes are weighted correctly.
        self.total_log_probs += loss.double() * count
        self.count += count

    def compute(self) -> Tensor:
        """Compute the Perplexity.
        Returns:
            Perplexity as exp of the mean negative log-likelihood over all counted tokens.
        """
        return torch.exp(self.total_log_probs / self.count)
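

if __name__ == "__main__":
    # Minimal usage sketch with illustrative shapes and values (not taken from the
    # training code that consumes this metric); it assumes the PyTorch fallback
    # loss, since flash-attn's fused kernel needs CUDA tensors.
    torch.manual_seed(0)
    batch_size, seq_len, vocab_size = 2, 8, 5
    logits = torch.randn(batch_size, seq_len, vocab_size)
    target = torch.randint(vocab_size, (batch_size, seq_len))

    metric = Perplexity()

    # 1) Let the metric compute the cross-entropy itself.
    metric.update(logits, target)

    # 2) Reuse a loss already computed by the training step, which is the whole
    #    point of the optional `loss` argument (see the header comment above).
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), target.reshape(-1))
    metric.update(logits, target, loss=loss)

    # exp of the token-weighted mean NLL accumulated over both updates
    print(metric.compute())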