ereniko committed on
Commit
44217ec
·
verified ·
1 Parent(s): e82a88e

Upload eval.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. eval.py +232 -0
eval.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Eval harness for İvme-Conversate.
3
+
4
+ Wraps the custom model + tokenizer in an lm-eval compatible interface and runs
5
+ HellaSwag and ARC-Easy — the two benchmarks scored on the Tiny-ML leaderboard.
6
+
7
+ Usage:
8
+ python eval.py --checkpoint checkpoints/ivme_base_ema.pt
9
+ python eval.py --checkpoint checkpoints/ivme_base_ema.pt --tasks hellaswag,arc_easy
10
+ python eval.py --checkpoint checkpoints/ivme_base_ema.pt --tasks hellaswag,arc_easy,piqa
11
+
12
+ Requirements:
13
+ pip install lm-eval tokenizers torch
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import argparse
19
+ import json
20
+ import sys
21
+ import torch
22
+ import numpy as np
23
+ from tokenizers import Tokenizer
24
+
25
+ # lm-eval imports
26
+ from lm_eval.api.model import LM
27
+ from lm_eval.api.instance import Instance
28
+ import lm_eval
29
+
30
+ # Local
31
+ sys.path.insert(0, ".")
32
+ from model import IvmeConfig, IvmeConversate
33
+
34
+ TOKENIZER_PATH = "ivme_tokenizer.json"
35
+ DEFAULT_TASKS = "hellaswag,arc_easy"
36
+
37
+
38
+ # --------------------------------------------------------------------------- #
39
+ # lm-eval wrapper
40
+ # --------------------------------------------------------------------------- #
41
class IvmeLM(LM):
    """lm-eval ``LM`` adapter around an İvme-Conversate checkpoint.

    The constructor loads the tokenizer from ``TOKENIZER_PATH`` and the model
    from a checkpoint dict that holds a ``"cfg"`` object and a ``"model"``
    state dict. The three request handlers lm-eval calls are implemented
    one request at a time (no padded batching):

    * ``loglikelihood``         -- multiple-choice style scoring
    * ``loglikelihood_rolling`` -- whole-document log-likelihood (perplexity)
    * ``generate_until``        -- greedy generation with stop strings
    """

    def __init__(self, checkpoint_path: str, device: str = "cuda", batch_size: int = 32):
        """Load tokenizer and model weights.

        Args:
            checkpoint_path: Path of a ``torch.save``-d dict with ``"cfg"`` and
                ``"model"`` keys.
            device: Requested torch device; silently replaced by ``"cpu"`` when
                CUDA is not available.
            batch_size: Number of requests handed to ``_loglikelihood_batch``
                per call.
        """
        super().__init__()
        self._device = torch.device(device if torch.cuda.is_available() else "cpu")
        self._batch_size = batch_size

        # Load tokenizer
        print(f"[eval] loading tokenizer from {TOKENIZER_PATH}")
        self._tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
        # Scoring needs the raw, unpadded, untruncated ids.
        self._tokenizer.no_truncation()
        self._tokenizer.no_padding()
        self.vocab_size = self._tokenizer.get_vocab_size()
        # None when the tokenizer has no "<|eos|>" entry; callers must check.
        self.eos_token_id = self._tokenizer.token_to_id("<|eos|>")

        # Load model
        print(f"[eval] loading model from {checkpoint_path}")
        # NOTE(security): weights_only=False unpickles arbitrary Python objects
        # (needed here because the checkpoint stores the config object). Only
        # load checkpoints from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        cfg = ckpt["cfg"]
        # Force SDPA for eval — no training kernels needed, wider compatibility
        cfg.attn_backend = "sdpa"
        self._model = IvmeConversate(cfg)
        self._model.load_state_dict(ckpt["model"])
        self._model.to(self._device)
        self._model.eval()
        n = self._model.num_params()
        print(f"[eval] model loaded: {n/1e6:.1f}M params on {self._device}")

    @property
    def max_length(self):
        """Maximum sequence length, read from the model config."""
        return self._model.cfg.max_seq_len

    @property
    def max_gen_toks(self):
        """Default cap on newly generated tokens for ``generate_until``."""
        return 256

    def tok_encode(self, text: str) -> list[int]:
        """Encode ``text`` to token ids."""
        return self._tokenizer.encode(text).ids

    def tok_decode(self, tokens: list[int]) -> str:
        """Decode token ids back to text."""
        return self._tokenizer.decode(tokens)

    def _forward_logits(self, token_ids: list[int]) -> torch.Tensor:
        """Run one un-batched forward pass and return the logits of that row.

        The model is called as ``logits, _ = model(input_ids)``; the result is
        assumed to be ``(1, seq, vocab)`` with one row of logits per input
        position -- TODO confirm against ``model.py``.
        """
        input_ids = torch.tensor([token_ids], dtype=torch.long, device=self._device)
        with torch.no_grad(), torch.autocast(
            device_type=self._device.type,
            dtype=torch.bfloat16,
            enabled=self._device.type == "cuda",
        ):
            logits, _ = self._model(input_ids)
        return logits[0]

    # ---- Required lm-eval interface methods -------------------------------- #

    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        """Compute log-likelihood of each (context, continuation) pair.

        Returns one ``(sum_log_prob, is_greedy)`` tuple per request, in request
        order.
        """
        results = []
        for i in range(0, len(requests), self._batch_size):
            batch = requests[i : i + self._batch_size]
            results.extend(self._loglikelihood_batch(batch))
        return results

    def _loglikelihood_batch(self, batch: list[Instance]) -> list[tuple[float, bool]]:
        """Score every request of ``batch`` sequentially.

        For each request the continuation's summed log-probability and a flag
        telling whether greedy decoding would reproduce it are returned.
        Requests for which no token can be scored yield ``(-inf, False)``.
        """
        results = []
        for req in batch:
            context, continuation = req.args

            # CRITICAL: tokenize context+continuation JOINTLY. With ByteLevel BPE,
            # tokenizing the continuation alone mishandles the leading space and
            # word-boundary merges, so the scored tokens wouldn't match what the
            # model actually predicts in context. The context is encoded on its
            # own only to measure where the continuation starts.
            ctx_ids = self.tok_encode(context)
            full_ids = self.tok_encode(context + continuation)
            cont_len = len(full_ids) - len(ctx_ids)

            # Guard: joint tokenization can merge across the boundary leaving
            # cont_len=0 or even negative. Fall back to scoring the last token.
            if cont_len <= 0:
                cont_len = 1

            if cont_len >= len(full_ids):
                # Nothing precedes the continuation (e.g. empty context). Use
                # EOS as the conditioning prefix so the first continuation
                # token can be scored too; without one, nothing is scorable.
                if not full_ids or self.eos_token_id is None:
                    results.append((-float("inf"), False))
                    continue
                full_ids = [self.eos_token_id] + full_ids

            all_ids = full_ids
            # Truncate from the left if too long.
            if len(all_ids) > self.max_length:
                all_ids = all_ids[-self.max_length:]

            # A continuation that fills the whole window must give up its
            # leading tokens: at least one position has to remain as
            # conditioning input, otherwise targets and logits go out of step.
            cont_len = min(cont_len, len(all_ids) - 1)
            if cont_len <= 0:
                results.append((-float("inf"), False))
                continue

            logits = self._forward_logits(all_ids)

            # logits[i] predicts the token at position i+1, so the last
            # cont_len tokens are scored by logits[len-cont_len-1 : len-1].
            first = len(all_ids) - cont_len - 1
            cont_logits = logits[first : first + cont_len]  # (cont_len, vocab)
            cont_targets = torch.tensor(
                all_ids[-cont_len:], dtype=torch.long, device=self._device
            )

            log_probs = torch.nn.functional.log_softmax(cont_logits.float(), dim=-1)
            token_log_probs = log_probs.gather(-1, cont_targets.unsqueeze(-1)).squeeze(-1)
            total_log_prob = token_log_probs.sum().item()

            greedy = (cont_logits.argmax(dim=-1) == cont_targets).all().item()
            results.append((total_log_prob, bool(greedy)))

        return results

    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
        """Compute rolling log-likelihood for perplexity tasks.

        The token stream is cut into consecutive windows of ``max_length``
        inputs; each window is scored independently and the sums are added.
        The very first token of a document is never scored (it has no
        preceding input), and texts of fewer than two tokens score ``0.0``.
        """
        results = []
        for req in requests:
            ids = self.tok_encode(req.args[0])
            total_ll = 0.0
            # Slide a window of max_length over the tokens.
            for start in range(0, max(1, len(ids) - 1), self.max_length):
                # One extra token so the last input position has a target.
                chunk = ids[start : start + self.max_length + 1]
                if len(chunk) < 2:
                    break
                targets = torch.tensor(chunk[1:], dtype=torch.long, device=self._device)
                logits = self._forward_logits(chunk[:-1])
                log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
                total_ll += log_probs.gather(-1, targets.unsqueeze(-1)).sum().item()
            results.append(total_ll)
        return results

    def generate_until(self, requests: list[Instance]) -> list[str]:
        """Greedy generation until stop string (used by some tasks).

        ``gen_kwargs["until"]`` may be a single string or a list of strings;
        the returned text is cut before the earliest stop string it contains
        and before the first EOS token, if any was generated.
        """
        results = []
        for req in requests:
            context, gen_kwargs = req.args

            until = gen_kwargs.get("until", ["<|eos|>"])
            # Task configs may give a bare string; iterating it directly would
            # treat every single character as a stop sequence.
            if until is None:
                until = []
            elif isinstance(until, str):
                until = [until]
            max_new = gen_kwargs.get("max_gen_toks", self.max_gen_toks)

            prompt_ids = self.tok_encode(context)
            # An empty prompt would produce a zero-length input tensor.
            if not prompt_ids and self.eos_token_id is not None:
                prompt_ids = [self.eos_token_id]
            # Keep the most recent tokens when the prompt exceeds the model's
            # window. NOTE(review): generate() may already crop internally --
            # verify in model.py; cropping here is harmless either way.
            if len(prompt_ids) > self.max_length:
                prompt_ids = prompt_ids[-self.max_length:]

            ids = torch.tensor([prompt_ids], dtype=torch.long, device=self._device)
            with torch.no_grad():
                out = self._model.generate(ids, max_new_tokens=max_new,
                                           temperature=1.0, top_k=1)  # greedy
            new_ids = out[0, ids.shape[1]:].tolist()

            # Cut at the EOS id before decoding: if "<|eos|>" is registered as
            # a special token, decode() drops it and a textual stop string
            # could never match it.
            if self.eos_token_id is not None and self.eos_token_id in new_ids:
                new_ids = new_ids[: new_ids.index(self.eos_token_id)]

            text = self.tok_decode(new_ids)
            for stop in until:
                # Skip empty stops; "" is "in" every string and would erase
                # the whole output.
                if stop and stop in text:
                    text = text[: text.index(stop)]
            results.append(text)
        return results
