adityashisharma commited on
Commit
a3b8ab7
·
verified ·
1 Parent(s): f2a9b8c

Create eval/eval_perplexity.py

Browse files
Files changed (1) hide show
  1. eval/eval_perplexity.py +20 -0
eval/eval_perplexity.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math, json, torch
2
+ from tokenizers import Tokenizer
3
+ from model.tiny_gpt2 import TinyGPT2, GPTConfig
4
+
5
+ tok = Tokenizer.from_file("out/tokenizer.json")
6
+ cfg = GPTConfig(**json.load(open("out/pretrain/gpt_config.json")))
7
+ m = TinyGPT2(cfg); m.load_state_dict(torch.load("out/sft/model_sft.pt", map_location="cpu")); m.eval()
8
+
9
+ val_text = open("data/corpus_raw.txt","r",encoding="utf-8").read()[:20000]
10
+ ids = tok.encode(val_text).ids
11
+ losses = []
12
+ with torch.no_grad():
13
+ for i in range(0, len(ids)-cfg.block_size, cfg.block_size):
14
+ x = torch.tensor([ids[i:i+cfg.block_size-1]], dtype=torch.long)
15
+ y = torch.tensor([ids[i+1:i+cfg.block_size]], dtype=torch.long)
16
+ logits = m(x)
17
+ loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
18
+ losses.append(loss.item())
19
+ ppl = math.exp(sum(losses)/len(losses)) if losses else float('inf')
20
+ print(f"Perplexity ~ {ppl:.2f}")