# LLM-1B-Lab / llm_lab/evaluation/perplexity.py
# History: 858e8b2 (Vjeong) — docs: translate all Korean comments and docstrings to English
"""Perplexity (PPL) evaluator."""
import math
import time
from typing import Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from llm_lab.config import EvalConfig
class PerplexityEvaluator:
"""Measures Perplexity (PPL).
What is Perplexity?
PPL = exp(average cross-entropy loss)
Intuitive meaning:
- PPL = 1: Perfect prediction (impossible)
- PPL = 10: Equivalent to picking from 10 candidates each time
- PPL = 100: Equivalent to picking from 100 candidates (close to random)
- PPL = 32000: Random selection from the entire vocab (initial random model)
Good benchmark for a 1B model (English web text):
- Trained on 5B tokens: PPL ~30-40
- Trained on 10B tokens: PPL ~20-30
- Trained on 20B tokens: PPL ~15-25
Measurement method:
- Compute cross-entropy over all tokens in the validation dataset
- Average per token, then apply exp()
- Padding tokens are excluded (ignore_index=-100)
"""
def __init__(self, config: EvalConfig):
self.config = config
@torch.no_grad()
def evaluate(
self,
model: nn.Module,
dataloader: DataLoader,
device: torch.device,
dtype: torch.dtype = torch.bfloat16,
desc: str = "Evaluation",
) -> Dict[str, float]:
"""Measures Perplexity.
Returns:
{
"loss": average cross-entropy loss,
"perplexity": exp(loss),
"num_tokens": total number of tokens used for evaluation,
"num_batches": number of batches used for evaluation,
}
"""
model.eval()
total_loss = 0.0
total_tokens = 0
num_batches = 0
print(f"\nπŸ“Š {desc}")
start_time = time.time()
for i, batch in enumerate(dataloader):
if i >= self.config.max_eval_batches:
break
input_ids = batch["input_ids"].to(device)
targets = batch["targets"].to(device)
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
logits, _ = model(input_ids)
# Per-token cross-entropy (reduction='none')
# logits: (B, S, V) β†’ (B*S, V)
# targets: (B, S) β†’ (B*S,)
loss_per_token = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=-100,
reduction="none",
)
# Count only valid tokens that are not -100
valid_mask = (targets.view(-1) != -100)
valid_tokens = valid_mask.sum().item()
total_loss += loss_per_token[valid_mask].sum().item()
total_tokens += valid_tokens
num_batches += 1
if (i + 1) % 20 == 0:
running_ppl = math.exp(min(total_loss / max(total_tokens, 1), 20))
print(f" Batch {i+1}/{self.config.max_eval_batches}: running PPL = {running_ppl:.2f}")
elapsed = time.time() - start_time
avg_loss = total_loss / max(total_tokens, 1)
perplexity = math.exp(min(avg_loss, 100)) # prevent overflow
results = {
"loss": round(avg_loss, 4),
"perplexity": round(perplexity, 2),
"num_tokens": total_tokens,
"num_batches": num_batches,
"eval_time_sec": round(elapsed, 1),
}
print(f" ────────────────────────────────")
print(f" Loss: {results['loss']:.4f}")
print(f" Perplexity: {results['perplexity']:.2f}")
print(f" Eval tokens: {total_tokens:,}")
print(f" Elapsed: {elapsed:.1f}s")
return results
@torch.no_grad()
def evaluate_per_position(
self,
model: nn.Module,
dataloader: DataLoader,
device: torch.device,
dtype: torch.dtype = torch.bfloat16,
max_batches: int = 50,
) -> List[float]:
"""Measures loss per position within a sequence.
Learning insight:
- Positions 0~10: Higher loss (insufficient context)
- Positions 100+: Loss stabilizes lower (context is leveraged)
- This pattern demonstrates the Transformer's in-context learning capability
"""
model.eval()
seq_len = None
position_loss_sum = None
position_count = None
for i, batch in enumerate(dataloader):
if i >= max_batches:
break
input_ids = batch["input_ids"].to(device)
targets = batch["targets"].to(device)
B, S = targets.shape
if seq_len is None:
seq_len = S
position_loss_sum = torch.zeros(S, device=device)
position_count = torch.zeros(S, device=device)
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
logits, _ = model(input_ids)
# Per-token loss in shape (B, S)
loss_per_token = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=-100,
reduction="none",
).view(B, S)
valid_mask = (targets != -100).float()
position_loss_sum += (loss_per_token * valid_mask).sum(dim=0)
position_count += valid_mask.sum(dim=0)
# Average loss per position
position_avg_loss = (position_loss_sum / position_count.clamp(min=1)).cpu().tolist()
return position_avg_loss