| |
"""
gemeo-bench — GEMEO Architecture v1.0 conformance test suite.

Checks whether a model instance conforms to GEMEO Architecture v1.0
(see gemeo_architecture_spec_v1.md §4). Runs 7 conformance tests.

Usage:
    python gemeo_bench.py check <path-to-checkpoint.pt> [--meds <test_dir>] [--readme <README.md>] [--json]
    python gemeo_bench.py check Raras-AI/gemeo-sus-v2   # from HF

Each test returns PASS / FAIL / SKIP with a one-line reason. A model is
GEMEO-conformant if all non-SKIP tests PASS. The process exits with
status 0 when the model is conformant and 1 otherwise.

This is a reference implementation. It is intentionally dependency-light
(torch + numpy only for the core checks). MEDS-substrate and gap-fill
tests require the test data; they SKIP gracefully if absent. The
grounding and citation tests read the model card given by --readme (or a
README.md next to the checkpoint) and SKIP when no model card is found.
"""
| from __future__ import annotations |
| import argparse |
| import json |
| import os |
| import sys |
| from dataclasses import dataclass, field |
|
|
|
|
@dataclass
class TestResult:
    """Outcome of a single conformance test."""

    # The class name starts with "Test", so pytest would try to collect it
    # (and emit a collection warning) whenever it is imported into a test
    # module. This is a plain class attribute, not a dataclass field.
    __test__ = False

    name: str  # display label of the test, e.g. "1. MEDS schema"
    status: str  # "PASS", "FAIL" or "SKIP"
    reason: str = ""  # one-line explanation printed next to the status
|
|
|
|
@dataclass
class ConformanceReport:
    """Ordered collection of test results for one model instance."""

    instance: str  # checkpoint path or repo id the report describes
    results: list = field(default_factory=list)  # TestResult entries, in run order

    @property
    def conformant(self) -> bool:
        """Return True when every non-SKIP result is a PASS.

        A status other than "PASS" or "SKIP" (a "FAIL", or an unknown or
        misspelled status) makes the report non-conformant. A report with
        no results, or with only SKIPs, is vacuously conformant.
        """
        return all(r.status == "PASS" for r in self.results if r.status != "SKIP")

    def add(self, name, status, reason=""):
        """Append a new result built from *name*, *status* and *reason*."""
        self.results.append(TestResult(name, status, reason))

    def print_report(self):
        """Print a human-readable table of all results plus the verdict."""
        rule = "=" * 72
        print(rule)
        print(f"GEMEO Architecture v1.0 conformance — {self.instance}")
        print(rule)
        icons = {"PASS": "✅", "FAIL": "❌", "SKIP": "⊘ "}
        # Only the three known statuses are tallied; anything else still
        # shows up in the table with a "? " icon.
        tally = {"PASS": 0, "FAIL": 0, "SKIP": 0}
        for r in self.results:
            print(f" {icons.get(r.status, '? ')} {r.name:<38} {r.status:<5} {r.reason}")
            if r.status in tally:
                tally[r.status] += 1
        print("-" * 72)
        verdict = "CONFORMANT ✅" if self.conformant else "NON-CONFORMANT ❌"
        n_pass, n_fail, n_skip = tally["PASS"], tally["FAIL"], tally["SKIP"]
        print(f" Verdict: {verdict} ({n_pass} pass / {n_fail} fail / {n_skip} skip)")
        print(rule)
