ereniko commited on
Commit
7e5763c
·
verified ·
1 Parent(s): 337273e

Upload eval_wikitext.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. eval_wikitext.py +92 -0
eval_wikitext.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WikiText-2 scorer for İvme — reports cross-entropy loss and perplexity.
3
+
4
+ The Tiny-ML leaderboard's "WikiText-2 ↓" column reports per-token cross-entropy
5
+ loss (e.g. competitors at 2.66, 3.08), NOT perplexity. We print both so you can
6
+ match whichever the leaderboard uses.
7
+
8
+ Method: concatenate the WikiText-2 test split, tokenize, and score with a
9
+ sliding window of the model's context length, summing log-probs over all
10
+ predicted tokens. CE loss = -mean(log p(token)). Perplexity = exp(CE loss).
11
+
12
+ Usage:
13
+ python eval_wikitext.py --checkpoint checkpoints/ivme_base_ema.pt
14
+ """
15
+
16
+ from __future__ import annotations
17
+ import argparse
18
+ import json
19
+ import math
20
+ import sys
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from tokenizers import Tokenizer
24
+ from datasets import load_dataset
25
+
26
+ sys.path.insert(0, ".")
27
+ from model import IvmeConfig, IvmeConversate
28
+
29
+ TOKENIZER_PATH = "ivme_tokenizer.json"
30
+
31
+
32
+ @torch.no_grad()
33
+ def main():
34
+ ap = argparse.ArgumentParser()
35
+ ap.add_argument("--checkpoint", required=True)
36
+ ap.add_argument("--output", default="wikitext_results.json")
37
+ ap.add_argument("--stride", type=int, default=None,
38
+ help="sliding-window stride; defaults to full context (non-overlapping)")
39
+ args = ap.parse_args()
40
+
41
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
+ tok = Tokenizer.from_file(TOKENIZER_PATH)
43
+
44
+ ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
45
+ cfg = ckpt["cfg"]
46
+ cfg.attn_backend = "sdpa"
47
+ ctx = cfg.max_seq_len
48
+ model = IvmeConversate(cfg).to(device)
49
+ model.load_state_dict(ckpt["model"])
50
+ model.eval()
51
+ print(f"[wikitext] model loaded: {model.num_params()/1e6:.1f}M on {device}")
52
+
53
+ print("[wikitext] loading WikiText-2 test split...")
54
+ ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")
55
+ text = "\n\n".join(t for t in ds["text"] if t.strip())
56
+ ids = tok.encode(text).ids
57
+ print(f"[wikitext] {len(ids):,} tokens")
58
+
59
+ stride = args.stride or ctx
60
+ total_nll = 0.0
61
+ total_tokens = 0
62
+
63
+ for start in range(0, len(ids) - 1, stride):
64
+ chunk = ids[start : start + ctx + 1]
65
+ if len(chunk) < 2:
66
+ break
67
+ inp = torch.tensor([chunk[:-1]], dtype=torch.long, device=device)
68
+ tgt = torch.tensor([chunk[1:]], dtype=torch.long, device=device)
69
+ with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
70
+ enabled=device.type == "cuda"):
71
+ logits, _ = model(inp)
72
+ logp = F.log_softmax(logits.float(), dim=-1)
73
+ tok_lp = logp[0, range(tgt.shape[1]), tgt[0]]
74
+ total_nll += -tok_lp.sum().item()
75
+ total_tokens += tgt.shape[1]
76
+
77
+ ce_loss = total_nll / total_tokens
78
+ ppl = math.exp(ce_loss)
79
+ print(f"\n{'='*52}")
80
+ print(f" WikiText-2 cross-entropy loss : {ce_loss:.4f}")
81
+ print(f" WikiText-2 perplexity : {ppl:.2f}")
82
+ print(f"{'='*52}")
83
+ print(f" (leaderboard column reports CE loss, lower is better)")
84
+
85
+ with open(args.output, "w") as f:
86
+ json.dump({"wikitext2_ce_loss": ce_loss, "wikitext2_ppl": ppl,
87
+ "tokens": total_tokens}, f, indent=2)
88
+ print(f"\n[wikitext] saved -> {args.output}")
89
+
90
+
91
+ if __name__ == "__main__":
92
+ main()