#!/usr/bin/env python3
"""
gemeo-bench — GEMEO Architecture v1.0 conformance test suite.

Checks whether a model instance conforms to GEMEO Architecture v1.0
(see gemeo_architecture_spec_v1.md §4). Runs 7 conformance tests.

Usage:
    python gemeo_bench.py check <checkpoint> [--meds <dir>] [--readme <README.md>] [--json]
    python gemeo_bench.py check Raras-AI/gemeo-sus-v2     # from HF

Each test returns PASS / FAIL / SKIP with a one-line reason. A model is
GEMEO-conformant if all non-SKIP tests PASS.

This is a reference implementation. It is intentionally dependency-light
(torch only for the core checks; huggingface_hub optionally, for repo ids).
MEDS-substrate and gap-fill tests require the test data; they SKIP
gracefully if absent.
"""
from __future__ import annotations

import argparse
import json
import os
import re
import sys
from dataclasses import dataclass, field


@dataclass
class TestResult:
    """Outcome of a single conformance test."""

    name: str
    status: str  # PASS / FAIL / SKIP
    reason: str = ""


@dataclass
class ConformanceReport:
    """Ordered list of TestResults for one model instance."""

    instance: str
    results: list[TestResult] = field(default_factory=list)

    @property
    def conformant(self) -> bool:
        """True when no test FAILed; SKIPped tests do not count against the model.

        NOTE: a report in which every test SKIPped is (vacuously) conformant.
        """
        return all(r.status != "FAIL" for r in self.results)

    def add(self, name: str, status: str, reason: str = "") -> None:
        """Append the result of test *name*."""
        self.results.append(TestResult(name, status, reason))

    def print_report(self) -> None:
        """Print a human-readable table of all results plus the verdict."""
        print("=" * 72)
        print(f"GEMEO Architecture v1.0 conformance — {self.instance}")
        print("=" * 72)
        icons = {"PASS": "✅", "FAIL": "❌", "SKIP": "⊘ "}
        for r in self.results:
            print(f"  {icons.get(r.status, '? ')} {r.name:<38} {r.status:<5} {r.reason}")
        print("-" * 72)
        verdict = "CONFORMANT ✅" if self.conformant else "NON-CONFORMANT ❌"
        n_pass = sum(1 for r in self.results if r.status == "PASS")
        n_fail = sum(1 for r in self.results if r.status == "FAIL")
        n_skip = sum(1 for r in self.results if r.status == "SKIP")
        print(f"  Verdict: {verdict}  ({n_pass} pass / {n_fail} fail / {n_skip} skip)")
        print("=" * 72)


# Checkpoint filenames tried, in order, inside a local directory or a HF repo.
_CKPT_FILENAMES = ("cdf_v13.pt", "cdf_v13_v3.pt", "model.pt", "pytorch_model.bin")


def load_checkpoint(path):
    """Load a checkpoint from a local file, a local directory, or a HF repo id.

    Returns ``(checkpoint_object, resolved_local_file_path)``.

    Raises FileNotFoundError when nothing loadable is found; for the HF path
    the last underlying error is chained as ``__cause__``.

    SECURITY: ``weights_only=False`` unpickles arbitrary Python objects, so a
    malicious checkpoint (local or downloaded from the Hub) can execute code
    on load. Only run this against checkpoints you trust.
    """
    import torch

    def _load(p):
        return torch.load(p, map_location="cpu", weights_only=False)

    if os.path.isdir(path):
        # A local model directory: look for the same filenames as on the Hub.
        for fname in _CKPT_FILENAMES:
            cand = os.path.join(path, fname)
            if os.path.isfile(cand):
                return _load(cand), cand
        raise FileNotFoundError(f"Could not load checkpoint: no known checkpoint file in {path}")
    if os.path.exists(path):
        return _load(path), path

    # Not a local path: treat it as a HF repo id (best effort).
    last_err = None
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as err:
        last_err = err
    else:
        for fname in _CKPT_FILENAMES:
            try:
                local = hf_hub_download(repo_id=path, filename=fname)
                return _load(local), local
            except Exception as err:  # missing file, invalid repo id, network, bad pickle...
                last_err = err
    raise FileNotFoundError(f"Could not load checkpoint: {path}") from last_err


