"""Validation perplexity evaluation.""" from __future__ import annotations import math import torch from train.loss import masked_cross_entropy @torch.no_grad() def evaluate_perplexity( model: torch.nn.Module, dataloader, device: torch.device, dtype: torch.dtype | None = None, max_batches: int = 16, ) -> dict[str, float]: """Evaluate average loss and perplexity on a validation loader.""" model.eval() losses: list[float] = [] for index, batch in enumerate(dataloader): if index >= max_batches: break input_ids = batch["input_ids"].to(device) labels = batch["labels"].to(device) loss_mask = batch["loss_mask"].to(device) if dtype is not None and device.type != "cpu": with torch.autocast(device_type=device.type, dtype=dtype): logits, _ = model(input_ids) loss = masked_cross_entropy(logits, labels, loss_mask) else: logits, _ = model(input_ids) loss = masked_cross_entropy(logits, labels, loss_mask) losses.append(float(loss)) model.train() mean_loss = sum(losses) / max(len(losses), 1) return {"loss": mean_loss, "perplexity": math.exp(min(mean_loss, 20.0))}