timmers commited on
Commit
a0fa886
·
verified ·
1 Parent(s): 553f58d

GEMEO Architecture v1.0 — spec + reference impl + figure

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figure1_gemeo_architecture.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-nc-4.0
3
+ language: [en, pt]
4
+ tags:
5
+ - world-model
6
+ - patient-digital-twin
7
+ - reference-architecture
8
+ - diffusion-forcing
9
+ - meds
10
+ - primekg
11
+ - rare-disease
12
+ library_name: pytorch
13
+ pipeline_tag: time-series-forecasting
14
+ ---
15
+
16
+ # GEMEO Architecture v1.0
17
+
18
+ > **A reference architecture for patient world models.** Six principles,
19
+ > pluggable substrate, three open instances. *Not a model* — a recipe for
20
+ > building a model.
21
+
22
+ ![GEMEO Architecture v1.0](./figure1_gemeo_architecture.png)
23
+
24
+ This repo contains the **architecture specification** and a **reference
25
+ implementation** (Apache-2.0 source, no weights). To use:
26
+
27
+ 1. Read [`gemeo_architecture_spec_v1.md`](./gemeo_architecture_spec_v1.md)
28
+ for the 6-principle conformance definition.
29
+ 2. Copy `reference_impl/` into your repo, adapt to your substrate (any
30
+ MEDS v0.4.1-compliant EHR), train.
31
+ 3. Name your instance `gemeo-<substrate>-v<n>` and submit a conformance
32
+ report.
33
+
34
+ ## Open instances (May 2026)
35
+
36
+ | Instance | Substrate | Params | Status |
37
+ |---|---|---|---|
38
+ | [`Raras-AI/gemeo-sus-v2`](https://huggingface.co/Raras-AI/gemeo-sus-v2) | DATASUS (Brazil, 42K patients) | 19.86M | ✅ released |
39
+ | [`Raras-AI/gemeo-twin-stack`](https://huggingface.co/Raras-AI/gemeo-twin-stack) | application layer | NeuralSurv + heads | ✅ released |
40
+ | `Raras-AI/gemeo-mayo-v3` | Mayo Clinic Platform (planned) | 300M | in proposal |
41
+ | `Raras-AI/gemeo-mimic-demo` | MIMIC-IV-DEMO | reference impl | in progress |
42
+
43
+ ## The six architectural principles
44
+
45
+ 1. **Diffusion Forcing backbone** with per-token σ ∼ 𝒰(0, 1)
46
+ 2. **Gated KG cross-attention** with tanh(α), α init = 0; real PrimeKG edges
47
+ 3. **MEDS v0.4.1 substrate** — `(subject_id, time, code, value)`
48
+ 4. **Bootstrap-then-learn** pattern per inference mode
49
+ 5. **Bidirectional health-system grounding** (formulary re-rank)
50
+ 6. **Audit-driven training** (Chinchilla scaling + SOTA component validation)
51
+
52
+ Full definitions and conformance tests in [`gemeo_architecture_spec_v1.md`](./gemeo_architecture_spec_v1.md).
53
+
54
+ ## Citation
55
+
56
+ ```bibtex
57
+ @misc{gemeo_arch_v1_2026,
58
+ title = {GEMEO Architecture Specification v1.0:
59
+ Reference architecture for patient world models},
60
+ author = {Verdial, Dimas Quintas and Kawassaki, Alexandre and the Raras AI team},
61
+ year = {2026},
62
+ url = {https://huggingface.co/Raras-AI/gemeo-arch},
63
+ }
64
+ ```
65
+
66
+ ⚠️ Research only. Not a medical device. No clinical use.
figure1_gemeo_architecture.png ADDED

Git LFS Details

  • SHA256: 1a564f3310f6da286689358b8164c8447b8522d5f038aaf08f5847e12a3a2786
  • Pointer size: 131 Bytes
  • Size of remote file: 262 kB
gemeo_architecture_spec_v1.md ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: "GEMEO Architecture Specification v1.0"
3
+ subtitle: "Reference architecture for patient world models and healthcare digital twins"
4
+ author: "Raras AI"
5
+ date: "May 2026"
6
+ geometry: margin=1in
7
+ fontsize: 11pt
8
+ mainfont: "Inter"
9
+ ---
10
+
11
+ # GEMEO Architecture Specification v1.0
12
+
13
+ > **Status:** Stable. This document defines GEMEO Architecture v1.0 — a reference design for **patient world models** in the lineage of Dreamer (Hafner 2019–2025), Diffusion Forcing (Chen NeurIPS 2024), Sora (OpenAI 2024), and Genie (DeepMind 2024–2025), applied to clinical event streams.
14
+ >
15
+ > **License:** CC-BY-NC 4.0 (specification text). Reference implementation in `Raras-AI/gemeo-arch` is Apache-2.0.
16
+ >
17
+ > **Authors:** Raras AI team. Correspondence: dimas@raras.ai
18
+
19
+ ---
20
+
21
+ ## 0. Scope and motivation
22
+
23
+ A *patient world model* is a learned generative model of patient-trajectory dynamics. It should support: (a) trajectory rollout conditional on actions, (b) counterfactual reasoning under interventions, (c) risk and outcome inference, and (d) drug repurposing. GEMEO Architecture v1.0 is the smallest concrete design we have found that satisfies all four under realistic constraints — sparse clinical events, missing labels, asymmetric data substrates, and regulatory grounding.
24
+
25
+ This document defines the architecture so that any institution holding longitudinal patient data can instantiate **its own GEMEO model** without sharing data with us. Implementations are expected to ship distinct *instances* (e.g. `gemeo-sus-v2`, `gemeo-mayo-v3`, `gemeo-mimic-demo`) while remaining GEMEO-conformant per §6 of this spec.
26
+
27
+ ## 1. The six architectural principles
28
+
29
+ A model is **GEMEO-conformant** if and only if it implements all six principles below. Optional extensions are allowed (§5) but do not affect conformance.
30
+
31
+ ### Principle 1 — Diffusion Forcing backbone with per-token σ
32
+
33
+ The trajectory backbone MUST be a transformer with **per-token noise level** σᵢ ∼ 𝒰(0, 1) sampled independently at each position during training (Chen et al., NeurIPS 2024 [DF]). Score-SDE backbones, pure autoregressive backbones, and uniform-σ backbones are **not** GEMEO-conformant.
34
+
35
+ Rationale: per-token σ enables variable-horizon rollout and gap-filling under one trained model, which is necessary for both the trajectory mode and the bidirectional masked-fill mode used in clinical-evidence reconstruction.
36
+
37
+ ### Principle 2 — Gated cross-attention to a biomedical knowledge graph
38
+
39
+ The model MUST attend to a *real* biomedical knowledge-graph ego-subgraph (PrimeKG [Chandak 2023] or equivalent), via cross-attention with a tanh-gated residual:
40
+
41
+ ```
42
+ h' = h + tanh(α) · CrossAttn(h, KG_ego)
43
+ ```
44
+
45
+ with α initialised to zero. The KG cross-attention MUST be inserted at ≥ 1 layer of the backbone. The KG ego-subgraph MUST contain real disease–gene and disease–phenotype edges; *simulated* or *RAG-only* KG conditioning is not GEMEO-conformant.
46
+
47
+ Rationale: zero-initialised gating (Flamingo [Alayrac 2022], Genie [Bruce 2024]) lets the backbone train as an unconditional model first and adopt KG signal only when it improves the loss, avoiding early-training instability. Real edges anchor the latent state to known biology and prevent hallucinated trajectories under counterfactual rollout.
48
+
49
+ ### Principle 3 — MEDS v0.4.1 substrate
50
+
51
+ Event input MUST be in the **Medical Event Data Standard v0.4.1** [McDermott 2024]. Each event is a tuple `(subject_id, time, code, numeric_value, text_value)`. Code-prefix patterns MUST follow the canonical MEDS conventions: `ICD10//`, `SIH//`, `APAC//`, `SIGTAP//`, `ORPHA//`, with `MEDS_BIRTH` and `MEDS_DEATH` reserved. The MEDS schema MUST be validated at export time via the official `meds.DataSchema`.
52
+
53
+ Rationale: MEDS is the lingua franca for EHR foundation models (EHRSHOT, CoMET, CLMBR, MOTOR all use MEDS-derivatives). Compliance ensures any GEMEO instance can be benchmarked against any MEDS-based competitor and that EHR substrates can be swapped without re-engineering the model.
54
+
55
+ ### Principle 4 — Bootstrap-then-learn pattern per inference mode
56
+
57
+ Each inference mode (trajectory, diagnosis, risk, counterfactual, repurposing, cohort) MUST expose the SAME function signature regardless of whether a learned head exists. When no learned head exists, the function MUST return a deterministic, rule-based or LLM-bootstrap output flagged `model="bootstrap"`. When a learned head exists, the function MUST return the learned output flagged with the head identifier (e.g. `model="neuralsurv"`).
58
+
59
+ This is not a software-engineering nicety. It is a deployment guarantee: a clinic can deploy a GEMEO instance before any GPU-trained head exists. Capabilities turn on monotonically as checkpoints land in the runtime.
60
+
61
+ ### Principle 5 — Bidirectional health-system grounding
62
+
63
+ Therapy recommendations MUST be re-ranked by the patient's actual health-system formulary. The re-ranking score MUST include at minimum:
64
+
65
+ ```
66
+ score_grounding(r, p) = π_formulary(r, dx) · ρ_region(r, p_region) · (1 − dist(p, c_ref) / D_max)
67
+ ```
68
+
69
+ where π_formulary ∈ {0, 1} marks formulary membership (PCDT in Brazil, NICE in UK, CMS-approved in US, etc.), ρ_region is the per-region empirical dispensation rate, and c_ref is the closest specialised referral centre. GEMEO instances that recommend therapies the local health system does not deliver are not conformant.
70
+
71
+ Rationale: clinical AI deployed without grounding to local formularies is not actionable. This principle distinguishes GEMEO from US-only or geography-blind foundation models.
72
+
73
+ ### Principle 6 — Audit-driven training
74
+
75
+ Every architectural decision MUST be auditable against a contemporaneous SOTA reference. Specifically:
76
+
77
+ - **Compute scaling**: parameter count MUST respect Chinchilla scaling for the available training tokens. (Our v1 made the inverse error — 125M params on 20M tokens, 30–60× too large — and was downsized to 19.86M in v2.)
78
+ - **Component validation**: each module (DF objective, KG cross-attention, WSD schedule, etc.) MUST cite the paper that established it as SOTA at training time and MUST be ablatable.
79
+ - **Vocabulary integrity**: token codes MUST match the canonical MEDS prefixes (e.g. `ICD10//`, not `CID10//`). Synthetic or hallucinated KG edges are forbidden.
80
+
81
+ A training run that cannot answer "why this layer, why this size, why this schedule" with a citation does not produce a GEMEO-conformant model.
82
+
83
+ ---
84
+
85
+ ## 2. Required components and module structure
86
+
87
+ A reference implementation MUST expose the following Python module structure:
88
+
89
+ ```
90
+ gemeo/
91
+ ├── __init__.py
92
+ ├── core.py # orchestrator
93
+ ├── api.py # FastAPI surface (/api/gemeo/*)
94
+ ├── types.py # typed dataclasses
95
+
96
+ ├── cdf/ # Principle 1 + 2: the world model itself
97
+ │ ├── diffusion_forcing.py # CDF transformer
98
+ │ ├── adaln_zero.py # DiT-style σ conditioning
99
+ │ ├── primekg_attention.py # gated KG cross-attention
100
+ │ ├── wsd_scheduler.py # WSD LR schedule
101
+ │ ├── meds_export.py # MEDS v0.4.1 schema validation
102
+ │ └── sample.py # inference / rollout
103
+
104
+ ├── encoder.py # static patient embedding (HGT or bootstrap)
105
+ ├── cohort.py # patients-like-mine (Principle 4 bootstrap)
106
+ ├── subgraph.py # KG sparsification
107
+ ├── trajectory.py # trajectory mode (delegates to cdf/)
108
+ ├── risk.py # NeuralSurv head (Principle 4)
109
+ ├── repurpose.py # drug repurposing (TxGNN slot)
110
+ ├── whatif.py # counterfactual engine
111
+ ├── ask.py # active learning
112
+ ├── ground_sus.py # Principle 5: health-system grounding
113
+ └── feedback.py # closed-loop label ingestion
114
+ ```
115
+
116
+ A subset is allowed — any instance MAY omit modes it does not need (e.g. a research-only instance might skip `ground_sus.py`). But if a module IS present, it MUST conform to the signatures in `types.py`.
117
+
118
+ ---
119
+
120
+ ## 3. Training recipe
121
+
122
+ ### 3.1 Required training-time properties
123
+
124
+ - **Per-token σ ∼ 𝒰(0.01, 0.99)** sampled independently. Absorbing-state corruption: position *i* becomes MASK with probability σᵢ.
125
+ - **Loss**: masked cross-entropy over corrupted positions, with **Min-SNR weighting** (Hang 2023) on per-token loss.
126
+ - **Conditional dropout 10%**: replace `cond` with `<NULL>` token for classifier-free guidance support at inference.
127
+ - **WSD LR schedule**: 5% warmup / 80% stable / 15% linear decay (Hu et al. MiniCPM 2024). Cosine schedules are not GEMEO-conformant.
128
+ - **bf16 mixed precision**.
129
+ - **Embedding tying** between input and output projections.
130
+ - **Architecture**: SwiGLU + RMSNorm + RoPE (Llama-style). Standard transformer with LayerNorm + ReLU is not conformant.
131
+
132
+ ### 3.2 Required validation gates
133
+
134
+ Before publication, an instance MUST report:
135
+
136
+ - Validation cross-entropy on a held-out random-σ task.
137
+ - Integrated Calibration Index (ICI) per Austin & Steyerberg.
138
+ - Gap-fill Top-K (K = 1, 3, 5, 10) on a contiguous mid-trajectory mask of size 24.
139
+ - Multi-horizon Top-K decay curve at prefix lengths *k* ∈ {32, 64, 128}.
140
+ - Per-event-class macro-AUROC with explicit honest reporting of horizon limitations (see §5 of the GEMEO/SUS-v2 model paper for the canonical honest format).
141
+ - Per-subgroup fairness check (sex, age band, region).
142
+
143
+ Instances that do not report all six are non-conformant for publication-grade claims.
144
+
145
+ ### 3.3 Compute budget guidance
146
+
147
+ - **Tier-S** (≤ 10M tokens): backbone ≤ 20M params, single H100 < 1 hour, single instance.
148
+ - **Tier-M** (10M–500M tokens): backbone 20–80M params, 8–24 H100-hours.
149
+ - **Tier-L** (500M–10B tokens): backbone 80M–300M params, 100+ H100-hours.
150
+ - **Tier-XL** (10B+ tokens): backbone 300M+ params, 1000+ H100-hours. Mayo / EHRSHOT / multi-modal class.
151
+
152
+ A tier-S instance with claimed tier-XL performance is suspicious by construction.
153
+
154
+ ---
155
+
156
+ ## 4. Conformance tests
157
+
158
+ A model is GEMEO-conformant if it passes the following automated tests, runnable via `gemeo-bench`:
159
+
160
+ ```bash
161
+ gemeo-bench check Raras-AI/your-gemeo-instance
162
+ ```
163
+
164
+ **Required tests:**
165
+
166
+ 1. `test_schema_meds` — model exports MEDS-validatable event streams.
167
+ 2. `test_per_token_sigma` — model accepts per-position σ vector.
168
+ 3. `test_kg_gating_init` — α = 0 at initialisation; sanity-check the gate.
169
+ 4. `test_gap_fill_recovery` — gap [24, 48) Top-10 ≥ 0.50 on test split (a real GEMEO instance, even tiny, recovers gaps far above random).
170
+ 5. `test_bootstrap_paths` — every inference mode returns a value with `model={bootstrap, learned}` identifier.
171
+ 6. `test_health_system_grounding` — at least one therapy recommendation differs by patient region.
172
+ 7. `test_audit_citations` — model card contains citations for: DF, AdaLN-Zero, WSD, MEDS, KG source, gating pattern.
173
+
174
+ The reference test suite is bundled with `Raras-AI/gemeo-arch` and runs in < 60 seconds on CPU for a tiny instance.
175
+
176
+ ---
177
+
178
+ ## 5. Optional extensions (not required for conformance)
179
+
180
+ - **Self Forcing training** (NeurIPS 2025 Spotlight [Self Forcing 2025]) — addresses tail exposure bias. Recommended for v2+ instances.
181
+ - **Positional features** beyond RoPE — explicit age / calendar-year / region embeddings concatenated with the token embedding. Useful when temporal sparsity is high.
182
+ - **Multimodal substrate** — clinical notes via Gemini / Llama-Med phenotype extraction, WES variant tokens, imaging features. These extend the MEDS event vocabulary but should retain MEDS-prefix conventions.
183
+ - **CoMET-style multi-sample inference** — Monte-Carlo aggregation of *n* trajectories at inference. Note: not required, and per the GEMEO/SUS-v2 ablation in §7.6 of the model paper, it did not outperform the one-shot trained protocol at tier-S.
184
+
185
+ ---
186
+
187
+ ## 6. Versioning and naming
188
+
189
+ - The architecture spec is versioned (`v1.0` here). Backward-incompatible changes increment the major version.
190
+ - Model instances are named `gemeo/<substrate>-v<version>` (e.g. `gemeo/sus-v2`, `gemeo/mayo-v3`, `gemeo/mimic-demo-v1`).
191
+ - The reference implementation is `Raras-AI/gemeo-arch` (architecture, no weights). Instances live at `Raras-AI/gemeo-<substrate>-v<n>`.
192
+ - Bundled application layer: `Raras-AI/gemeo-twin-stack` (the six-mode wrapper around any conformant instance).
193
+
194
+ ---
195
+
196
+ ## 7. Validated instances (May 2026)
197
+
198
+ | Instance | Substrate | Params | Status | Reference |
199
+ |---|---|---|---|---|
200
+ | `gemeo-sus-v2` | DATASUS (Brazil, 42 K patients) | 19.86 M | ✅ open | Verdial 2026 |
201
+ | `gemeo-mayo-v3` | Mayo Clinic Platform (planned, 3 M) | 300 M (planned) | in proposal | Mayo Accelerate proposal |
202
+ | `gemeo-mimic-demo` | MIMIC-IV-DEMO (Open, ~100 pts) | reference impl | in progress | — |
203
+
204
+ External instances are welcome. Submit a pull request to `Raras-AI/gemeo-arch` with the conformance-test output to be listed.
205
+
206
+ ---
207
+
208
+ ## 8. Related architectures (not GEMEO-conformant, by design)
209
+
210
+ We name these explicitly so that reviewers and adopters can distinguish:
211
+
212
+ - **Sora** [OpenAI 2024]: world model for video, not patient events. DF lineage but no clinical substrate.
213
+ - **Dreamer 4** [Hafner 2025]: world model for general environments. DF lineage. Not specialised to MEDS, no KG cross-attention, no health-system grounding.
214
+ - **EHRWorld** [arXiv 2602.03569]: autoregressive Qwen-fine-tune on EHR. Not Diffusion Forcing. Not conformant to Principle 1.
215
+ - **CoMET** [Epic 2508.12104]: large-scale generative medical event model. Not Diffusion Forcing — autoregressive only. Not conformant to Principle 1.
216
+ - **MOTOR** [Steinberg 2023]: time-to-event foundation model. Not generative-dynamics. Conformant to substrate (MEDS) but not to Principles 1, 2.
217
+ - **CLMBR-T** [Stanford]: contrastive foundation model. Not generative. Not conformant to Principle 1.
218
+ - **Delphi-2M** [Nature 2025]: autoregressive transformer over UK Biobank + Danish registry. Not Diffusion Forcing. Not conformant to Principle 1.
219
+ - **RareGraph-Synth** [arXiv 2510.06267]: score-SDE diffusion on *synthetic* rare-disease graphs. Not GEMEO-conformant: synthetic substrate (violates Principle 6's audit requirements) and score-SDE not Diffusion Forcing.
220
+
221
+ GEMEO is the architecture that combines (DF + KG cross-attn + MEDS + health-system grounding + bootstrap-then-learn + audit) into a single deployable design. To our knowledge, no prior architecture satisfies all six.
222
+
223
+ ---
224
+
225
+ ## 9. Citation
226
+
227
+ ```bibtex
228
+ @misc{gemeo_arch_v1_2026,
229
+ title = {GEMEO Architecture Specification v1.0:
230
+ Reference architecture for patient world models},
231
+ author = {Verdial, Dimas Quintas and Kawassaki, Alexandre and the Raras AI team},
232
+ year = {2026},
233
+ url = {https://huggingface.co/Raras-AI/gemeo-arch},
234
+ note = {Specification document. Reference implementation Apache-2.0.}
235
+ }
236
+ ```
237
+
238
+ ## 10. References
239
+
240
+ - [DF] Chen, B. et al. *Diffusion Forcing: Next-Token Prediction Meets Full-Sequence Diffusion.* NeurIPS 2024 (arXiv:2407.01392).
241
+ - [Dreamer 4] Hafner, D. et al. *Training Agents Inside of Scalable World Models.* arXiv:2509.24527, Sept 2025.
242
+ - [Sora] OpenAI. *Video generation models as world simulators.* Feb 2024.
243
+ - [Genie] Bruce, J. et al. *Genie: Generative Interactive Environments.* ICML 2024.
244
+ - [PrimeKG] Chandak, P., Huang, K., Zitnik, M. *Building a knowledge graph to enable precision medicine.* Nature Sci Data 2023.
245
+ - [MEDS] McDermott, M. et al. *MEDS: Medical Event Data Standard v0.4.1.* GitHub, 2024.
246
+ - [DiT] Peebles, W. & Xie, S. *Scalable Diffusion Models with Transformers.* ICCV 2023.
247
+ - [Flamingo] Alayrac, J.-B. et al. *Flamingo: a Visual Language Model for Few-Shot Learning.* NeurIPS 2022.
248
+ - [MiniCPM] Hu, S. et al. *MiniCPM: Unveiling the Potential of Small Language Models with Scalable Training Strategies.* 2024.
249
+ - [Min-SNR] Hang, T. et al. *Efficient Diffusion Training via Min-SNR Weighting Strategy.* ICCV 2023.
250
+ - [Self Forcing 2025] *Self Forcing: Training Diffusion-Forcing Agents Without Exposure Bias.* arXiv:2506.08009, NeurIPS 2025 Spotlight.
251
+
252
+ ---
253
+
254
+ *End of GEMEO Architecture Specification v1.0.*
reference_impl/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GEMEO-CDF: Causal Diffusion Forcing for clinical trajectories.
2
+
3
+ Three "first in medicine" hooks:
4
+ 1. DIFFUSION FORCING (Chen MIT NeurIPS 2024 → Dreamer 4 Hafner 2025 backbone)
5
+ — independent per-token noise levels unify AR + diffusion + counterfactual
6
+ in ONE loss. Zero clinical port as of May 2026.
7
+
8
+ 2. LATENT ACTION MODEL (Genie / DeepMind 2024)
9
+ — VQ-VAE codebook over (state_t, state_{t+1}) deltas discovers a
10
+ treatment vocabulary without RxNorm/ATC labels. Solves the APAC
11
+ miscoding / sparsity / off-label labelling pain in DATASUS.
12
+
13
+ 3. PROCESS REWARD VERIFIER (o3 / MAI-DxO 2025 pattern)
14
+ — small PRM scores top-K rollouts at inference, returns top-1 +
15
+ uncertainty band. Deliberative trajectory generation, novel in EHR.
16
+
17
+ Modules:
18
+ diffusion_forcing.py — core architecture (per-token noise + block-causal)
19
+ lam.py — Latent Action Model (VQ-VAE codebook)
20
+ train_cdf.py — training loop with diffusion forcing objective
21
+ sample.py — sampling: AR mode / denoise mode / counterfactual
22
+ distill.py — Shortcut Forcing distillation (Dreamer 4)
23
+ prm.py — Process Reward Verifier
24
+ """
25
+ from .diffusion_forcing import CDFTransformer, CDFConfig
26
+ from .lam import LatentActionVQVAE, LAMConfig
27
+ from .train_cdf import train_cdf
28
+
29
+ __all__ = [
30
+ "CDFTransformer", "CDFConfig",
31
+ "LatentActionVQVAE", "LAMConfig",
32
+ "train_cdf",
33
+ ]
reference_impl/adaln_zero.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AdaLN-Zero conditioning module (DiT-style, Peebles 2023).
2
+
3
+ Used in: DiT (ICCV 2023), Stable Diffusion 3 (Esser 2024), Sora, Lumina-Next,
4
+ PixArt-Sigma. Standard for diffusion conditioning in 2025-2026.
5
+
6
+ Why for Diffusion Forcing on EHR:
7
+ - Per-token sigma + global cond/action → per-token (scale, shift, gate)
8
+ - Gates init to zero ⇒ block starts as identity ⇒ no catastrophic init
9
+ - Much better CFG (dropped condition path goes through zero gates,
10
+ not corrupting residual stream)
11
+ - DFoT (Diffusion Forcing Transformer 2, ICLR 2026) confirms +3-8% win
12
+
13
+ We fuse THREE conditioning signals:
14
+ - sigma (B, T) per-token noise level → time_emb (B, T, D)
15
+ - cond (B,) cohort-level treatment id → cond_emb (B, D) → broadcast
16
+ - action(B, T) per-token latent action id → action_emb (B, T, D)
17
+
18
+ Combined into c_t (B, T, D) → ConditioningMLP → 6 modulation tensors
19
+ per block. Each block uses them as:
20
+
21
+ h = x + gate_msa * Attn(scale_msa * Norm(x) + shift_msa)
22
+ h = h + gate_mlp * MLP(scale_mlp * Norm(h) + shift_mlp)
23
+ """
24
+ from __future__ import annotations
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+
29
+ class AdaLNZeroModulator(nn.Module):
30
+ """Generates per-token (scale, shift, gate) for AdaLN-Zero block.
31
+
32
+ Input: fused conditioning vector c (B, T, d_model).
33
+ Output: 6 tensors of shape (B, T, d_model) each:
34
+ (scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp)
35
+ """
36
+ def __init__(self, d_model: int):
37
+ super().__init__()
38
+ self.modulator = nn.Sequential(
39
+ nn.SiLU(),
40
+ nn.Linear(d_model, 6 * d_model, bias=True),
41
+ )
42
+ # Zero-init for the gate-producing rows (AdaLN-Zero trick)
43
+ # We zero-init ALL outputs initially; gate stays zero so block is identity
44
+ nn.init.zeros_(self.modulator[-1].weight)
45
+ nn.init.zeros_(self.modulator[-1].bias)
46
+
47
+ def forward(self, c: torch.Tensor) -> tuple[torch.Tensor, ...]:
48
+ # c: (B, T, d_model)
49
+ out = self.modulator(c) # (B, T, 6*d_model)
50
+ return out.chunk(6, dim=-1)
51
+
52
+
53
+ class AdaLNZeroBlock(nn.Module):
54
+ """Transformer block with AdaLN-Zero modulation.
55
+
56
+ Drop-in replacement for the standard pre-norm block. Reads
57
+ pre-computed modulation tensors and applies them around Attn + MLP.
58
+ """
59
+ def __init__(self, d_model: int, n_heads: int, ffn: int, dropout: float,
60
+ rope=None, kg_xattn=None):
61
+ super().__init__()
62
+ self.d_model = d_model
63
+ self.n_heads = n_heads
64
+ self.head_dim = d_model // n_heads
65
+ self.rope = rope
66
+ self.kg_xattn = kg_xattn
67
+
68
+ self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
69
+ self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
70
+ self.proj = nn.Linear(d_model, d_model, bias=False)
71
+ self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
72
+ self.mlp = nn.Sequential(
73
+ nn.Linear(d_model, ffn, bias=False),
74
+ nn.GELU(),
75
+ nn.Linear(ffn, d_model, bias=False),
76
+ )
77
+ self.dropout = nn.Dropout(dropout)
78
+
79
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor,
80
+ scale_msa, shift_msa, gate_msa,
81
+ scale_mlp, shift_mlp, gate_mlp,
82
+ kg_ctx: torch.Tensor | None = None) -> torch.Tensor:
83
+ import torch.nn.functional as F
84
+ B, T, D = x.shape
85
+ # MSA branch
86
+ h = self.norm1(x) * (1 + scale_msa) + shift_msa
87
+ qkv = self.qkv(h).reshape(B, T, 3, self.n_heads, self.head_dim)
88
+ q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
89
+ if self.rope is not None:
90
+ q, k = self.rope(q, k, T)
91
+ out = F.scaled_dot_product_attention(
92
+ q, k, v,
93
+ attn_mask=(~attn_mask).float().masked_fill(attn_mask, float("-inf"))[None, None],
94
+ dropout_p=self.dropout.p if self.training else 0.0,
95
+ )
96
+ out = out.transpose(1, 2).reshape(B, T, D)
97
+ x = x + gate_msa * self.dropout(self.proj(out))
98
+ # KG cross-attention (between MSA and MLP)
99
+ if self.kg_xattn is not None and kg_ctx is not None:
100
+ x = self.kg_xattn(x, kg_ctx)
101
+ # MLP branch
102
+ h = self.norm2(x) * (1 + scale_mlp) + shift_mlp
103
+ x = x + gate_mlp * self.dropout(self.mlp(h))
104
+ return x
105
+
106
+
107
+ class FusedConditioner(nn.Module):
108
+ """Fuse (sigma, cond, action) into one per-token conditioning vector.
109
+
110
+ Output (B, T, d_model) consumed by AdaLNZeroModulator per layer.
111
+ """
112
+ def __init__(self, d_model: int, n_conditions: int, n_actions: int,
113
+ use_action: bool = True):
114
+ super().__init__()
115
+ self.d_model = d_model
116
+ self.use_action = use_action
117
+ # Sigma → sinusoidal embedding
118
+ self.sigma_proj = nn.Sequential(
119
+ nn.Linear(d_model, d_model), nn.SiLU(), nn.Linear(d_model, d_model),
120
+ )
121
+ self.cond_emb = nn.Embedding(n_conditions, d_model)
122
+ if use_action:
123
+ self.action_emb = nn.Embedding(n_actions + 1, d_model)
124
+ self.fuse = nn.Sequential(
125
+ nn.SiLU(),
126
+ nn.Linear(d_model, d_model),
127
+ )
128
+
129
+ def sinusoidal(self, sigma: torch.Tensor) -> torch.Tensor:
130
+ import math
131
+ half = self.d_model // 2
132
+ freqs = torch.exp(
133
+ -math.log(10000.0) * torch.arange(half, device=sigma.device) / half
134
+ )
135
+ ang = sigma.float().unsqueeze(-1) * freqs
136
+ emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
137
+ return self.sigma_proj(emb)
138
+
139
+ def forward(self, sigma: torch.Tensor, cond: torch.Tensor,
140
+ action: torch.Tensor | None = None) -> torch.Tensor:
141
+ # sigma (B, T) → time_emb (B, T, D)
142
+ time_emb = self.sinusoidal(sigma)
143
+ # cond (B,) → (B, D) → broadcast to (B, T, D)
144
+ cond_emb = self.cond_emb(cond).unsqueeze(1).expand_as(time_emb)
145
+ fused = time_emb + cond_emb
146
+ if self.use_action and action is not None:
147
+ fused = fused + self.action_emb(action)
148
+ return self.fuse(fused)
reference_impl/diffusion_forcing_v13.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GEMEO-CDF v13 — audit-driven Chinchilla-correct architecture.
2
+
3
+ Per the SOTA audit (May 2026):
4
+ - Path B (CLMBR fine-tune) BLOCKED: CLMBR-T-base is HF-gated (manual approval)
5
+ - Path A adopted: small from-scratch model + KG adapters + MEDS interop
6
+
7
+ Architecture:
8
+ - 12M backbone params (Chinchilla-respecting for ~20M token corpus)
9
+ - d_model=384, n_layers=8, n_heads=6, ffn=1024, ctx=512
10
+ - SwiGLU MLP (ffn:d_model = 2.67)
11
+ - Tied embeddings (saves ~12M at vocab=32k)
12
+ - Dropout 0.1 everywhere (small-data critical)
13
+ - Block-causal attention (Diffusion Forcing)
14
+ - Per-token sigma noise (independent)
15
+ - GATED KG cross-attention (tanh(α)·xattn, α init=0)
16
+ - Layers 4, 6, 7 (3 of 8)
17
+ - Lets model learn to use KG progressively, doesn't disrupt early loss
18
+ - DF objective + LM-aux loss (joint training, paper-grade)
19
+
20
+ Sources audited:
21
+ - CoMET (Aug 2025): tokens-per-param ratio
22
+ - CLMBR (Stanford): adapter pattern for cross-site transfer
23
+ - MDLM (Sahoo 2024): masked diffusion, matches AR at equal FLOPs
24
+ - Genie (DeepMind 2024): gated cross-attention pattern
25
+ - SD3 (Esser 2024): AdaLN-Zero zero-init gates
26
+ """
27
+ from __future__ import annotations
28
+ import math
29
+ from dataclasses import dataclass, field
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+
35
+
36
+ @dataclass
37
+ class CDFv13Config:
38
+ # Vocab + sequence
39
+ vocab_size: int = 32768 # MEDS-derived (will be much smaller in practice)
40
+ mask_token: int = 32767
41
+ max_seq_len: int = 512
42
+ block_size: int = 16
43
+ # Architecture (Chinchilla-correct for ~20M tokens)
44
+ d_model: int = 384
45
+ n_heads: int = 6
46
+ n_layers: int = 8
47
+ ffn: int = 1024 # SwiGLU effective; flag below uses 2 projections
48
+ dropout: float = 0.1
49
+ emb_dropout: float = 0.1
50
+ use_swiglu: bool = True
51
+ use_rmsnorm: bool = True
52
+ tie_embeddings: bool = True
53
+ # Diffusion forcing
54
+ cond_dropout: float = 0.10
55
+ # KG conditioning (GATED adapters)
56
+ use_kg: bool = True
57
+ kg_dim: int = 3072
58
+ kg_attn_layers: list = field(default_factory=lambda: [4, 6, 7])
59
+ # Latent action
60
+ use_latent_action: bool = False # Dropped per audit (concept shaky)
61
+ n_latent_actions: int = 512
62
+ # Conditioning
63
+ n_conditions: int = 64
64
+
65
+
66
+ class RMSNorm(nn.Module):
67
+ """Root-mean-square LayerNorm (LLaMA/Mistral style)."""
68
+ def __init__(self, d: int, eps: float = 1e-6):
69
+ super().__init__()
70
+ self.weight = nn.Parameter(torch.ones(d))
71
+ self.eps = eps
72
+
73
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
74
+ norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
75
+ return (norm * self.weight.float()).to(x.dtype)
76
+
77
+
78
+ class SwiGLU(nn.Module):
79
+ """SwiGLU MLP (used in LLaMA/Gemma/Mistral)."""
80
+ def __init__(self, d_in: int, d_hidden: int, dropout: float = 0.1):
81
+ super().__init__()
82
+ self.w_gate = nn.Linear(d_in, d_hidden, bias=False)
83
+ self.w_up = nn.Linear(d_in, d_hidden, bias=False)
84
+ self.w_down = nn.Linear(d_hidden, d_in, bias=False)
85
+ self.dropout = nn.Dropout(dropout)
86
+
87
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
88
+ return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))
89
+
90
+
91
+ class RotaryEmbedding(nn.Module):
92
+ """RoPE (Su et al. 2021)."""
93
+ def __init__(self, dim: int, max_seq: int = 8192, base: float = 10000.0):
94
+ super().__init__()
95
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
96
+ t = torch.arange(max_seq).float()
97
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
98
+ emb = torch.cat([freqs, freqs], dim=-1)
99
+ self.register_buffer("cos", emb.cos(), persistent=False)
100
+ self.register_buffer("sin", emb.sin(), persistent=False)
101
+
102
+ def forward(self, q, k, seq_len):
103
+ cos = self.cos[:seq_len].to(q.dtype).to(q.device)
104
+ sin = self.sin[:seq_len].to(q.dtype).to(q.device)
105
+ def rot_half(x):
106
+ x1, x2 = x.chunk(2, dim=-1)
107
+ return torch.cat([-x2, x1], dim=-1)
108
+ return (q * cos) + (rot_half(q) * sin), (k * cos) + (rot_half(k) * sin)
109
+
110
+
111
+ class PerTokenSigmaEmbed(nn.Module):
112
+ """Sinusoidal embedding of per-position diffusion noise sigma in [0,1]."""
113
+ def __init__(self, d: int):
114
+ super().__init__()
115
+ self.d = d
116
+ self.proj = nn.Sequential(
117
+ nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d),
118
+ )
119
+
120
+ def forward(self, sigma: torch.Tensor) -> torch.Tensor:
121
+ half = self.d // 2
122
+ freqs = torch.exp(
123
+ -math.log(10000.0) * torch.arange(half, device=sigma.device) / half
124
+ )
125
+ ang = sigma.float().unsqueeze(-1) * freqs
126
+ emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
127
+ return self.proj(emb)
128
+
129
+
130
+ class GatedKGCrossAttention(nn.Module):
131
+ """Cross-attention to KG ego-subgraph, with GATED output.
132
+
133
+ `tanh(alpha) * cross_attn(x_seq, x_kg)` where alpha is a learnable scalar
134
+ initialized to 0. This means at init the cross-attention contributes
135
+ NOTHING to the residual stream, so the model trains identically to
136
+ no-KG until it discovers KG is useful. Prevents catastrophic loss
137
+ spikes on small data.
138
+
139
+ Pattern from: Genie (DeepMind 2024), Flamingo (DeepMind 2022).
140
+ """
141
+ def __init__(self, d_model: int, kg_dim: int, n_heads: int = 8, dropout: float = 0.1):
142
+ super().__init__()
143
+ self.n_heads = n_heads
144
+ self.head_dim = d_model // n_heads
145
+ # Project KG to d_model (run inline so we don't need separate KGProjector module)
146
+ self.kg_in_proj = nn.Linear(kg_dim, d_model, bias=False)
147
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
148
+ self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)
149
+ self.out_proj = nn.Linear(d_model, d_model, bias=False)
150
+ self.norm_q = RMSNorm(d_model)
151
+ self.norm_kv = RMSNorm(d_model)
152
+ self.dropout = nn.Dropout(dropout)
153
+ # Gate (scalar per block, init=0)
154
+ self.alpha = nn.Parameter(torch.zeros(1))
155
+
156
+ def forward(self, x_seq: torch.Tensor, kg_raw: torch.Tensor) -> torch.Tensor:
157
+ """
158
+ x_seq: (B, T, d_model)
159
+ kg_raw: (B, N_kg, kg_dim) -- raw KG embeddings (e.g. 3072)
160
+ """
161
+ B, T, D = x_seq.shape
162
+ kg_proj = self.kg_in_proj(kg_raw) # (B, N_kg, D)
163
+ N_kg = kg_proj.size(1)
164
+ q = self.q_proj(self.norm_q(x_seq))
165
+ kv = self.kv_proj(self.norm_kv(kg_proj))
166
+ k, v = kv.chunk(2, dim=-1)
167
+ q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
168
+ k = k.reshape(B, N_kg, self.n_heads, self.head_dim).transpose(1, 2)
169
+ v = v.reshape(B, N_kg, self.n_heads, self.head_dim).transpose(1, 2)
170
+ out = F.scaled_dot_product_attention(
171
+ q, k, v, dropout_p=self.dropout.p if self.training else 0.0)
172
+ out = out.transpose(1, 2).reshape(B, T, D)
173
+ gate = torch.tanh(self.alpha)
174
+ return x_seq + gate * self.dropout(self.out_proj(out))
175
+
176
+
177
+ class CDFv13Block(nn.Module):
178
+ """Pre-norm transformer block + optional gated KG cross-attn."""
179
+ def __init__(self, cfg: CDFv13Config, rope: RotaryEmbedding,
180
+ layer_idx: int):
181
+ super().__init__()
182
+ self.cfg = cfg
183
+ self.rope = rope
184
+ self.layer_idx = layer_idx
185
+ norm_cls = RMSNorm if cfg.use_rmsnorm else nn.LayerNorm
186
+ self.norm1 = norm_cls(cfg.d_model)
187
+ self.norm2 = norm_cls(cfg.d_model)
188
+ self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
189
+ self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
190
+ if cfg.use_swiglu:
191
+ self.mlp = SwiGLU(cfg.d_model, cfg.ffn, cfg.dropout)
192
+ else:
193
+ self.mlp = nn.Sequential(
194
+ nn.Linear(cfg.d_model, cfg.ffn, bias=False),
195
+ nn.GELU(),
196
+ nn.Linear(cfg.ffn, cfg.d_model, bias=False),
197
+ nn.Dropout(cfg.dropout),
198
+ )
199
+ self.dropout = nn.Dropout(cfg.dropout)
200
+ self.head_dim = cfg.d_model // cfg.n_heads
201
+ # Gated KG cross-attention (only in specified layers)
202
+ self.use_kg_in_layer = cfg.use_kg and layer_idx in cfg.kg_attn_layers
203
+ if self.use_kg_in_layer:
204
+ self.kg_xattn = GatedKGCrossAttention(
205
+ cfg.d_model, cfg.kg_dim, cfg.n_heads, cfg.dropout)
206
+
207
+ def forward(self, x, attn_mask, kg_raw=None):
208
+ B, T, D = x.shape
209
+ # MSA
210
+ h = self.norm1(x)
211
+ qkv = self.qkv(h).reshape(B, T, 3, self.cfg.n_heads, self.head_dim)
212
+ q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
213
+ q, k = self.rope(q, k, T)
214
+ out = F.scaled_dot_product_attention(
215
+ q, k, v,
216
+ attn_mask=(~attn_mask).float().masked_fill(attn_mask, float("-inf"))[None, None],
217
+ dropout_p=self.cfg.dropout if self.training else 0.0,
218
+ )
219
+ out = out.transpose(1, 2).reshape(B, T, D)
220
+ x = x + self.dropout(self.proj(out))
221
+ # Gated KG cross-attn (if enabled at this layer)
222
+ if self.use_kg_in_layer and kg_raw is not None:
223
+ x = self.kg_xattn(x, kg_raw)
224
+ # MLP
225
+ x = x + self.mlp(self.norm2(x))
226
+ return x
227
+
228
+
229
+ class CDFv13Transformer(nn.Module):
230
+ """Audit-compliant CDF v13: 12M backbone + KG adapters + DF objective."""
231
+
232
+ def __init__(self, cfg: CDFv13Config | None = None):
233
+ super().__init__()
234
+ self.cfg = cfg or CDFv13Config()
235
+ c = self.cfg
236
+ norm_cls = RMSNorm if c.use_rmsnorm else nn.LayerNorm
237
+
238
+ self.tok_emb = nn.Embedding(c.vocab_size, c.d_model)
239
+ self.emb_dropout = nn.Dropout(c.emb_dropout)
240
+
241
+ # Per-token sigma embedding (additive)
242
+ self.sigma_emb = PerTokenSigmaEmbed(c.d_model)
243
+ # Global condition embedding (additive, broadcast)
244
+ self.cond_emb = nn.Embedding(c.n_conditions, c.d_model)
245
+
246
+ # RoPE
247
+ self.rope = RotaryEmbedding(c.d_model // c.n_heads, max_seq=c.max_seq_len * 2)
248
+
249
+ # Blocks
250
+ self.blocks = nn.ModuleList([
251
+ CDFv13Block(c, self.rope, layer_idx=i) for i in range(c.n_layers)
252
+ ])
253
+ self.final_norm = norm_cls(c.d_model)
254
+ self.head = nn.Linear(c.d_model, c.vocab_size, bias=False)
255
+ if c.tie_embeddings:
256
+ self.head.weight = self.tok_emb.weight
257
+
258
+ # Block-causal mask buffer
259
+ T = c.max_seq_len
260
+ block_id = torch.arange(T) // c.block_size
261
+ mask = block_id.unsqueeze(0) < block_id.unsqueeze(1)
262
+ self.register_buffer("block_mask", mask, persistent=False)
263
+
264
+ # Init
265
+ self.apply(self._init_weights)
266
+
267
+ def _init_weights(self, m):
268
+ if isinstance(m, nn.Linear):
269
+ nn.init.normal_(m.weight, mean=0.0, std=0.02)
270
+ if m.bias is not None: nn.init.zeros_(m.bias)
271
+ elif isinstance(m, nn.Embedding):
272
+ nn.init.normal_(m.weight, mean=0.0, std=0.02)
273
+
274
+ def forward(self, x, sigma, cond, kg_raw=None):
275
+ B, T = x.shape
276
+ h = self.tok_emb(x) + self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
277
+ h = self.emb_dropout(h)
278
+ mask = self.block_mask[:T, :T]
279
+ for blk in self.blocks:
280
+ h = blk(h, mask, kg_raw=kg_raw)
281
+ h = self.final_norm(h)
282
+ return self.head(h)
283
+
284
+ def diffusion_forcing_loss(self, x_clean, cond, kg_raw=None,
285
+ mode: str = "uniform") -> torch.Tensor:
286
+ """Standard absorbing-state DF loss with per-token sigma.
287
+
288
+ mode: 'uniform' (default — safer for discrete than logit-normal per audit)
289
+ 'logit_normal' (SD3-style — keep as ablation only)
290
+ """
291
+ B, T = x_clean.shape
292
+ device = x_clean.device
293
+ # CFG cond dropout
294
+ drop = torch.rand(B, device=device) < self.cfg.cond_dropout
295
+ cond = torch.where(drop, torch.zeros_like(cond), cond)
296
+ if kg_raw is not None:
297
+ drop_kg = (torch.rand(B, device=device) < self.cfg.cond_dropout).float()
298
+ kg_raw = kg_raw * (1 - drop_kg).reshape(B, 1, 1)
299
+ # Sample per-token sigma
300
+ if mode == "logit_normal":
301
+ sigma = torch.sigmoid(torch.randn(B, T, device=device)).clamp(0.01, 0.99)
302
+ else:
303
+ sigma = torch.rand(B, T, device=device).clamp(0.01, 0.99)
304
+ # Absorbing-state corruption
305
+ corrupt = torch.rand(B, T, device=device) < sigma
306
+ x_noisy = torch.where(corrupt, self.cfg.mask_token, x_clean)
307
+ logits = self.forward(x_noisy, sigma, cond, kg_raw=kg_raw)
308
+ ce = F.cross_entropy(
309
+ logits.reshape(-1, self.cfg.vocab_size),
310
+ x_clean.reshape(-1),
311
+ reduction="none",
312
+ ).reshape(B, T)
313
+ n = corrupt.float().sum().clamp(min=1.0)
314
+ return (ce * corrupt.float()).sum() / n
reference_impl/eval_sota.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SOTA evaluation suite for CDFv13 — audit-proof.
2
+
3
+ Per the May 2026 SOTA audit, replaces "Top-1 mid-position" (not recognized)
4
+ with the canonical EHR foundation model metric stack:
5
+
6
+ Classification (next-event, downstream tasks):
7
+ - AUROC + AUPRC + Brier
8
+ - Calibration: ICI (Austin & Steyerberg 2019)
9
+ - Decision-curve analysis (Vickers)
10
+ - Bootstrap 95% CI (≥2000 resamples) — required for rare disease
11
+
12
+ Survival (DATASUS SIM mortality):
13
+ - Uno's C (concordance_index_ipcw) — preferred over Harrell at high censoring
14
+ - Integrated Brier Score (1/3/5y)
15
+ - Time-dependent AUC
16
+
17
+ Counterfactual / causal:
18
+ - ATE with bootstrap CI
19
+ - E-value (VanderWeele)
20
+ - Negative-control outcome + exposure
21
+ - Tipping-point analysis
22
+
23
+ Generation fidelity (CoMET / SynthEHRella):
24
+ - Dim-wise probability match
25
+ - MMD (Maximum Mean Discrepancy) with RBF kernel
26
+ - TSTR (Train-on-Synthetic-Test-on-Real)
27
+
28
+ Subgroup fairness (npj DM requirement):
29
+ - Stratified metrics: sex, age band, UF region
30
+
31
+ Split strategy (DATASUS rare disease):
32
+ - Temporal: train ≤2022, val 2023, test 2024-2025
33
+ - Geographic: train SE+S, test N+NE (UF cross-region = "external")
34
+ - Patient-level 5-fold CV (variance estimation)
35
+ """
36
+ from __future__ import annotations
37
+ import math
38
+ import numpy as np
39
+ import torch
40
+ from typing import Callable
41
+
42
+
43
+ # ---------- Classification ----------
44
+
45
+ def auroc(y: np.ndarray, p: np.ndarray) -> float:
46
+ from sklearn.metrics import roc_auc_score
47
+ if len(np.unique(y)) < 2: return float("nan")
48
+ return roc_auc_score(y, p)
49
+
50
+
51
+ def auprc(y: np.ndarray, p: np.ndarray) -> float:
52
+ from sklearn.metrics import average_precision_score
53
+ if len(np.unique(y)) < 2: return float("nan")
54
+ return average_precision_score(y, p)
55
+
56
+
57
+ def brier(y: np.ndarray, p: np.ndarray) -> float:
58
+ from sklearn.metrics import brier_score_loss
59
+ return brier_score_loss(y, p)
60
+
61
+
62
+ def ici(y: np.ndarray, p: np.ndarray, frac: float = 0.75) -> float:
63
+ """Integrated Calibration Index (Austin & Steyerberg 2019).
64
+ Lowess-smoothed deviation from perfect calibration.
65
+ """
66
+ from statsmodels.nonparametric.smoothers_lowess import lowess
67
+ sm = lowess(y, p, frac=frac, return_sorted=True)
68
+ return float(np.mean(np.abs(sm[:, 1] - sm[:, 0])))
69
+
70
+
71
+ def net_benefit(y: np.ndarray, p: np.ndarray, threshold: float) -> float:
72
+ """Net benefit at a given decision threshold (Vickers DCA)."""
73
+ tp = ((p >= threshold) & (y == 1)).sum()
74
+ fp = ((p >= threshold) & (y == 0)).sum()
75
+ n = len(y)
76
+ if threshold >= 1.0: return 0.0
77
+ return tp / n - (fp / n) * (threshold / (1 - threshold))
78
+
79
+
80
+ def decision_curve(y: np.ndarray, p: np.ndarray,
81
+ thresholds: list[float] = None) -> dict:
82
+ """Decision-curve analysis: net benefit across thresholds vs treat-all/treat-none."""
83
+ if thresholds is None:
84
+ thresholds = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
85
+ model_nb = [net_benefit(y, p, t) for t in thresholds]
86
+ treat_all_nb = [(y.mean()) - (1 - y.mean()) * (t / (1 - t)) if t < 1 else 0
87
+ for t in thresholds]
88
+ treat_none_nb = [0.0] * len(thresholds)
89
+ return {
90
+ "thresholds": thresholds,
91
+ "model": model_nb,
92
+ "treat_all": treat_all_nb,
93
+ "treat_none": treat_none_nb,
94
+ }
95
+
96
+
97
+ def bootstrap_ci(y: np.ndarray, p: np.ndarray, metric_fn: Callable,
98
+ n_boot: int = 2000, seed: int = 0,
99
+ ci: tuple[float, float] = (2.5, 97.5)) -> tuple[float, float, float]:
100
+ """Bootstrap 95% CI for any (y, p) -> scalar metric."""
101
+ rng = np.random.default_rng(seed)
102
+ n = len(y)
103
+ stats = []
104
+ for _ in range(n_boot):
105
+ idx = rng.integers(0, n, n)
106
+ if len(np.unique(y[idx])) < 2: continue
107
+ try:
108
+ stats.append(metric_fn(y[idx], p[idx]))
109
+ except Exception:
110
+ continue
111
+ if not stats: return (float("nan"),) * 3
112
+ return (
113
+ float(np.percentile(stats, ci[0])),
114
+ float(np.median(stats)),
115
+ float(np.percentile(stats, ci[1])),
116
+ )
117
+
118
+
119
+ # ---------- Survival ----------
120
+
121
+ def uno_c_index(y_train_event, y_train_time, y_test_event, y_test_time,
122
+ risk_score, tau: float = None) -> float:
123
+ """Uno's C-index (IPCW concordance), preferred at high censoring.
124
+ Requires scikit-survival.
125
+ """
126
+ try:
127
+ from sksurv.metrics import concordance_index_ipcw
128
+ except ImportError:
129
+ return float("nan")
130
+ # Build structured arrays
131
+ y_train = np.array(
132
+ list(zip(y_train_event.astype(bool), y_train_time.astype(float))),
133
+ dtype=[("event", "?"), ("time", "<f8")],
134
+ )
135
+ y_test = np.array(
136
+ list(zip(y_test_event.astype(bool), y_test_time.astype(float))),
137
+ dtype=[("event", "?"), ("time", "<f8")],
138
+ )
139
+ if tau is None:
140
+ tau = float(y_test_time.max()) * 0.95
141
+ c, *_ = concordance_index_ipcw(y_train, y_test, risk_score, tau=tau)
142
+ return float(c)
143
+
144
+
145
+ def integrated_brier_score(y_train_event, y_train_time, y_test_event, y_test_time,
146
+ surv_pred: np.ndarray, times: np.ndarray) -> float:
147
+ """Integrated Brier Score (lower is better)."""
148
+ try:
149
+ from sksurv.metrics import integrated_brier_score as ibs_fn
150
+ except ImportError:
151
+ return float("nan")
152
+ y_train = np.array(
153
+ list(zip(y_train_event.astype(bool), y_train_time.astype(float))),
154
+ dtype=[("event", "?"), ("time", "<f8")],
155
+ )
156
+ y_test = np.array(
157
+ list(zip(y_test_event.astype(bool), y_test_time.astype(float))),
158
+ dtype=[("event", "?"), ("time", "<f8")],
159
+ )
160
+ return float(ibs_fn(y_train, y_test, surv_pred, times))
161
+
162
+
163
+ # ---------- Causal / Counterfactual ----------
164
+
165
+ def e_value(rr: float) -> float:
166
+ """E-value (VanderWeele & Ding 2017): min strength of unmeasured
167
+ confounder needed to explain away an observed RR.
168
+ """
169
+ rr = max(rr, 1e-9)
170
+ if rr >= 1.0:
171
+ return rr + math.sqrt(rr * (rr - 1))
172
+ rr_inv = 1.0 / rr
173
+ return rr_inv + math.sqrt(rr_inv * (rr_inv - 1))
174
+
175
+
176
+ def negative_control_check(nc_ate: float, threshold: float = 0.02) -> bool:
177
+ """Negative-control outcome: ATE on a control outcome should be ~0."""
178
+ return abs(nc_ate) < threshold
179
+
180
+
181
+ def tipping_point(observed_effect: float, ci_half_width: float) -> float:
182
+ """How much would unmeasured confounding need to shift effect to nullify?"""
183
+ if abs(observed_effect) <= ci_half_width:
184
+ return 0.0
185
+ return float(abs(observed_effect) - ci_half_width)
186
+
187
+
188
+ # ---------- Generation fidelity (SynthEHRella triad) ----------
189
+
190
+ def dim_wise_probability(real_seq: torch.Tensor, synth_seq: torch.Tensor,
191
+ vocab_size: int) -> float:
192
+ """Compare per-token Bernoulli rates between real and synthetic batches.
193
+
194
+ Returns mean abs difference (lower = closer match).
195
+ """
196
+ real_one_hot = F.one_hot(real_seq, vocab_size).float().mean(dim=(0, 1))
197
+ synth_one_hot = F.one_hot(synth_seq, vocab_size).float().mean(dim=(0, 1))
198
+ return float((real_one_hot - synth_one_hot).abs().mean())
199
+
200
+
201
+ def mmd_rbf(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> float:
202
+ """Maximum Mean Discrepancy with RBF kernel.
203
+
204
+ x, y: (B, D) flattened embeddings. Returns MMD^2 (lower = closer).
205
+ """
206
+ def rbf(a, b):
207
+ d = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(-1)
208
+ return torch.exp(-d / (2 * sigma ** 2))
209
+ return float(rbf(x, x).mean() + rbf(y, y).mean() - 2 * rbf(x, y).mean())
210
+
211
+
212
+ # ---------- Subgroup fairness ----------
213
+
214
+ def stratified_metrics(y: np.ndarray, p: np.ndarray,
215
+ groups: np.ndarray,
216
+ metric_fn: Callable = auroc) -> dict[str, float]:
217
+ """Compute metric per subgroup (sex, age band, UF region)."""
218
+ out = {}
219
+ for g in np.unique(groups):
220
+ mask = groups == g
221
+ if mask.sum() > 10:
222
+ try:
223
+ out[str(g)] = metric_fn(y[mask], p[mask])
224
+ except Exception:
225
+ out[str(g)] = float("nan")
226
+ return out
227
+
228
+
229
+ # ---------- DATASUS split strategies ----------
230
+
231
+ def temporal_split(events: list[dict], train_until: int = 2022,
232
+ val_year: int = 2023):
233
+ """Temporal split for DATASUS: train ≤2022, val 2023, test 2024+."""
234
+ train, val, test = [], [], []
235
+ for e in events:
236
+ y = e.get("year") or 2020
237
+ if y <= train_until: train.append(e)
238
+ elif y == val_year: val.append(e)
239
+ else: test.append(e)
240
+ return train, val, test
241
+
242
+
243
+ def geographic_split(patients: list[dict], external_ufs: set = None):
244
+ """Geographic split: train on SE+S, test on N+NE.
245
+ For DATASUS this is the closest analog to "external validation."
246
+ """
247
+ if external_ufs is None:
248
+ external_ufs = {"AC", "AL", "AP", "AM", "BA", "CE", "MA", "PA",
249
+ "PB", "PE", "PI", "RN", "SE", "TO", "RR", "RO"}
250
+ train, test = [], []
251
+ for p in patients:
252
+ uf = next((e.get("uf_code") for e in p.get("events", []) if e.get("uf_code")),
253
+ None)
254
+ (test if uf in external_ufs else train).append(p)
255
+ return train, test
256
+
257
+
258
+ # ---------- Combined eval report ----------
259
+
260
+ def full_eval_report(y: np.ndarray, p: np.ndarray,
261
+ groups_sex: np.ndarray = None,
262
+ groups_age: np.ndarray = None,
263
+ groups_uf: np.ndarray = None,
264
+ n_boot: int = 2000) -> dict:
265
+ """Generate a full audit-proof report for a binary classification task.
266
+
267
+ Returns a dict with point estimates + bootstrap CIs + DCA + fairness.
268
+ """
269
+ import torch.nn.functional as F # local import to keep top clean
270
+
271
+ auroc_lo, auroc_med, auroc_hi = bootstrap_ci(y, p, auroc, n_boot)
272
+ auprc_lo, auprc_med, auprc_hi = bootstrap_ci(y, p, auprc, n_boot)
273
+ brier_lo, brier_med, brier_hi = bootstrap_ci(y, p, brier, n_boot)
274
+
275
+ report = {
276
+ "n_eval": len(y),
277
+ "prevalence": float(y.mean()),
278
+ "auroc": {"point": auroc(y, p), "ci95": [auroc_lo, auroc_hi], "median": auroc_med},
279
+ "auprc": {"point": auprc(y, p), "ci95": [auprc_lo, auprc_hi], "median": auprc_med},
280
+ "brier": {"point": brier(y, p), "ci95": [brier_lo, brier_hi], "median": brier_med},
281
+ "ici": ici(y, p),
282
+ "decision_curve": decision_curve(y, p),
283
+ }
284
+ if groups_sex is not None:
285
+ report["fairness_sex"] = stratified_metrics(y, p, groups_sex, auroc)
286
+ if groups_age is not None:
287
+ report["fairness_age"] = stratified_metrics(y, p, groups_age, auroc)
288
+ if groups_uf is not None:
289
+ report["fairness_uf"] = stratified_metrics(y, p, groups_uf, auroc)
290
+ return report
reference_impl/meds_export.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MEDS v0.4.1 exporter for DATASUS — audit-proof, interop-ready.
2
+
3
+ Verified against:
4
+ - meds 0.4.1 schemas (DataSchema, CodeMetadataSchema)
5
+ - https://github.com/Medical-Event-Data-Standard/meds
6
+ - CLMBR/MOTOR/EHRSHOT/CoMET tokenization conventions
7
+
8
+ Code conventions (interop-compatible):
9
+ - Static (time=None): GENDER//, RACE//, UF//, MUN//, ORPHA//
10
+ - Birth/Death: MEDS_BIRTH, MEDS_DEATH (reserved)
11
+ - Diagnoses: ICD10//<cid> (NOT CID10// — interop with OHDSI/Athena)
12
+ - Hospitalization: SIH//ADM, SIH//DIS (numeric_value=LOS_days on DIS)
13
+ - Procedures: SIGTAP//<10-digit> (Brazil-local namespace)
14
+ - Drugs (APAC): APAC//<sigtap> (numeric_value=monthly_cost_brl)
15
+ - Outpatient (BPA-I): BPAI//<sigtap>
16
+ - Visits: Visit//{IP, OP, ER} (matches CLMBR convention)
17
+
18
+ Outputs canonical MEDS dataset:
19
+ /out/
20
+ ├── data/ # parquet shards by subject
21
+ │ ├── shard_0.parquet
22
+ │ └── ...
23
+ ├── metadata/
24
+ │ └── codes.parquet # REQUIRED: every unique code with description + parent_codes
25
+ └── dataset_metadata.json # MEDS dataset metadata
26
+ """
27
+ from __future__ import annotations
28
+ import os
29
+ import json
30
+ import logging
31
+ from collections import defaultdict, Counter
32
+ from datetime import datetime
33
+ from typing import Iterator
34
+
35
+ import pyarrow as pa
36
+ import pyarrow.parquet as pq
37
+ import meds
38
+
39
+ log = logging.getLogger("gemeo.cdf.meds_export")
40
+
41
+
42
+ def _parse_date(s) -> datetime | None:
43
+ """Parse date string from various DATASUS formats."""
44
+ if s is None: return None
45
+ s = str(s).strip()
46
+ if not s or s in ("0", "None", "nan"): return None
47
+ try:
48
+ if "-" in s:
49
+ return datetime.strptime(s[:10], "%Y-%m-%d")
50
+ if len(s) == 8:
51
+ return datetime.strptime(s, "%Y%m%d")
52
+ except ValueError:
53
+ return None
54
+ return None
55
+
56
+
57
+ def _ym(year, month) -> datetime | None:
58
+ if year is None: return None
59
+ try:
60
+ return datetime(int(year), int(month) if month else 1, 1)
61
+ except (ValueError, TypeError):
62
+ return None
63
+
64
+
65
+ def datasus_patient_to_meds_rows(p: dict, subject_id: int) -> list[tuple]:
66
+ """Convert one DATASUS patient trajectory to a list of MEDS rows.
67
+
68
+ Each row is (subject_id, time, code, numeric_value, text_value).
69
+ Returns rows ready to write to a parquet shard.
70
+ """
71
+ rows = []
72
+
73
+ # ---- Static (time=None) ----
74
+ if p.get("sex"):
75
+ rows.append((subject_id, None, f"GENDER//{p['sex']}", None, None))
76
+ # ORPHA is rare-disease specific (parallel to ICD10)
77
+ for orpha in p.get("orphas", []):
78
+ rows.append((subject_id, None, f"ORPHA//{orpha}", None, None))
79
+
80
+ # ---- Birth (use birth_year as Jan 1) ----
81
+ birth_year = p.get("birth_year")
82
+ birth_dt = datetime(int(birth_year), 1, 1) if birth_year else None
83
+ if birth_dt:
84
+ rows.append((subject_id, birth_dt, "MEDS_BIRTH", None, None))
85
+
86
+ # ---- Events ----
87
+ for e in p.get("events", []):
88
+ et = e.get("type")
89
+
90
+ if et == "admission": # SIH-RD
91
+ t = _ym(e.get("year"), e.get("month")) or _parse_date(e.get("admission_date"))
92
+ if not t: continue
93
+ rows.append((subject_id, t, "SIH//ADM", None, None))
94
+ rows.append((subject_id, t, "Visit//IP", None, None))
95
+ cid = e.get("cid_princ", "")
96
+ if cid: rows.append((subject_id, t, f"ICD10//{cid}", None, None))
97
+ proc = e.get("primary_procedure")
98
+ if proc: rows.append((subject_id, t, f"SIGTAP//{proc[:10]}", None, None))
99
+ los = e.get("los_days")
100
+ disch_dt = _parse_date(e.get("discharge_date")) or t
101
+ if e.get("death_during_stay"):
102
+ rows.append((subject_id, disch_dt, "MEDS_DEATH", None, None))
103
+ else:
104
+ rows.append((subject_id, disch_dt, "SIH//DIS",
105
+ float(los) if los is not None else None, None))
106
+
107
+ elif et == "treatment": # APAC-SIA orphan drug
108
+ t = _ym(e.get("year"), e.get("month"))
109
+ if not t: continue
110
+ cid = e.get("cid", "")
111
+ if cid: rows.append((subject_id, t, f"ICD10//{cid}", None, None))
112
+ proc = e.get("procedure_code", "")[:10]
113
+ if proc:
114
+ cost = e.get("monthly_cost_brl")
115
+ rows.append((subject_id, t, f"APAC//{proc}",
116
+ float(cost) if cost is not None else None, None))
117
+
118
+ elif et == "outpatient_proc": # BPA-I
119
+ t = _parse_date(e.get("auth_date")) or _ym(e.get("year"), e.get("month"))
120
+ if not t: continue
121
+ cid = e.get("cid", "")
122
+ if cid: rows.append((subject_id, t, f"ICD10//{cid}", None, None))
123
+ proc = e.get("procedure_code", "")[:10]
124
+ if proc:
125
+ rows.append((subject_id, t, f"BPAI//{proc}", None, None))
126
+
127
+ elif et == "death": # SIM
128
+ t = _parse_date(e.get("date_of_death")) or _ym(e.get("year"), e.get("month"))
129
+ if not t: continue
130
+ rows.append((subject_id, t, "MEDS_DEATH", None, None))
131
+ cid = (e.get("cause_cid") or e.get("cid_princ") or e.get("cid", ""))
132
+ if cid: rows.append((subject_id, t, f"ICD10//{cid}", None, None))
133
+
134
+ # Sort: nulls first (static), then by time
135
+ rows.sort(key=lambda r: (r[1] is not None, r[1] or datetime(1900, 1, 1)))
136
+ return rows
137
+
138
+
139
+ def export_to_meds(patients: list[dict], out_dir: str,
140
+ shard_size: int = 5000,
141
+ dataset_name: str = "GEMEO-DATASUS",
142
+ version: str = "v13"):
143
+ """Export a list of DATASUS patient trajectories to MEDS v0.4.1 format.
144
+
145
+ Parameters
146
+ ----------
147
+ patients : list of dict
148
+ Each dict must have: patient_id, sex, birth_year, orphas (list),
149
+ events (list of dicts with 'type', 'year', 'month', etc.)
150
+ out_dir : str
151
+ Output directory (will create data/ and metadata/ subdirs)
152
+ shard_size : int
153
+ Number of subjects per parquet shard
154
+ """
155
+ os.makedirs(f"{out_dir}/data", exist_ok=True)
156
+ os.makedirs(f"{out_dir}/metadata", exist_ok=True)
157
+
158
+ log.info(f"Exporting {len(patients)} patients to MEDS at {out_dir}")
159
+
160
+ # Map patient_id (string hash) → int64 subject_id (MEDS requires int64)
161
+ pid_to_sid = {p["patient_id"]: i for i, p in enumerate(patients)}
162
+
163
+ # ---- Stream rows ----
164
+ all_codes = Counter()
165
+ shard_idx = 0
166
+ shard_rows = []
167
+ n_events = 0
168
+ n_subjects = 0
169
+
170
+ for p in patients:
171
+ sid = pid_to_sid[p["patient_id"]]
172
+ rows = datasus_patient_to_meds_rows(p, sid)
173
+ shard_rows.extend(rows)
174
+ n_events += len(rows)
175
+ n_subjects += 1
176
+ for r in rows:
177
+ all_codes[r[2]] += 1
178
+ # Write shard when full
179
+ if n_subjects % shard_size == 0 and shard_rows:
180
+ _write_shard(shard_rows, f"{out_dir}/data/shard_{shard_idx}.parquet")
181
+ shard_idx += 1
182
+ shard_rows = []
183
+
184
+ # Write remaining
185
+ if shard_rows:
186
+ _write_shard(shard_rows, f"{out_dir}/data/shard_{shard_idx}.parquet")
187
+
188
+ log.info(f" wrote {shard_idx + 1} data shards, {n_events} rows, {n_subjects} subjects")
189
+
190
+ # ---- codes.parquet (REQUIRED in MEDS v0.4) ----
191
+ code_rows = []
192
+ for code, count in all_codes.most_common():
193
+ # parent_codes: empty for Brazil-local namespaces; populated for ICD10 -> SNOMED if mapped
194
+ parent_codes = _get_parent_codes(code)
195
+ code_rows.append({
196
+ "code": code,
197
+ "description": _get_description(code, count),
198
+ "parent_codes": parent_codes,
199
+ })
200
+ code_table = pa.Table.from_pylist(code_rows, schema=meds.CodeMetadataSchema.schema())
201
+ pq.write_table(code_table, f"{out_dir}/metadata/codes.parquet")
202
+ log.info(f" wrote metadata/codes.parquet ({len(code_rows)} unique codes)")
203
+
204
+ # ---- dataset_metadata.json ----
205
+ md = {
206
+ "dataset_name": dataset_name,
207
+ "dataset_version": version,
208
+ "etl_name": "gemeo.cdf.meds_export",
209
+ "etl_version": "1.0.0",
210
+ "meds_version": meds.__version__,
211
+ "n_subjects": n_subjects,
212
+ "n_events": n_events,
213
+ "n_unique_codes": len(all_codes),
214
+ "top_codes": dict(all_codes.most_common(30)),
215
+ }
216
+ with open(f"{out_dir}/dataset_metadata.json", "w") as f:
217
+ json.dump(md, f, indent=2, default=str)
218
+ log.info(f" wrote dataset_metadata.json")
219
+
220
+ return md
221
+
222
+
223
+ def _write_shard(rows: list[tuple], path: str):
224
+ """Write a list of (subject_id, time, code, numeric_value, text_value) to parquet."""
225
+ if not rows: return
226
+ # Build columnar arrays
227
+ subject_id = pa.array([r[0] for r in rows], type=pa.int64())
228
+ time = pa.array([r[1] for r in rows], type=pa.timestamp("us"))
229
+ code = pa.array([r[2] for r in rows], type=pa.string())
230
+ numeric_value = pa.array([r[3] for r in rows], type=pa.float32())
231
+ text_value = pa.array([r[4] for r in rows], type=pa.large_string())
232
+ table = pa.Table.from_arrays(
233
+ [subject_id, time, code, numeric_value, text_value],
234
+ names=["subject_id", "time", "code", "numeric_value", "text_value"],
235
+ )
236
+ # Validate against MEDS schema
237
+ expected_schema = meds.DataSchema.schema()
238
+ # Cast if needed
239
+ table = table.cast(expected_schema, safe=False)
240
+ pq.write_table(table, path, compression="zstd")
241
+
242
+
243
+ # Brazilian-specific mapping tables (extend as needed)
244
+ ICD10_CHAPTERS = {
245
+ "A": "Certain infectious and parasitic diseases",
246
+ "B": "Certain infectious and parasitic diseases",
247
+ "C": "Neoplasms",
248
+ "D": "Neoplasms / Diseases of the blood and immune",
249
+ "E": "Endocrine, nutritional and metabolic diseases",
250
+ "F": "Mental, Behavioral and Neurodevelopmental disorders",
251
+ "G": "Diseases of the nervous system",
252
+ "H": "Diseases of the eye / ear",
253
+ "I": "Diseases of the circulatory system",
254
+ "J": "Diseases of the respiratory system",
255
+ "K": "Diseases of the digestive system",
256
+ "L": "Diseases of the skin and subcutaneous tissue",
257
+ "M": "Diseases of the musculoskeletal system",
258
+ "N": "Diseases of the genitourinary system",
259
+ "O": "Pregnancy, childbirth and the puerperium",
260
+ "P": "Certain conditions originating in the perinatal period",
261
+ "Q": "Congenital malformations, deformations and chromosomal abnormalities",
262
+ "R": "Symptoms, signs and abnormal clinical and laboratory findings",
263
+ "S": "Injury, poisoning and certain other consequences of external causes",
264
+ "T": "Injury, poisoning and certain other consequences of external causes",
265
+ "V": "External causes of morbidity",
266
+ "W": "External causes of morbidity",
267
+ "X": "External causes of morbidity",
268
+ "Y": "External causes of morbidity",
269
+ "Z": "Factors influencing health status and contact with health services",
270
+ }
271
+
272
+
273
+ def _get_description(code: str, count: int) -> str:
274
+ """Generate a brief description for a code (used in codes.parquet)."""
275
+ if code in ("MEDS_BIRTH",): return "Birth event (reserved)"
276
+ if code in ("MEDS_DEATH",): return "Death event (reserved)"
277
+ parts = code.split("//")
278
+ if len(parts) < 2: return f"Unknown code (n={count})"
279
+ domain, val = parts[0], "//".join(parts[1:])
280
+ if domain == "GENDER": return f"Patient sex = {val}"
281
+ if domain == "ORPHA": return f"Orphanet rare disease {val}"
282
+ if domain == "ICD10":
283
+ ch = ICD10_CHAPTERS.get(val[0], "Unknown chapter")
284
+ return f"ICD-10 {val} ({ch})"
285
+ if domain == "SIH": return f"SIH hospitalization {val}"
286
+ if domain == "Visit": return f"Visit type {val}"
287
+ if domain == "SIGTAP": return f"SIGTAP procedure {val}"
288
+ if domain == "APAC": return f"APAC orphan-drug authorization {val}"
289
+ if domain == "BPAI": return f"BPA-I outpatient procedure {val}"
290
+ if domain == "UF": return f"Residence UF {val}"
291
+ return f"{domain} code {val}"
292
+
293
+
294
+ def _get_parent_codes(code: str) -> list[str]:
295
+ """Return parent codes for ontology hierarchy (currently minimal)."""
296
+ parts = code.split("//")
297
+ if len(parts) < 2: return []
298
+ domain, val = parts[0], "//".join(parts[1:])
299
+ parents = []
300
+ if domain == "ICD10" and len(val) >= 3:
301
+ # ICD-10 chapter as parent
302
+ chapter = val[0]
303
+ if chapter in ICD10_CHAPTERS:
304
+ parents.append(f"ICD10//chapter_{chapter}")
305
+ # 3-char prefix as parent (e.g., E84.0 → E84)
306
+ if "." in val:
307
+ parents.append(f"ICD10//{val.split('.')[0]}")
308
+ elif len(val) > 3:
309
+ parents.append(f"ICD10//{val[:3]}")
310
+ if domain == "SIGTAP" and len(val) >= 4:
311
+ # 4-digit group as parent (SIGTAP 10-digit → 4-digit group)
312
+ parents.append(f"SIGTAP//group_{val[:4]}")
313
+ return parents
314
+
315
+
316
+ def load_meds_dataset(meds_dir: str) -> dict:
317
+ """Load a MEDS dataset back from parquet for inspection or downstream processing."""
318
+ import glob
319
+ shards = sorted(glob.glob(f"{meds_dir}/data/*.parquet"))
320
+ tables = [pq.read_table(p) for p in shards]
321
+ data = pa.concat_tables(tables) if tables else None
322
+ codes = pq.read_table(f"{meds_dir}/metadata/codes.parquet")
323
+ md = json.load(open(f"{meds_dir}/dataset_metadata.json"))
324
+ return {"data": data, "codes": codes, "metadata": md}
325
+
326
+
327
+ if __name__ == "__main__":
328
+ # Quick test on real patient data
329
+ logging.basicConfig(level=logging.INFO,
330
+ format="%(asctime)s %(levelname)s %(message)s")
331
+ PATIENTS = "/tmp/datasus_patient_trajectories_v2.json"
332
+ if os.path.exists(PATIENTS):
333
+ patients = json.load(open(PATIENTS))[:50] # 50 patients smoke test
334
+ md = export_to_meds(patients, "/tmp/meds_smoke_test")
335
+ print("\n=== smoke test result ===")
336
+ print(json.dumps(md, indent=2, default=str))
reference_impl/primekg_attention.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PrimeKG cross-attention — graph-RAG into the Diffusion Forcing denoiser.
2
+
3
+ Now uses REAL EDGES from raras-app/data/graph-ml/hetero_graph.json:
4
+ - disease → has_phenotype → phenotype (curated phenotype linkage)
5
+ - disease → associated_with → gene (causal gene evidence)
6
+ - gene → interacts_with → gene (PPI network)
7
+ - phenotype → is_a → phenotype (HPO ontology)
8
+
9
+ Ego-subgraph BFS:
10
+ 1. Start from disease node (ORPHA → PrimeKG index)
11
+ 2. 1-hop: pull connected phenotypes (top-K by edge weight or count)
12
+ 3. 1-hop: pull connected genes
13
+ 4. 2-hop: gene→gene neighbors (interacting partners)
14
+ 5. Concatenate fused embeddings of all selected nodes → cross-attn context
15
+
16
+ Falls back to cosine-similarity if graph not loaded.
17
+
18
+ White-space architecture (May 2026):
19
+ - EHRWorld, CLARITY, Time-Aware G-Transformer all skip KG conditioning
20
+ - PhenoKG/RareNet use KG for RETRIEVAL (rare disease diagnosis)
21
+ - We use it for GENERATION (counterfactual trajectory completion)
22
+ """
23
+ from __future__ import annotations
24
+ import os
25
+ import json
26
+ import logging
27
+ from functools import lru_cache
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+ log = logging.getLogger("gemeo.cdf.kg")
35
+
36
+ # Try raras-app paths first (richer, including hetero_graph edges + node_texts)
37
+ RARAS_KG_DIR = "/Users/dimas/raras-app/data/graph-ml"
38
+ LOCAL_KG_DIR = os.path.join(
39
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
40
+
41
+
42
+ def _kg_path(name: str) -> str:
43
+ """Prefer raras-app path if available, fall back to local fp16."""
44
+ raras = os.path.join(RARAS_KG_DIR, name)
45
+ if os.path.exists(raras):
46
+ return raras
47
+ local = os.path.join(LOCAL_KG_DIR, name)
48
+ return local if os.path.exists(local) else None
49
+
50
+
51
+ @lru_cache(maxsize=1)
52
+ def load_kg(prefer_raras: bool = True) -> dict | None:
53
+ """Load PrimeKG: fused embeddings + node ids + edges + texts.
54
+
55
+ Returns dict:
56
+ emb : {kind: torch.Tensor(N, 3072)}
57
+ idx2id : {kind: {pos: id_str}}
58
+ id2idx : {kind: {id_str: pos}}
59
+ edges : {edge_type: {'src': [...], 'dst': [...]}}
60
+ adj : {edge_type: {src_idx: [dst_idx, ...]}} -- precomputed
61
+ texts : {kind: [str, ...]} -- aligned to position
62
+ num_nodes : {kind: int}
63
+ """
64
+ # Try raras-app full file first, then local fp16
65
+ emb_path = (os.path.join(RARAS_KG_DIR, "fused_embeddings.npz")
66
+ if prefer_raras and os.path.exists(os.path.join(RARAS_KG_DIR, "fused_embeddings.npz"))
67
+ else _kg_path("fused_embeddings_fp16.npz"))
68
+ if not emb_path or not os.path.exists(emb_path):
69
+ log.warning("PrimeKG fused embeddings not found")
70
+ return None
71
+
72
+ nids_path = _kg_path("node_ids.json")
73
+ graph_path = _kg_path("hetero_graph.json")
74
+ texts_path = _kg_path("node_texts.json")
75
+
76
+ fz = np.load(emb_path)
77
+ nids = json.load(open(nids_path)) if nids_path else {}
78
+ graph = json.load(open(graph_path)) if graph_path else {"edges": {}, "num_nodes": {}}
79
+ texts = json.load(open(texts_path)) if texts_path else {}
80
+
81
+ out = {"emb": {}, "id2idx": {}, "idx2id": {}, "edges": {}, "adj": {},
82
+ "texts": texts, "num_nodes": graph.get("num_nodes", {})}
83
+
84
+ for kind in ("disease", "phenotype", "gene"):
85
+ if kind in fz.files:
86
+ out["emb"][kind] = torch.from_numpy(fz[kind].astype(np.float32))
87
+ if kind in nids:
88
+ out["idx2id"][kind] = {int(k): v for k, v in nids[kind].items()}
89
+ out["id2idx"][kind] = {v: int(k) for k, v in nids[kind].items()}
90
+
91
+ # Build adjacency from edges
92
+ for edge_type, edata in graph.get("edges", {}).items():
93
+ adj = {}
94
+ srcs = edata.get("src", []) if isinstance(edata, dict) else []
95
+ dsts = edata.get("dst", []) if isinstance(edata, dict) else []
96
+ for s, d in zip(srcs, dsts):
97
+ adj.setdefault(int(s), []).append(int(d))
98
+ out["adj"][edge_type] = adj
99
+ out["edges"][edge_type] = edata
100
+
101
+ log.info(f" KG loaded from {emb_path}")
102
+ log.info(f" disease={out['emb'].get('disease', torch.empty(0)).shape}, "
103
+ f"phenotype={out['emb'].get('phenotype', torch.empty(0)).shape}, "
104
+ f"gene={out['emb'].get('gene', torch.empty(0)).shape}")
105
+ log.info(f" edges: {list(out['edges'].keys())}")
106
+ return out
107
+
108
+
109
+ def ego_subgraph_real(orpha_code: str, k_pheno: int = 16, k_gene: int = 16,
110
+ k_gene_2hop: int = 0, kg: dict | None = None) -> torch.Tensor:
111
+ """BFS ego-subgraph using REAL PrimeKG edges (not cosine similarity).
112
+
113
+ Returns concatenated embeddings (N, 3072) where:
114
+ - 1 disease node (the query)
115
+ - up to k_pheno phenotype nodes (direct edges)
116
+ - up to k_gene gene nodes (direct edges)
117
+ - up to k_gene_2hop gene-gene 2-hop neighbors
118
+
119
+ Falls back to cosine similarity if no edges available.
120
+ """
121
+ if kg is None:
122
+ kg = load_kg()
123
+ if kg is None or "disease" not in kg["emb"]:
124
+ return None
125
+
126
+ d_id = kg["id2idx"]["disease"].get(str(orpha_code))
127
+ if d_id is None:
128
+ return None
129
+
130
+ d_emb = kg["emb"]["disease"][d_id]
131
+ nodes = [d_emb.unsqueeze(0)]
132
+
133
+ # Phenotype neighbors (via disease__has_phenotype__phenotype)
134
+ adj = kg["adj"].get("disease__has_phenotype__phenotype", {})
135
+ pheno_neighbors = adj.get(d_id, [])
136
+ if pheno_neighbors and "phenotype" in kg["emb"]:
137
+ pheno_neighbors = pheno_neighbors[:k_pheno]
138
+ nodes.append(kg["emb"]["phenotype"][pheno_neighbors])
139
+ elif "phenotype" in kg["emb"]:
140
+ # Fallback: cosine similarity
141
+ pool = kg["emb"]["phenotype"]
142
+ sim = F.cosine_similarity(d_emb.unsqueeze(0), pool, dim=-1)
143
+ top = sim.topk(min(k_pheno, pool.size(0))).indices
144
+ nodes.append(pool[top])
145
+
146
+ # Gene neighbors (via disease__associated_with__gene)
147
+ g_adj = kg["adj"].get("disease__associated_with__gene", {})
148
+ gene_neighbors = g_adj.get(d_id, [])
149
+ if gene_neighbors and "gene" in kg["emb"]:
150
+ gene_neighbors = gene_neighbors[:k_gene]
151
+ nodes.append(kg["emb"]["gene"][gene_neighbors])
152
+
153
+ # 2-hop: gene-gene neighbors of the genes we just pulled
154
+ if k_gene_2hop > 0:
155
+ gg_adj = kg["adj"].get("gene__interacts_with__gene", {})
156
+ seen = set(gene_neighbors)
157
+ second_hop = []
158
+ for g in gene_neighbors:
159
+ for g2 in gg_adj.get(g, []):
160
+ if g2 not in seen:
161
+ second_hop.append(g2)
162
+ seen.add(g2)
163
+ if len(second_hop) >= k_gene_2hop: break
164
+ if len(second_hop) >= k_gene_2hop: break
165
+ if second_hop:
166
+ nodes.append(kg["emb"]["gene"][second_hop])
167
+ elif "gene" in kg["emb"]:
168
+ pool = kg["emb"]["gene"]
169
+ sim = F.cosine_similarity(d_emb.unsqueeze(0), pool, dim=-1)
170
+ top = sim.topk(min(k_gene, pool.size(0))).indices
171
+ nodes.append(pool[top])
172
+
173
+ return torch.cat(nodes, dim=0)
174
+
175
+
176
+ # Keep old API name for backward compat
177
+ ego_subgraph = ego_subgraph_real
178
+
179
+
180
+ class KGCrossAttention(nn.Module):
181
+ """Cross-attention from sequence (B, T, d_model) to KG ego (B, N, d_model)."""
182
+ def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.1):
183
+ super().__init__()
184
+ self.n_heads = n_heads
185
+ self.head_dim = d_model // n_heads
186
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
187
+ self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)
188
+ self.out_proj = nn.Linear(d_model, d_model, bias=False)
189
+ self.norm_q = nn.LayerNorm(d_model)
190
+ self.norm_kv = nn.LayerNorm(d_model)
191
+ self.dropout = nn.Dropout(dropout)
192
+
193
+ def forward(self, x_seq: torch.Tensor, x_kg: torch.Tensor) -> torch.Tensor:
194
+ B, T, D = x_seq.shape
195
+ _, N, _ = x_kg.shape
196
+ q = self.q_proj(self.norm_q(x_seq))
197
+ kv = self.kv_proj(self.norm_kv(x_kg))
198
+ k, v = kv.chunk(2, dim=-1)
199
+ q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
200
+ k = k.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)
201
+ v = v.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)
202
+ out = F.scaled_dot_product_attention(
203
+ q, k, v, dropout_p=self.dropout.p if self.training else 0.0)
204
+ out = out.transpose(1, 2).reshape(B, T, D)
205
+ return x_seq + self.dropout(self.out_proj(out))
206
+
207
+
208
+ class KGProjector(nn.Module):
209
+ """Project 3072-d KG embeddings to d_model with LayerNorm."""
210
+ def __init__(self, kg_dim: int, d_model: int):
211
+ super().__init__()
212
+ self.proj = nn.Sequential(
213
+ nn.Linear(kg_dim, d_model),
214
+ nn.GELU(),
215
+ nn.LayerNorm(d_model),
216
+ )
217
+
218
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
219
+ return self.proj(x)
220
+
221
+
222
+ def build_kg_batch(orpha_strings: list[str], d_model: int,
223
+ projector: KGProjector,
224
+ k_pheno: int = 16, k_gene: int = 16,
225
+ k_gene_2hop: int = 0) -> torch.Tensor:
226
+ """Build (B, N, d_model) batched KG context for a batch of patient ORPHAs.
227
+
228
+ Falls back to zero context for missing ORPHAs.
229
+ """
230
+ kg = load_kg()
231
+ if kg is None:
232
+ return torch.zeros(len(orpha_strings), 1, d_model,
233
+ device=next(projector.parameters()).device)
234
+ N = 1 + k_pheno + k_gene + k_gene_2hop
235
+ egos = []
236
+ for orpha in orpha_strings:
237
+ e = ego_subgraph_real(orpha, k_pheno, k_gene, k_gene_2hop, kg)
238
+ if e is None:
239
+ e = torch.zeros(N, kg["emb"]["disease"].size(-1))
240
+ elif e.size(0) < N:
241
+ pad = torch.zeros(N - e.size(0), e.size(-1))
242
+ e = torch.cat([e, pad], dim=0)
243
+ egos.append(e[:N])
244
+ egos = torch.stack(egos, dim=0)
245
+ return projector(egos.to(next(projector.parameters()).device))
246
+
247
+
248
+ def precompute_kg_for_dataset(orpha_codes: list[str], projector: KGProjector,
249
+ k_pheno: int = 16, k_gene: int = 16,
250
+ batch_size: int = 32) -> torch.Tensor:
251
+ """Pre-compute KG context for an entire dataset in batches.
252
+
253
+ Returns (N_patients, kg_nodes, d_model) tensor on projector device.
254
+ Saves to disk-cacheable format.
255
+ """
256
+ out = []
257
+ for i in range(0, len(orpha_codes), batch_size):
258
+ batch = orpha_codes[i:i + batch_size]
259
+ ctx = build_kg_batch(batch, projector.proj[0].out_features,
260
+ projector, k_pheno, k_gene)
261
+ out.append(ctx.cpu())
262
+ return torch.cat(out, dim=0)
reference_impl/sample.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sampling primitives for CDF: AR mode, denoise mode, counterfactual rollouts.
2
+
3
+ Diffusion Forcing flexibility — the same model handles:
4
+
5
+ AR mode:
6
+ Sigma_future = 1, sigma_past = 0. Roll forward like an autoregressive
7
+ transformer but with per-token noise control.
8
+
9
+ Denoise mode (bidirectional):
10
+ Sigma low everywhere. Run k denoise steps, model fills the whole sequence.
11
+
12
+ Counterfactual mode (the TTE primitive):
13
+ Sigma=0 on observed tokens (clamp them clean), sigma=1 on tokens to
14
+ generate. Condition on (cohort, intervention_action_id). Sample N times,
15
+ compare distributions of outcome tokens.
16
+
17
+ CFG (classifier-free guidance) wraps any mode:
18
+ logits_g = (1 + gamma) * logits(c) - gamma * logits(null_c)
19
+
20
+ Shortcut Forcing (Dreamer 4) reduces denoise steps from 32-64 to 4 via
21
+ distilled student model — implemented in distill.py.
22
+ """
23
+ from __future__ import annotations
24
+ import logging
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+
29
+ from .diffusion_forcing import CDFTransformer
30
+
31
+ log = logging.getLogger("gemeo.cdf.sample")
32
+
33
+
34
+ @torch.no_grad()
35
+ def sample_denoise(
36
+ model: CDFTransformer,
37
+ cond: torch.Tensor,
38
+ *,
39
+ seed_prefix: torch.Tensor | None = None,
40
+ observed_mask: torch.Tensor | None = None, # (B, T) True = clamped clean
41
+ action: torch.Tensor | None = None,
42
+ gamma: float = 2.0,
43
+ n_steps: int = 32,
44
+ null_cond: int = 0,
45
+ schedule: str = "cosine",
46
+ ) -> torch.Tensor:
47
+ """Denoise-mode sampling: fully-masked sequence + iterative refinement.
48
+
49
+ Supports:
50
+ - seed_prefix: clean tokens kept at sigma=0 for positions [0, L)
51
+ - observed_mask: arbitrary positions to clamp (counterfactual mode)
52
+ - CFG via (cond, null_cond) pair
53
+ """
54
+ cfg = model.cfg
55
+ device = cond.device
56
+ B = cond.size(0)
57
+ T = cfg.max_seq_len
58
+
59
+ # Init with MASK
60
+ x = torch.full((B, T), cfg.mask_token, device=device, dtype=torch.long)
61
+ fixed_mask = torch.zeros(B, T, dtype=torch.bool, device=device)
62
+ if seed_prefix is not None:
63
+ L = seed_prefix.size(1)
64
+ x[:, :L] = seed_prefix
65
+ fixed_mask[:, :L] = True
66
+ if observed_mask is not None:
67
+ fixed_mask |= observed_mask
68
+
69
+ # Build noise schedule
70
+ if schedule == "cosine":
71
+ # smooth cosine from 1 -> 0
72
+ ts = torch.cos(torch.linspace(0, torch.pi/2, n_steps+1, device=device))
73
+ else:
74
+ ts = torch.linspace(1.0, 0.0, n_steps+1, device=device)
75
+
76
+ null = torch.full_like(cond, null_cond)
77
+ null_action = (torch.full_like(action, cfg.n_latent_actions)
78
+ if action is not None and cfg.use_latent_action else None)
79
+
80
+ for k in range(n_steps):
81
+ # Per-token sigma: fixed positions at 0, dynamic positions at ts[k]
82
+ sigma = torch.where(fixed_mask, torch.zeros_like(ts[k:k+1]).expand(B, T),
83
+ torch.full((B, T), ts[k].item(), device=device))
84
+ logits_c = model(x, sigma, cond, action)
85
+ if gamma > 0:
86
+ logits_n = model(x, sigma, null, null_action)
87
+ logits = (1 + gamma) * logits_c - gamma * logits_n
88
+ else:
89
+ logits = logits_c
90
+ logits[:, :, cfg.mask_token] = -1e9
91
+
92
+ probs = F.softmax(logits, dim=-1)
93
+ confs, preds = probs.max(dim=-1)
94
+
95
+ # Confidence-based remasking: reveal top-(1 - ts[k+1]) fraction of free tokens
96
+ t_next = ts[k+1].item()
97
+ target_kept = int(round((1 - t_next) * T))
98
+ revealed = (x != cfg.mask_token) | fixed_mask
99
+ already = revealed.sum(dim=-1)
100
+ new_x = x.clone()
101
+ for b in range(B):
102
+ need = max(0, target_kept - int(already[b].item()))
103
+ if need == 0:
104
+ continue
105
+ confs_b = torch.where(revealed[b], torch.full_like(confs[b], -1e9), confs[b])
106
+ topi = confs_b.topk(need).indices
107
+ new_x[b, topi] = preds[b, topi]
108
+ x = new_x
109
+
110
+ # Final cleanup
111
+ mask_left = x == cfg.mask_token
112
+ if mask_left.any():
113
+ sigma_final = torch.zeros(B, T, device=device)
114
+ logits_c = model(x, sigma_final, cond, action)
115
+ if gamma > 0:
116
+ logits_n = model(x, sigma_final, null, null_action)
117
+ logits = (1 + gamma) * logits_c - gamma * logits_n
118
+ else:
119
+ logits = logits_c
120
+ logits[:, :, cfg.mask_token] = -1e9
121
+ preds = logits.argmax(-1)
122
+ x = torch.where(mask_left, preds, x)
123
+ return x
124
+
125
+
126
+ @torch.no_grad()
127
+ def sample_ar(
128
+ model: CDFTransformer,
129
+ cond: torch.Tensor,
130
+ prefix: torch.Tensor,
131
+ *,
132
+ action: torch.Tensor | None = None,
133
+ max_new: int = 50,
134
+ temperature: float = 1.0,
135
+ gamma: float = 0.0,
136
+ null_cond: int = 0,
137
+ ) -> torch.Tensor:
138
+ """AR-mode sampling: future tokens at sigma=1, past at sigma=0.
139
+
140
+ Faster than denoise mode when you only want to continue a prefix.
141
+ """
142
+ cfg = model.cfg
143
+ device = cond.device
144
+ B = cond.size(0)
145
+ x = prefix.clone().to(device)
146
+ if x.dim() == 1: x = x.unsqueeze(0)
147
+ null = torch.full_like(cond, null_cond)
148
+ null_action = (torch.full_like(action, cfg.n_latent_actions)
149
+ if action is not None and cfg.use_latent_action else None)
150
+
151
+ for _ in range(max_new):
152
+ T_now = x.size(1)
153
+ if T_now >= cfg.max_seq_len:
154
+ break
155
+ # Pad with MASK
156
+ x_pad = torch.cat([x, torch.full((B, 1), cfg.mask_token,
157
+ device=device, dtype=torch.long)], dim=1)
158
+ sigma = torch.zeros(B, T_now + 1, device=device)
159
+ sigma[:, -1] = 1.0
160
+ a_pad = None
161
+ if action is not None and cfg.use_latent_action:
162
+ a_pad = torch.cat([action[:, :T_now],
163
+ torch.full((B, 1), cfg.n_latent_actions,
164
+ device=device, dtype=torch.long)], dim=1)
165
+ logits = model(x_pad, sigma, cond, a_pad)
166
+ if gamma > 0:
167
+ logits_n = model(x_pad, sigma, null, null_action)
168
+ logits = (1 + gamma) * logits - gamma * logits_n
169
+ logits[:, :, cfg.mask_token] = -1e9
170
+ p = F.softmax(logits[:, -1] / max(temperature, 1e-3), dim=-1)
171
+ nxt = torch.multinomial(p, 1)
172
+ x = torch.cat([x, nxt], dim=1)
173
+ return x
174
+
175
+
176
+ @torch.no_grad()
177
+ def counterfactual_rollout(
178
+ model: CDFTransformer,
179
+ seed_prefix: torch.Tensor,
180
+ treatment_cond: int,
181
+ untreated_cond: int,
182
+ *,
183
+ treatment_action: int | None = None,
184
+ untreated_action: int | None = None,
185
+ n_samples: int = 100,
186
+ gamma: float = 2.0,
187
+ n_steps: int = 32,
188
+ ) -> dict:
189
+ """Sample paired counterfactual trajectories under treatment vs no-treatment.
190
+
191
+ Two ways to specify the intervention:
192
+ - via cond id (cohort-level): treatment_cond / untreated_cond
193
+ - via latent action id (per-token): treatment_action / untreated_action
194
+ """
195
+ cfg = model.cfg
196
+ device = next(model.parameters()).device
197
+ seed = seed_prefix.unsqueeze(0).expand(n_samples, -1).to(device)
198
+ T = cfg.max_seq_len
199
+
200
+ cond_tx = torch.full((n_samples,), treatment_cond, device=device, dtype=torch.long)
201
+ cond_null = torch.full((n_samples,), untreated_cond, device=device, dtype=torch.long)
202
+
203
+ action_tx = action_null = None
204
+ if cfg.use_latent_action:
205
+ action_tx = torch.full((n_samples, T),
206
+ treatment_action if treatment_action is not None
207
+ else cfg.n_latent_actions,
208
+ device=device, dtype=torch.long)
209
+ action_null = torch.full((n_samples, T),
210
+ untreated_action if untreated_action is not None
211
+ else cfg.n_latent_actions,
212
+ device=device, dtype=torch.long)
213
+
214
+ traj_tx = sample_denoise(model, cond_tx, seed_prefix=seed,
215
+ action=action_tx, gamma=gamma, n_steps=n_steps)
216
+ traj_null = sample_denoise(model, cond_null, seed_prefix=seed,
217
+ action=action_null, gamma=gamma, n_steps=n_steps)
218
+ return {
219
+ "traj_treated": traj_tx, "traj_untreated": traj_null,
220
+ "n": n_samples, "treatment_cond": treatment_cond,
221
+ "untreated_cond": untreated_cond, "gamma": gamma,
222
+ }
223
+
224
+
225
+ def outcome_rate(traj: torch.Tensor, target_ids: list[int]) -> float:
226
+ if not target_ids:
227
+ return 0.0
228
+ target = torch.tensor(target_ids, device=traj.device)
229
+ has = (traj.unsqueeze(-1) == target).any(dim=(-1, -2))
230
+ return has.float().mean().item()
reference_impl/wsd_scheduler.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """WSD (Warmup-Stable-Decay) LR scheduler — manual implementation.
2
+
3
+ Per MiniCPM (Hu et al. 2024) and the data-constrained scaling literature:
4
+ Phase 1 (warmup, 1-5% of total_steps): linear 0 → peak_lr
5
+ Phase 2 (stable, 60-80%): constant peak_lr
6
+ Phase 3 (decay, 10-25%): linear or 1/sqrt to peak_lr * 0.1
7
+
8
+ Beats cosine for:
9
+ - data-limited regimes (we can extend stable phase if loss still falls)
10
+ - continue-pretrain (sharp decay enables clean fine-tune handoff)
11
+ """
12
+ from __future__ import annotations
13
+ import math
14
+ import torch
15
+ from torch.optim.lr_scheduler import LambdaLR
16
+
17
+
18
+ def wsd_lr_schedule(step: int, total_steps: int,
19
+ warmup_steps: int = 500,
20
+ stable_frac: float = 0.80,
21
+ decay_frac: float = 0.15,
22
+ min_lr_ratio: float = 0.1,
23
+ decay_type: str = "linear") -> float:
24
+ """Return LR multiplier in [min_lr_ratio, 1.0] for a given step."""
25
+ if step < warmup_steps:
26
+ return step / max(1, warmup_steps)
27
+ # remainder of steps after warmup
28
+ remaining = total_steps - warmup_steps
29
+ if remaining <= 0:
30
+ return 1.0
31
+ stable_steps = int(stable_frac * remaining)
32
+ decay_steps = int(decay_frac * remaining)
33
+ pos = step - warmup_steps
34
+ if pos < stable_steps:
35
+ return 1.0
36
+ decay_pos = pos - stable_steps
37
+ if decay_pos >= decay_steps:
38
+ return min_lr_ratio
39
+ progress = decay_pos / max(1, decay_steps)
40
+ if decay_type == "linear":
41
+ return 1.0 - (1.0 - min_lr_ratio) * progress
42
+ elif decay_type == "cosine":
43
+ return min_lr_ratio + 0.5 * (1 - min_lr_ratio) * (1 + math.cos(math.pi * progress))
44
+ elif decay_type == "inv_sqrt":
45
+ return max(min_lr_ratio, 1.0 / math.sqrt(1 + progress * 10))
46
+ else:
47
+ raise ValueError(f"unknown decay_type: {decay_type}")
48
+
49
+
50
+ def get_wsd_scheduler(optimizer: torch.optim.Optimizer,
51
+ total_steps: int,
52
+ warmup_steps: int = 500,
53
+ stable_frac: float = 0.80,
54
+ decay_frac: float = 0.15,
55
+ min_lr_ratio: float = 0.1,
56
+ decay_type: str = "linear") -> LambdaLR:
57
+ """Build a LambdaLR scheduler with WSD schedule."""
58
+ def fn(step):
59
+ return wsd_lr_schedule(step, total_steps, warmup_steps,
60
+ stable_frac, decay_frac, min_lr_ratio, decay_type)
61
+ return LambdaLR(optimizer, lr_lambda=fn)
62
+
63
+
64
+ if __name__ == "__main__":
65
+ # Visualize the schedule
66
+ total = 10000
67
+ warmup = 500
68
+ print(f"WSD schedule preview: total={total}, warmup={warmup}, stable=80%, decay=15%")
69
+ print(f" step lr_mult")
70
+ for s in [0, 250, 500, 1000, 5000, 8000, 8500, 9000, 9500, 9800, 9999]:
71
+ m = wsd_lr_schedule(s, total, warmup, 0.80, 0.15, 0.1, "linear")
72
+ print(f" {s:>5} {m:.4f}")