codsworth-3.8m / codsworth /eval /perplexity.py
Jaqshanahan's picture
Initial upload of Codsworth model
b84d85a verified
import logging
from typing import Optional
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# Module-level logger registered under the fixed name "codsworth".
# NOTE(review): nothing in the visible part of this file logs through it.
logger = logging.getLogger("codsworth")
class Perplexity:
    """Running perplexity built from pre-computed mean losses.

    Each update contributes ``loss * num_tokens`` to a running sum, and the
    metric is ``exp(total_loss / total_tokens)``.  With no tokens recorded,
    or when the exponential would overflow a Python float, the result is
    ``inf``.
    """

    def __init__(self, pad_token_id: Optional[int] = None):
        """Create an empty accumulator.

        Args:
            pad_token_id: Stored on the instance as-is.  None of the methods
                of this class read it.
        """
        self.pad_token_id = pad_token_id
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        self.total_loss = 0.0
        self.total_tokens = 0

    def update(self, loss: float, num_tokens: int):
        """Add one measurement.

        Args:
            loss: Mean loss over ``num_tokens`` tokens.
            num_tokens: Weight of this measurement.
        """
        # float() keeps the accumulator a plain number even when a 0-dim
        # tensor is passed, so no tensor (or autograd graph) is retained.
        self.total_loss += float(loss) * num_tokens
        self.total_tokens += num_tokens

    def compute(self) -> float:
        """Return the perplexity over everything recorded since the last reset.

        Returns:
            ``exp`` of the token-weighted mean loss, or ``inf`` when no
            tokens were recorded or the exponential overflows.
        """
        if self.total_tokens == 0:
            return float("inf")
        avg_loss = self.total_loss / self.total_tokens
        try:
            return math.exp(avg_loss)
        except OverflowError:
            # math.exp raises for arguments above ~709; a diverged model
            # should report an infinite perplexity rather than crash.
            return float("inf")

    def __call__(self, loss: float, num_tokens: int = 1) -> float:
        """Record one measurement and return the updated perplexity."""
        self.update(loss, num_tokens)
        return self.compute()

    def __str__(self) -> str:
        return f"Perplexity: {self.compute():.4f}"
class TokenPerplexity:
    """Perplexity computed from logits and labels with next-token shifting.

    ``update`` scores position ``t`` of ``logits`` against position ``t + 1``
    of ``labels`` and skips targets equal to ``ignore_index``.

    Attributes:
        losses: One entry per scored batch, the mean loss over that batch's
            valid target tokens.
        num_tokens: Total number of valid target tokens seen.
        total_loss: Sum of the per-token losses over all valid targets.
    """

    def __init__(self, ignore_index: int = -100):
        """Create an empty accumulator.

        Args:
            ignore_index: Label value that marks a target to be skipped.
        """
        self.ignore_index = ignore_index
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        self.losses = []
        self.num_tokens = 0
        self.total_loss = 0.0

    def update(self, logits: torch.Tensor, labels: torch.Tensor):
        """Score one batch.

        Args:
            logits: Tensor of shape ``(..., seq_len, vocab_size)``.
            labels: Integer tensor of shape ``(..., seq_len)``.

        A batch without any valid target token is ignored entirely.
        """
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]

        # Count targets on the *shifted* labels: the first label of each
        # sequence is never predicted and must not be counted.
        valid_tokens = int((shift_labels != self.ignore_index).sum().item())
        if valid_tokens == 0:
            return

        # Ignored positions contribute exactly 0 to the summed loss, so
        # dividing by the valid count gives the true per-token mean.
        batch_loss = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=self.ignore_index,
            reduction="sum",
        ).item()

        self.total_loss += batch_loss
        self.num_tokens += valid_tokens
        self.losses.append(batch_loss / valid_tokens)

    def compute(self) -> float:
        """Return ``exp`` of the unweighted mean of the per-batch mean losses.

        Every scored batch counts equally regardless of its token count.
        Returns ``inf`` when no batch has been scored or on overflow.
        """
        if not self.losses:
            return float("inf")
        return self._safe_exp(sum(self.losses) / len(self.losses))

    def compute_cumulative(self) -> float:
        """Return ``exp`` of the mean loss over all valid tokens seen.

        Every token counts equally regardless of which batch it came from.
        Returns ``inf`` when no token has been scored or on overflow.
        """
        if self.num_tokens == 0:
            return float("inf")
        return self._safe_exp(self.total_loss / self.num_tokens)

    @staticmethod
    def _safe_exp(value: float) -> float:
        """Return ``math.exp(value)``, mapping overflow to ``inf``."""
        try:
            return math.exp(value)
        except OverflowError:
            return float("inf")
