---
datasets:
- kreasof-ai/ECA-Zero
---

```python
import re
import torch
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
from datasets import load_dataset
from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast
import fla
from fla.models import path_attn  # <-- Add this line (registers the architecture with transformers)

# --- Configuration ---
MODEL_ID = "THIS REPO"  # placeholder: replace with the actual model repo id before running
DATASET_ID = "kreasof-ai/ECA-Zero"
BATCH_SIZE = 128
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# From the dataset generation script: Wolfram class -> elementary CA rule numbers.
WOLFRAM_CLASSES_MAP = {
    1: [0, 8, 32, 40, 128, 136, 160, 168],
    2: [1, 19, 23, 29, 37, 50, 108, 178],
    3: [30, 45, 60, 90, 105, 126, 150],
    4: [54, 106, 110, 124, 137, 147, 193],
}

# Invert for fast lookup: Rule -> Class
RULE_TO_CLASS = {}
for cls, rules in WOLFRAM_CLASSES_MAP.items():
    for r in rules:
        RULE_TO_CLASS[r] = cls


class ECAVerifier:
    """Parses ECA task prompts and verifies model answers by re-simulating the automaton."""

    def __init__(self):
        # Pre-compiled patterns for the prompt fields emitted by the dataset generator.
        self.re_rule = re.compile(r"Rule: (\d+)")
        self.re_start = re.compile(r"Start: ([01]+)")
        self.re_end = re.compile(r"End: ([01]+)")
        self.re_steps = re.compile(r"Steps: (\d+)")
        self.re_hint_class = re.compile(r"Hint: Class (\d)")
        self.re_tt = re.compile(r"([01]{3})->([01])")

    def get_wolfram_class(self, prompt):
        """Return the Wolfram class (1-4) for a prompt, or 0 if it cannot be determined."""
        # 1. Check for explicit Hint (Induction tasks)
        m = self.re_hint_class.search(prompt)
        if m:
            return int(m.group(1))
        # 2. Check for Rule ID (Deduction/Abduction) and look up
        m = self.re_rule.search(prompt)
        if m:
            rule = int(m.group(1))
            return RULE_TO_CLASS.get(rule, 0)  # 0 = Unknown/Other
        return 0

    def get_next_state(self, state, rule):
        """Apply one update of ECA `rule` to `state` (list of 0/1) with periodic boundaries."""
        next_state = []
        L = len(state)
        for i in range(L):
            # Neighborhood (left, center, right) with wrap-around indexing.
            l, c, r = state[(i - 1) % L], state[i], state[(i + 1) % L]
            pattern = (l << 2) | (c << 1) | r
            # Bit `pattern` of the rule number encodes the output for this neighborhood.
            bit = 1 if (rule & (1 << pattern)) else 0
            next_state.append(bit)
        return next_state

    def simulate(self, start_state, rule, steps):
        """Run `steps` ECA updates starting from `start_state`; return the final state."""
        current = list(start_state)
        for _ in range(steps):
            current = self.get_next_state(current, rule)
        return current

    def parse_rule_string(self, text):
        """Decode a truth-table answer ("111->0 110->1 ...") into a rule number, or None."""
        matches = self.re_tt.findall(text)
        if not matches:
            return None
        rule = 0
        for pat, res in matches:
            if res == '1':
                rule |= (1 << int(pat, 2))
        return rule

    def verify(self, task_type, prompt, model_output_str):
        """Check a model answer for one of the three task types by re-simulation.

        deduction: answer must equal simulate(start, rule, steps).
        induction: any rule whose simulation maps start -> end is accepted.
        abduction: any start state that evolves to the given end state is accepted.
        """
        try:
            steps = int(self.re_steps.search(prompt).group(1))
            start_match = self.re_start.search(prompt)
            start_state = [int(x) for x in start_match.group(1)] if start_match else None
            end_match = self.re_end.search(prompt)
            end_state = [int(x) for x in end_match.group(1)] if end_match else None
            rule_match = self.re_rule.search(prompt)
            rule = int(rule_match.group(1)) if rule_match else None
        except AttributeError:
            # "Steps:" missing from the prompt -> unparseable; count as incorrect.
            return False

        answer = model_output_str.strip()

        try:
            if task_type == 'deduction':
                # Keep only 0/1 characters so stray formatting doesn't fail the parse.
                pred_state = [int(x) for x in answer if x in '01']
                if not pred_state:
                    return False
                expected = self.simulate(start_state, rule, steps)
                return pred_state == expected

            elif task_type == 'induction':
                pred_rule = self.parse_rule_string(answer)
                if pred_rule is None:
                    return False
                sim_end = self.simulate(start_state, pred_rule, steps)
                return sim_end == end_state

            elif task_type == 'abduction':
                pred_start = [int(x) for x in answer if x in '01']
                if not pred_start or len(pred_start) != len(end_state):
                    return False
                sim_end = self.simulate(pred_start, rule, steps)
                return sim_end == end_state
        except Exception:
            # Any simulation failure (e.g. missing fields for this task type) -> wrong.
            return False

        return False


def main():
    print(f"Loading tokenizer from {MODEL_ID}...")
    try:
        tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_ID)
    # FIX: was a bare `except:`, which also swallows KeyboardInterrupt/SystemExit.
    except Exception:
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model from {MODEL_ID}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map=DEVICE,
    )
    print("Compiling the model")
    model = torch.compile(model)
    model.eval()

    print("Loading Test Set...")
    dataset = load_dataset(DATASET_ID, split="test")
    verifier = ECAVerifier()

    # Storage: results[task][class_id] = [True, False, ...]
    results = defaultdict(lambda: defaultdict(list))

    print("Starting Stratified Evaluation...")
    for i in tqdm(range(0, len(dataset), BATCH_SIZE)):
        # Slicing a datasets.Dataset yields a dict of column -> list.
        batch = dataset[i : i + BATCH_SIZE]
        tasks = batch['task']
        inputs = batch['input']

        # Prompt format: BOS + task text + blank line, after which the model answers.
        prompts = [f"{tokenizer.bos_token}{inp}\n\n" for inp in inputs]

        # FIX: Added return_token_type_ids=False
        encodings = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
            return_token_type_ids=False,
        ).to(DEVICE)

        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=encodings['input_ids'],
                max_new_tokens=2048,
                do_sample=False,  # greedy decoding for a deterministic benchmark
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Keep special tokens so the prompt/answer boundary survives decoding.
        decoded_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)

        for j, raw_output in enumerate(decoded_outputs):
            # FIX: the original split on the empty string (`raw_output.split("")`),
            # which raises ValueError (and `"" in raw_output` is always True).
            # The prompt ends with "\n\n", so the text after the last "\n\n" is
            # taken as the model's answer.
            # NOTE(review): the intended separator was most likely a special answer
            # token that was lost when this card was rendered — confirm against the
            # training data format.
            if "\n\n" in raw_output:
                final_answer = raw_output.split("\n\n")[-1].replace(tokenizer.eos_token, "").strip()
            else:
                final_answer = ""

            # Determine Class
            w_class = verifier.get_wolfram_class(inputs[j])

            # Verify
            is_correct = verifier.verify(tasks[j], inputs[j], final_answer)

            # Store
            results[tasks[j]][w_class].append(is_correct)
            results[tasks[j]]["ALL"].append(is_correct)

    # --- Print Report ---
    print("\n" + "="*60)
    print("STRATIFIED RESULTS (Accuracy by Wolfram Class)")
    print("="*60)

    # Define column headers
    print(f"{'Task':<12} | {'Class 1':<10} | {'Class 2':<10} | {'Class 3':<10} | {'Class 4':<10} | {'OVERALL':<10}")
    print("-" * 75)

    for task in ["deduction", "induction", "abduction"]:
        row_str = f"{task.capitalize():<12} | "
        for c in [1, 2, 3, 4]:
            outcomes = results[task][c]
            if outcomes:
                acc = sum(outcomes) / len(outcomes)
                row_str += f"{acc:.1%} ({len(outcomes):<3}) | "  # concise
            else:
                row_str += "N/A | "

        # Overall
        all_outcomes = results[task]["ALL"]
        if all_outcomes:
            total_acc = sum(all_outcomes) / len(all_outcomes)
            row_str += f"{total_acc:.1%} ({len(all_outcomes)})"

        print(row_str)

    print("="*60)
    print("Class Legend:")
    print("1: Uniform (Trivial) | 2: Periodic (Easy) | 3: Chaotic (Hard) | 4: Complex (Hardest)")


if __name__ == "__main__":
    main()
```

```
============================================================
STRATIFIED RESULTS (Accuracy by Wolfram Class)
============================================================
Task         | Class 1    | Class 2    | Class 3    | Class 4    | OVERALL
---------------------------------------------------------------------------
Deduction    | 15.9% (113) | 8.4% (226) | 2.7% (412) | 2.4% (410) | 5.0% (1161)
Induction    | 6.2% (113) | 5.3% (227) | 6.3% (414) | 9.2% (411) | 7.1% (1165)
Abduction    | 6.4% (47 ) | 8.6% (185) | 7.2% (388) | 9.8% (387) | 8.4% (1007)
============================================================
Class Legend:
1: Uniform (Trivial) | 2: Periodic (Easy) | 3: Chaotic (Hard) | 4: Complex (Hardest)
```