"""Perplexity (PPL) evaluator."""
import math
import time
from typing import Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from llm_lab.config import EvalConfig
class PerplexityEvaluator:
"""Measures Perplexity (PPL).
What is Perplexity?
PPL = exp(average cross-entropy loss)
Intuitive meaning:
- PPL = 1: Perfect prediction (impossible)
- PPL = 10: Equivalent to picking from 10 candidates each time
- PPL = 100: Equivalent to picking from 100 candidates (close to random)
- PPL = 32000: Random selection from the entire vocab (initial random model)
Good benchmark for a 1B model (English web text):
- Trained on 5B tokens: PPL ~30-40
- Trained on 10B tokens: PPL ~20-30
- Trained on 20B tokens: PPL ~15-25
Measurement method:
- Compute cross-entropy over all tokens in the validation dataset
- Average per token, then apply exp()
- Padding tokens are excluded (ignore_index=-100)
"""
def __init__(self, config: EvalConfig):
self.config = config
@torch.no_grad()
def evaluate(
self,
model: nn.Module,
dataloader: DataLoader,
device: torch.device,
dtype: torch.dtype = torch.bfloat16,
desc: str = "Evaluation",
) -> Dict[str, float]:
"""Measures Perplexity.
Returns:
{
"loss": average cross-entropy loss,
"perplexity": exp(loss),
"num_tokens": total number of tokens used for evaluation,
"num_batches": number of batches used for evaluation,
}
"""
model.eval()
total_loss = 0.0
total_tokens = 0
num_batches = 0
print(f"\nπ {desc}")
start_time = time.time()
for i, batch in enumerate(dataloader):
if i >= self.config.max_eval_batches:
break
input_ids = batch["input_ids"].to(device)
targets = batch["targets"].to(device)
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
logits, _ = model(input_ids)
# Per-token cross-entropy (reduction='none')
# logits: (B, S, V) β (B*S, V)
# targets: (B, S) β (B*S,)
loss_per_token = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=-100,
reduction="none",
)
# Count only valid tokens that are not -100
valid_mask = (targets.view(-1) != -100)
valid_tokens = valid_mask.sum().item()
total_loss += loss_per_token[valid_mask].sum().item()
total_tokens += valid_tokens
num_batches += 1
if (i + 1) % 20 == 0:
running_ppl = math.exp(min(total_loss / max(total_tokens, 1), 20))
print(f" Batch {i+1}/{self.config.max_eval_batches}: running PPL = {running_ppl:.2f}")
elapsed = time.time() - start_time
avg_loss = total_loss / max(total_tokens, 1)
perplexity = math.exp(min(avg_loss, 100)) # prevent overflow
results = {
"loss": round(avg_loss, 4),
"perplexity": round(perplexity, 2),
"num_tokens": total_tokens,
"num_batches": num_batches,
"eval_time_sec": round(elapsed, 1),
}
print(f" ββββββββββββββββββββββββββββββββ")
print(f" Loss: {results['loss']:.4f}")
print(f" Perplexity: {results['perplexity']:.2f}")
print(f" Eval tokens: {total_tokens:,}")
print(f" Elapsed: {elapsed:.1f}s")
return results
@torch.no_grad()
def evaluate_per_position(
self,
model: nn.Module,
dataloader: DataLoader,
device: torch.device,
dtype: torch.dtype = torch.bfloat16,
max_batches: int = 50,
) -> List[float]:
"""Measures loss per position within a sequence.
Learning insight:
- Positions 0~10: Higher loss (insufficient context)
- Positions 100+: Loss stabilizes lower (context is leveraged)
- This pattern demonstrates the Transformer's in-context learning capability
"""
model.eval()
seq_len = None
position_loss_sum = None
position_count = None
for i, batch in enumerate(dataloader):
if i >= max_batches:
break
input_ids = batch["input_ids"].to(device)
targets = batch["targets"].to(device)
B, S = targets.shape
if seq_len is None:
seq_len = S
position_loss_sum = torch.zeros(S, device=device)
position_count = torch.zeros(S, device=device)
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
logits, _ = model(input_ids)
# Per-token loss in shape (B, S)
loss_per_token = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=-100,
reduction="none",
).view(B, S)
valid_mask = (targets != -100).float()
position_loss_sum += (loss_per_token * valid_mask).sum(dim=0)
position_count += valid_mask.sum(dim=0)
# Average loss per position
position_avg_loss = (position_loss_sum / position_count.clamp(min=1)).cpu().tolist()
return position_avg_loss
|