# NanoGPT-X_Base / eval-loss.py
# -*- coding: utf-8 -*-
# Compare validation loss of multiple GPT checkpoints
# Works with old and new checkpoint formats
# Compatible with Antonín Tomeček Transformer code
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
# =========================
# CONFIG
# =========================
CHECKPOINTS = {
"pretrain_900k": "checkpoints/step_900000.pt",
"continual_100k": "checkpoints/step_100000.pt",
"continual_200k": "checkpoints/step_200000.pt",
"continual_300k": "checkpoints/step_300000.pt",
"continual_400k": "checkpoints/step_400000.pt",
"continual_500k": "checkpoints/step_500000.pt",
}
TOKENIZER_MODEL_PATH = "tokenizer.model"
VALID_BIN = "valid.bin"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 1  # can be increased depending on available VRAM
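# NOTE: valid.bin is expected to be a flat array of token ids stored as int32
# (matching the np.int32 memmap in MemmapDataset below). A minimal sketch of
# how such a file could be produced is included here for reference only; it is
# never called, and the input path and encode call are assumptions rather than
# part of the original pipeline.
def _example_build_valid_bin(text_path="valid.txt", out_path=VALID_BIN):
    # Illustrative sketch: encode raw text with the same SentencePiece model
    # and dump the ids as a flat int32 array that np.memmap can read back.
    sp = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
    ids = sp.encode(open(text_path, encoding="utf-8").read())
    np.array(ids, dtype=np.int32).tofile(out_path)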
# =========================
# ModelArgs
# =========================
from dataclasses import dataclass
@dataclass
class ModelArgs:
dim: int = 768
n_layers: int = 12
n_heads: int = 12
n_kv_heads: int = 4
vocab_size: int = 32000
multiple_of: int = 256
ffn_dim_multiplier: float = 3.0
norm_eps: float = 1e-5
max_seq_len: int = 1024
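# Derived shapes with these defaults: head_dim = 768 / 12 = 64; grouped-query
# attention shares each of the 4 KV heads across 12 / 4 = 3 query heads; the
# SwiGLU hidden size is 256 * ceil(768 * 3.0 / 256) = 2304. vocab_size is a
# placeholder and is overwritten from the tokenizer in main().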
# =========================
# Dataset
# =========================
class MemmapDataset(Dataset):
def __init__(self, path: str, max_seq_len: int, stride=None):
self.tokens = np.memmap(path, dtype=np.int32, mode="r")
self.max_seq_len = max_seq_len
self.stride = stride or max_seq_len // 2
max_start = len(self.tokens) - (max_seq_len + 1)
if max_start <= 0:
raise ValueError("Dataset too small")
self.starts = list(range(0, max_start, self.stride))
if self.starts[-1] != max_start:
self.starts.append(max_start)
def __len__(self):
return len(self.starts)
def __getitem__(self, idx):
i = self.starts[idx]
seq = torch.from_numpy(
self.tokens[i:i + self.max_seq_len + 1].copy()
).long()
return seq[:-1], seq[1:]
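# Validation windows overlap: with the default stride of max_seq_len // 2,
# consecutive windows share half their tokens, and the final window is pinned
# to the end of the file. Overlapping tokens are therefore scored more than
# once, so the reported loss is an average over window positions rather than
# over unique tokens.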
# =========================
# Transformer model
# =========================
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
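# RMSNorm: scales x by the reciprocal root-mean-square of its features
# (no mean-centering or bias, unlike LayerNorm), then applies a learned gain.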
def precompute_freqs_cis(dim, seq_len, theta=10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
t = torch.arange(seq_len)
freqs = torch.outer(t, freqs)
return freqs.cos(), freqs.sin()
def apply_rotary_emb(x, cos, sin):
x1, x2 = x[..., 0::2], x[..., 1::2]
cos = cos.unsqueeze(0).unsqueeze(2)
sin = sin.unsqueeze(0).unsqueeze(2)
out = torch.empty_like(x)
out[..., 0::2] = x1 * cos - x2 * sin
out[..., 1::2] = x1 * sin + x2 * cos
return out
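# Rotary position embeddings (RoPE): precompute_freqs_cis builds per-position
# cos/sin tables with frequencies theta^(-2i/d); apply_rotary_emb rotates each
# (even, odd) feature pair of q and k by the position-dependent angle, so
# attention scores depend on relative positions.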
class Attention(nn.Module):
def __init__(self, args):
super().__init__()
self.n_heads = args.n_heads
self.head_dim = args.dim // args.n_heads
self.n_kv_heads = args.n_kv_heads
self.repeat_kv = args.n_heads // args.n_kv_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
def forward(self, x, cos, sin):
B, T, _ = x.shape
q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)
k = k.repeat_interleave(self.repeat_kv, dim=2)
v = v.repeat_interleave(self.repeat_kv, dim=2)
q = apply_rotary_emb(q, cos, sin)
k = apply_rotary_emb(k, cos, sin)
q = q.transpose(1,2)
k = k.transpose(1,2)
v = v.transpose(1,2)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out = out.transpose(1,2).contiguous().view(B, T, -1)
return self.wo(out)
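# Grouped-query attention: wk/wv project to only n_kv_heads heads (smaller
# checkpoints), and repeat_interleave expands K/V back to the full n_heads
# before scaled_dot_product_attention, which applies the causal mask via
# is_causal=True.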
class FeedForward(nn.Module):
def __init__(self, dim, multiple_of, mult):
super().__init__()
        hidden = multiple_of * ((int(dim * mult) + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden, bias=False)
self.w2 = nn.Linear(hidden, dim, bias=False)
self.w3 = nn.Linear(dim, hidden, bias=False)
    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
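# SwiGLU feed-forward: w2(silu(w1(x)) * w3(x)), with the hidden width rounded
# up to the nearest multiple of `multiple_of`.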
class TransformerBlock(nn.Module):
def __init__(self, args):
super().__init__()
self.attn = Attention(args)
self.ffn = FeedForward(args.dim, args.multiple_of, args.ffn_dim_multiplier)
self.attn_norm = RMSNorm(args.dim, args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, args.norm_eps)
def forward(self, x, cos, sin):
x = x + self.attn(self.attn_norm(x), cos, sin)
x = x + self.ffn(self.ffn_norm(x))
return x
class Transformer(nn.Module):
def __init__(self, args):
super().__init__()
self.tok_emb = nn.Embedding(args.vocab_size, args.dim)
self.layers = nn.ModuleList([TransformerBlock(args) for _ in range(args.n_layers)])
self.norm = RMSNorm(args.dim, args.norm_eps)
self.out = nn.Linear(args.dim, args.vocab_size, bias=False)
cos, sin = precompute_freqs_cis(args.dim//args.n_heads, args.max_seq_len*2)
self.register_buffer("cos_cached", cos, persistent=False)
self.register_buffer("sin_cached", sin, persistent=False)
def forward(self, tokens):
B, T = tokens.shape
h = self.tok_emb(tokens)
cos = self.cos_cached[:T]
sin = self.sin_cached[:T]
for layer in self.layers:
h = layer(h, cos, sin)
h = self.norm(h)
return self.out(h)
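# Pre-norm decoder: each block normalizes before attention/FFN and adds a
# residual connection. The cos/sin cache is built for 2 * max_seq_len
# positions and sliced to the current sequence length T on every forward pass.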
# =========================
# Eval function
# =========================
def evaluate_checkpoint(path, valid_loader, tokenizer, args):
ckpt = torch.load(path, map_location="cpu", weights_only=False)
    # Support both the old and the new checkpoint format
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
state_dict = ckpt["model_state_dict"]
else:
state_dict = ckpt
model = Transformer(args)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()
total_loss = 0.0
total_tokens = 0
with torch.no_grad():
for x, y in valid_loader:
x = x.to(DEVICE)
y = y.to(DEVICE)
logits = model(x)
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
y.view(-1),
ignore_index=tokenizer.pad_id(),
reduction="sum",
)
total_loss += loss.item()
total_tokens += (y != tokenizer.pad_id()).sum().item()
return total_loss / total_tokens
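# The returned value is the summed cross-entropy divided by the number of
# non-pad target tokens, i.e. the mean negative log-likelihood per token in
# nats; main() exponentiates it to report perplexity. If the SentencePiece
# model defines no pad piece, pad_id() is -1, so ignore_index matches nothing
# and every token is counted.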
# =========================
# MAIN
# =========================
def main():
    # fixed ModelArgs (vocab_size is overwritten from the tokenizer below)
args = ModelArgs()
tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
args.vocab_size = tokenizer.vocab_size()
# dataset
valid_ds = MemmapDataset(VALID_BIN, args.max_seq_len)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
print("="*70)
print("Checkpoint comparison (validation)")
print("="*70)
results = {}
for name, path in CHECKPOINTS.items():
print(f"[Eval] {name}")
loss = evaluate_checkpoint(path, valid_loader, tokenizer, args)
ppl = math.exp(loss)
results[name] = (loss, ppl)
print(f" Val loss: {loss:.6f}")
print(f" Perplexity: {ppl:.2f}")
print("-"*50)
print("\nSummary:")
for name, (loss, ppl) in results.items():
print(f"{name:20s} | loss {loss:.6f} | ppl {ppl:.2f}")
print("="*70)
if __name__ == "__main__":
main()