ChavyvAkvar commited on
Commit
f3f7555
·
verified ·
1 Parent(s): 716716a

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +236 -0
README.md ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ datasets:
3
+ - kreasof-ai/ECA-Zero
4
+ ---
5
+ ```python
6
+ import re
7
+ import torch
8
+ import pandas as pd
9
+ from tqdm import tqdm
10
+ from collections import defaultdict
11
+ from datasets import load_dataset
12
+ from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast
13
+
14
+ # --- Configuration ---
15
+ MODEL_ID = "THIS REPO"
16
+ DATASET_ID = "kreasof-ai/ECA-Zero"
17
+ BATCH_SIZE = 64
18
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
+
20
+ # From the dataset generation script
21
+ WOLFRAM_CLASSES_MAP = {
22
+ 1: [0, 8, 32, 40, 128, 136, 160, 168],
23
+ 2: [1, 19, 23, 29, 37, 50, 108, 178],
24
+ 3: [30, 45, 60, 90, 105, 126, 150],
25
+ 4: [54, 106, 110, 124, 137, 147, 193]
26
+ }
27
+
28
+ # Invert for fast lookup: Rule -> Class
29
+ RULE_TO_CLASS = {}
30
+ for cls, rules in WOLFRAM_CLASSES_MAP.items():
31
+ for r in rules:
32
+ RULE_TO_CLASS[r] = cls
33
+
34
class ECAVerifier:
    """Parses ECA task prompts and checks model answers via forward simulation."""

    def __init__(self):
        # Pre-compiled patterns for the prompt fields and truth-table entries.
        self.re_rule = re.compile(r"Rule: (\d+)")
        self.re_start = re.compile(r"Start: ([01]+)")
        self.re_end = re.compile(r"End: ([01]+)")
        self.re_steps = re.compile(r"Steps: (\d+)")
        self.re_hint_class = re.compile(r"Hint: Class (\d)")
        self.re_tt = re.compile(r"([01]{3})->([01])")

    def get_wolfram_class(self, prompt):
        """Return the Wolfram class (1-4) for a prompt, or 0 when unknown."""
        # Induction prompts carry an explicit class hint.
        hint = self.re_hint_class.search(prompt)
        if hint is not None:
            return int(hint.group(1))

        # Deduction/abduction prompts name the rule; map it to its class.
        rule_m = self.re_rule.search(prompt)
        if rule_m is not None:
            return RULE_TO_CLASS.get(int(rule_m.group(1)), 0)  # 0 = Unknown/Other

        return 0

    def get_next_state(self, state, rule):
        """Apply one ECA update with periodic (wrap-around) boundaries."""
        n = len(state)
        return [
            (rule >> ((state[(i - 1) % n] << 2) | (state[i] << 1) | state[(i + 1) % n])) & 1
            for i in range(n)
        ]

    def simulate(self, start_state, rule, steps):
        """Run `steps` synchronous updates of `rule` from `start_state`."""
        current = list(start_state)
        for _ in range(steps):
            current = self.get_next_state(current, rule)
        return current

    def parse_rule_string(self, text):
        """Rebuild a rule number from 'xyz->b' truth-table fragments, or None."""
        entries = self.re_tt.findall(text)
        if not entries:
            return None
        rule = 0
        for pattern, output in entries:
            if output == '1':
                rule |= 1 << int(pattern, 2)
        return rule

    def verify(self, task_type, prompt, model_output_str):
        """Check a model answer against a forward simulation.

        Returns False on any parse failure or simulation mismatch.
        """
        try:
            steps = int(self.re_steps.search(prompt).group(1))
            start_m = self.re_start.search(prompt)
            end_m = self.re_end.search(prompt)
            rule_m = self.re_rule.search(prompt)
            start_state = [int(ch) for ch in start_m.group(1)] if start_m else None
            end_state = [int(ch) for ch in end_m.group(1)] if end_m else None
            rule = int(rule_m.group(1)) if rule_m else None
        except AttributeError:
            # No "Steps:" field found in the prompt.
            return False

        answer = model_output_str.strip()
        try:
            if task_type == 'deduction':
                # Predicted end state must match the simulated one exactly.
                predicted = [int(ch) for ch in answer if ch in '01']
                return bool(predicted) and predicted == self.simulate(start_state, rule, steps)

            if task_type == 'induction':
                # Any rule that maps Start to End in `steps` updates counts.
                inferred_rule = self.parse_rule_string(answer)
                if inferred_rule is None:
                    return False
                return self.simulate(start_state, inferred_rule, steps) == end_state

            if task_type == 'abduction':
                # Any same-length preimage that evolves into End counts.
                candidate = [int(ch) for ch in answer if ch in '01']
                if not candidate or len(candidate) != len(end_state):
                    return False
                return self.simulate(candidate, rule, steps) == end_state
        except Exception:
            return False
        return False
115
+
116
def main():
    """Run stratified evaluation of the model on the ECA-Zero test split.

    Generates greedy completions for every test prompt, extracts the answer
    that follows the closing </think> tag, verifies it with ECAVerifier, and
    prints per-task accuracy broken down by Wolfram class (1-4).
    """
    print(f"Loading tokenizer from {MODEL_ID}...")
    try:
        tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_ID)
    except Exception:
        # FIX: was a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit. Fall back to AutoTokenizer.
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # Generation needs a pad token; reuse EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): batched generation with a causal LM normally requires
    # left padding (tokenizer.padding_side = "left"); confirm what this
    # model expects, since right padding can corrupt padded rows' outputs.

    print(f"Loading model from {MODEL_ID}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map=DEVICE,
    )
    model.eval()

    print("Loading Test Set...")
    dataset = load_dataset(DATASET_ID, split="test")
    verifier = ECAVerifier()

    # Storage: results[task][class_id] = [True, False, ...]
    results = defaultdict(lambda: defaultdict(list))

    print("Starting Stratified Evaluation...")

    for i in tqdm(range(0, len(dataset), BATCH_SIZE)):
        batch = dataset[i : i + BATCH_SIZE]
        tasks = batch['task']
        inputs = batch['input']

        # Prompt format: BOS + task input + an opened <think> block.
        prompts = [f"{tokenizer.bos_token}{inp}\n<think>\n" for inp in inputs]

        # FIX: Added return_token_type_ids=False
        encodings = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
            return_token_type_ids=False
        ).to(DEVICE)

        with torch.no_grad():
            # Greedy decoding (do_sample=False) for deterministic evaluation.
            generated_ids = model.generate(
                **encodings,
                max_new_tokens=2048,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # Keep special tokens so the </think> delimiter survives decoding.
        decoded_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)

        for j, raw_output in enumerate(decoded_outputs):
            # The answer is whatever follows the last </think> tag; a missing
            # tag counts as an (incorrect) empty answer.
            if "</think>" in raw_output:
                final_answer = raw_output.split("</think>")[-1].replace(tokenizer.eos_token, "").strip()
            else:
                final_answer = ""

            # Determine Class
            w_class = verifier.get_wolfram_class(inputs[j])

            # Verify
            is_correct = verifier.verify(tasks[j], inputs[j], final_answer)

            # Store under both the specific class and the "ALL" bucket.
            results[tasks[j]][w_class].append(is_correct)
            results[tasks[j]]["ALL"].append(is_correct)

    # --- Print Report ---
    print("\n" + "="*60)
    print("STRATIFIED RESULTS (Accuracy by Wolfram Class)")
    print("="*60)

    # Define column headers
    print(f"{'Task':<12} | {'Class 1':<10} | {'Class 2':<10} | {'Class 3':<10} | {'Class 4':<10} | {'OVERALL':<10}")
    print("-" * 75)

    for task in ["deduction", "induction", "abduction"]:
        row_str = f"{task.capitalize():<12} | "

        for c in [1, 2, 3, 4]:
            outcomes = results[task][c]
            if outcomes:
                acc = sum(outcomes) / len(outcomes)
                row_str += f"{acc:.1%} ({len(outcomes):<3}) | "  # concise
            else:
                row_str += "N/A | "

        # Overall
        all_outcomes = results[task]["ALL"]
        if all_outcomes:
            total_acc = sum(all_outcomes) / len(all_outcomes)
            row_str += f"{total_acc:.1%} ({len(all_outcomes)})"

        print(row_str)

    print("="*60)
    print("Class Legend:")
    print("1: Uniform (Trivial) | 2: Periodic (Easy) | 3: Chaotic (Hard) | 4: Complex (Hardest)")
220
# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
222
+ ```
223
+
224
+ ```
225
+ ============================================================
226
+ STRATIFIED RESULTS (Accuracy by Wolfram Class)
227
+ ============================================================
228
+ Task | Class 1 | Class 2 | Class 3 | Class 4 | OVERALL
229
+ ---------------------------------------------------------------------------
230
+ Deduction | 32.7% (113) | 20.4% (226) | 24.8% (412) | 22.0% (410) | 23.7% (1161)
231
+ Induction | 98.2% (113) | 79.7% (227) | 83.1% (414) | 79.1% (411) | 82.5% (1165)
232
+ Abduction | 17.0% (47 ) | 29.7% (185) | 19.8% (388) | 18.3% (387) | 21.0% (1007)
233
+ ============================================================
234
+ Class Legend:
235
+ 1: Uniform (Trivial) | 2: Periodic (Easy) | 3: Chaotic (Hard) | 4: Complex (Hardest)
236
+ ```