| |
| """Local evaluation for AION. |
| |
| These are NOT official leaderboard results. They are local sanity checks plus a |
| small GSM8K sample if the dataset file exists. |
| """ |
| from __future__ import annotations |
| import json |
| import re |
| from pathlib import Path |
| from aion import generate, BASE |
|
|
|
|
def ok_contains(prompt, needles):
    """Generate a reply for *prompt* and look for any expected substring.

    The check is case-insensitive: the reply and each needle are lower-cased
    before the substring test.

    Args:
        prompt: Text passed unchanged to ``generate``.
        needles: Iterable of strings; one hit is enough to pass.

    Returns:
        A ``(matched, reply)`` tuple. ``matched`` is True when at least one
        needle occurs in the reply; ``reply`` is the unmodified text returned
        by ``generate``.
    """
    reply = generate(prompt)
    haystack = reply.lower()
    matched = False
    for needle in needles:
        if needle.lower() in haystack:
            # One hit is sufficient; stop scanning the remaining needles.
            matched = True
            break
    return matched, reply
|
|
|
|
# A comma is treated as a thousands separator only when it sits between a
# digit and a group of exactly three digits ("1,234", "2,500,000.25").
# Other commas are kept so that "(3,4)" is not fused into "34".
_THOUSANDS_SEP_RE = re.compile(r"(?<=\d),(?=\d{3}(?!\d))")

# Optionally signed integer or decimal. The lookbehind keeps a hyphen that
# directly follows a digit ("3-5", "2024-05") from being read as a minus sign.
_NUMBER_RE = re.compile(r"(?<!\d)-?\d+(?:\.\d+)?")


def extract_number(text: str):
    """Return the last number mentioned in the answer part of *text*.

    Only the text after the final ``"## Answer"`` marker is searched; when the
    marker is absent the whole text is searched.

    Args:
        text: Model output to scan.

    Returns:
        The last number as a string with thousands separators removed
        (for example ``"1234"`` or ``"-7.5"``), or ``None`` when the searched
        text contains no number.
    """
    answer_part = text.split("## Answer")[-1]
    cleaned = _THOUSANDS_SEP_RE.sub("", answer_part)
    numbers = _NUMBER_RE.findall(cleaned)
    return numbers[-1] if numbers else None
|
|
|
|
def norm_num(x):
    """Convert *x* to a float, ignoring comma separators.

    The value is stringified, commas are removed, and the result is parsed
    with ``float``.

    Args:
        x: A number, a numeric string such as ``"1,234.5"``, or ``None``.

    Returns:
        The parsed float, or ``None`` when *x* is ``None`` or its text is not
        a valid float literal.
    """
    if x is None:
        return None
    try:
        return float(str(x).replace(",", ""))
    except ValueError:
        # Only a malformed literal is an expected failure here; anything else
        # is a real bug and should surface instead of being swallowed.
        return None
|
|
|
|
# Keyword sanity suites: suite name -> list of (prompt, expected substrings).
# A test passes when the generated reply contains at least one of its
# substrings (the comparison itself is done by ok_contains).
_KEYWORD_SUITES = {
    "chat": [
        ("hola", ["hello", "awake"]),
        ("what can you do", ["html", "math", "python"]),
        ("who are you", ["aion", "assistant", "learning"]),
    ],
    "python": [
        ("write code to keep numbers greater than 12", ["filter_greater_than_12", "x > 12"]),
        ("write a function that filters even numbers from a list", ["filter_even_numbers", "% 2"]),
        ("load json file", ["json.load", "open"]),
    ],
    "web": [
        ("create a responsive landing page with dark mode", ["<!doctype html>", "toggle theme", "@media"]),
        ("build navbar with hamburger menu", ["menu-btn", "aria-expanded", "nav-links"]),
        ("make todo app with local storage", ["localstorage", "tasks", "render"]),
        ("fetch api example", ["fetch(", "async", "json"]),
    ],
    "math_science": [
        ("solve 2x + 5 = 17", ["x = -b/a = 6", "= 6"]),
        ("derivative of 3x^2+2x+1", ["6x + 2"]),
        ("integral of 6x^2+4x", ["2x^3", "2x^2"]),
        ("force mass 10 acceleration 2", ["20 n"]),
        ("moles mass 10 molar 2", ["5 mol"]),
        ("what is photosynthesis", ["glucose", "oxygen", "chloroplasts"]),
    ],
}