# ---- Matching helpers -------------------------------------------------------

# One letter in any script (word char that is not a digit or underscore).
_LETTER = r"[^\W\d_]"


def _mentions(text: str, term: str) -> bool:
    """Return True if *term* occurs in *text* as a whole word, not inside another word.

    A letter may not directly precede a term that starts with a letter, nor
    directly follow a term that ends with one. So "dit" does not match "audit"
    and "sus" does not match "consensus", while "chen2024", "ground_sus" and
    versioned arXiv ids such as "2407.01392v2" still match.
    """
    pattern = re.escape(term)
    if term[:1].isalpha():
        pattern = rf"(?<!{_LETTER}){pattern}"
    if term[-1:].isalpha():
        pattern = rf"{pattern}(?!{_LETTER})"
    return re.search(pattern, text) is not None


def _read_text(path):
    """Return the text of file *path* decoded as UTF-8, or None if it is not a file."""
    if not path or not os.path.isfile(path):
        return None
    with open(path, encoding="utf-8", errors="replace") as fh:
        return fh.read()


# ---- Individual conformance tests -------------------------------------------

_MEDS_PREFIXES = ("ICD10//", "SIH//", "APAC//", "SIGTAP//", "BPAI//", "ORPHA", "MEDS_")


def _check_meds_schema(report, vocab):
    """Test 1: vocab uses canonical MEDS code prefixes and never CID10//."""
    name = "1. MEDS schema"
    if not vocab:
        report.add(name, "SKIP", "no vocab in checkpoint")
        return
    tokens = [str(v) for v in vocab]
    # Substring match: a token counts if a prefix appears anywhere in it,
    # not only at the start.
    n_meds = sum(1 for tok in tokens if any(p in tok for p in _MEDS_PREFIXES))
    # Forbidden: CID10// (Brazilian non-canonical) instead of ICD10//
    if any(tok.startswith("CID10//") for tok in tokens):
        report.add(name, "FAIL", "found non-canonical CID10// prefix (should be ICD10//)")
    elif n_meds >= 3:
        report.add(name, "PASS", f"{n_meds} tokens use canonical MEDS prefixes")
    else:
        report.add(name, "FAIL", f"only {n_meds} MEDS-prefixed tokens")


def _check_per_token_sigma(report, cfg, impl_dirs):
    """Test 2: the model's forward accepts a per-token sigma of shape (B, T).

    Returns the freshly constructed (untrained) model so test 3 can inspect
    its parameters, or None if the model could not be built.
    """
    import torch

    name = "2. Per-token sigma"
    try:
        for d in impl_dirs:
            sys.path.insert(0, d)
        from diffusion_forcing_v13 import CDFv13Config, CDFv13Transformer

        known = CDFv13Config.__dataclass_fields__
        c = CDFv13Config(**{k: v for k, v in cfg.items() if k in known})
        model = CDFv13Transformer(c)
        B, T = 2, min(16, c.max_seq_len)
        x = torch.randint(0, c.vocab_size, (B, T))
        cond = torch.zeros(B, dtype=torch.long)
    except Exception as e:
        report.add(name, "SKIP", f"cannot instantiate model: {type(e).__name__}: {str(e)[:50]}")
        return None

    try:
        with torch.no_grad():
            out = model(x, torch.rand(B, T), cond)  # per-token sigma!
    except Exception as e:
        err = f"{type(e).__name__}: {str(e)[:50]}"
        # Retry with one sigma per sequence to tell "rejects per-token sigma"
        # (a conformance FAIL) apart from "this harness cannot drive the
        # forward pass at all" (SKIP).
        try:
            with torch.no_grad():
                model(x, torch.rand(B), cond)
        except Exception:
            report.add(name, "SKIP", f"cannot run forward pass: {err}")
        else:
            report.add(name, "FAIL", f"per-sequence sigma works but (B,T) sigma raised {err}")
        return model

    shape = getattr(out, "shape", None)
    if shape is None:
        report.add(name, "FAIL", f"unexpected output type {type(out).__name__}")
    elif tuple(shape) == (B, T, c.vocab_size):
        report.add(name, "PASS", f"accepts (B,T) sigma → logits {tuple(shape)}")
    else:
        report.add(name, "FAIL", f"unexpected output shape {tuple(shape)}")
    return model