def calculate_perplexity(
    model: nn.Module,
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    pad_token_id: int = 0,
) -> float:
    """Return the perplexity of one batch as ``exp`` of the model's loss.

    Args:
        model: Callable as ``model(input_ids=..., labels=...)``; its result
            is indexed with ``"loss"``.
        input_ids: Forwarded to the model unchanged.
        labels: Forwarded to the model unchanged.
        pad_token_id: Accepted but not read by this function.

    Returns:
        ``exp(loss)`` as a Python float, or ``inf`` when the model reports a
        loss of ``None``.
    """
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        batch_loss = model(input_ids=input_ids, labels=labels)["loss"]

    if batch_loss is not None:
        return torch.exp(batch_loss).item()
    return float("inf")
def calculate_batch_perplexity(
    model: nn.Module,
    dataloader,
    pad_token_id: int = 0,
) -> float:
    """Return the token-weighted perplexity over every batch of a dataloader.

    The model is switched to eval mode and left there.  Each batch's loss is
    weighted by its number of countable label tokens before averaging.

    Args:
        model: Callable as ``model(input_ids=..., labels=...)``; its result
            is indexed with ``"loss"``.
        dataloader: Iterable of mappings with ``"input_ids"`` and
            ``"labels"`` entries.
        pad_token_id: Label value that is not counted as a token.

    Returns:
        ``exp`` of the weighted mean loss, or ``inf`` when no tokens were
        counted or the exponential overflows.

    NOTE(review): the loss is assumed to be the model's mean over the same
    tokens that are counted here; confirm against the model's loss code.
    """
    # -100 is the default ``ignore_index`` of torch.nn.CrossEntropyLoss and
    # can never be a real token id, so such positions are never counted.
    ignore_index = -100

    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for batch in dataloader:
            labels = batch["labels"]
            outputs = model(input_ids=batch["input_ids"], labels=labels)
            loss = outputs["loss"]
            if loss is None:
                # No loss reported for this batch: skip it instead of
                # failing on ``None.item()``.
                continue
            countable = (labels != pad_token_id) & (labels != ignore_index)
            valid_tokens = int(countable.sum().item())
            total_loss += loss.item() * valid_tokens
            total_tokens += valid_tokens

    if total_tokens == 0:
        return float("inf")
    try:
        return math.exp(total_loss / total_tokens)
    except OverflowError:
        # math.exp raises above ~709; report an infinite perplexity instead.
        return float("inf")
