gemeo-arch / gemeo_bench.py
timmers's picture
Add gemeo-bench conformance CLI (7 tests from spec §4)
cbd573e verified
#!/usr/bin/env python3
"""
gemeo-bench — GEMEO Architecture v1.0 conformance test suite.
Checks whether a model instance conforms to GEMEO Architecture v1.0
(see gemeo_architecture_spec_v1.md §4). Runs 7 conformance tests.
Usage:
python gemeo_bench.py check <path-to-checkpoint.pt> [--meds <test_dir>]
python gemeo_bench.py check Raras-AI/gemeo-sus-v2 # from HF
Each test returns PASS / FAIL / SKIP with a one-line reason. A model is
GEMEO-conformant if all non-SKIP tests PASS.
This is a reference implementation. It is intentionally dependency-light
(torch + numpy only for the core checks). MEDS-substrate and gap-fill
tests require the test data; they SKIP gracefully if absent.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from dataclasses import dataclass, field
@dataclass
class TestResult:
    """Outcome of one conformance check, as recorded in a report."""

    name: str  # label shown in the report, e.g. "1. MEDS schema"
    status: str  # PASS / FAIL / SKIP
    reason: str = ""  # one-line explanation printed after the status
@dataclass
class ConformanceReport:
    """Ordered collection of conformance results for a single model instance."""

    instance: str  # identifier printed in the report header
    results: list = field(default_factory=list)  # TestResult entries, in insertion order

    @property
    def conformant(self) -> bool:
        """True unless some recorded result has status "FAIL" (SKIPs do not count against)."""
        return not any(res.status == "FAIL" for res in self.results)

    def add(self, name, status, reason=""):
        """Record one result built from *name*, *status* and an optional *reason*."""
        entry = TestResult(name, status, reason)
        self.results.append(entry)

    def print_report(self):
        """Print the per-test table and the overall verdict to stdout."""
        heavy_rule = "=" * 72
        markers = {"PASS": "✅", "FAIL": "❌", "SKIP": "⊘ "}
        print(heavy_rule)
        print(f"GEMEO Architecture v1.0 conformance — {self.instance}")
        print(heavy_rule)
        for res in self.results:
            marker = markers.get(res.status, "? ")
            print(f" {marker} {res.name:<38} {res.status:<5} {res.reason}")
        print("-" * 72)
        # Statuses outside PASS/FAIL/SKIP are shown in the table but not tallied.
        tally = {"PASS": 0, "FAIL": 0, "SKIP": 0}
        for res in self.results:
            if res.status in tally:
                tally[res.status] += 1
        if self.conformant:
            verdict = "CONFORMANT ✅"
        else:
            verdict = "NON-CONFORMANT ❌"
        print(f" Verdict: {verdict} ({tally['PASS']} pass / {tally['FAIL']} fail / {tally['SKIP']} skip)")
        print(heavy_rule)
def load_checkpoint(path, candidates=None):
    """Load a checkpoint from a local file, a local directory, or a Hugging Face repo id.

    Resolution order:
      1. ``path`` is a directory -> load the first of ``candidates`` present in it.
      2. ``path`` exists         -> load it directly.
      3. otherwise               -> treat ``path`` as a Hub repo id and try to
         download (and load) each of ``candidates`` in turn.

    Args:
        path: checkpoint file, directory containing one, or HF repo id.
        candidates: filenames to look for, in order, in a directory / repo.
            Defaults to ``cdf_v13.pt``, ``cdf_v13_v3.pt``, ``model.pt``,
            ``pytorch_model.bin``.

    Returns:
        ``(checkpoint_object, resolved_local_path)``.

    Raises:
        FileNotFoundError: nothing loadable was found. On the Hub path the last
            download/load error (or the ImportError for a missing
            ``huggingface_hub``) is chained as ``__cause__``.

    Security: files are read with ``torch.load(weights_only=False)``, i.e. full
    unpickling, which can execute arbitrary code embedded in the file. Only
    load local files or Hub repos you trust.
    """
    import torch

    if candidates is None:
        names = ("cdf_v13.pt", "cdf_v13_v3.pt", "model.pt", "pytorch_model.bin")
    else:
        names = tuple(candidates)

    def _load(local):
        # SECURITY: weights_only=False unpickles arbitrary objects (see docstring).
        return torch.load(local, map_location="cpu", weights_only=False)

    if os.path.isdir(path):
        for fname in names:
            local = os.path.join(path, fname)
            if os.path.isfile(local):
                return _load(local), local
        raise FileNotFoundError(
            f"Could not load checkpoint: {path} (directory contains none of {', '.join(names)})"
        )
    if os.path.exists(path):
        return _load(path), path

    # Not a local path: try the Hugging Face Hub.
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as exc:
        raise FileNotFoundError(f"Could not load checkpoint: {path}") from exc

    last_err = None
    for fname in names:
        try:
            local = hf_hub_download(repo_id=path, filename=fname)
            return _load(local), local
        except Exception as exc:  # missing file/repo, invalid id, network or unpickling error
            # Keep trying the remaining names, but remember why this one failed.
            last_err = exc
    raise FileNotFoundError(f"Could not load checkpoint: {path}") from last_err
# Canonical MEDS code markers for test 1.
_MEDS_PREFIXES = ("ICD10//", "SIH//", "APAC//", "SIGTAP//", "BPAI//", "ORPHA", "MEDS_")

# Terms whose presence in the model card counts as declared health-system grounding (test 6).
_GROUNDING_TERMS = ("pcdt", "sus", "formulary", "nice", "cms", "ground_sus", "dispensation")

# Building block -> terms, any of which counts as citing it (test 7).
_REQUIRED_CITES = {
    "diffusion forcing": ["diffusion forcing", "2407.01392", "chen"],
    "adaln/dit": ["adaln", "dit", "peebles", "2212.09748"],
    "wsd schedule": ["wsd", "minicpm"],
    "meds": ["meds", "medical event data standard"],
    "primekg": ["primekg", "chandak"],
}


def _mentions(text, term):
    """Return True if lower-case *term* appears in lower-case *text* as a whole word/phrase.

    Plain substring tests give false positives ("dit" in "audit", "chen" in
    "kitchen", "sus" in "consensus"). A term edge that is a letter must not
    touch another letter, and a digit edge must not touch another digit.
    A trailing plural "s" is allowed ("DiTs"). Spaces inside a multi-word
    term also match hyphens, underscores and line breaks ("diffusion-forcing").
    """
    import re

    pattern = r"[\s_\-]+".join(re.escape(word) for word in term.split())
    if term[:1].isalpha():
        pattern = r"(?<![a-z])" + pattern
    elif term[:1].isdigit():
        pattern = r"(?<![0-9])" + pattern
    if term[-1:].isalpha():
        pattern += r"s?(?![a-z])"
    elif term[-1:].isdigit():
        pattern += r"(?![0-9])"
    return re.search(pattern, text) is not None


def _config_dict(raw):
    """Normalise a checkpoint's ``config`` entry to a plain dict ({} if absent or unusable)."""
    import dataclasses

    if isinstance(raw, dict):
        return raw
    if dataclasses.is_dataclass(raw) and not isinstance(raw, type):
        return dataclasses.asdict(raw)
    if hasattr(raw, "__dict__"):
        return dict(vars(raw))
    return {}