def _check_kg_gating(report, cfg, model):
    """Test 3: KG cross-attention is configured and its alpha gates are initialised to 0."""
    import torch

    name = "3. KG gating"
    kg_layers = cfg.get("kg_attn_layers") or []
    if not (cfg.get("use_kg", False) and kg_layers):
        report.add(name, "FAIL", "use_kg=False or no kg_attn_layers")
        return
    if model is None:
        report.add(name, "SKIP", "model could not be instantiated (see test 2)")
        return
    # `model` is freshly constructed (no weights loaded), so its parameters
    # hold their initial values.
    gates = [(n, p) for n, p in model.named_parameters() if "alpha" in n.lower()]
    if not gates:
        report.add(name, "FAIL", f"KG cross-attn at layers {kg_layers} but no alpha gate params found")
        return
    nonzero = [n for n, p in gates if torch.count_nonzero(p.detach()).item()]
    if nonzero:
        report.add(name, "FAIL", f"{len(nonzero)}/{len(gates)} alpha gates not zero-initialised (e.g. {nonzero[0]})")
    else:
        report.add(name, "PASS", f"KG cross-attn at layers {kg_layers}; {len(gates)} gate params present, all init 0")


def _check_gap_fill(report, meds_dir):
    """Test 4: gap-fill recovery; always SKIPs in this lightweight CLI."""
    name = "4. Gap-fill recovery"
    if meds_dir and os.path.isdir(meds_dir):
        report.add(name, "SKIP",
                   "MEDS dir present but full gap-fill eval not run in lightweight CLI "
                   "(use benchmark_v4_clinical.py for the full test)")
    else:
        report.add(name, "SKIP", "no MEDS test dir provided (--meds)")


# The world-model core MUST be present; full twin-stack modules are in the app layer repo.
_CORE_MODULES = ("diffusion_forcing_v13.py", "primekg_attention.py")


def _check_module_structure(report, impl_dirs):
    """Test 5: the world-model core modules ship alongside the checkpoint."""
    name = "5. Bootstrap-then-learn"
    existing = [d for d in impl_dirs if os.path.isdir(d)]
    if not existing:
        report.add(name, "SKIP", "no src/ dir found alongside checkpoint")
        return
    # Look in every implementation dir, matching the sys.path setup of test 2.
    files = set().union(*(os.listdir(d) for d in existing))
    missing = [m for m in _CORE_MODULES if m not in files]
    if missing:
        report.add(name, "FAIL", f"missing core modules: {', '.join(missing)}")
    else:
        report.add(name, "PASS", "world-model core present; full mode modules in gemeo-twin-stack")


_GROUNDING_TERMS = ("pcdt", "sus", "formulary", "nice", "cms", "ground_sus", "dispensation")


def _check_grounding(report, readme_text):
    """Test 6: the model card references health-system grounding."""
    name = "6. Health-system grounding"
    if not readme_text:
        report.add(name, "SKIP", "no README provided; grounding is in gemeo-twin-stack ground_sus.py")
        return
    text = readme_text.lower()
    if any(_mentions(text, t) for t in _GROUNDING_TERMS):
        report.add(name, "PASS", "grounding referenced in model card")
    else:
        report.add(name, "SKIP",
                   "README does not reference health-system grounding; "
                   "grounding is in gemeo-twin-stack ground_sus.py")