|
|
|
|
def load_checkpoint(path):
    """Load a checkpoint from a local path or a Hugging Face repo id.

    Args:
        path: Filesystem path to a checkpoint file. If nothing exists at
            that path it is treated as a Hugging Face repo id and a fixed
            list of well-known checkpoint filenames is tried in order.

    Returns:
        A ``(checkpoint, resolved_path)`` tuple, where ``resolved_path`` is
        the local file that was actually loaded.

    Raises:
        FileNotFoundError: If the path does not exist locally and no
            candidate file could be downloaded and loaded. The message
            lists why each attempt failed.
    """
    import torch

    # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
    # arbitrary Python objects, so loading a checkpoint from an untrusted
    # source (including a downloaded repo) can execute code. Kept as-is
    # because checkpoints here carry non-tensor entries (config, vocab);
    # only run this tool on checkpoints you trust.
    if os.path.exists(path):
        return torch.load(path, map_location="cpu", weights_only=False), path

    candidates = ("cdf_v13.pt", "cdf_v13_v3.pt", "model.pt", "pytorch_model.bin")
    failures = []
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        failures.append("huggingface_hub is not installed, so the Hub lookup was skipped")
    else:
        for fname in candidates:
            try:
                local = hf_hub_download(repo_id=path, filename=fname)
                return torch.load(local, map_location="cpu", weights_only=False), local
            except Exception as exc:
                # Best effort: a missing or unloadable candidate must not
                # stop the search, but the cause is kept for the final
                # error message instead of being dropped.
                failures.append(f"{fname}: {type(exc).__name__}")
    raise FileNotFoundError(
        f"Could not load checkpoint: {path} ({'; '.join(failures)})"
    )
|
|
|
|
def _impl_dirs(resolved):
    """Return the ``src`` and ``reference_impl`` directories next to the checkpoint file."""
    base = os.path.dirname(resolved)
    return [os.path.join(base, "src"), os.path.join(base, "reference_impl")]


def _check_meds_schema(report, vocab):
    """Test 1: the vocabulary uses canonical MEDS token prefixes."""
    label = "1. MEDS schema"
    meds_prefixes = ("ICD10//", "SIH//", "APAC//", "SIGTAP//", "BPAI//", "ORPHA", "MEDS_")
    if not vocab:
        report.add(label, "SKIP", "no vocab in checkpoint")
        return
    if any(v.startswith("CID10//") for v in vocab):
        report.add(label, "FAIL", "found non-canonical CID10// prefix (should be ICD10//)")
        return
    # NOTE(review): a token counts when a marker occurs anywhere in it, not
    # only at the start (the original test was `startswith(p) or p in v`,
    # which reduces to `p in v`). Confirm whether a strict prefix match is
    # what the spec intends.
    n_meds = sum(1 for v in vocab if any(p in v for p in meds_prefixes))
    if n_meds >= 3:
        report.add(label, "PASS", f"{n_meds} tokens use canonical MEDS prefixes")
    else:
        report.add(label, "FAIL", f"only {n_meds} MEDS-prefixed tokens")


def _check_per_token_sigma(report, cfg, resolved):
    """Test 2: the model accepts a per-token (B, T) sigma.

    Returns the instantiated model so later checks can reuse it, or None
    when the reference implementation could not be imported or built.
    """
    import torch

    label = "2. Per-token sigma"
    try:
        # NOTE(review): this imports (and therefore executes) Python code
        # located next to the checkpoint; only run on trusted directories.
        # Inserted in this order so reference_impl ends up ahead of src.
        for d in _impl_dirs(resolved):
            sys.path.insert(0, d)
        from diffusion_forcing_v13 import CDFv13Transformer, CDFv13Config
        known = CDFv13Config.__dataclass_fields__
        c = CDFv13Config(**{k: v for k, v in cfg.items() if k in known})
        model = CDFv13Transformer(c)
    except Exception as e:
        # A missing prerequisite (no implementation, unusable config) means
        # the test cannot run at all, hence SKIP.
        report.add(label, "SKIP", f"cannot instantiate model: {type(e).__name__}: {str(e)[:50]}")
        return None

    batch, seq_len = 2, min(16, c.max_seq_len)
    try:
        x = torch.randint(0, c.vocab_size, (batch, seq_len))
        sigma = torch.rand(batch, seq_len)
        cond = torch.zeros(batch, dtype=torch.long)
        with torch.no_grad():
            out = model(x, sigma, cond)
        shape = tuple(out.shape)
    except Exception as e:
        # The model exists but rejects a per-token sigma: that is a
        # conformance failure, not a skipped test.
        report.add(label, "FAIL",
                   f"forward pass with (B,T) sigma raised {type(e).__name__}: {str(e)[:50]}")
        return model

    if shape == (batch, seq_len, c.vocab_size):
        report.add(label, "PASS", f"accepts (B,T) sigma → logits {shape}")
    else:
        report.add(label, "FAIL", f"unexpected output shape {shape}")
    return model


