"""Validation perplexity evaluation."""
from __future__ import annotations
import math
import torch
from train.loss import masked_cross_entropy
@torch.no_grad()
def evaluate_perplexity(
    model: torch.nn.Module,
    dataloader,
    device: torch.device,
    dtype: torch.dtype | None = None,
    max_batches: int = 16,
) -> dict[str, float]:
    """Evaluate average loss and perplexity on a validation loader.

    The model is put in eval mode for the duration of the call and its
    previous ``training`` flag is restored afterwards, even if evaluation
    raises. Gradients are disabled by the ``torch.no_grad`` decorator.

    Args:
        model: Module called as ``model(input_ids)``; it must return a pair
            whose first element is the logits.
        dataloader: Iterable of mappings with ``"input_ids"``, ``"labels"``
            and ``"loss_mask"`` entries, each supporting ``.to(device)``.
        device: Device the batch entries are moved to before the forward
            pass.
        dtype: Autocast dtype. Autocast is only used when this is not
            ``None`` and ``device`` is not a CPU device.
        max_batches: Upper bound on the number of batches evaluated.

    Returns:
        ``{"loss": ..., "perplexity": ...}`` where ``loss`` is the unweighted
        mean of the per-batch losses and ``perplexity`` is ``exp(loss)`` with
        the exponent capped at 20.0. When no batch is evaluated the loss is
        0.0 and the perplexity 1.0.
    """
    # Remember the caller's mode so evaluation does not silently flip an
    # eval-mode model into training mode.
    was_training = model.training
    # Loop-invariant: decide once whether the forward pass runs under autocast.
    use_autocast = dtype is not None and device.type != "cpu"
    losses: list[float] = []

    model.eval()
    try:
        for index, batch in enumerate(dataloader):
            if index >= max_batches:
                break
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            loss_mask = batch["loss_mask"].to(device)
            if use_autocast:
                with torch.autocast(device_type=device.type, dtype=dtype):
                    logits, _ = model(input_ids)
                    loss = masked_cross_entropy(logits, labels, loss_mask)
            else:
                logits, _ = model(input_ids)
                loss = masked_cross_entropy(logits, labels, loss_mask)
            # Keep plain Python floats so no per-batch result objects are
            # retained across iterations.
            losses.append(float(loss))
    finally:
        # Restore the original mode on success and on error alike.
        model.train(was_training)

    # NOTE(review): this averages per-batch losses with equal weight; whether
    # that equals a per-token average depends on how masked_cross_entropy
    # normalises and on the batches having equal unmasked counts -- confirm.
    mean_loss = sum(losses) / max(len(losses), 1)
    # Cap the exponent so a diverged loss cannot overflow math.exp.
    return {"loss": mean_loss, "perplexity": math.exp(min(mean_loss, 20.0))}
|