# Number of GSM8K records evaluated; the result key spells out the same size.
_GSM8K_SAMPLE_SIZE = 30
_GSM8K_RESULT_KEY = "gsm8k_test_sample_30_not_official"


def _score_keyword_suite(tests) -> dict:
    """Run one keyword suite and return its pass count, accuracy and samples."""
    passed = 0
    samples = []
    for prompt, needles in tests:
        ok, out = ok_contains(prompt, needles)
        passed += int(ok)
        samples.append({
            "prompt": prompt,
            "passed": ok,
            "expected_contains": needles,
            "output_preview": out[:700],
        })
    total = len(tests)
    # Guard the division the same way the GSM8K branch does.
    accuracy = passed / total if total else 0
    return {"passed": passed, "total": total, "accuracy": accuracy, "samples": samples}


def _read_jsonl_lines(path: Path, limit: int) -> list:
    """Return up to *limit* non-blank lines from the JSONL file at *path*."""
    records = []
    # Iterate the file rather than using str.splitlines(): splitlines() also
    # breaks on characters such as U+2028 that may appear unescaped inside a
    # JSON string, which would cut a record in half.
    with path.open(encoding="utf-8") as handle:
        for raw in handle:
            if not raw.strip():
                # Blank lines (e.g. a trailing newline) are not records.
                continue
            records.append(raw)
            if len(records) >= limit:
                break
    return records


def _run_gsm8k_sample(path: Path, limit: int = _GSM8K_SAMPLE_SIZE) -> dict:
    """Score the first *limit* GSM8K records in *path* by numeric answer match."""
    correct = 0
    samples = []
    for line in _read_jsonl_lines(path, limit):
        obj = json.loads(line)
        question = obj["question"]
        # The reference answer is the text following the last "####" marker.
        golds = re.findall(r"####\s*([^\n]+)", obj["answer"])
        gold = norm_num(golds[-1]) if golds else None
        out = generate(question)
        pred = norm_num(extract_number(out))
        is_ok = gold is not None and pred is not None and abs(pred - gold) < 1e-6
        correct += int(is_ok)
        samples.append({
            "question": question[:300],
            "gold": gold,
            "pred": pred,
            "passed": is_ok,
            "output_preview": out[:500],
        })
    total = len(samples)
    accuracy = correct / total if total else 0
    return {"passed": correct, "total": total, "accuracy": accuracy, "samples": samples}


def _write_reports(results: dict, out_dir: Path) -> None:
    """Write the full JSON results and a Markdown summary into *out_dir*."""
    # parents=True so a missing intermediate directory does not abort the run.
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "aion_local_eval.json").write_text(
        json.dumps(results, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    report = [
        "# AION Local Evaluation",
        "",
        "These are local sanity checks, not official HF leaderboard results.",
        "",
    ]
    report.extend(
        f"- **{k}**: {v['passed']}/{v['total']} = {v['accuracy']:.2%}"
        for k, v in results.items()
    )
    (out_dir / "aion_local_eval.md").write_text("\n".join(report), encoding="utf-8")


def run() -> None:
    """Run the local evaluation, write the reports and print a summary.

    Steps:
      1. Score every keyword suite in ``_KEYWORD_SUITES``.
      2. If ``gsm8k_test.jsonl`` exists under ``BASE``, score a sample of it.
      3. Write ``aion_local_eval.json`` and ``aion_local_eval.md`` into
         ``BASE / "results"``.
      4. Print the per-suite passed/total/accuracy figures as JSON.
    """
    results = {
        name: _score_keyword_suite(tests)
        for name, tests in _KEYWORD_SUITES.items()
    }

    gsm_path = BASE / "outputs" / "unified_learning_ai" / "online_datasets" / "gsm8k_test.jsonl"
    if gsm_path.exists():
        results[_GSM8K_RESULT_KEY] = _run_gsm8k_sample(gsm_path)

    _write_reports(results, BASE / "results")

    summary = {
        k: {"passed": v["passed"], "total": v["total"], "accuracy": v["accuracy"]}
        for k, v in results.items()
    }
    print(json.dumps(summary, indent=2))
|
|
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    run()
|
|