File size: 3,267 Bytes
7e5763c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
WikiText-2 scorer for İvme — reports cross-entropy loss and perplexity.

The Tiny-ML leaderboard's "WikiText-2 ↓" column reports per-token cross-entropy
loss (e.g. competitors at 2.66, 3.08), NOT perplexity. We print both so you can
match whichever the leaderboard uses.

Method: concatenate the WikiText-2 test split, tokenize, and score with a
sliding window of the model's context length, summing log-probs over all
predicted tokens. CE loss = -mean(log p(token)). Perplexity = exp(CE loss).

Usage:
    python eval_wikitext.py --checkpoint checkpoints/ivme_base_ema.pt
"""

from __future__ import annotations
import argparse
import json
import math
import sys
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from datasets import load_dataset

sys.path.insert(0, ".")
from model import IvmeConfig, IvmeConversate

# Serialized `tokenizers` tokenizer used to encode WikiText-2. This is a
# relative path, so it is resolved against the current working directory.
TOKENIZER_PATH: str = "ivme_tokenizer.json"


@torch.no_grad()
def _score_sliding_window(model, ids: list[int], ctx: int, stride: int,
                          device: torch.device) -> tuple[float, int]:
    """Return ``(total_nll, n_scored)`` for token ids *ids* under a sliding window.

    Each window feeds up to ``ctx`` input tokens to *model* and then advances
    by ``stride``. Only targets that no earlier window has scored contribute.
    Every token after the first is therefore counted exactly once, however
    much the windows overlap. The overlapping prefix of a window serves purely
    as conditioning context.

    Args:
        model: called as ``model(inp)`` with a ``(1, T)`` long tensor. Expected
            to return ``(logits, _)`` where logits is indexed
            ``[batch, position, vocab]``.
        ids: token ids of the whole evaluation text.
        ctx: maximum number of input tokens per window.
        stride: window step. Must satisfy ``1 <= stride <= ctx``, otherwise
            tokens would be skipped or the scan would not advance.
        device: device the input tensors are created on.

    Raises:
        ValueError: if *stride* is outside ``[1, ctx]``.
    """
    if not 1 <= stride <= ctx:
        raise ValueError(f"stride must be in [1, {ctx}], got {stride}")

    n = len(ids)
    total_nll = 0.0
    total_tokens = 0
    # Exclusive upper bound of target positions already scored. Position 0 is
    # never a target (nothing precedes it), so scoring begins at position 1.
    scored_until = 1
    for start in range(0, n - 1, stride):
        end = min(start + ctx + 1, n)
        chunk = ids[start:end]
        # This window predicts positions start+1 .. end-1; keep only the tail
        # beyond what earlier windows already scored (all of it when the
        # windows do not overlap, i.e. stride == ctx).
        n_new = end - max(scored_until, start + 1)
        inp = torch.tensor([chunk[:-1]], dtype=torch.long, device=device)
        tgt = torch.tensor(chunk[-n_new:], dtype=torch.long, device=device)
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
                            enabled=device.type == "cuda"):
            logits, _ = model(inp)
        # Upcast just the scored positions to fp32 so the log-softmax inside
        # cross_entropy is not computed in bf16.
        nll = F.cross_entropy(logits[0, -n_new:].float(), tgt, reduction="sum")
        total_nll += nll.item()
        total_tokens += n_new
        scored_until = end
        if end == n:
            # Every target has been scored; later windows would only re-score.
            break
    return total_nll, total_tokens


@torch.no_grad()
def main():
    """Score an İvme checkpoint on the WikiText-2 test split.

    Prints the per-token cross-entropy loss and perplexity (exp of the loss).
    Writes ``wikitext2_ce_loss``, ``wikitext2_ppl`` and the number of scored
    ``tokens`` as JSON to ``--output``. Exits with an argparse error if
    ``--stride`` is outside ``[1, context length]``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--output", default="wikitext_results.json")
    ap.add_argument("--stride", type=int, default=None,
                    help="sliding-window stride; defaults to full context (non-overlapping)")
    args = ap.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = Tokenizer.from_file(TOKENIZER_PATH)

    # weights_only=False unpickles arbitrary Python objects (the checkpoint
    # carries a config object under "cfg"), so only load trusted checkpoints.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    cfg = ckpt["cfg"]
    cfg.attn_backend = "sdpa"
    ctx = cfg.max_seq_len

    # Validate before building the model so a bad flag fails fast.
    # A stride above ctx would skip tokens between windows; a stride below 1
    # would never advance.
    stride = args.stride or ctx
    if not 1 <= stride <= ctx:
        ap.error(f"--stride must be between 1 and the context length ({ctx}), "
                 f"got {stride}")

    model = IvmeConversate(cfg).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    print(f"[wikitext] model loaded: {model.num_params()/1e6:.1f}M on {device}")

    print("[wikitext] loading WikiText-2 test split...")
    ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")
    text = "\n\n".join(t for t in ds["text"] if t.strip())
    ids = tok.encode(text).ids
    print(f"[wikitext] {len(ids):,} tokens")

    total_nll, total_tokens = _score_sliding_window(model, ids, ctx, stride, device)
    if total_tokens == 0:
        # Fewer than two tokens: nothing to predict, so the loss is undefined.
        sys.exit("[wikitext] no tokens to score (need at least 2 tokens)")

    ce_loss = total_nll / total_tokens
    ppl = math.exp(ce_loss)
    print(f"\n{'='*52}")
    print(f"  WikiText-2 cross-entropy loss : {ce_loss:.4f}")
    print(f"  WikiText-2 perplexity         : {ppl:.2f}")
    print(f"{'='*52}")
    print("  (leaderboard column reports CE loss, lower is better)")

    with open(args.output, "w", encoding="utf-8") as f:
        json.dump({"wikitext2_ce_loss": ce_loss, "wikitext2_ppl": ppl,
                   "tokens": total_tokens}, f, indent=2)
    print(f"\n[wikitext] saved -> {args.output}")


if __name__ == "__main__":
    main()