import logging
import math
from typing import Iterable, Iterator, Mapping, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger("codsworth")


class Perplexity:
    """Perplexity metric for language model evaluation.

    Accumulates per-batch mean losses weighted by their token counts;
    :meth:`compute` returns ``exp`` of the token-weighted mean loss.
    """

    def __init__(self, pad_token_id: Optional[int] = None):
        # Stored for callers' reference; no method of this class reads it.
        self.pad_token_id = pad_token_id
        self.reset()

    def reset(self) -> None:
        """Clear the accumulated loss and token count."""
        self.total_loss = 0.0
        self.total_tokens = 0

    def update(self, loss: float, num_tokens: int) -> None:
        """Add a batch whose mean loss is *loss* over *num_tokens* tokens."""
        # float() also accepts a 0-dim tensor and avoids holding on to it
        # (and any autograd graph attached to it) in the running total.
        self.total_loss += float(loss) * num_tokens
        self.total_tokens += num_tokens

    def compute(self) -> float:
        """Return the token-weighted perplexity, or ``inf`` if nothing was added."""
        if self.total_tokens == 0:
            return float("inf")
        return math.exp(self.total_loss / self.total_tokens)

    def __call__(self, loss: float, num_tokens: int = 1) -> float:
        """Update with one batch and return the running perplexity."""
        self.update(loss, num_tokens)
        return self.compute()

    def __str__(self) -> str:
        return f"Perplexity: {self.compute():.4f}"


class TokenPerplexity:
    """Perplexity computed per-token for better granularity.

    :meth:`update` takes raw logits and labels, shifts them by one position
    (logits at step t are scored against the label at step t + 1) and skips
    positions whose shifted label equals ``ignore_index``.
    """

    def __init__(self, ignore_index: int = -100):
        self.ignore_index = ignore_index
        self.reset()

    def reset(self) -> None:
        """Clear all accumulated statistics."""
        self.losses = []  # per-batch mean loss over that batch's valid tokens
        self.num_tokens = 0  # valid (non-ignored) tokens seen so far
        self.total_loss = 0.0  # sum of per-token losses over all valid tokens

    def update(self, logits: torch.Tensor, labels: torch.Tensor) -> None:
        """Accumulate cross-entropy for one batch.

        Args:
            logits: Scores whose last dim is the vocabulary and whose
                second-to-last dim is the sequence (assumed
                ``(batch, seq, vocab)`` — TODO confirm with callers).
            labels: Target ids aligned with ``logits`` along the sequence dim.
        """
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:].reshape(-1)
        token_losses = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels,
            ignore_index=self.ignore_index,
            reduction="none",
        )
        # Ignored positions come back as 0 with reduction="none"; they must be
        # excluded from both the sum and the count, and the count must use the
        # *shifted* labels that were actually scored.
        valid_mask = shift_labels != self.ignore_index
        valid_tokens = int(valid_mask.sum().item())
        if valid_tokens == 0:
            # Nothing was scored; recording a batch mean here would be 0/0.
            return
        batch_loss_sum = token_losses[valid_mask].sum().item()
        self.num_tokens += valid_tokens
        self.total_loss += batch_loss_sum
        self.losses.append(batch_loss_sum / valid_tokens)

    def compute(self) -> float:
        """Return ``exp`` of the unweighted mean of per-batch mean losses.

        Returns ``inf`` when no batch with valid tokens has been recorded.
        """
        if not self.losses:
            return float("inf")
        return math.exp(sum(self.losses) / len(self.losses))

    def compute_cumulative(self) -> float:
        """Return the token-weighted perplexity over every valid token seen.

        Returns ``inf`` when no valid tokens have been recorded.
        """
        if self.num_tokens == 0:
            return float("inf")
        return math.exp(self.total_loss / self.num_tokens)


def calculate_perplexity(
    model: nn.Module,
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    pad_token_id: int = 0,
) -> float:
    """Calculate perplexity for a single batch.

    Runs ``model(input_ids=..., labels=...)`` under ``torch.no_grad()`` and
    returns ``exp(outputs["loss"])``, or ``inf`` if the model reports no loss.
    The model's train/eval mode is left unchanged. ``pad_token_id`` is not
    used; it is kept for signature compatibility.
    """
    with torch.no_grad():
        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs["loss"]
    if loss is None:
        return float("inf")
    return torch.exp(loss).item()


