# -*- coding: utf-8 -*-
# Compare validation loss of multiple GPT checkpoints
# Works with old and new checkpoint formats
# Compatible with Antonín Tomeček Transformer code

import math
from dataclasses import dataclass

import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# =========================
# CONFIG
# =========================
CHECKPOINTS = {
    "pretrain_900k": "checkpoints/step_900000.pt",
    "continual_100k": "checkpoints/step_100000.pt",
    "continual_200k": "checkpoints/step_200000.pt",
    "continual_300k": "checkpoints/step_300000.pt",
    "continual_400k": "checkpoints/step_400000.pt",
    "continual_500k": "checkpoints/step_500000.pt",
}

TOKENIZER_MODEL_PATH = "tokenizer.model"
VALID_BIN = "valid.bin"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 1  # can be increased depending on available VRAM


# =========================
# ModelArgs
# =========================
@dataclass
class ModelArgs:
    dim: int = 768
    n_layers: int = 12
    n_heads: int = 12
    n_kv_heads: int = 4
    vocab_size: int = 32000
    multiple_of: int = 256
    ffn_dim_multiplier: float = 3.0
    norm_eps: float = 1e-5
    max_seq_len: int = 1024


# =========================
# Dataset
# =========================
class MemmapDataset(Dataset):
    """Sliding-window view over a flat binary file of int32 token IDs."""

    def __init__(self, path: str, max_seq_len: int, stride=None):
        self.tokens = np.memmap(path, dtype=np.int32, mode="r")
        self.max_seq_len = max_seq_len
        self.stride = stride or max_seq_len // 2

        max_start = len(self.tokens) - (max_seq_len + 1)
        if max_start <= 0:
            raise ValueError("Dataset too small")

        self.starts = list(range(0, max_start, self.stride))
        if self.starts[-1] != max_start:
            self.starts.append(max_start)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        i = self.starts[idx]
        seq = torch.from_numpy(
            self.tokens[i:i + self.max_seq_len + 1].copy()
        ).long()
        return seq[:-1], seq[1:]


# =========================
# Transformer model
# =========================
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


def precompute_freqs_cis(dim, seq_len, theta=10000.0):
    # Rotary-embedding angles for every (position, channel-pair) combination
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
    t = torch.arange(seq_len)
    freqs = torch.outer(t, freqs)
    return freqs.cos(), freqs.sin()


def apply_rotary_emb(x, cos, sin):
    # x: (B, T, n_heads, head_dim); cos/sin: (T, head_dim // 2)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


class Attention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.n_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads
        self.n_kv_heads = args.n_kv_heads
        self.repeat_kv = args.n_heads // args.n_kv_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

    def forward(self, x, cos, sin):
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)

        # Grouped-query attention: expand KV heads to match the query heads
        k = k.repeat_interleave(self.repeat_kv, dim=2)
        v = v.repeat_interleave(self.repeat_kv, dim=2)

        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.wo(out)


class FeedForward(nn.Module):
    def __init__(self, dim, multiple_of, mult):
        super().__init__()
        hidden = multiple_of * ((int(dim * mult) + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)

    def forward(self, x):
        # SwiGLU feed-forward
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.attn = Attention(args)
        self.ffn = FeedForward(args.dim, args.multiple_of, args.ffn_dim_multiplier)
        self.attn_norm = RMSNorm(args.dim, args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, args.norm_eps)

    def forward(self, x, cos, sin):
        x = x + self.attn(self.attn_norm(x), cos, sin)
        x = x + self.ffn(self.ffn_norm(x))
        return x


class Transformer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.tok_emb = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleList([TransformerBlock(args) for _ in range(args.n_layers)])
        self.norm = RMSNorm(args.dim, args.norm_eps)
        self.out = nn.Linear(args.dim, args.vocab_size, bias=False)

        cos, sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len * 2)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def forward(self, tokens):
        B, T = tokens.shape
        h = self.tok_emb(tokens)
        cos = self.cos_cached[:T]
        sin = self.sin_cached[:T]
        for layer in self.layers:
            h = layer(h, cos, sin)
        h = self.norm(h)
        return self.out(h)


# =========================
# Eval function
# =========================
def evaluate_checkpoint(path, valid_loader, tokenizer, args):
    ckpt = torch.load(path, map_location="cpu", weights_only=False)

    # Support both the old (bare state dict) and new checkpoint formats
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        state_dict = ckpt

    model = Transformer(args)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()

    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for x, y in tqdm(valid_loader, desc="eval", leave=False):
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            logits = model(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                y.view(-1),
                ignore_index=tokenizer.pad_id(),
                reduction="sum",
            )
            total_loss += loss.item()
            total_tokens += (y != tokenizer.pad_id()).sum().item()

    return total_loss / total_tokens


# =========================
# MAIN
# =========================
def main():
    # fixed ModelArgs (must match the architecture of the trained checkpoints)
    args = ModelArgs()

    tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
    args.vocab_size = tokenizer.vocab_size()

    # dataset
    valid_ds = MemmapDataset(VALID_BIN, args.max_seq_len)
    valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=2, pin_memory=True)

    print("=" * 70)
    print("Checkpoint comparison (validation)")
    print("=" * 70)

    results = {}

    for name, path in CHECKPOINTS.items():
        print(f"[Eval] {name}")
        loss = evaluate_checkpoint(path, valid_loader, tokenizer, args)
        ppl = math.exp(loss)
        results[name] = (loss, ppl)
        print(f"  Val loss:   {loss:.6f}")
        print(f"  Perplexity: {ppl:.2f}")
        print("-" * 50)

    print("\nSummary:")
    for name, (loss, ppl) in results.items():
        print(f"{name:20s} | loss {loss:.6f} | ppl {ppl:.2f}")
    print("=" * 70)


if __name__ == "__main__":
    main()
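

# ---------------------------------------------------------------------
# Optional helper, not used by the evaluation above: a minimal sketch of
# how a compatible VALID_BIN could be produced, assuming the validation
# text sits in a plain UTF-8 file. The path "valid.txt" and the name
# build_valid_bin are illustrative only; the original pipeline that
# created valid.bin is not shown in this script. MemmapDataset expects a
# flat stream of int32 token IDs, so that is the layout written here;
# whether documents were separated with EOS during training is an
# assumption, adjust to match your own preprocessing.
# ---------------------------------------------------------------------
def build_valid_bin(text_path="valid.txt", out_path=VALID_BIN,
                    model_path=TOKENIZER_MODEL_PATH):
    sp = spm.SentencePieceProcessor(model_file=model_path)
    with open(text_path, "r", encoding="utf-8") as src, open(out_path, "wb") as dst:
        for line in src:
            ids = sp.encode(line, out_type=int)
            if not ids:
                continue
            # Mark line/document boundaries with EOS if the model defines one
            if sp.eos_id() >= 0:
                ids.append(sp.eos_id())
            np.asarray(ids, dtype=np.int32).tofile(dst)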