def _check_kg_gating(report, cfg, ckpt, model):
    """Test 3: KG cross-attention is enabled and gate parameters can be listed."""
    label = "3. KG gating"
    try:
        kg_layers = cfg.get("kg_attn_layers", [])
        use_kg = cfg.get("use_kg", False)
        if not (use_kg and kg_layers):
            report.add(label, "FAIL", "use_kg=False or no kg_attn_layers")
            return
        if model is None:
            report.add(label, "SKIP",
                       "model could not be instantiated; gate params cannot be inspected")
            return
        try:
            model.load_state_dict(ckpt["model_state"])
        except Exception:
            # Deliberate best effort: parameter names are listed below
            # whether or not the weights could be loaded.
            pass
        alphas = [n for n, _ in model.named_parameters() if "alpha" in n.lower()]
        # NOTE(review): this reports PASS even when zero gate params are
        # found; confirm whether the spec requires at least one.
        report.add(label, "PASS",
                   f"KG cross-attn at layers {kg_layers}; {len(alphas)} gate params present")
    except Exception as e:
        report.add(label, "SKIP", f"{type(e).__name__}")


def _check_gap_fill(report, meds_dir):
    """Test 4: gap-fill recovery; always SKIP here, with a reason that depends on meds_dir."""
    label = "4. Gap-fill recovery"
    if meds_dir and os.path.isdir(meds_dir):
        report.add(label, "SKIP",
                   "MEDS dir present but full gap-fill eval not run in lightweight CLI "
                   "(use benchmark_v4_clinical.py for the full test)")
    else:
        report.add(label, "SKIP", "no MEDS test dir provided (--meds)")


def _check_bootstrap(report, resolved):
    """Test 5: the core implementation modules sit next to the checkpoint."""
    label = "5. Bootstrap-then-learn"
    # src/ is preferred over reference_impl/ when both exist.
    impl_dir = next((d for d in _impl_dirs(resolved) if os.path.isdir(d)), None)
    if impl_dir is None:
        report.add(label, "SKIP", "no src/ dir found alongside checkpoint")
        return
    files = set(os.listdir(impl_dir))
    if {"diffusion_forcing_v13.py", "primekg_attention.py"} <= files:
        report.add(label, "PASS",
                   "world-model core present; full mode modules in gemeo-twin-stack")
    else:
        report.add(label, "FAIL", "missing core modules")


def _check_grounding(report, readme_text):
    """Test 6: the model card mentions health-system grounding."""
    label = "6. Health-system grounding"
    grounding_terms = ["pcdt", "sus", "formulary", "nice", "cms", "ground_sus", "dispensation"]
    text = (readme_text or "").lower()
    # NOTE(review): plain substring matching, so short terms such as "sus"
    # or "nice" also match inside unrelated words.
    if any(t in text for t in grounding_terms):
        report.add(label, "PASS", "grounding referenced in model card")
    elif readme_text:
        # A README was supplied, so saying "no README provided" would be
        # wrong; the status stays SKIP as before.
        report.add(label, "SKIP",
                   "no grounding terms found in README; grounding is in gemeo-twin-stack ground_sus.py")
    else:
        report.add(label, "SKIP",
                   "no README provided; grounding is in gemeo-twin-stack ground_sus.py")


