sage / eval /perplexity.py
sage002's picture
feat: add authenticated remote control UI and ngrok launcher
b4f432f verified
"""Validation perplexity evaluation."""
from __future__ import annotations

import contextlib
import itertools
import math

import torch

from train.loss import masked_cross_entropy
@torch.no_grad()
def evaluate_perplexity(
    model: torch.nn.Module,
    dataloader,
    device: torch.device,
    dtype: torch.dtype | None = None,
    max_batches: int = 16,
) -> dict[str, float]:
    """Evaluate average loss and perplexity on a validation loader.

    At most ``max_batches`` batches are drawn from ``dataloader``. Each batch
    is indexed with the keys ``"input_ids"``, ``"labels"`` and
    ``"loss_mask"``; the three values are moved to ``device``. The model is
    called on ``input_ids`` and must return a 2-tuple whose first element is
    passed as logits to ``masked_cross_entropy`` together with the labels and
    the mask.

    Args:
        model: Module to evaluate. It is put in eval mode for the duration of
            the call and its previous train/eval mode is restored afterwards,
            even if evaluation raises.
        dataloader: Iterable of batches (see above).
        device: Device the batch tensors are moved to.
        dtype: When given and ``device`` is not a CPU device, the forward pass
            and the loss run under ``torch.autocast`` with this dtype.
        max_batches: Upper bound on the number of batches consumed. Values
            of zero or below evaluate nothing.

    Returns:
        ``{"loss": ..., "perplexity": ...}`` where ``loss`` is the unweighted
        mean of the per-batch losses and ``perplexity`` is
        ``exp(min(loss, 20.0))``. When no batch is evaluated the result is
        ``{"loss": 0.0, "perplexity": 1.0}``.
    """
    use_autocast = dtype is not None and device.type != "cpu"
    # Remember the caller's mode so evaluation does not flip an eval-mode
    # model into training mode as a side effect.
    was_training = model.training
    model.eval()
    losses: list[float] = []
    try:
        # islice stops without pulling an extra batch past the limit; the
        # clamp keeps negative limits meaning "no batches" (islice rejects
        # negative stops).
        for batch in itertools.islice(dataloader, max(max_batches, 0)):
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            loss_mask = batch["loss_mask"].to(device)
            # A fresh context per batch: autocast when requested on a
            # non-CPU device, otherwise a no-op.
            precision = (
                torch.autocast(device_type=device.type, dtype=dtype)
                if use_autocast
                else contextlib.nullcontext()
            )
            with precision:
                logits, _ = model(input_ids)
                loss = masked_cross_entropy(logits, labels, loss_mask)
            losses.append(float(loss))
    finally:
        model.train(was_training)
    # max(..., 1) avoids dividing by zero when no batch was evaluated.
    mean_loss = sum(losses) / max(len(losses), 1)
    # The exponent is capped at 20.0 so a very large loss cannot overflow.
    return {"loss": mean_loss, "perplexity": math.exp(min(mean_loss, 20.0))}