File size: 1,253 Bytes
ef18673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4f432f
ef18673
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Validation perplexity evaluation."""

from __future__ import annotations

import contextlib
import itertools
import math

import torch

from train.loss import masked_cross_entropy


@torch.no_grad()
def evaluate_perplexity(
    model: torch.nn.Module,
    dataloader,
    device: torch.device,
    dtype: torch.dtype | None = None,
    max_batches: int = 16,
) -> dict[str, float]:
    """Evaluate average loss and perplexity on a validation loader."""
    model.eval()
    losses: list[float] = []
    for index, batch in enumerate(dataloader):
        if index >= max_batches:
            break
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        loss_mask = batch["loss_mask"].to(device)
        if dtype is not None and device.type != "cpu":
            with torch.autocast(device_type=device.type, dtype=dtype):
                logits, _ = model(input_ids)
                loss = masked_cross_entropy(logits, labels, loss_mask)
        else:
            logits, _ = model(input_ids)
            loss = masked_cross_entropy(logits, labels, loss_mask)
        losses.append(float(loss))
    model.train()
    mean_loss = sum(losses) / max(len(losses), 1)
    return {"loss": mean_loss, "perplexity": math.exp(min(mean_loss, 20.0))}