def _check_meds_schema(report, vocab):
    """Test 1: vocab tokens use canonical MEDS code prefixes and never ``CID10//``."""
    name = "1. MEDS schema"
    if not vocab:
        report.add(name, "SKIP", "no vocab in checkpoint")
        return
    tokens = [t for t in vocab if isinstance(t, str)]  # ignore non-string entries
    # Forbidden: CID10// (Brazilian non-canonical) instead of ICD10//
    if any(t.startswith("CID10//") for t in tokens):
        report.add(name, "FAIL", "found non-canonical CID10// prefix (should be ICD10//)")
        return
    # Containment, not a strict prefix test: a token counts if any marker occurs anywhere in it.
    n_meds = sum(1 for t in tokens if any(p in t for p in _MEDS_PREFIXES))
    if n_meds >= 3:
        report.add(name, "PASS", f"{n_meds} tokens use canonical MEDS prefixes")
    else:
        report.add(name, "FAIL", f"only {n_meds} MEDS-prefixed tokens")


def _check_per_token_sigma(report, cfg, base_dir):
    """Test 2: the reference model accepts a (B, T) per-token sigma tensor.

    Returns the freshly constructed model (no checkpoint weights loaded) so
    test 3 can inspect initial parameter values, or None if construction failed.
    """
    import torch

    name = "2. Per-token sigma"
    # The reference implementation is looked up next to the checkpoint;
    # inserting src/ then reference_impl/ at index 0 gives reference_impl/ priority.
    for sub in ("src", "reference_impl"):
        impl_path = os.path.join(base_dir, sub)
        if impl_path not in sys.path:
            sys.path.insert(0, impl_path)

    model = None
    try:
        from diffusion_forcing_v13 import CDFv13Config, CDFv13Transformer

        known = CDFv13Config.__dataclass_fields__
        config = CDFv13Config(**{k: v for k, v in cfg.items() if k in known})
        model = CDFv13Transformer(config)
        batch, seq = 2, min(16, config.max_seq_len)
        tokens = torch.randint(0, config.vocab_size, (batch, seq))
        sigma = torch.rand(batch, seq)  # one noise level per token: the property under test
        cond = torch.zeros(batch, dtype=torch.long)
        with torch.no_grad():
            logits = model(tokens, sigma, cond)
        shape = tuple(logits.shape)
    except Exception as e:
        report.add(name, "SKIP", f"cannot instantiate model: {type(e).__name__}: {str(e)[:50]}")
        return model

    if shape == (batch, seq, config.vocab_size):
        report.add(name, "PASS", f"accepts (B,T) sigma → logits {shape}")
    else:
        report.add(name, "FAIL", f"unexpected output shape {shape}")
    return model