186
+
187
+
188
+ # --------------------------------------------------------------------------- #
189
+ # Main
190
+ # --------------------------------------------------------------------------- #
191
def main():
    """Command-line entry point: evaluate a checkpoint and report accuracy.

    Parses ``--checkpoint``, ``--tasks``, ``--batch_size``, ``--device`` and
    ``--output``, runs the requested lm-eval tasks zero-shot, prints a summary
    table and writes ``results["results"]`` to the output file as JSON.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True)
    ap.add_argument("--tasks", default=DEFAULT_TASKS)
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--output", default="eval_results.json")
    args = ap.parse_args()

    # Drop empty names produced by stray commas ("hellaswag,,arc_easy,") and
    # fail before the (slow) model load if nothing is left.
    task_list = [name.strip() for name in args.tasks.split(",") if name.strip()]
    if not task_list:
        ap.error("--tasks must name at least one task")

    model = IvmeLM(args.checkpoint, device=args.device, batch_size=args.batch_size)

    print(f"\n[eval] running tasks: {task_list}")
    results = lm_eval.simple_evaluate(
        model=model,
        tasks=task_list,
        num_fewshot=0,  # zero-shot, matching the leaderboard
        batch_size=args.batch_size,
        log_samples=False,
    )

    # Print a clean summary
    print("\n" + "=" * 52)
    print("  İvme-Conversate Eval Results")
    print("=" * 52)
    for task, metrics in results["results"].items():
        # Explicit None checks: a genuine accuracy of 0.0 is falsy and must
        # not be replaced by acc_norm.
        acc = metrics.get("acc,none")
        if acc is None:
            acc = metrics.get("acc_norm,none")
        if acc is None:
            acc = 0.0
        print(f"  {task:<20} {acc*100:.2f}%")
    print("=" * 52)
    print(f"  Model params : {model._model.num_params()/1e6:.1f}M")
    print(f"  Checkpoint   : {args.checkpoint}")
    print(f"  Eval mode    : zero-shot")
    print("=" * 52)

    # Save full results for the model card / leaderboard PR
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(results["results"], f, indent=2)
    print(f"\n[eval] full results saved -> {args.output}")


if __name__ == "__main__":
    main()