class StreamingPerplexity:
    """Streaming perplexity that keeps only two running scalars.

    ``update`` adds the log-probability the logits assign to each valid
    next-token target; ``compute`` returns ``exp(-mean log-probability)``.

    Attributes:
        total_log_prob: Sum of target log-probabilities over valid tokens.
        total_tokens: Number of valid target tokens seen.
    """

    def __init__(self, ignore_index: int = -100):
        """Create an empty accumulator.

        Args:
            ignore_index: Label value that marks a target to be skipped.
        """
        self.ignore_index = ignore_index
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        self.total_log_prob = 0.0
        self.total_tokens = 0

    def update(self, logits: torch.Tensor, labels: torch.Tensor):
        """Score one batch.

        Args:
            logits: Tensor of shape ``(..., seq_len, vocab_size)``.
            labels: Integer tensor of shape ``(..., seq_len)``.
        """
        # Position t of the logits is scored against label t + 1.
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]

        # The mask must come from the shifted labels so that its shape
        # matches the gathered log-probabilities.
        valid_mask = shift_labels != self.ignore_index
        # gather() rejects out-of-range indices such as -100; point ignored
        # positions at class 0 and drop them again through the mask.
        safe_labels = shift_labels.masked_fill(~valid_mask, 0)

        log_probs = F.log_softmax(shift_logits, dim=-1)
        token_log_prob = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)

        self.total_log_prob += token_log_prob[valid_mask].sum().item()
        self.total_tokens += int(valid_mask.sum().item())

    def compute(self) -> float:
        """Return the perplexity over all tokens seen since the last reset.

        Returns:
            ``exp(-total_log_prob / total_tokens)``, or ``inf`` when no
            tokens were scored or the exponential overflows.
        """
        if self.total_tokens == 0:
            return float("inf")
        avg_log_prob = self.total_log_prob / self.total_tokens
        try:
            return math.exp(-avg_log_prob)
        except OverflowError:
            return float("inf")
def compute_perplexity_metrics(
    model: nn.Module,
    dataloader,
    pad_token_id: int = 0,
) -> dict:
    """Compute corpus-level and per-batch perplexity statistics.

    The model is switched to eval mode and left there.  Batches for which
    the model reports a loss of ``None`` are skipped and do not appear in
    any of the statistics.

    Args:
        model: Callable as ``model(input_ids=..., labels=...)``; its result
            is indexed with ``"loss"``.
        dataloader: Iterable of mappings with ``"input_ids"`` and
            ``"labels"`` entries.
        pad_token_id: Label value that is not counted as a token.

    Returns:
        A dict with the keys ``perplexity``, ``avg_loss``, ``total_tokens``,
        ``total_batches``, ``mean_batch_perplexity``,
        ``min_batch_perplexity`` and ``max_batch_perplexity``.  When no
        tokens were counted, ``perplexity`` and ``avg_loss`` are ``inf``;
        when no batch was scored, the three per-batch statistics are ``inf``.

    NOTE(review): the loss is assumed to be the model's mean over the same
    tokens that are counted here; confirm against the model's loss code.
    """
    # -100 is the default ``ignore_index`` of torch.nn.CrossEntropyLoss and
    # can never be a real token id, so such positions are never counted.
    ignore_index = -100
    inf = float("inf")

    model.eval()
    total_loss = 0.0
    total_tokens = 0
    batch_perplexities = []
    with torch.no_grad():
        for batch in dataloader:
            labels = batch["labels"]
            outputs = model(input_ids=batch["input_ids"], labels=labels)
            loss = outputs["loss"]
            if loss is None:
                continue
            countable = (labels != pad_token_id) & (labels != ignore_index)
            valid_tokens = int(countable.sum().item())
            total_loss += loss.item() * valid_tokens
            total_tokens += valid_tokens
            # torch.exp saturates to inf on overflow instead of raising.
            batch_perplexities.append(torch.exp(loss).item())

    if total_tokens > 0:
        avg_loss = total_loss / total_tokens
        try:
            perplexity = math.exp(avg_loss)
        except OverflowError:
            perplexity = inf
    else:
        # Nothing was measured: report "unknown/worst" rather than the
        # perfect score exp(0) == 1.0.
        avg_loss = inf
        perplexity = inf

    scored = len(batch_perplexities)
    return {
        "perplexity": perplexity,
        "avg_loss": avg_loss,
        "total_tokens": total_tokens,
        "total_batches": scored,
        "mean_batch_perplexity": sum(batch_perplexities) / scored if scored else inf,
        "min_batch_perplexity": min(batch_perplexities) if scored else inf,
        "max_batch_perplexity": max(batch_perplexities) if scored else inf,
    }