def _check_citations(report, readme_text):
    """Test 7: the model card cites every required building block."""
    label = "7. Audit citations"
    required_cites = {
        "diffusion forcing": ["diffusion forcing", "2407.01392", "chen"],
        "adaln/dit": ["adaln", "dit", "peebles", "2212.09748"],
        "wsd schedule": ["wsd", "minicpm"],
        "meds": ["meds", "medical event data standard"],
        "primekg": ["primekg", "chandak"],
    }
    if not readme_text:
        report.add(label, "SKIP", "no README text provided")
        return
    text = readme_text.lower()
    # NOTE(review): plain substring matching; "dit" also matches words like
    # "condition" or "edit", and "chen" matches "kitchen".
    missing = [k for k, terms in required_cites.items() if not any(t in text for t in terms)]
    if missing:
        report.add(label, "FAIL", f"missing citations: {', '.join(missing)}")
    else:
        report.add(label, "PASS", "all required building blocks cited")


def run_conformance(ckpt_path, meds_dir=None, readme_text=None):
    """Run the seven conformance tests against one checkpoint.

    Args:
        ckpt_path: Local checkpoint path or Hugging Face repo id; passed
            to ``load_checkpoint``.
        meds_dir: Optional MEDS test-data directory (only affects the SKIP
            reason of the gap-fill test).
        readme_text: Optional model-card text used by the grounding and
            citation tests.

    Returns:
        A ``ConformanceReport`` holding one result per test, in order 1-7.

    Raises:
        FileNotFoundError: Propagated from ``load_checkpoint`` when the
            checkpoint cannot be found or loaded.
    """
    report = ConformanceReport(instance=ckpt_path)
    ckpt, resolved = load_checkpoint(ckpt_path)
    cfg = ckpt.get("config", {})
    vocab = ckpt.get("vocab", [])

    _check_meds_schema(report, vocab)
    model = _check_per_token_sigma(report, cfg, resolved)
    _check_kg_gating(report, cfg, ckpt, model)
    _check_gap_fill(report, meds_dir)
    _check_bootstrap(report, resolved)
    _check_grounding(report, readme_text)
    _check_citations(report, readme_text)
    return report
|
|
|
|
def main():
    """Command-line entry point for the conformance checker.

    Parses the ``check`` sub-command, locates the model card, runs the
    conformance tests, and prints either a text report or JSON.

    Exits with status 0 when the model is conformant and 1 otherwise; it
    also exits with status 1 (after printing help) when no sub-command is
    given.
    """
    ap = argparse.ArgumentParser(description="GEMEO Architecture v1.0 conformance checker")
    sub = ap.add_subparsers(dest="cmd")
    chk = sub.add_parser("check", help="Run conformance tests on a checkpoint")
    chk.add_argument("checkpoint", help="Path to .pt checkpoint or HF repo id")
    chk.add_argument("--meds", default=None, help="Path to MEDS test data dir (for gap-fill test)")
    chk.add_argument("--readme", default=None, help="Path to model card README (for citation/grounding tests)")
    chk.add_argument("--json", action="store_true", help="Emit JSON")
    args = ap.parse_args()

    if args.cmd != "check":
        ap.print_help()
        sys.exit(1)

    if args.readme and os.path.exists(args.readme):
        readme_path = args.readme
    else:
        if args.readme:
            # Warn on stderr so --json output on stdout stays parseable.
            print(f"warning: README not found at {args.readme}; "
                  "falling back to README.md next to the checkpoint", file=sys.stderr)
        # Fall back to a README.md sitting next to the checkpoint path.
        readme_path = os.path.join(
            os.path.dirname(os.path.abspath(args.checkpoint)), "README.md")

    readme_text = None
    if os.path.exists(readme_path):
        # Explicit encoding: model cards are typically UTF-8 and the locale
        # default can fail to decode them. errors="replace" is enough here
        # because the text is only used for substring matching.
        with open(readme_path, encoding="utf-8", errors="replace") as fh:
            readme_text = fh.read()

    report = run_conformance(args.checkpoint, meds_dir=args.meds, readme_text=readme_text)
    if args.json:
        payload = {
            "instance": report.instance,
            "conformant": report.conformant,
            "results": [{"name": r.name, "status": r.status, "reason": r.reason}
                        for r in report.results],
        }
        print(json.dumps(payload, indent=2))
    else:
        report.print_report()
    sys.exit(0 if report.conformant else 1)
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|