def _iter_batch_losses(
    model: nn.Module,
    dataloader: Iterable[Mapping[str, torch.Tensor]],
    pad_token_id: int,
) -> Iterator[Tuple[torch.Tensor, int]]:
    """Yield ``(loss, valid_tokens)`` for each batch; skip batches with no loss."""
    for batch in dataloader:
        labels = batch["labels"]
        outputs = model(input_ids=batch["input_ids"], labels=labels)
        loss = outputs["loss"]
        if loss is None:
            logger.warning("Model returned no loss for a batch; skipping it.")
            continue
        # NOTE(review): tokens are counted as ``labels != pad_token_id``. If the
        # labels mask padding with -100 instead, pass pad_token_id=-100 so the
        # weighting matches the positions the model actually scored.
        valid_tokens = int((labels != pad_token_id).sum().item())
        yield loss, valid_tokens


def calculate_batch_perplexity(
    model: nn.Module,
    dataloader: Iterable[Mapping[str, torch.Tensor]],
    pad_token_id: int = 0,
) -> float:
    """Calculate token-weighted perplexity over an entire dataloader.

    Puts the model in eval mode (and leaves it there). Each batch's loss is
    weighted by its count of labels not equal to ``pad_token_id``. Returns
    ``inf`` when no tokens were counted.
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for loss, valid_tokens in _iter_batch_losses(model, dataloader, pad_token_id):
            total_loss += loss.item() * valid_tokens
            total_tokens += valid_tokens
    if total_tokens == 0:
        return float("inf")
    return math.exp(total_loss / total_tokens)


class StreamingPerplexity:
    """Streaming perplexity calculator for large datasets.

    Keeps only a running sum of target log-probabilities and a token count,
    so memory use is constant in the number of batches.
    """

    def __init__(self, ignore_index: int = -100):
        self.ignore_index = ignore_index
        self.reset()

    def reset(self) -> None:
        """Clear the running log-probability sum and token count."""
        self.total_log_prob = 0.0
        self.total_tokens = 0

    def update(self, logits: torch.Tensor, labels: torch.Tensor) -> None:
        """Accumulate log-probabilities of the next-token targets in one batch."""
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        # The mask must come from the shifted labels so it has the same shape
        # as the gathered log-probs below.
        valid_mask = shift_labels != self.ignore_index
        # gather() rejects out-of-range indices such as -100; use a dummy
        # in-range index for ignored positions (they are masked out anyway).
        gather_index = shift_labels.masked_fill(~valid_mask, 0)
        # Compute log-softmax in float32 so half-precision logits do not lose
        # precision.
        log_probs = F.log_softmax(shift_logits.float(), dim=-1)
        token_log_probs = log_probs.gather(-1, gather_index.unsqueeze(-1)).squeeze(-1)
        self.total_log_prob += token_log_probs[valid_mask].sum().item()
        self.total_tokens += int(valid_mask.sum().item())

    def compute(self) -> float:
        """Return ``exp(-mean log-prob)``, or ``inf`` if no tokens were seen."""
        if self.total_tokens == 0:
            return float("inf")
        return math.exp(-self.total_log_prob / self.total_tokens)


def compute_perplexity_metrics(
    model: nn.Module,
    dataloader: Iterable[Mapping[str, torch.Tensor]],
    pad_token_id: int = 0,
) -> dict:
    """Compute comprehensive perplexity metrics over a dataloader.

    Puts the model in eval mode (and leaves it there). Returns a dict with the
    token-weighted ``perplexity`` and ``avg_loss`` (both ``inf`` when no
    tokens were counted), ``total_tokens``, ``total_batches`` (batches that
    produced a loss), and the mean/min/max of per-batch perplexities (``inf``
    when there were no such batches).
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    batch_perplexities = []
    with torch.no_grad():
        for loss, valid_tokens in _iter_batch_losses(model, dataloader, pad_token_id):
            total_loss += loss.item() * valid_tokens
            total_tokens += valid_tokens
            batch_perplexities.append(torch.exp(loss).item())

    if total_tokens == 0:
        # exp(0) == 1 would falsely report a perfect model on empty input.
        avg_loss = float("inf")
    else:
        avg_loss = total_loss / total_tokens

    has_batches = bool(batch_perplexities)
    return {
        "perplexity": math.exp(avg_loss),
        "avg_loss": avg_loss,
        "total_tokens": total_tokens,
        "total_batches": len(batch_perplexities),
        "mean_batch_perplexity": (
            sum(batch_perplexities) / len(batch_perplexities) if has_batches else float("inf")
        ),
        "min_batch_perplexity": min(batch_perplexities) if has_batches else float("inf"),
        "max_batch_perplexity": max(batch_perplexities) if has_batches else float("inf"),
    }