_REQUIRED_CITES = {
    "diffusion forcing": ["diffusion forcing", "2407.01392", "chen"],
    "adaln/dit": ["adaln", "dit", "peebles", "2212.09748"],
    "wsd schedule": ["wsd", "minicpm"],
    "meds": ["meds", "medical event data standard"],
    "primekg": ["primekg", "chandak"],
}


def _check_citations(report, readme_text):
    """Test 7: the model card cites every required building block."""
    name = "7. Audit citations"
    if not readme_text:
        report.add(name, "SKIP", "no README text provided")
        return
    text = readme_text.lower()
    missing = [k for k, terms in _REQUIRED_CITES.items()
               if not any(_mentions(text, t) for t in terms)]
    if missing:
        report.add(name, "FAIL", f"missing citations: {', '.join(missing)}")
    else:
        report.add(name, "PASS", "all required building blocks cited")


def run_conformance(ckpt_path, meds_dir=None, readme_text=None):
    """Run the seven conformance tests and return a ConformanceReport.

    ckpt_path: local checkpoint file/directory or HF repo id (see load_checkpoint).
    meds_dir: optional MEDS test-data directory (test 4).
    readme_text: optional model-card text (tests 6 and 7).
    """
    report = ConformanceReport(instance=ckpt_path)
    ckpt, resolved = load_checkpoint(ckpt_path)
    if not isinstance(ckpt, dict):
        # No metadata container to inspect (e.g. a pickled module object).
        ckpt = {}
    cfg = ckpt.get("config") or {}
    if not isinstance(cfg, dict):
        cfg = dict(vars(cfg)) if hasattr(cfg, "__dict__") else {}
    # Fall back to the token->id map's keys when no explicit vocab list is stored.
    vocab = ckpt.get("vocab") or list(ckpt.get("tok2id") or {})

    ckpt_dir = os.path.dirname(resolved)
    impl_dirs = [os.path.join(ckpt_dir, "src"), os.path.join(ckpt_dir, "reference_impl")]

    _check_meds_schema(report, vocab)
    model = _check_per_token_sigma(report, cfg, impl_dirs)
    _check_kg_gating(report, cfg, model)
    _check_gap_fill(report, meds_dir)
    _check_module_structure(report, impl_dirs)
    _check_grounding(report, readme_text)
    _check_citations(report, readme_text)
    return report


def main():
    """CLI entry point; exits 0 if the checkpoint is conformant, 1 otherwise."""
    ap = argparse.ArgumentParser(description="GEMEO Architecture v1.0 conformance checker")
    sub = ap.add_subparsers(dest="cmd")
    chk = sub.add_parser("check", help="Run conformance tests on a checkpoint")
    chk.add_argument("checkpoint",
                     help="Path to .pt checkpoint (or a directory containing one) or HF repo id")
    chk.add_argument("--meds", default=None, help="Path to MEDS test data dir (for gap-fill test)")
    chk.add_argument("--readme", default=None,
                     help="Path to model card README (for citation/grounding tests)")
    chk.add_argument("--json", action="store_true", help="Emit JSON")
    args = ap.parse_args()
    if args.cmd != "check":
        ap.print_help()
        sys.exit(1)

    readme_text = _read_text(args.readme)
    if readme_text is None:
        if args.readme:
            # stderr keeps --json output on stdout machine-readable.
            print(f"warning: README not found: {args.readme}; trying next to the checkpoint",
                  file=sys.stderr)
        # Try README.md next to the checkpoint (or inside it, for a directory).
        ckpt = os.path.abspath(args.checkpoint)
        base = ckpt if os.path.isdir(ckpt) else os.path.dirname(ckpt)
        readme_text = _read_text(os.path.join(base, "README.md"))

    report = run_conformance(args.checkpoint, meds_dir=args.meds, readme_text=readme_text)
    if args.json:
        payload = {
            "instance": report.instance,
            "conformant": report.conformant,
            "results": [{"name": r.name, "status": r.status, "reason": r.reason}
                        for r in report.results],
        }
        print(json.dumps(payload, indent=2))
    else:
        report.print_report()
    sys.exit(0 if report.conformant else 1)


if __name__ == "__main__":
    main()