| import logging |
| from typing import Optional |
| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
# Module logger under a fixed application-level name (not derived from __name__).
logger = logging.getLogger(name="codsworth")
|
|
|
|
class Perplexity:
    """Running perplexity built from externally computed mean losses.

    Each ``update`` adds ``loss * num_tokens`` to a running sum, so ``compute``
    returns ``exp`` of the token-weighted mean loss across all updates.
    """

    def __init__(self, pad_token_id: Optional[int] = None):
        # Stored for API compatibility; no method of this class reads it.
        self.pad_token_id = pad_token_id
        self.reset()

    def reset(self):
        """Discard all accumulated loss and token counts."""
        self.total_loss = 0.0
        self.total_tokens = 0

    def update(self, loss: float, num_tokens: int):
        """Add one batch whose mean per-token loss is ``loss`` over ``num_tokens`` tokens.

        ``loss`` is coerced with ``float()`` so that passing a scalar tensor
        (possibly attached to an autograd graph) does not turn the running
        sum into a tensor that keeps that graph alive.
        """
        self.total_loss += float(loss) * num_tokens
        self.total_tokens += num_tokens

    def compute(self) -> float:
        """Return ``exp(total_loss / total_tokens)``.

        Returns:
            ``inf`` when no tokens have been seen, or when the average loss is
            too large for ``math.exp`` to represent.
        """
        if self.total_tokens == 0:
            return float("inf")

        avg_loss = self.total_loss / self.total_tokens
        try:
            return math.exp(avg_loss)
        except OverflowError:
            # math.exp raises (rather than returning inf) above ~709.
            return float("inf")

    def __call__(self, loss: float, num_tokens: int = 1) -> float:
        """Update with one batch and return the running perplexity."""
        self.update(loss, num_tokens)
        return self.compute()

    def __str__(self) -> str:
        return f"Perplexity: {self.compute():.4f}"
|
|
|
|
class TokenPerplexity:
    """Next-token perplexity with per-batch and token-weighted aggregation.

    Logits at sequence position ``t`` are scored against the label at
    ``t + 1``. Positions whose (shifted) label equals ``ignore_index`` are
    excluded from both the loss and the token count.
    """

    def __init__(self, ignore_index: int = -100):
        self.ignore_index = ignore_index
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        # Mean loss of each batch that had at least one scored token.
        self.losses = []
        # Number of scored (shifted, non-ignored) tokens seen so far.
        self.num_tokens = 0
        # Sum of per-token losses over every scored token.
        self.total_loss = 0.0

    def update(self, logits: torch.Tensor, labels: torch.Tensor):
        """Accumulate the cross-entropy of one batch.

        Args:
            logits: scores whose last dim is the vocabulary and whose
                second-to-last dim is the sequence (sliced ``[..., :-1, :]``).
            labels: integer targets aligned with ``logits`` along the sequence
                dim; presumably ``(batch, seq)`` — TODO confirm with callers.
        """
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        flat_labels = shift_labels.view(-1)

        token_losses = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            flat_labels,
            ignore_index=self.ignore_index,
            reduction="none",
        )

        # reduction="none" yields 0.0 at ignored positions, so they must be
        # masked out explicitly instead of being averaged in with .mean().
        valid_mask = flat_labels != self.ignore_index
        valid_tokens = int(valid_mask.sum().item())
        if valid_tokens == 0:
            # Nothing scored; recording 0.0 would bias compute() downward.
            return

        batch_loss = token_losses[valid_mask].sum().item()
        self.num_tokens += valid_tokens
        self.total_loss += batch_loss
        self.losses.append(batch_loss / valid_tokens)

    def compute(self) -> float:
        """Perplexity from the unweighted mean of per-batch mean losses.

        Returns ``inf`` when no batch has contributed a scored token.
        """
        if not self.losses:
            return float("inf")
        return self._safe_exp(sum(self.losses) / len(self.losses))

    def compute_cumulative(self) -> float:
        """Perplexity from the token-weighted mean loss over all batches.

        Returns ``inf`` when no scored tokens have been seen.
        """
        if self.num_tokens == 0:
            return float("inf")
        return self._safe_exp(self.total_loss / self.num_tokens)

    @staticmethod
    def _safe_exp(value: float) -> float:
        """``math.exp`` that returns ``inf`` instead of raising on overflow."""
        try:
            return math.exp(value)
        except OverflowError:
            return float("inf")
