#!/usr/bin/env python3
"""
gemeo-bench — GEMEO Architecture v1.0 conformance test suite.
Checks whether a model instance conforms to GEMEO Architecture v1.0
(see gemeo_architecture_spec_v1.md §4). Runs 7 conformance tests.
Usage:
    python gemeo_bench.py check <path-to-checkpoint.pt> [--meds <test_dir>] [--readme <README.md>] [--json]
    python gemeo_bench.py check Raras-AI/gemeo-sus-v2   # from HF
Each test returns PASS / FAIL / SKIP with a one-line reason. A model is
GEMEO-conformant if all non-SKIP tests PASS.
This is a reference implementation. It is intentionally dependency-light
(torch + numpy only for the core checks). MEDS-substrate and gap-fill
tests require the test data; they SKIP gracefully if absent.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from dataclasses import dataclass, field
@dataclass
class TestResult:
    """Outcome of one conformance test, as stored in a ConformanceReport."""

    # Display label of the test, e.g. "1. MEDS schema".
    name: str
    # Verdict string: "PASS", "FAIL" or "SKIP".
    status: str
    # Optional one-line explanation printed next to the verdict.
    reason: str = ""
@dataclass
class ConformanceReport:
    """Collects the per-test results for one model instance and renders them.

    Attributes:
        instance: Identifier of the checked model (shown in the report header).
        results: Result records in the order the tests were added; each record
            exposes ``name``, ``status`` and ``reason``.
    """

    instance: str
    results: list = field(default_factory=list)

    @property
    def conformant(self) -> bool:
        """Return True when every non-SKIP result is a PASS.

        Comparing against "PASS" (rather than merely "not FAIL") means an
        unexpected status string such as "ERROR" can never be counted as
        conformant. A report with no results, or only SKIPs, is still
        reported as conformant.
        """
        return all(r.status == "PASS" for r in self.results if r.status != "SKIP")

    def add(self, name, status, reason=""):
        """Append a result record with the given name, status and reason."""
        self.results.append(TestResult(name, status, reason))

    def print_report(self):
        """Print a human-readable table of all results plus the verdict."""
        rule = "=" * 72
        icons = {"PASS": "✅", "FAIL": "❌", "SKIP": "⊘ "}

        print(rule)
        print(f"GEMEO Architecture v1.0 conformance — {self.instance}")
        print(rule)
        for r in self.results:
            # Unknown statuses fall back to a "?" marker instead of raising.
            print(f" {icons.get(r.status, '? ')} {r.name:<38} {r.status:<5} {r.reason}")
        print("-" * 72)

        verdict = "CONFORMANT ✅" if self.conformant else "NON-CONFORMANT ❌"
        n_pass = sum(1 for r in self.results if r.status == "PASS")
        n_fail = sum(1 for r in self.results if r.status == "FAIL")
        n_skip = sum(1 for r in self.results if r.status == "SKIP")
        print(f" Verdict: {verdict} ({n_pass} pass / {n_fail} fail / {n_skip} skip)")
        print(rule)
def load_checkpoint(path):
    """Load a checkpoint from a local path or a Hugging Face repo id.

    Args:
        path: Filesystem path to a checkpoint, or a Hugging Face repo id that
            is probed for a handful of common checkpoint filenames.

    Returns:
        A ``(checkpoint, resolved_path)`` tuple: the object returned by
        ``torch.load`` and the local file path it was read from.

    Raises:
        FileNotFoundError: If ``path`` does not exist locally and no candidate
            file could be downloaded and loaded from the Hub. The last
            underlying error (if any) is attached as ``__cause__``.
    """
    import torch

    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # Python objects, which can execute code embedded in the file. Only run
    # this tool on checkpoints (local or from the Hub) that you trust.
    if os.path.exists(path):
        return torch.load(path, map_location="cpu", weights_only=False), path

    # Not a local path: treat it as a Hugging Face repo id.
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as exc:
        raise FileNotFoundError(
            f"Could not load checkpoint: {path} "
            "(not a local path, and huggingface_hub is not installed)"
        ) from exc

    last_error = None
    # Common checkpoint filenames, tried in order; the first that loads wins.
    for fname in ["cdf_v13.pt", "cdf_v13_v3.pt", "model.pt", "pytorch_model.bin"]:
        try:
            local = hf_hub_download(repo_id=path, filename=fname)
            return torch.load(local, map_location="cpu", weights_only=False), local
        except Exception as exc:
            # Best-effort probing: remember why this candidate failed so the
            # final error is diagnosable, then try the next filename.
            last_error = exc
            continue

    raise FileNotFoundError(f"Could not load checkpoint: {path}") from last_error
# Code prefixes / markers that identify canonical MEDS vocabulary tokens.
_MEDS_PREFIXES = ("ICD10//", "SIH//", "APAC//", "SIGTAP//", "BPAI//", "ORPHA", "MEDS_")

# Source files that make up the world-model core (Test 5).
_CORE_MODULES = ("diffusion_forcing_v13.py", "primekg_attention.py")

# Substrings whose presence in the model card counts as a grounding reference (Test 6).
_GROUNDING_TERMS = ("pcdt", "sus", "formulary", "nice", "cms", "ground_sus", "dispensation")

# Building block -> substrings, any one of which counts as a citation (Test 7).
_REQUIRED_CITES = {
    "diffusion forcing": ["diffusion forcing", "2407.01392", "chen"],
    "adaln/dit": ["adaln", "dit", "peebles", "2212.09748"],
    "wsd schedule": ["wsd", "minicpm"],
    "meds": ["meds", "medical event data standard"],
    "primekg": ["primekg", "chandak"],
}


def _check_meds_schema(report, vocab):
    """Test 1: the vocabulary uses canonical MEDS code prefixes."""
    name = "1. MEDS schema"
    if not vocab:
        report.add(name, "SKIP", "no vocab in checkpoint")
        return
    # A marker anywhere in the token counts (this also covers the prefix case).
    n_meds = sum(1 for v in vocab if any(p in v for p in _MEDS_PREFIXES))
    # Forbidden: CID10// (Brazilian non-canonical) instead of ICD10//
    if any(v.startswith("CID10//") for v in vocab):
        report.add(name, "FAIL", "found non-canonical CID10// prefix (should be ICD10//)")
    elif n_meds >= 3:
        report.add(name, "PASS", f"{n_meds} tokens use canonical MEDS prefixes")
    else:
        report.add(name, "FAIL", f"only {n_meds} MEDS-prefixed tokens")


def _check_per_token_sigma(report, cfg, impl_dirs):
    """Test 2: the model accepts a (B, T) sigma; returns the model or None."""
    import torch

    name = "2. Per-token sigma"
    try:
        # Make the reference implementation next to the checkpoint importable.
        for d in impl_dirs:
            sys.path.insert(0, d)
        from diffusion_forcing_v13 import CDFv13Transformer, CDFv13Config

        known = CDFv13Config.__dataclass_fields__
        c = CDFv13Config(**{k: v for k, v in cfg.items() if k in known})
        model = CDFv13Transformer(c)
    except Exception as e:
        report.add(name, "SKIP",
                   f"cannot instantiate model: {type(e).__name__}: {str(e)[:50]}")
        return None

    try:
        # Try a forward with a per-token sigma vector
        B, T = 2, min(16, c.max_seq_len)
        expected = (B, T, c.vocab_size)
        x = torch.randint(0, c.vocab_size, (B, T))
        sigma = torch.rand(B, T)  # per-token!
        cond = torch.zeros(B, dtype=torch.long)
        with torch.no_grad():
            out = model(x, sigma, cond)
        shape = tuple(out.shape)
    except Exception as e:
        # NOTE(review): a forward failure is still reported as SKIP (as
        # before), which does not affect the verdict; consider whether a
        # model that rejects a (B, T) sigma should FAIL this test instead.
        report.add(name, "SKIP",
                   f"forward pass with (B,T) sigma failed: {type(e).__name__}: {str(e)[:50]}")
        return model

    if shape == expected:
        report.add(name, "PASS", f"accepts (B,T) sigma → logits {shape}")
    else:
        report.add(name, "FAIL", f"unexpected output shape {shape}")
    return model


def _check_kg_gating(report, ckpt, cfg, model):
    """Test 3: KG cross-attention is enabled and its gate parameters are listed."""
    name = "3. KG gating"
    try:
        kg_layers = cfg.get("kg_attn_layers", [])
        use_kg = cfg.get("use_kg", False)
        if not (use_kg and kg_layers):
            report.add(name, "FAIL", "use_kg=False or no kg_attn_layers")
            return
        if model is None:
            # Explicit SKIP instead of relying on a NameError being caught.
            report.add(name, "SKIP", "model not instantiated; cannot inspect gate params")
            return
        # Loading the weights is best-effort, but a failure is surfaced in
        # the reason instead of being swallowed silently.
        load_note = ""
        try:
            model.load_state_dict(ckpt["model_state"])
        except Exception as e:
            load_note = f" (checkpoint weights not loaded: {type(e).__name__})"
        alphas = [n for n, _ in model.named_parameters() if "alpha" in n.lower()]
        report.add(name, "PASS",
                   f"KG cross-attn at layers {kg_layers}; {len(alphas)} gate params present"
                   f"{load_note}")
    except Exception as e:
        report.add(name, "SKIP", f"{type(e).__name__}")


def _check_gap_fill(report, meds_dir):
    """Test 4: gap-fill recovery; always SKIP in this lightweight CLI."""
    name = "4. Gap-fill recovery"
    if meds_dir and os.path.isdir(meds_dir):
        report.add(name, "SKIP",
                   "MEDS dir present but full gap-fill eval not run in lightweight CLI "
                   "(use benchmark_v4_clinical.py for the full test)")
    else:
        report.add(name, "SKIP", "no MEDS test dir provided (--meds)")


def _check_module_structure(report, impl_dirs):
    """Test 5: the world-model core source files sit next to the checkpoint."""
    name = "5. Bootstrap-then-learn"
    impl_dir = next((d for d in impl_dirs if os.path.isdir(d)), None)
    if impl_dir is None:
        report.add(name, "SKIP", "no src/ dir found alongside checkpoint")
        return
    files = set(os.listdir(impl_dir))
    # The world-model core MUST be present; full twin-stack modules are in the app layer repo.
    if all(m in files for m in _CORE_MODULES):
        report.add(name, "PASS",
                   "world-model core present; full mode modules in gemeo-twin-stack")
    else:
        report.add(name, "FAIL", "missing core modules")


def _check_grounding(report, readme_text):
    """Test 6: the model card mentions a health-system grounding term."""
    name = "6. Health-system grounding"
    if not readme_text:
        report.add(name, "SKIP",
                   "no README provided; grounding is in gemeo-twin-stack ground_sus.py")
        return
    text = readme_text.lower()
    if any(t in text for t in _GROUNDING_TERMS):
        report.add(name, "PASS", "grounding referenced in model card")
    else:
        # Still a SKIP (unchanged verdict), but the reason no longer claims
        # that no README was provided.
        report.add(name, "SKIP",
                   "README has no grounding reference; "
                   "grounding is in gemeo-twin-stack ground_sus.py")


def _check_citations(report, readme_text):
    """Test 7: the model card cites every required building block."""
    name = "7. Audit citations"
    if not readme_text:
        report.add(name, "SKIP", "no README text provided")
        return
    text = readme_text.lower()
    missing = [k for k, terms in _REQUIRED_CITES.items()
               if not any(t in text for t in terms)]
    if missing:
        report.add(name, "FAIL", f"missing citations: {', '.join(missing)}")
    else:
        report.add(name, "PASS", "all required building blocks cited")


def run_conformance(ckpt_path, meds_dir=None, readme_text=None):
    """Run the seven conformance tests against a checkpoint.

    Args:
        ckpt_path: Local checkpoint path or Hugging Face repo id; passed to
            ``load_checkpoint`` and used as the report's instance label.
        meds_dir: Optional MEDS test-data directory (only affects the reason
            text of Test 4, which is always SKIP here).
        readme_text: Optional model-card text used by Tests 6 and 7.

    Returns:
        A ConformanceReport holding one result per test, in test order.

    Raises:
        FileNotFoundError: Propagated from ``load_checkpoint`` when the
            checkpoint cannot be loaded.
    """
    report = ConformanceReport(instance=ckpt_path)
    ckpt, resolved = load_checkpoint(ckpt_path)
    cfg = ckpt.get("config", {})
    vocab = ckpt.get("vocab", [])

    # Directories next to the checkpoint that may hold the reference code.
    base = os.path.dirname(resolved)
    impl_dirs = [os.path.join(base, "src"), os.path.join(base, "reference_impl")]

    _check_meds_schema(report, vocab)
    model = _check_per_token_sigma(report, cfg, impl_dirs)
    _check_kg_gating(report, ckpt, cfg, model)
    _check_gap_fill(report, meds_dir)
    _check_module_structure(report, impl_dirs)
    _check_grounding(report, readme_text)
    _check_citations(report, readme_text)
    return report
def main():
    """Command-line entry point: parse arguments, run the checks, and exit.

    Exits with status 0 when the report is conformant and 1 otherwise. It
    also exits with 1 when no sub-command is given or the checkpoint cannot
    be loaded.
    """
    ap = argparse.ArgumentParser(description="GEMEO Architecture v1.0 conformance checker")
    sub = ap.add_subparsers(dest="cmd")
    chk = sub.add_parser("check", help="Run conformance tests on a checkpoint")
    chk.add_argument("checkpoint", help="Path to .pt checkpoint or HF repo id")
    chk.add_argument("--meds", default=None, help="Path to MEDS test data dir (for gap-fill test)")
    chk.add_argument("--readme", default=None, help="Path to model card README (for citation/grounding tests)")
    chk.add_argument("--json", action="store_true", help="Emit JSON")
    args = ap.parse_args()
    if args.cmd != "check":
        ap.print_help()
        sys.exit(1)

    # Pick the model card: the explicit --readme if it is a file, otherwise
    # a README.md sitting next to the checkpoint.
    readme_path = None
    if args.readme and os.path.isfile(args.readme):
        readme_path = args.readme
    else:
        if args.readme:
            # Warn on stderr so --json output on stdout stays clean.
            print(f"warning: --readme {args.readme} not found; "
                  "looking for README.md next to the checkpoint", file=sys.stderr)
        cand = os.path.join(os.path.dirname(os.path.abspath(args.checkpoint)), "README.md")
        if os.path.isfile(cand):
            readme_path = cand

    readme_text = None
    if readme_path is not None:
        # Explicit UTF-8 (model cards often contain non-ASCII text) and a
        # context manager so the handle is closed; undecodable bytes are
        # replaced because the text is only used for substring checks.
        with open(readme_path, encoding="utf-8", errors="replace") as fh:
            readme_text = fh.read()

    try:
        report = run_conformance(args.checkpoint, meds_dir=args.meds, readme_text=readme_text)
    except FileNotFoundError as exc:
        # Message goes to stderr; exit status is 1, as with the former traceback.
        sys.exit(f"error: {exc}")

    if args.json:
        print(json.dumps({"instance": report.instance, "conformant": report.conformant,
                          "results": [{"name": r.name, "status": r.status, "reason": r.reason}
                                      for r in report.results]}, indent=2))
    else:
        report.print_report()
    sys.exit(0 if report.conformant else 1)


if __name__ == "__main__":
    main()
|