| """Perplexity (PPL) evaluator.""" |
|
|
| import math |
| import time |
| from typing import Dict, List |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| from llm_lab.config import EvalConfig |
|
|
|
|
| class PerplexityEvaluator: |
| """Measures Perplexity (PPL). |
| |
| What is Perplexity? |
| PPL = exp(average cross-entropy loss) |
| |
| Intuitive meaning: |
| - PPL = 1: Perfect prediction (impossible) |
| - PPL = 10: Equivalent to picking from 10 candidates each time |
| - PPL = 100: Equivalent to picking from 100 candidates (close to random) |
| - PPL = 32000: Random selection from the entire vocab (initial random model) |
| |
| Good benchmark for a 1B model (English web text): |
| - Trained on 5B tokens: PPL ~30-40 |
| - Trained on 10B tokens: PPL ~20-30 |
| - Trained on 20B tokens: PPL ~15-25 |
| |
| Measurement method: |
| - Compute cross-entropy over all tokens in the validation dataset |
| - Average per token, then apply exp() |
| - Padding tokens are excluded (ignore_index=-100) |
| """ |
|
|
| def __init__(self, config: EvalConfig): |
| self.config = config |
|
|
| @torch.no_grad() |
| def evaluate( |
| self, |
| model: nn.Module, |
| dataloader: DataLoader, |
| device: torch.device, |
| dtype: torch.dtype = torch.bfloat16, |
| desc: str = "Evaluation", |
| ) -> Dict[str, float]: |
| """Measures Perplexity. |
| |
| Returns: |
| { |
| "loss": average cross-entropy loss, |
| "perplexity": exp(loss), |
| "num_tokens": total number of tokens used for evaluation, |
| "num_batches": number of batches used for evaluation, |
| } |
| """ |
| model.eval() |
|
|
| total_loss = 0.0 |
| total_tokens = 0 |
| num_batches = 0 |
|
|
| print(f"\nπ {desc}") |
| start_time = time.time() |
|
|
| for i, batch in enumerate(dataloader): |
| if i >= self.config.max_eval_batches: |
| break |
|
|
| input_ids = batch["input_ids"].to(device) |
| targets = batch["targets"].to(device) |
|
|
| with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)): |
| logits, _ = model(input_ids) |
|
|
| |
| |
| |
| loss_per_token = F.cross_entropy( |
| logits.view(-1, logits.size(-1)), |
| targets.view(-1), |
| ignore_index=-100, |
| reduction="none", |
| ) |
|
|
| |
| valid_mask = (targets.view(-1) != -100) |
| valid_tokens = valid_mask.sum().item() |
|
|
| total_loss += loss_per_token[valid_mask].sum().item() |
| total_tokens += valid_tokens |
| num_batches += 1 |
|
|
| if (i + 1) % 20 == 0: |
| running_ppl = math.exp(min(total_loss / max(total_tokens, 1), 20)) |
| print(f" Batch {i+1}/{self.config.max_eval_batches}: running PPL = {running_ppl:.2f}") |
|
|
| elapsed = time.time() - start_time |
| avg_loss = total_loss / max(total_tokens, 1) |
| perplexity = math.exp(min(avg_loss, 100)) |
|
|
| results = { |
| "loss": round(avg_loss, 4), |
| "perplexity": round(perplexity, 2), |
| "num_tokens": total_tokens, |
| "num_batches": num_batches, |
| "eval_time_sec": round(elapsed, 1), |
| } |
|
|
| print(f" ββββββββββββββββββββββββββββββββ") |
| print(f" Loss: {results['loss']:.4f}") |
| print(f" Perplexity: {results['perplexity']:.2f}") |
| print(f" Eval tokens: {total_tokens:,}") |
| print(f" Elapsed: {elapsed:.1f}s") |
|
|
| return results |
|
|
| @torch.no_grad() |
| def evaluate_per_position( |
| self, |
| model: nn.Module, |
| dataloader: DataLoader, |
| device: torch.device, |
| dtype: torch.dtype = torch.bfloat16, |
| max_batches: int = 50, |
| ) -> List[float]: |
| """Measures loss per position within a sequence. |
| |
| Learning insight: |
| - Positions 0~10: Higher loss (insufficient context) |
| - Positions 100+: Loss stabilizes lower (context is leveraged) |
| - This pattern demonstrates the Transformer's in-context learning capability |
| """ |
| model.eval() |
| seq_len = None |
| position_loss_sum = None |
| position_count = None |
|
|
| for i, batch in enumerate(dataloader): |
| if i >= max_batches: |
| break |
|
|
| input_ids = batch["input_ids"].to(device) |
| targets = batch["targets"].to(device) |
| B, S = targets.shape |
|
|
| if seq_len is None: |
| seq_len = S |
| position_loss_sum = torch.zeros(S, device=device) |
| position_count = torch.zeros(S, device=device) |
|
|
| with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)): |
| logits, _ = model(input_ids) |
|
|
| |
| loss_per_token = F.cross_entropy( |
| logits.view(-1, logits.size(-1)), |
| targets.view(-1), |
| ignore_index=-100, |
| reduction="none", |
| ).view(B, S) |
|
|
| valid_mask = (targets != -100).float() |
| position_loss_sum += (loss_per_token * valid_mask).sum(dim=0) |
| position_count += valid_mask.sum(dim=0) |
|
|
| |
| position_avg_loss = (position_loss_sum / position_count.clamp(min=1)).cpu().tolist() |
| return position_avg_loss |
|
|