File size: 10,181 Bytes
cbd573e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
#!/usr/bin/env python3
"""
gemeo-bench — GEMEO Architecture v1.0 conformance test suite.

Checks whether a model instance conforms to GEMEO Architecture v1.0
(see gemeo_architecture_spec_v1.md §4). Runs 7 conformance tests.

Usage:
    python gemeo_bench.py check <path-to-checkpoint.pt> [--meds <test_dir>] [--readme <README.md>] [--json]
    python gemeo_bench.py check Raras-AI/gemeo-sus-v2   # from HF

Each test returns PASS / FAIL / SKIP with a one-line reason. A model is
GEMEO-conformant if all non-SKIP tests PASS.

This is a reference implementation. It is intentionally dependency-light
(torch + numpy only for the core checks). MEDS-substrate and gap-fill
tests require the test data; they SKIP gracefully if absent.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from dataclasses import dataclass, field


@dataclass
class TestResult:
    """Outcome of a single conformance test."""

    # Display label of the test, e.g. "1. MEDS schema".
    name: str
    # One of "PASS", "FAIL" or "SKIP".
    status: str
    # Optional one-line explanation shown next to the status.
    reason: str = ""


@dataclass
class ConformanceReport:
    """Collects per-test outcomes for one model instance and renders a verdict.

    The report counts as conformant as long as no recorded result has status
    "FAIL"; "SKIP" results (and an empty report) do not block conformance.
    """

    # Identifier of the checked instance, echoed in the report header.
    instance: str
    # Recorded outcomes, in the order they were added.
    results: list = field(default_factory=list)

    @property
    def conformant(self) -> bool:
        """Return True unless at least one recorded result is a FAIL."""
        for outcome in self.results:
            if outcome.status == "FAIL":
                return False
        return True

    def add(self, name, status, reason=""):
        """Record one test outcome as a TestResult."""
        outcome = TestResult(name, status, reason)
        self.results.append(outcome)

    def print_report(self):
        """Print the per-test table and the overall verdict to stdout."""
        heavy_rule = "=" * 72
        light_rule = "-" * 72
        icons = {"PASS": "✅", "FAIL": "❌", "SKIP": "⊘ "}

        print(heavy_rule)
        print(f"GEMEO Architecture v1.0 conformance — {self.instance}")
        print(heavy_rule)

        # One row per result; statuses outside the icon table get a "? " marker.
        for outcome in self.results:
            icon = icons.get(outcome.status, "? ")
            print("  {} {:<38} {:<5} {}".format(icon, outcome.name, outcome.status, outcome.reason))
        print(light_rule)

        if self.conformant:
            verdict = "CONFORMANT ✅"
        else:
            verdict = "NON-CONFORMANT ❌"

        # Only the three known statuses are counted in the summary line.
        tally = {"PASS": 0, "FAIL": 0, "SKIP": 0}
        for outcome in self.results:
            if outcome.status in tally:
                tally[outcome.status] += 1

        print(f"  Verdict: {verdict}   ({tally['PASS']} pass / {tally['FAIL']} fail / {tally['SKIP']} skip)")
        print(heavy_rule)


def load_checkpoint(path):
    """Load a checkpoint from a local path or a Hugging Face repo id.

    Args:
        path: Filesystem path to a checkpoint, or a Hugging Face Hub repo id
            that is searched for a few common checkpoint filenames.

    Returns:
        A ``(checkpoint, resolved_path)`` tuple: the object returned by
        ``torch.load`` and the local path it was read from.

    Raises:
        FileNotFoundError: If ``path`` does not exist locally and no
            candidate file could be downloaded and loaded from the Hub. The
            last underlying error is included in the message and attached
            as ``__cause__`` so the real reason is not lost.
    """
    import torch

    # SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary
    # Python objects, which can execute code embedded in the file. Only run
    # this tool on checkpoints from sources you trust.
    if os.path.exists(path):
        return torch.load(path, map_location="cpu", weights_only=False), path

    # Not a local path: treat it as a Hub repo id and probe common filenames.
    candidate_files = ("cdf_v13.pt", "cdf_v13_v3.pt", "model.pt", "pytorch_model.bin")
    last_error = None
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as exc:
        # huggingface_hub is optional; remember why the Hub lookup was skipped.
        last_error = exc
    else:
        for fname in candidate_files:
            try:
                local = hf_hub_download(repo_id=path, filename=fname)
                return torch.load(local, map_location="cpu", weights_only=False), local
            except Exception as exc:
                # Best effort: try the next candidate, but keep the reason.
                last_error = exc

    detail = ""
    if last_error is not None:
        detail = f" (last error: {type(last_error).__name__}: {last_error})"
    raise FileNotFoundError(f"Could not load checkpoint: {path}{detail}") from last_error


def run_conformance(ckpt_path, meds_dir=None, readme_text=None):
    """Run the seven GEMEO Architecture v1.0 conformance tests on a checkpoint.

    Args:
        ckpt_path: Local checkpoint path or Hugging Face repo id, passed to
            ``load_checkpoint``.
        meds_dir: Optional directory with MEDS test data. The gap-fill test
            is reported as SKIP either way; only the reason differs.
        readme_text: Optional model-card text scanned by the grounding and
            citation tests.

    Returns:
        A ``ConformanceReport`` with one result per test, in test order.

    Raises:
        FileNotFoundError: Propagated from ``load_checkpoint`` when the
            checkpoint cannot be located.
    """
    import torch

    report = ConformanceReport(instance=ckpt_path)
    ckpt, resolved = load_checkpoint(ckpt_path)
    # A checkpoint that is not a dict (or that stores None for a field)
    # carries no usable metadata; treat it as empty so the individual tests
    # report SKIP/FAIL instead of crashing the whole run.
    if not isinstance(ckpt, dict):
        ckpt = {}
    cfg = ckpt.get("config") or {}
    vocab = ckpt.get("vocab") or []
    ckpt_dir = os.path.dirname(resolved)
    impl_candidates = [os.path.join(ckpt_dir, "src"),
                       os.path.join(ckpt_dir, "reference_impl")]

    # ---- Test 1: MEDS schema (vocab uses canonical MEDS code prefixes) ----
    MEDS_PREFIXES = ("ICD10//", "SIH//", "APAC//", "SIGTAP//", "BPAI//", "ORPHA", "MEDS_")
    if vocab:
        # Non-string entries cannot carry a code prefix; ignore them.
        tokens = [v for v in vocab if isinstance(v, str)]
        # A prefix counts when it appears anywhere in the token.
        n_meds = sum(1 for v in tokens if any(p in v for p in MEDS_PREFIXES))
        # Forbidden: CID10// (Brazilian non-canonical) instead of ICD10//
        has_bad = any(v.startswith("CID10//") for v in tokens)
        if has_bad:
            report.add("1. MEDS schema", "FAIL", "found non-canonical CID10// prefix (should be ICD10//)")
        elif n_meds >= 3:
            report.add("1. MEDS schema", "PASS", f"{n_meds} tokens use canonical MEDS prefixes")
        else:
            report.add("1. MEDS schema", "FAIL", f"only {n_meds} MEDS-prefixed tokens")
    else:
        report.add("1. MEDS schema", "SKIP", "no vocab in checkpoint")

    # ---- Test 2: per-token sigma support ----
    # `model` stays None when the reference implementation cannot be imported
    # or constructed; Test 3 checks for that explicitly.
    model = None
    try:
        # Make the implementation shipped next to the checkpoint importable.
        for d in impl_candidates:
            sys.path.insert(0, d)
        from diffusion_forcing_v13 import CDFv13Transformer, CDFv13Config
        # Pass through only the config keys the dataclass actually declares.
        c = CDFv13Config(**{k: v for k, v in cfg.items() if k in CDFv13Config.__dataclass_fields__})
        model = CDFv13Transformer(c)
        # Try a forward with a per-token sigma vector
        B, T = 2, min(16, c.max_seq_len)
        x = torch.randint(0, c.vocab_size, (B, T))
        sigma = torch.rand(B, T)  # per-token!
        cond = torch.zeros(B, dtype=torch.long)
        with torch.no_grad():
            out = model(x, sigma, cond)
        if out.shape == (B, T, c.vocab_size):
            report.add("2. Per-token sigma", "PASS", f"accepts (B,T) sigma → logits {tuple(out.shape)}")
        else:
            report.add("2. Per-token sigma", "FAIL", f"unexpected output shape {tuple(out.shape)}")
    except Exception as e:
        report.add("2. Per-token sigma", "SKIP", f"cannot instantiate model: {type(e).__name__}: {str(e)[:50]}")

    # ---- Test 3: KG gating init (alpha = 0) ----
    try:
        kg_layers = cfg.get("kg_attn_layers", [])
        use_kg = cfg.get("use_kg", False)
        if not (use_kg and kg_layers):
            report.add("3. KG gating", "FAIL", "use_kg=False or no kg_attn_layers")
        elif model is None:
            report.add("3. KG gating", "SKIP",
                       "model not instantiated; gate params could not be inspected")
        else:
            # Loading the weights is best effort: the check below only counts
            # parameter names. Record a failure instead of hiding it.
            load_note = ""
            try:
                model.load_state_dict(ckpt["model_state"])
            except Exception as e:
                load_note = f" (weights not loaded: {type(e).__name__})"
            alphas = [n for n, _ in model.named_parameters() if "alpha" in n.lower()]
            report.add("3. KG gating", "PASS",
                       f"KG cross-attn at layers {kg_layers}; {len(alphas)} gate params present{load_note}")
    except Exception as e:
        report.add("3. KG gating", "SKIP", f"{type(e).__name__}")

    # ---- Test 4: gap-fill recovery (requires MEDS test data) ----
    if meds_dir and os.path.isdir(meds_dir):
        report.add("4. Gap-fill recovery", "SKIP",
                   "MEDS dir present but full gap-fill eval not run in lightweight CLI "
                   "(use benchmark_v4_clinical.py for the full test)")
    else:
        report.add("4. Gap-fill recovery", "SKIP", "no MEDS test dir provided (--meds)")

    # ---- Test 5: bootstrap-then-learn (module structure) ----
    # We check that the architecture exposes the expected inference-mode modules.
    impl_dir = next((d for d in impl_candidates if os.path.isdir(d)), None)
    if impl_dir:
        files = set(os.listdir(impl_dir))
        # The world-model core MUST be present; full twin-stack modules are in the app layer repo.
        core_present = "diffusion_forcing_v13.py" in files and "primekg_attention.py" in files
        if core_present:
            report.add("5. Bootstrap-then-learn", "PASS",
                       "world-model core present; full mode modules in gemeo-twin-stack")
        else:
            report.add("5. Bootstrap-then-learn", "FAIL", "missing core modules")
    else:
        report.add("5. Bootstrap-then-learn", "SKIP", "no src/ dir found alongside checkpoint")

    # ---- Test 6: health-system grounding (declared in config or readme) ----
    # NOTE: plain substring matching, so short terms such as "sus" or "nice"
    # also match inside longer words.
    grounding_terms = ["pcdt", "sus", "formulary", "nice", "cms", "ground_sus", "dispensation"]
    if not readme_text:
        report.add("6. Health-system grounding", "SKIP",
                   "no README provided; grounding is in gemeo-twin-stack ground_sus.py")
    elif any(t in readme_text.lower() for t in grounding_terms):
        report.add("6. Health-system grounding", "PASS", "grounding referenced in model card")
    else:
        # A README was supplied, so say what was actually (not) found.
        report.add("6. Health-system grounding", "SKIP",
                   "no grounding terms found in README; grounding is in gemeo-twin-stack ground_sus.py")

    # ---- Test 7: audit citations (model card cites required building blocks) ----
    # NOTE: also substring matching; e.g. "dit" matches inside "condition".
    REQUIRED_CITES = {
        "diffusion forcing": ["diffusion forcing", "2407.01392", "chen"],
        "adaln/dit": ["adaln", "dit", "peebles", "2212.09748"],
        "wsd schedule": ["wsd", "minicpm"],
        "meds": ["meds", "medical event data standard"],
        "primekg": ["primekg", "chandak"],
    }
    if readme_text:
        text = readme_text.lower()
        missing = [k for k, terms in REQUIRED_CITES.items() if not any(t in text for t in terms)]
        if not missing:
            report.add("7. Audit citations", "PASS", "all required building blocks cited")
        else:
            report.add("7. Audit citations", "FAIL", f"missing citations: {', '.join(missing)}")
    else:
        report.add("7. Audit citations", "SKIP", "no README text provided")

    return report


def _read_readme(path):
    """Return the text of the README at *path*, decoded as UTF-8."""
    # Undecodable bytes are replaced rather than raised: the text is only
    # scanned for keywords, so a stray byte should not abort the run.
    with open(path, encoding="utf-8", errors="replace") as fh:
        return fh.read()


def main():
    """Command-line entry point for the ``check`` subcommand.

    Parses the arguments, locates the model-card README (the ``--readme``
    path, else ``README.md`` next to the checkpoint), runs the conformance
    tests and prints the report as text or JSON.

    Exits with status 0 when the report is conformant and 1 otherwise,
    including when no valid subcommand is given.
    """
    ap = argparse.ArgumentParser(description="GEMEO Architecture v1.0 conformance checker")
    sub = ap.add_subparsers(dest="cmd")
    chk = sub.add_parser("check", help="Run conformance tests on a checkpoint")
    chk.add_argument("checkpoint", help="Path to .pt checkpoint or HF repo id")
    chk.add_argument("--meds", default=None, help="Path to MEDS test data dir (for gap-fill test)")
    chk.add_argument("--readme", default=None, help="Path to model card README (for citation/grounding tests)")
    chk.add_argument("--json", action="store_true", help="Emit JSON")
    args = ap.parse_args()

    if args.cmd != "check":
        ap.print_help()
        sys.exit(1)

    readme_text = None
    if args.readme and os.path.isfile(args.readme):
        readme_text = _read_readme(args.readme)
    else:
        if args.readme:
            # Warn on stderr so --json output on stdout stays parseable.
            print(f"warning: README not found: {args.readme}; "
                  "trying README.md next to the checkpoint", file=sys.stderr)
        # Try README.md next to the checkpoint
        cand = os.path.join(os.path.dirname(os.path.abspath(args.checkpoint)), "README.md")
        if os.path.isfile(cand):
            readme_text = _read_readme(cand)

    report = run_conformance(args.checkpoint, meds_dir=args.meds, readme_text=readme_text)
    if args.json:
        payload = {
            "instance": report.instance,
            "conformant": report.conformant,
            "results": [{"name": r.name, "status": r.status, "reason": r.reason}
                        for r in report.results],
        }
        print(json.dumps(payload, indent=2))
    else:
        report.print_report()
    sys.exit(0 if report.conformant else 1)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()