|
|
|
|
def calculate_perplexity(
    model: nn.Module,
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    pad_token_id: int = 0,
) -> float:
    """Return ``exp(loss)`` for a single batch.

    The model is called without gradients as
    ``model(input_ids=..., labels=...)`` and its ``"loss"`` entry is read.
    This function does not switch the model into eval mode.

    Args:
        pad_token_id: accepted for signature parity with the dataloader
            helpers; not used here (masking presumably happens inside the
            model — TODO confirm).

    Returns:
        ``exp(loss)`` as a Python float, or ``inf`` if the model reports no loss.
    """
    with torch.no_grad():
        batch_loss = model(input_ids=input_ids, labels=labels)["loss"]

    return float("inf") if batch_loss is None else torch.exp(batch_loss).item()
|
|
|
|
def calculate_batch_perplexity(
    model: nn.Module,
    dataloader,
    pad_token_id: int = 0,
) -> float:
    """Calculate token-weighted perplexity over an entire dataloader.

    Puts ``model`` into eval mode (and leaves it there), runs every batch
    without gradients, and weights each batch's mean ``loss`` by the number
    of label positions that differ from ``pad_token_id``.

    Args:
        model: called as ``model(input_ids=..., labels=...)``; its output
            must support ``outputs["loss"]``.
        dataloader: iterable of mappings with ``"input_ids"`` and ``"labels"``.
        pad_token_id: label value excluded from each batch's token count.

    Returns:
        ``exp(total_loss / total_tokens)``; ``inf`` when no tokens were
        counted or the average loss is too large to exponentiate.
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for batch in dataloader:
            labels = batch["labels"]
            loss = model(input_ids=batch["input_ids"], labels=labels)["loss"]

            if loss is None:
                # Skip rather than crash on None.item(); a missing loss leaves
                # nothing to weight for this batch.
                logging.getLogger("codsworth").warning(
                    "Model returned no loss for a batch; skipping it."
                )
                continue

            # NOTE(review): counts unshifted labels; if the model shifts labels
            # internally (causal LM) this over-counts by one per sequence, and
            # -100 labels are counted unless pad_token_id == -100 — confirm.
            valid_tokens = (labels != pad_token_id).sum().item()
            total_loss += loss.item() * valid_tokens
            total_tokens += valid_tokens

    if total_tokens == 0:
        return float("inf")

    try:
        return math.exp(total_loss / total_tokens)
    except OverflowError:
        # math.exp raises above ~709 instead of returning inf.
        return float("inf")
|
|
|
|
class StreamingPerplexity:
    """Streaming perplexity calculator for large datasets.

    Accumulates the summed log-probability of each target token and the count
    of scored tokens, so memory use is constant regardless of dataset size.
    Logits at position ``t`` are scored against the label at ``t + 1``;
    shifted labels equal to ``ignore_index`` are skipped.
    """

    def __init__(self, ignore_index: int = -100):
        self.ignore_index = ignore_index
        self.reset()

    def reset(self):
        """Clear the accumulated log-probability and token count."""
        self.total_log_prob = 0.0
        self.total_tokens = 0

    def update(self, logits: torch.Tensor, labels: torch.Tensor):
        """Accumulate log-probabilities for one batch.

        Args:
            logits: scores whose last dim is the vocabulary and whose
                second-to-last dim is the sequence (sliced ``[..., :-1, :]``).
            labels: integer targets aligned with ``logits`` along the sequence
                dim; presumably ``(batch, seq)`` — TODO confirm with callers.
        """
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]

        # The mask must come from the *shifted* labels so its shape matches
        # the per-position log-probabilities gathered below.
        valid_mask = shift_labels != self.ignore_index
        # gather() rejects out-of-range indices such as -100, so substitute a
        # valid id at ignored positions; those entries are masked out anyway.
        gather_index = shift_labels.masked_fill(~valid_mask, 0)

        log_probs = F.log_softmax(shift_logits, dim=-1)
        token_log_probs = log_probs.gather(
            -1,
            gather_index.unsqueeze(-1),
        ).squeeze(-1)

        self.total_log_prob += token_log_probs[valid_mask].sum().item()
        self.total_tokens += int(valid_mask.sum().item())

    def compute(self) -> float:
        """Return ``exp(-mean log-prob)``; ``inf`` if empty or overflowing."""
        if self.total_tokens == 0:
            return float("inf")

        avg_log_prob = self.total_log_prob / self.total_tokens
        try:
            return math.exp(-avg_log_prob)
        except OverflowError:
            # math.exp raises above ~709 instead of returning inf.
            return float("inf")
|
|
|
|
def compute_perplexity_metrics(
    model: nn.Module,
    dataloader,
    pad_token_id: int = 0,
) -> dict:
    """Compute corpus-level and per-batch perplexity statistics.

    Puts ``model`` into eval mode (and leaves it there) and runs every batch
    without gradients. Batches for which the model reports no loss are
    skipped with a warning.

    Args:
        model: called as ``model(input_ids=..., labels=...)``; its output
            must support ``outputs["loss"]``.
        dataloader: iterable of mappings with ``"input_ids"`` and ``"labels"``.
        pad_token_id: label value excluded from each batch's token count.

    Returns:
        A dict with:
            ``perplexity``: exp of the token-weighted mean loss (``inf`` when
                no tokens were counted or the value overflows).
            ``avg_loss``: token-weighted mean loss (``inf`` when no tokens).
            ``total_tokens``: number of counted label positions.
            ``total_batches``: number of batches that produced a loss.
            ``mean_batch_perplexity`` / ``min_batch_perplexity`` /
                ``max_batch_perplexity``: statistics over per-batch
                ``exp(loss)``; ``inf`` when no batch produced a loss.
    """
    model.eval()

    total_loss = 0.0
    total_tokens = 0
    total_batches = 0
    batch_perplexities = []

    with torch.no_grad():
        for batch in dataloader:
            labels = batch["labels"]
            loss = model(input_ids=batch["input_ids"], labels=labels)["loss"]

            if loss is None:
                # Skip rather than crash on None.item().
                logging.getLogger("codsworth").warning(
                    "Model returned no loss for a batch; skipping it."
                )
                continue

            # NOTE(review): counts unshifted labels; if the model shifts labels
            # internally this over-counts by one per sequence — confirm.
            valid_tokens = (labels != pad_token_id).sum().item()

            total_loss += loss.item() * valid_tokens
            total_tokens += valid_tokens
            total_batches += 1
            batch_perplexities.append(torch.exp(loss).item())

    if total_tokens:
        avg_loss = total_loss / total_tokens
        try:
            perplexity = math.exp(avg_loss)
        except OverflowError:
            # math.exp raises above ~709 instead of returning inf.
            perplexity = float("inf")
    else:
        # No counted tokens: report "undefined" as inf, not exp(0) == 1.0,
        # which would read as a perfect model.
        avg_loss = float("inf")
        perplexity = float("inf")

    if batch_perplexities:
        mean_batch = sum(batch_perplexities) / len(batch_perplexities)
        min_batch = min(batch_perplexities)
        max_batch = max(batch_perplexities)
    else:
        mean_batch = min_batch = max_batch = float("inf")

    return {
        "perplexity": perplexity,
        "avg_loss": avg_loss,
        "total_tokens": total_tokens,
        "total_batches": total_batches,
        "mean_batch_perplexity": mean_batch,
        "min_batch_perplexity": min_batch,
        "max_batch_perplexity": max_batch,
    }