amirali1985 commited on
Commit
241a3ad
·
verified ·
1 Parent(s): 216e5a0

Upload modular/code/evaluate.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modular/code/evaluate.py +62 -0
modular/code/evaluate.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modular arithmetic evaluator — single-token accuracy via recursion.
3
+
4
+ Usage:
5
+ from arithmetic.modular.training.evaluate import ModularEvaluator
6
+ ev = ModularEvaluator(model, device="cuda", K=1)
7
+ acc = ev.run(test_examples)
8
+ """
9
+ import torch
10
+ from typing import List
11
+
12
+ from sorl.sorl_trainer import infer_insert_mask, insert_tokens_with_padding, expand_prompt_len
13
+ from arithmetic.modular.data.modular import ModularExample, PROMPT_LEN, PAD
14
+
15
+
16
+ class ModularEvaluator:
17
+
18
+ def __init__(self, model, device: str = "cuda", K: int = 1):
19
+ self.model = model
20
+ self.device = device
21
+ self.K = K
22
+ self.base_v = int(model.vocab_sizes[0].item())
23
+
24
+ @torch.no_grad()
25
+ def run(self, examples: List[ModularExample], max_examples: int = 0) -> float:
26
+ self.model.eval()
27
+ if max_examples > 0:
28
+ examples = examples[:max_examples]
29
+ correct = 0
30
+ for ex in examples:
31
+ ids = torch.tensor(ex.tokens, dtype=torch.long, device=self.device).unsqueeze(0)
32
+ attn = torch.ones_like(ids)
33
+ pl = torch.tensor([PROMPT_LEN], dtype=torch.long, device=self.device)
34
+
35
+ im = infer_insert_mask(ids, self.K, attn)
36
+ ep = expand_prompt_len(pl, im)
37
+ ed, ea = insert_tokens_with_padding(ids, attn, im, self.base_v, PAD)
38
+
39
+ data, _, _ = self.model.recursion(
40
+ ed, ea, max_iterations=2,
41
+ memory_span_abs=512, memory_span_traj=512,
42
+ temperature=0.0, prompt_len=ep,
43
+ )
44
+
45
+ # Forward pass to get logits on the filled sequence
46
+ block_mask = self.model._create_sorl_block_mask(data, 512, 512)
47
+ out = self.model.model.forward(
48
+ input_ids=data, attention_mask=ea,
49
+ block_mask=block_mask, use_cache=False,
50
+ )
51
+ logits = out.logits
52
+
53
+ # Result token is the (PROMPT_LEN)-th trajectory token (0-indexed)
54
+ is_traj = data[0] < self.base_v
55
+ traj_pos = is_traj.nonzero(as_tuple=True)[0]
56
+ result_pos = traj_pos[PROMPT_LEN].item()
57
+ pred = logits[0, result_pos - 1, :self.base_v].argmax().item()
58
+
59
+ if pred == ex.result:
60
+ correct += 1
61
+
62
+ return correct / max(len(examples), 1)