def _check_kg_gating(report, cfg, model):
    """Test 3: KG cross-attention is enabled and every alpha gate is 0 at initialisation.

    *model* must be freshly constructed: loading trained weights would
    overwrite exactly the initial values this test is about.
    """
    name = "3. KG gating"
    try:
        kg_layers = cfg.get("kg_attn_layers", [])
        if not (cfg.get("use_kg", False) and kg_layers):
            report.add(name, "FAIL", "use_kg=False or no kg_attn_layers")
            return
        if model is None:
            report.add(name, "SKIP", "model could not be instantiated (see test 2)")
            return
        gates = [(n, p) for n, p in model.named_parameters() if "alpha" in n.lower()]
        if not gates:
            report.add(name, "FAIL", f"KG cross-attn at layers {kg_layers} but no alpha gate params found")
            return
        nonzero = [n for n, p in gates if bool(p.detach().ne(0).any())]
        if nonzero:
            report.add(name, "FAIL",
                       f"{len(nonzero)}/{len(gates)} alpha gate params not zero at init (e.g. {nonzero[0]})")
        else:
            report.add(name, "PASS",
                       f"KG cross-attn at layers {kg_layers}; {len(gates)} gate params, all zero at init")
    except Exception as e:
        report.add(name, "SKIP", f"{type(e).__name__}")


def _check_gap_fill(report, meds_dir):
    """Test 4: gap-fill recovery. Always SKIPs here; the full eval lives in another script."""
    name = "4. Gap-fill recovery"
    if meds_dir and os.path.isdir(meds_dir):
        report.add(name, "SKIP",
                   "MEDS dir present but full gap-fill eval not run in lightweight CLI "
                   "(use benchmark_v4_clinical.py for the full test)")
    else:
        report.add(name, "SKIP", "no MEDS test dir provided (--meds)")


def _check_bootstrap(report, base_dir):
    """Test 5: the world-model core modules ship in src/ or reference_impl/ next to the checkpoint."""
    name = "5. Bootstrap-then-learn"
    search = (os.path.join(base_dir, "src"), os.path.join(base_dir, "reference_impl"))
    impl_dir = next((d for d in search if os.path.isdir(d)), None)
    if impl_dir is None:
        report.add(name, "SKIP", "no src/ dir found alongside checkpoint")
        return
    files = set(os.listdir(impl_dir))
    # The world-model core MUST be present; full twin-stack modules are in the app layer repo.
    if {"diffusion_forcing_v13.py", "primekg_attention.py"} <= files:
        report.add(name, "PASS", "world-model core present; full mode modules in gemeo-twin-stack")
    else:
        report.add(name, "FAIL", "missing core modules")


def _check_grounding(report, readme_text):
    """Test 6: the model card references health-system grounding."""
    name = "6. Health-system grounding"
    if not readme_text:
        report.add(name, "SKIP", "no README provided; grounding is in gemeo-twin-stack ground_sus.py")
        return
    text = readme_text.lower()
    if any(_mentions(text, t) for t in _GROUNDING_TERMS):
        report.add(name, "PASS", "grounding referenced in model card")
    else:
        # A README was given, so "no README provided" would be a misleading reason.
        report.add(name, "SKIP",
                   "model card does not reference health-system grounding; "
                   "grounding is in gemeo-twin-stack ground_sus.py")


