#!/usr/bin/env python3 """Local evaluation for AION. These are NOT official leaderboard results. They are local sanity checks plus a small GSM8K sample if the dataset file exists. """ from __future__ import annotations import json import re from pathlib import Path from aion import generate, BASE def ok_contains(prompt, needles): out = generate(prompt) low = out.lower() return any(n.lower() in low for n in needles), out def extract_number(text: str): # Prefer last number in Answer section. ans = text.split("## Answer")[-1] nums = re.findall(r"-?\d+(?:\.\d+)?", ans.replace(",", "")) return nums[-1] if nums else None def norm_num(x): try: return float(str(x).replace(",", "")) except Exception: return None def run(): results = {} suites = { "chat": [ ("hola", ["hello", "awake"]), ("what can you do", ["html", "math", "python"]), ("who are you", ["aion", "assistant", "learning"]), ], "python": [ ("write code to keep numbers greater than 12", ["filter_greater_than_12", "x > 12"]), ("write a function that filters even numbers from a list", ["filter_even_numbers", "% 2"]), ("load json file", ["json.load", "open"]), ], "web": [ ("create a responsive landing page with dark mode", ["", "toggle theme", "@media"]), ("build navbar with hamburger menu", ["menu-btn", "aria-expanded", "nav-links"]), ("make todo app with local storage", ["localstorage", "tasks", "render"]), ("fetch api example", ["fetch(", "async", "json"]), ], "math_science": [ ("solve 2x + 5 = 17", ["x = -b/a = 6", "= 6"]), ("derivative of 3x^2+2x+1", ["6x + 2"]), ("integral of 6x^2+4x", ["2x^3", "2x^2"]), ("force mass 10 acceleration 2", ["20 n"]), ("moles mass 10 molar 2", ["5 mol"]), ("what is photosynthesis", ["glucose", "oxygen", "chloroplasts"]), ], } for name, tests in suites.items(): passed = 0 samples = [] for prompt, needles in tests: ok, out = ok_contains(prompt, needles) passed += int(ok) samples.append({"prompt": prompt, "passed": ok, "expected_contains": needles, "output_preview": out[:700]}) results[name] = {"passed": passed, "total": len(tests), "accuracy": passed/len(tests), "samples": samples} # Small GSM8K sample: official dataset format, local tiny subset only. gsm_path = BASE / "outputs" / "unified_learning_ai" / "online_datasets" / "gsm8k_test.jsonl" gsm_samples = [] if gsm_path.exists(): lines = gsm_path.read_text(encoding="utf-8").splitlines()[:30] correct = 0 total = 0 for line in lines: obj = json.loads(line) q = obj["question"] golds = re.findall(r"####\s*([^\n]+)", obj["answer"]) gold = norm_num(golds[-1]) if golds else None out = generate(q) pred = norm_num(extract_number(out)) is_ok = gold is not None and pred is not None and abs(pred - gold) < 1e-6 correct += int(is_ok) total += 1 gsm_samples.append({"question": q[:300], "gold": gold, "pred": pred, "passed": is_ok, "output_preview": out[:500]}) results["gsm8k_test_sample_30_not_official"] = {"passed": correct, "total": total, "accuracy": correct/total if total else 0, "samples": gsm_samples} out_dir = BASE / "results" out_dir.mkdir(exist_ok=True) (out_dir / "aion_local_eval.json").write_text(json.dumps(results, indent=2, ensure_ascii=False), encoding="utf-8") # markdown summary lines = ["# AION Local Evaluation", "", "These are local sanity checks, not official HF leaderboard results.", ""] for k,v in results.items(): lines.append(f"- **{k}**: {v['passed']}/{v['total']} = {v['accuracy']:.2%}") (out_dir / "aion_local_eval.md").write_text("\n".join(lines), encoding="utf-8") print(json.dumps({k:{"passed":v["passed"],"total":v["total"],"accuracy":v["accuracy"]} for k,v in results.items()}, indent=2)) if __name__ == "__main__": run()