"""Perplexity (PPL) evaluator.""" import math import time from typing import Dict, List import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from llm_lab.config import EvalConfig class PerplexityEvaluator: """Measures Perplexity (PPL). What is Perplexity? PPL = exp(average cross-entropy loss) Intuitive meaning: - PPL = 1: Perfect prediction (impossible) - PPL = 10: Equivalent to picking from 10 candidates each time - PPL = 100: Equivalent to picking from 100 candidates (close to random) - PPL = 32000: Random selection from the entire vocab (initial random model) Good benchmark for a 1B model (English web text): - Trained on 5B tokens: PPL ~30-40 - Trained on 10B tokens: PPL ~20-30 - Trained on 20B tokens: PPL ~15-25 Measurement method: - Compute cross-entropy over all tokens in the validation dataset - Average per token, then apply exp() - Padding tokens are excluded (ignore_index=-100) """ def __init__(self, config: EvalConfig): self.config = config @torch.no_grad() def evaluate( self, model: nn.Module, dataloader: DataLoader, device: torch.device, dtype: torch.dtype = torch.bfloat16, desc: str = "Evaluation", ) -> Dict[str, float]: """Measures Perplexity. Returns: { "loss": average cross-entropy loss, "perplexity": exp(loss), "num_tokens": total number of tokens used for evaluation, "num_batches": number of batches used for evaluation, } """ model.eval() total_loss = 0.0 total_tokens = 0 num_batches = 0 print(f"\nšŸ“Š {desc}") start_time = time.time() for i, batch in enumerate(dataloader): if i >= self.config.max_eval_batches: break input_ids = batch["input_ids"].to(device) targets = batch["targets"].to(device) with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)): logits, _ = model(input_ids) # Per-token cross-entropy (reduction='none') # logits: (B, S, V) → (B*S, V) # targets: (B, S) → (B*S,) loss_per_token = F.cross_entropy( logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100, reduction="none", ) # Count only valid tokens that are not -100 valid_mask = (targets.view(-1) != -100) valid_tokens = valid_mask.sum().item() total_loss += loss_per_token[valid_mask].sum().item() total_tokens += valid_tokens num_batches += 1 if (i + 1) % 20 == 0: running_ppl = math.exp(min(total_loss / max(total_tokens, 1), 20)) print(f" Batch {i+1}/{self.config.max_eval_batches}: running PPL = {running_ppl:.2f}") elapsed = time.time() - start_time avg_loss = total_loss / max(total_tokens, 1) perplexity = math.exp(min(avg_loss, 100)) # prevent overflow results = { "loss": round(avg_loss, 4), "perplexity": round(perplexity, 2), "num_tokens": total_tokens, "num_batches": num_batches, "eval_time_sec": round(elapsed, 1), } print(f" ────────────────────────────────") print(f" Loss: {results['loss']:.4f}") print(f" Perplexity: {results['perplexity']:.2f}") print(f" Eval tokens: {total_tokens:,}") print(f" Elapsed: {elapsed:.1f}s") return results @torch.no_grad() def evaluate_per_position( self, model: nn.Module, dataloader: DataLoader, device: torch.device, dtype: torch.dtype = torch.bfloat16, max_batches: int = 50, ) -> List[float]: """Measures loss per position within a sequence. Learning insight: - Positions 0~10: Higher loss (insufficient context) - Positions 100+: Loss stabilizes lower (context is leveraged) - This pattern demonstrates the Transformer's in-context learning capability """ model.eval() seq_len = None position_loss_sum = None position_count = None for i, batch in enumerate(dataloader): if i >= max_batches: break input_ids = batch["input_ids"].to(device) targets = batch["targets"].to(device) B, S = targets.shape if seq_len is None: seq_len = S position_loss_sum = torch.zeros(S, device=device) position_count = torch.zeros(S, device=device) with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)): logits, _ = model(input_ids) # Per-token loss in shape (B, S) loss_per_token = F.cross_entropy( logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100, reduction="none", ).view(B, S) valid_mask = (targets != -100).float() position_loss_sum += (loss_per_token * valid_mask).sum(dim=0) position_count += valid_mask.sum(dim=0) # Average loss per position position_avg_loss = (position_loss_sum / position_count.clamp(min=1)).cpu().tolist() return position_avg_loss