def _check_citations(report, readme_text):
    """Test 7: the model card cites every required building block."""
    name = "7. Audit citations"
    if not readme_text:
        report.add(name, "SKIP", "no README text provided")
        return
    text = readme_text.lower()
    missing = [k for k, terms in _REQUIRED_CITES.items()
               if not any(_mentions(text, t) for t in terms)]
    if missing:
        report.add(name, "FAIL", f"missing citations: {', '.join(missing)}")
    else:
        report.add(name, "PASS", "all required building blocks cited")


def run_conformance(ckpt_path, meds_dir=None, readme_text=None):
    """Run the seven GEMEO v1.0 conformance tests against one checkpoint.

    Args:
        ckpt_path: local checkpoint path or HF repo id (resolved by ``load_checkpoint``).
        meds_dir: optional MEDS test-data directory (test 4).
        readme_text: optional model-card text (tests 6 and 7).

    Returns:
        A ``ConformanceReport`` with one result per test, in order 1..7.

    Raises:
        Whatever ``load_checkpoint`` raises when the checkpoint cannot be loaded.
    """
    report = ConformanceReport(instance=ckpt_path)
    ckpt, resolved = load_checkpoint(ckpt_path)
    if not isinstance(ckpt, dict):
        # e.g. a pickled module object: there is no config/vocab entry to inspect.
        ckpt = {}
    cfg = _config_dict(ckpt.get("config"))
    vocab = ckpt.get("vocab") or []
    base_dir = os.path.dirname(resolved)

    _check_meds_schema(report, vocab)
    model = _check_per_token_sigma(report, cfg, base_dir)
    _check_kg_gating(report, cfg, model)
    _check_gap_fill(report, meds_dir)
    _check_bootstrap(report, base_dir)
    _check_grounding(report, readme_text)
    _check_citations(report, readme_text)
    return report
def _read_text(path):
    """Read *path* as UTF-8 text (undecodable bytes replaced), always closing the file."""
    with open(path, encoding="utf-8", errors="replace") as fh:
        return fh.read()


def main():
    """Command-line entry point: ``gemeo_bench.py check <checkpoint> [--meds DIR] [--readme FILE] [--json]``.

    Exit status: 0 if conformant, 1 if non-conformant or no sub-command was
    given, 2 if the checkpoint could not be loaded.
    """
    ap = argparse.ArgumentParser(description="GEMEO Architecture v1.0 conformance checker")
    sub = ap.add_subparsers(dest="cmd")
    chk = sub.add_parser("check", help="Run conformance tests on a checkpoint")
    chk.add_argument("checkpoint", help="Path to .pt checkpoint or HF repo id")
    chk.add_argument("--meds", default=None, help="Path to MEDS test data dir (for gap-fill test)")
    chk.add_argument("--readme", default=None, help="Path to model card README (for citation/grounding tests)")
    chk.add_argument("--json", action="store_true", help="Emit JSON")
    args = ap.parse_args()
    if args.cmd != "check":
        ap.print_help()
        sys.exit(1)

    readme_text = None
    if args.readme:
        if os.path.isfile(args.readme):
            readme_text = _read_text(args.readme)
        else:
            print(f"warning: --readme {args.readme!r} not found; looking next to the checkpoint",
                  file=sys.stderr)
    if readme_text is None:
        # Try README.md next to the checkpoint (or inside it, if a directory was given).
        if os.path.isdir(args.checkpoint):
            ckpt_dir = args.checkpoint
        else:
            ckpt_dir = os.path.dirname(os.path.abspath(args.checkpoint))
        cand = os.path.join(ckpt_dir, "README.md")
        if os.path.isfile(cand):
            readme_text = _read_text(cand)

    try:
        report = run_conformance(args.checkpoint, meds_dir=args.meds, readme_text=readme_text)
    except FileNotFoundError as e:
        print(f"error: {e}", file=sys.stderr)
        sys.exit(2)

    if args.json:
        print(json.dumps({"instance": report.instance, "conformant": report.conformant,
                          "results": [{"name": r.name, "status": r.status, "reason": r.reason}
                                      for r in report.results]}, indent=2))
    else:
        report.print_report()
    sys.exit(0 if report.conformant else 1)


if __name__ == "__main__":
    main()