GEMEO Architecture v1.0 — spec + reference impl + figure
- .gitattributes +1 -0
- README.md +66 -0
- figure1_gemeo_architecture.png +3 -0
- gemeo_architecture_spec_v1.md +254 -0
- reference_impl/__init__.py +33 -0
- reference_impl/adaln_zero.py +148 -0
- reference_impl/diffusion_forcing_v13.py +314 -0
- reference_impl/eval_sota.py +290 -0
- reference_impl/meds_export.py +336 -0
- reference_impl/primekg_attention.py +262 -0
- reference_impl/sample.py +230 -0
- reference_impl/wsd_scheduler.py +72 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+figure1_gemeo_architecture.png filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
|
@@ -0,0 +1,66 @@
---
license: cc-by-nc-4.0
language: [en, pt]
tags:
  - world-model
  - patient-digital-twin
  - reference-architecture
  - diffusion-forcing
  - meds
  - primekg
  - rare-disease
library_name: pytorch
pipeline_tag: time-series-forecasting
---

# GEMEO Architecture v1.0

> **A reference architecture for patient world models.** Six principles,
> pluggable substrate, three open instances. *Not a model*: a recipe for
> building a model.

![GEMEO Architecture v1.0](./figure1_gemeo_architecture.png)

This repo contains the **architecture specification** and a **reference
implementation** (Apache-2.0 source, no weights). To use:

1. Read [`gemeo_architecture_spec_v1.md`](./gemeo_architecture_spec_v1.md)
   for the 6-principle conformance definition.
2. Copy `reference_impl/` into your repo, adapt to your substrate (any
   MEDS v0.4.1-compliant EHR), train.
3. Name your instance `gemeo-<substrate>-v<n>` and submit a conformance
   report.

## Open instances (May 2026)

| Instance | Substrate | Params | Status |
|---|---|---|---|
| [`Raras-AI/gemeo-sus-v2`](https://huggingface.co/Raras-AI/gemeo-sus-v2) | DATASUS (Brazil, 42K patients) | 19.86M | ✅ released |
| [`Raras-AI/gemeo-twin-stack`](https://huggingface.co/Raras-AI/gemeo-twin-stack) | application layer | NeuralSurv + heads | ✅ released |
| `Raras-AI/gemeo-mayo-v3` | Mayo Clinic Platform (planned) | 300M | in proposal |
| `Raras-AI/gemeo-mimic-demo` | MIMIC-IV-DEMO | reference impl | in progress |

## The six architectural principles

1. **Diffusion Forcing backbone** with per-token σ ∼ 𝒰(0, 1)
2. **Gated KG cross-attention** with tanh(α), α init = 0; real PrimeKG edges
3. **MEDS v0.4.1 substrate**: `(subject_id, time, code, value)`
4. **Bootstrap-then-learn** pattern per inference mode
5. **Bidirectional health-system grounding** (formulary re-rank)
6. **Audit-driven training** (Chinchilla scaling + SOTA component validation)

Full definitions and conformance tests in [`gemeo_architecture_spec_v1.md`](./gemeo_architecture_spec_v1.md).

## Citation

```bibtex
@misc{gemeo_arch_v1_2026,
  title  = {GEMEO Architecture Specification v1.0:
            Reference architecture for patient world models},
  author = {Verdial, Dimas Quintas and Kawassaki, Alexandre and the Raras AI team},
  year   = {2026},
  url    = {https://huggingface.co/Raras-AI/gemeo-arch},
}
```

⚠️ Research only. Not a medical device. No clinical use.
figure1_gemeo_architecture.png
ADDED
Binary file stored with Git LFS; no text diff.
gemeo_architecture_spec_v1.md
ADDED
|
@@ -0,0 +1,254 @@
---
title: "GEMEO Architecture Specification v1.0"
subtitle: "Reference architecture for patient world models and healthcare digital twins"
author: "Raras AI"
date: "May 2026"
geometry: margin=1in
fontsize: 11pt
mainfont: "Inter"
---

# GEMEO Architecture Specification v1.0

> **Status:** Stable. This document defines GEMEO Architecture v1.0, a reference design for **patient world models** in the lineage of Dreamer (Hafner 2019–2025), Diffusion Forcing (Chen NeurIPS 2024), Sora (OpenAI 2024), and Genie (DeepMind 2024–2025), applied to clinical event streams.
>
> **License:** CC-BY-NC 4.0 (specification text). Reference implementation in `Raras-AI/gemeo-arch` is Apache-2.0.
>
> **Authors:** Raras AI team. Correspondence: dimas@raras.ai

---

## 0. Scope and motivation

A *patient world model* is a learned generative model of patient-trajectory dynamics. It should support: (a) trajectory rollout conditional on actions, (b) counterfactual reasoning under interventions, (c) risk and outcome inference, and (d) drug repurposing. GEMEO Architecture v1.0 is the smallest concrete design we have found that satisfies all four under realistic constraints: sparse clinical events, missing labels, asymmetric data substrates, and regulatory grounding.

This document defines the architecture so that any institution holding longitudinal patient data can instantiate **its own GEMEO model** without sharing data with us. Implementations are expected to ship distinct *instances* (e.g. `gemeo-sus-v2`, `gemeo-mayo-v3`, `gemeo-mimic-demo`) while remaining GEMEO-conformant per the principles in §1 and the conformance tests in §4 of this spec.

## 1. The six architectural principles

A model is **GEMEO-conformant** if and only if it implements all six principles below. Optional extensions are allowed (§5) but do not affect conformance.

### Principle 1: Diffusion Forcing backbone with per-token σ

The trajectory backbone MUST be a transformer with a **per-token noise level** σᵢ ∼ 𝒰(0, 1) sampled independently at each position during training (Chen et al., NeurIPS 2024 [DF]). Score-SDE backbones, pure autoregressive backbones, and uniform-σ backbones are **not** GEMEO-conformant.

Rationale: per-token σ enables variable-horizon rollout and gap-filling under one trained model, which is necessary for both the trajectory mode and the bidirectional masked-fill mode used in clinical-evidence reconstruction.

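A minimal sketch of per-token noise sampling in plain Python. It assumes the absorbing-state corruption described in §3.1 (position *i* becomes MASK with probability σᵢ). `MASK_ID` and `corrupt_tokens` are illustrative names, not identifiers from the reference implementation.

```python
import random

MASK_ID = 0  # hypothetical id reserved for the MASK token


def corrupt_tokens(tokens, rng, lo=0.0, hi=1.0):
    """Draw an independent sigma per position and mask with that probability.

    Returns (corrupted, sigmas, is_masked). A uniform-sigma scheme would draw
    a single sigma for the whole sequence; here every position gets its own.
    """
    sigmas = [rng.uniform(lo, hi) for _ in tokens]
    is_masked = [rng.random() < s for s in sigmas]
    corrupted = [MASK_ID if m else t for t, m in zip(tokens, is_masked)]
    return corrupted, sigmas, is_masked


rng = random.Random(0)
tokens = [11, 12, 13, 14, 15, 16]
corrupted, sigmas, is_masked = corrupt_tokens(tokens, rng)
```

Setting `lo = hi = 1.0` masks every position and `lo = hi = 0.0` masks none, which are the two limiting cases (full-sequence denoising versus clean context) that a single per-token scheme interpolates between.
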
### Principle 2: Gated cross-attention to a biomedical knowledge graph

The model MUST attend to a *real* biomedical knowledge-graph ego-subgraph (PrimeKG [Chandak 2023] or equivalent), via cross-attention with a tanh-gated residual:

```
h' = h + tanh(α) · CrossAttn(h, KG_ego)
```

with α initialised to zero. The KG cross-attention MUST be inserted in at least one layer of the backbone. The KG ego-subgraph MUST contain real disease–gene and disease–phenotype edges; *simulated* or *RAG-only* KG conditioning is not GEMEO-conformant.

Rationale: zero-initialised gating (Flamingo [Alayrac 2022], Genie [Bruce 2024]) lets the backbone train as an unconditional model first and adopt KG signal only when it improves the loss, avoiding early-training instability. Real edges anchor the latent state to known biology and prevent hallucinated trajectories under counterfactual rollout.

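The effect of the zero-initialised gate can be illustrated without any deep-learning library. In the sketch below, `kg_update` is a stand-in for the output of `CrossAttn(h, KG_ego)`; the function name and the toy numbers are assumptions for illustration only.

```python
import math


def gated_residual(h, kg_update, alpha):
    """Apply h' = h + tanh(alpha) * kg_update element-wise.

    `kg_update` stands in for CrossAttn(h, KG_ego); computing the attention
    itself is out of scope for this sketch.
    """
    gate = math.tanh(alpha)
    return [h_i + gate * u_i for h_i, u_i in zip(h, kg_update)]


h = [0.5, -1.0, 2.0]
kg_update = [10.0, 10.0, 10.0]

at_init = gated_residual(h, kg_update, alpha=0.0)         # gate closed: identity
after_training = gated_residual(h, kg_update, alpha=0.1)  # gate slightly open
```

With α = 0 the hidden state passes through unchanged regardless of how large the KG update is, which is why the backbone can start training as if the KG branch were absent.
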
### Principle 3: MEDS v0.4.1 substrate

Event input MUST be in the **Medical Event Data Standard v0.4.1** [McDermott 2024]. Each event is a tuple `(subject_id, time, code, numeric_value, text_value)`. Code-prefix patterns MUST follow the canonical MEDS conventions: `ICD10//`, `SIH//`, `APAC//`, `SIGTAP//`, `ORPHA//`, with `MEDS_BIRTH` and `MEDS_DEATH` reserved. The MEDS schema MUST be validated at export time via the official `meds.DataSchema`.

Rationale: MEDS is the lingua franca for EHR foundation models (EHRSHOT, CoMET, CLMBR, MOTOR all use MEDS-derivatives). Compliance ensures any GEMEO instance can be benchmarked against any MEDS-based competitor and that EHR substrates can be swapped without re-engineering the model.

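As a rough illustration of the event tuple and the prefix rule, the sketch below uses a plain dataclass. It is not the `meds` package API; the `Event` class, `has_valid_code`, and the example codes are hypothetical, and real exports would still go through the official schema validation.

```python
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

# Prefixes and reserved codes listed in Principle 3.
ALLOWED_PREFIXES = ("ICD10//", "SIH//", "APAC//", "SIGTAP//", "ORPHA//")
RESERVED_CODES = ("MEDS_BIRTH", "MEDS_DEATH")


@dataclass(frozen=True)
class Event:
    subject_id: int
    time: datetime
    code: str
    numeric_value: Optional[float] = None
    text_value: Optional[str] = None


def has_valid_code(event: Event) -> bool:
    """True if the code is reserved or starts with an allowed prefix."""
    return event.code in RESERVED_CODES or event.code.startswith(ALLOWED_PREFIXES)


good = Event(1, datetime(2024, 3, 1), "ICD10//E75.2")
bad = Event(1, datetime(2024, 3, 1), "CID10//E75.2")  # non-canonical prefix
birth = Event(1, datetime(1990, 1, 1), "MEDS_BIRTH")
```

A check like this is cheap to run over a whole export and catches vocabulary drift such as `CID10//` before training starts.
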
### Principle 4: Bootstrap-then-learn pattern per inference mode

Each inference mode (trajectory, diagnosis, risk, counterfactual, repurposing, cohort) MUST expose the SAME function signature regardless of whether a learned head exists. When no learned head exists, the function MUST return a deterministic, rule-based or LLM-bootstrap output flagged `model="bootstrap"`. When a learned head exists, the function MUST return the learned output flagged with the head identifier (e.g. `model="neuralsurv"`).

This is not a software-engineering nicety. It is a deployment guarantee: a clinic can deploy a GEMEO instance before any GPU-trained head exists. Capabilities turn on monotonically as checkpoints land in the runtime.

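One way to read this contract is sketched below for the risk mode. The `RiskResult` type, the `risk` function, and the rule inside `bootstrap_risk` are all hypothetical; the only parts taken from the text are the shared signature and the `model` flag.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class RiskResult:
    score: float
    model: str  # "bootstrap" or the learned head identifier


def bootstrap_risk(age: int, n_admissions: int) -> float:
    """Deterministic rule-based fallback (illustrative rule, not clinical)."""
    return min(1.0, 0.01 * age + 0.05 * n_admissions)


def risk(age: int, n_admissions: int,
         learned_head: Optional[Callable[[int, int], float]] = None,
         head_name: str = "neuralsurv") -> RiskResult:
    """Same signature and return type whether or not a learned head exists."""
    if learned_head is None:
        return RiskResult(bootstrap_risk(age, n_admissions), model="bootstrap")
    return RiskResult(learned_head(age, n_admissions), model=head_name)


before = risk(40, 2)                                 # no checkpoint yet
after = risk(40, 2, learned_head=lambda a, n: 0.37)  # stand-in learned head
```

Callers never branch on whether a checkpoint exists; they only read the `model` field if they need to know which path produced the value.
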
### Principle 5: Bidirectional health-system grounding

Therapy recommendations MUST be re-ranked by the patient's actual health-system formulary. The re-ranking score MUST include at minimum:

```
score_grounding(r, p) = π_formulary(r, dx) · ρ_region(r, p_region) · (1 − dist(p, c_ref) / D_max)
```

where π_formulary ∈ {0, 1} marks formulary membership (PCDT in Brazil, NICE in UK, CMS-approved in US, etc.), ρ_region is the per-region empirical dispensation rate, and c_ref is the closest specialised referral centre. GEMEO instances that recommend therapies the local health system does not deliver are not conformant.

Rationale: clinical AI deployed without grounding to local formularies is not actionable. This principle distinguishes GEMEO from US-only or geography-blind foundation models.

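The grounding score is a product of three factors, so any single zero removes a therapy from contention. The sketch below evaluates it for three made-up candidates; the therapy names, rates, and distances are invented, and clamping the distance at `d_max_km` is an assumption of this sketch rather than something the formula states.

```python
def score_grounding(on_formulary: bool, dispensation_rate: float,
                    dist_km: float, d_max_km: float) -> float:
    """pi_formulary * rho_region * (1 - dist / D_max), as in Principle 5.

    Distances beyond d_max_km are clamped so the score never goes negative;
    the clamp is an assumption of this sketch.
    """
    pi = 1.0 if on_formulary else 0.0
    proximity = 1.0 - min(dist_km, d_max_km) / d_max_km
    return pi * dispensation_rate * proximity


candidates = {
    "therapy_a": score_grounding(True, 0.8, dist_km=50.0, d_max_km=500.0),
    "therapy_b": score_grounding(True, 0.3, dist_km=50.0, d_max_km=500.0),
    "therapy_c": score_grounding(False, 0.9, dist_km=10.0, d_max_km=500.0),
}
ranked = sorted(candidates, key=candidates.get, reverse=True)
```

Here `therapy_c` has the highest dispensation rate and the shortest distance but scores zero because it is off-formulary, so it ranks last.
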
### Principle 6: Audit-driven training

Every architectural decision MUST be auditable against a contemporaneous SOTA reference. Specifically:

- **Compute scaling**: parameter count MUST respect Chinchilla scaling for the available training tokens. (Our v1 violated this: 125M params on 20M tokens, 30–60× too large. It was downsized to 19.86M in v2.)
- **Component validation**: each module (DF objective, KG cross-attention, WSD schedule, etc.) MUST cite the paper that established it as SOTA at training time and MUST be ablatable.
- **Vocabulary integrity**: token codes MUST match the canonical MEDS prefixes (e.g. `ICD10//`, not `CID10//`). Synthetic or hallucinated KG edges are forbidden.

A training run that cannot answer "why this layer, why this size, why this schedule" with a citation does not produce a GEMEO-conformant model.

---

## 2. Required components and module structure

A reference implementation MUST expose the following Python module structure:

```
gemeo/
├── __init__.py
├── core.py                    # orchestrator
├── api.py                     # FastAPI surface (/api/gemeo/*)
├── types.py                   # typed dataclasses
│
├── cdf/                       # Principle 1 + 2: the world model itself
│   ├── diffusion_forcing.py   # CDF transformer
│   ├── adaln_zero.py          # DiT-style σ conditioning
│   ├── primekg_attention.py   # gated KG cross-attention
│   ├── wsd_scheduler.py       # WSD LR schedule
│   ├── meds_export.py         # MEDS v0.4.1 schema validation
│   └── sample.py              # inference / rollout
│
├── encoder.py                 # static patient embedding (HGT or bootstrap)
├── cohort.py                  # patients-like-mine (Principle 4 bootstrap)
├── subgraph.py                # KG sparsification
├── trajectory.py              # trajectory mode (delegates to cdf/)
├── risk.py                    # NeuralSurv head (Principle 4)
├── repurpose.py               # drug repurposing (TxGNN slot)
├── whatif.py                  # counterfactual engine
├── ask.py                     # active learning
├── ground_sus.py              # Principle 5: health-system grounding
└── feedback.py                # closed-loop label ingestion
```

A subset is allowed: any instance MAY omit modes it does not need (e.g. a research-only instance might skip `ground_sus.py`). But if a module IS present, it MUST conform to the signatures in `types.py`.

---

## 3. Training recipe

### 3.1 Required training-time properties

- **Per-token σ ∼ 𝒰(0.01, 0.99)** sampled independently. Absorbing-state corruption: position *i* becomes MASK with probability σᵢ.
- **Loss**: masked cross-entropy over corrupted positions, with **Min-SNR weighting** (Hang 2023) on per-token loss.
- **Conditional dropout 10%**: replace `cond` with `<NULL>` token for classifier-free guidance support at inference.
- **WSD LR schedule**: 5% warmup / 80% stable / 15% linear decay (Hu et al. MiniCPM 2024). Cosine schedules are not GEMEO-conformant.
- **bf16 mixed precision**.
- **Embedding tying** between input and output projections.
- **Architecture**: SwiGLU + RMSNorm + RoPE (Llama-style). Standard transformer with LayerNorm + ReLU is not conformant.

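The WSD split can be written as a small step-to-learning-rate function. This is a sketch under stated assumptions, not the contents of `wsd_scheduler.py`: the function name, the linear warmup starting just above zero, and decaying all the way to `min_lr = 0` are choices made here for illustration.

```python
def wsd_lr(step: int, total_steps: int, peak_lr: float,
           warmup_frac: float = 0.05, decay_frac: float = 0.15,
           min_lr: float = 0.0) -> float:
    """Warmup-Stable-Decay learning rate at a given 0-indexed step."""
    warmup_steps = round(total_steps * warmup_frac)
    decay_steps = round(total_steps * decay_frac)
    decay_start = total_steps - decay_steps
    if step < warmup_steps:
        # linear warmup, reaching peak_lr on the last warmup step
        return peak_lr * (step + 1) / warmup_steps
    if step < decay_start:
        # stable phase: hold the peak learning rate
        return peak_lr
    # linear decay from peak_lr towards min_lr
    remaining = (total_steps - step) / decay_steps
    return min_lr + (peak_lr - min_lr) * max(0.0, remaining)


schedule = [wsd_lr(s, total_steps=1000, peak_lr=3e-4) for s in range(1000)]
```

For 1000 steps this gives 50 warmup steps, 800 stable steps, and 150 decay steps. In PyTorch the same function could be wrapped in `torch.optim.lr_scheduler.LambdaLR` by returning the ratio to `peak_lr` instead of the absolute value.
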
### 3.2 Required validation gates

Before publication, an instance MUST report:

- Validation cross-entropy on a held-out random-σ task.
- Integrated Calibration Index (ICI) per Austin & Steyerberg.
- Gap-fill Top-K (K = 1, 3, 5, 10) on a contiguous mid-trajectory mask of size 24.
- Multi-horizon Top-K decay curve at prefix lengths *k* ∈ {32, 64, 128}.
- Per-event-class macro-AUROC with explicit honest reporting of horizon limitations (see §5 of the GEMEO/SUS-v2 model paper for the canonical honest format).
- Per-subgroup fairness check (sex, age band, region).

Instances that do not report all six are non-conformant for publication-grade claims.

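Gap-fill Top-K reduces to a hit rate over the masked positions. The sketch below assumes the model has already produced a ranked candidate list per masked position; `top_k_accuracy` and the toy codes are illustrative, and the real gate uses a contiguous gap of 24 positions rather than 4.

```python
def top_k_accuracy(ranked_predictions, targets, k):
    """Fraction of masked positions whose true code is in the top-k predictions.

    ranked_predictions[i] lists candidate codes for masked position i, ordered
    from most to least likely; targets[i] is the true code at that position.
    """
    hits = sum(t in preds[:k] for preds, t in zip(ranked_predictions, targets))
    return hits / len(targets)


# Toy gap of 4 masked positions.
targets = ["A", "B", "C", "D"]
ranked_predictions = [
    ["A", "X", "Y"],  # hit at rank 1
    ["X", "B", "Y"],  # hit at rank 2
    ["X", "Y", "C"],  # hit at rank 3
    ["X", "Y", "Z"],  # miss
]
report = {k: top_k_accuracy(ranked_predictions, targets, k) for k in (1, 3)}
```

Reporting several values of K on the same gap shows how quickly the correct code rises in the ranking, which a single Top-1 number hides.
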
### 3.3 Compute budget guidance

- **Tier-S** (≤ 10M tokens): backbone ≤ 20M params, single H100 < 1 hour, single instance.
- **Tier-M** (10M–500M tokens): backbone 20–80M params, 8–24 H100-hours.
- **Tier-L** (500M–10B tokens): backbone 80M–300M params, 100+ H100-hours.
- **Tier-XL** (10B+ tokens): backbone 300M+ params, 1000+ H100-hours. Mayo / EHRSHOT / multi-modal class.

A tier-S instance with claimed tier-XL performance is suspicious by construction.

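The tier table can be turned into a lookup for use in an audit script. This is a sketch: `TIERS` and `compute_tier` are hypothetical names, and treating each upper bound as inclusive is an assumption, since the text leaves the boundaries between tiers open.

```python
# (tier, maximum training tokens, backbone guidance) transcribed from §3.3
TIERS = [
    ("S", 10e6, "<= 20M params"),
    ("M", 500e6, "20-80M params"),
    ("L", 10e9, "80M-300M params"),
    ("XL", float("inf"), "300M+ params"),
]


def compute_tier(n_tokens: float) -> tuple[str, str]:
    """Return (tier name, backbone guidance) for a training-token count."""
    for name, max_tokens, guidance in TIERS:
        if n_tokens <= max_tokens:
            return name, guidance
    raise ValueError("n_tokens must be a real number")


small = compute_tier(5e6)
medium = compute_tier(2e8)
huge = compute_tier(1e12)
```
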
---

## 4. Conformance tests

A model is GEMEO-conformant if it passes the following automated tests, runnable via `gemeo-bench`:

```bash
gemeo-bench check Raras-AI/your-gemeo-instance
```

**Required tests:**

1. `test_schema_meds`: model exports MEDS-validatable event streams.
2. `test_per_token_sigma`: model accepts a per-position σ vector.
3. `test_kg_gating_init`: α = 0 at initialisation; sanity-check the gate.
4. `test_gap_fill_recovery`: gap [24, 48) Top-10 ≥ 0.50 on test split (a real GEMEO instance, even tiny, recovers gaps far above random).
5. `test_bootstrap_paths`: every inference mode returns a value with a `model` identifier, either `"bootstrap"` or the learned head name.
6. `test_health_system_grounding`: at least one therapy recommendation differs by patient region.
7. `test_audit_citations`: model card contains citations for DF, AdaLN-Zero, WSD, MEDS, KG source, and gating pattern.

The reference test suite is bundled with `Raras-AI/gemeo-arch` and runs in < 60 seconds on CPU for a tiny instance.

---

## 5. Optional extensions (not required for conformance)

- **Self Forcing training** (NeurIPS 2025 Spotlight [Self Forcing 2025]): addresses tail exposure bias. Recommended for v2+ instances.
- **Positional features** beyond RoPE: explicit age / calendar-year / region embeddings concatenated with the token embedding. Useful when temporal sparsity is high.
- **Multimodal substrate**: clinical notes via Gemini / Llama-Med phenotype extraction, WES variant tokens, imaging features. These extend the MEDS event vocabulary but should retain MEDS-prefix conventions.
- **CoMET-style multi-sample inference**: Monte-Carlo aggregation of *n* trajectories at inference. Note: not required, and per the GEMEO/SUS-v2 ablation in §7.6 of the model paper, it did not outperform the one-shot trained protocol at tier-S.

---

## 6. Versioning and naming

- The architecture spec is versioned (`v1.0` here). Backward-incompatible changes increment the major version.
- Model instances are named `gemeo/<substrate>-v<version>` (e.g. `gemeo/sus-v2`, `gemeo/mayo-v3`, `gemeo/mimic-demo-v1`).
- The reference implementation is `Raras-AI/gemeo-arch` (architecture, no weights). Instances live at `Raras-AI/gemeo-<substrate>-v<n>`.
- Bundled application layer: `Raras-AI/gemeo-twin-stack` (the six-mode wrapper around any conformant instance).

---

## 7. Validated instances (May 2026)

| Instance | Substrate | Params | Status | Reference |
|---|---|---|---|---|
| `gemeo-sus-v2` | DATASUS (Brazil, 42 K patients) | 19.86 M | ✅ open | Verdial 2026 |
| `gemeo-mayo-v3` | Mayo Clinic Platform (planned, 3 M) | 300 M (planned) | in proposal | Mayo Accelerate proposal |
| `gemeo-mimic-demo` | MIMIC-IV-DEMO (Open, ~100 pts) | reference impl | in progress | n/a |

External instances are welcome. Submit a pull request to `Raras-AI/gemeo-arch` with the conformance-test output to be listed.

---

## 8. Related architectures (not GEMEO-conformant, by design)

We name these explicitly so that reviewers and adopters can distinguish:

- **Sora** [OpenAI 2024]: world model for video, not patient events. DF lineage but no clinical substrate.
- **Dreamer 4** [Hafner 2025]: world model for general environments. DF lineage. Not specialised to MEDS, no KG cross-attention, no health-system grounding.
- **EHRWorld** [arXiv 2602.03569]: autoregressive Qwen-fine-tune on EHR. Not Diffusion Forcing. Not conformant to Principle 1.
- **CoMET** [Epic 2508.12104]: large-scale generative medical event model. Autoregressive only, not Diffusion Forcing. Not conformant to Principle 1.
- **MOTOR** [Steinberg 2023]: time-to-event foundation model. Not generative-dynamics. Conformant to substrate (MEDS) but not to Principles 1 and 2.
- **CLMBR-T** [Stanford]: contrastive foundation model. Not generative. Not conformant to Principle 1.
- **Delphi-2M** [Nature 2025]: autoregressive transformer over UK Biobank + Danish registry. Not Diffusion Forcing. Not conformant to Principle 1.
- **RareGraph-Synth** [arXiv 2510.06267]: score-SDE diffusion on *synthetic* rare-disease graphs. Not GEMEO-conformant: synthetic substrate (violates Principle 6's audit requirements) and score-SDE rather than Diffusion Forcing.

GEMEO is the architecture that combines (DF + KG cross-attn + MEDS + health-system grounding + bootstrap-then-learn + audit) into a single deployable design. To our knowledge, no prior architecture satisfies all six.

---

## 9. Citation

```bibtex
@misc{gemeo_arch_v1_2026,
  title  = {GEMEO Architecture Specification v1.0:
            Reference architecture for patient world models},
  author = {Verdial, Dimas Quintas and Kawassaki, Alexandre and the Raras AI team},
  year   = {2026},
  url    = {https://huggingface.co/Raras-AI/gemeo-arch},
  note   = {Specification document. Reference implementation Apache-2.0.}
}
```

## 10. References

- [DF] Chen, B. et al. *Diffusion Forcing: Next-Token Prediction Meets Full-Sequence Diffusion.* NeurIPS 2024 (arXiv:2407.01392).
- [Dreamer 4] Hafner, D. et al. *Training Agents Inside of Scalable World Models.* arXiv:2509.24527, Sept 2025.
- [Sora] OpenAI. *Video generation models as world simulators.* Feb 2024.
- [Genie] Bruce, J. et al. *Genie: Generative Interactive Environments.* ICML 2024.
- [PrimeKG] Chandak, P., Huang, K., Zitnik, M. *Building a knowledge graph to enable precision medicine.* Scientific Data, 2023.
- [MEDS] McDermott, M. et al. *MEDS: Medical Event Data Standard v0.4.1.* GitHub, 2024.
- [DiT] Peebles, W. & Xie, S. *Scalable Diffusion Models with Transformers.* ICCV 2023.
- [Flamingo] Alayrac, J.-B. et al. *Flamingo: a Visual Language Model for Few-Shot Learning.* NeurIPS 2022.
- [MiniCPM] Hu, S. et al. *MiniCPM: Unveiling the Potential of Small Language Models with Scalable Training Strategies.* 2024.
- [Min-SNR] Hang, T. et al. *Efficient Diffusion Training via Min-SNR Weighting Strategy.* ICCV 2023.
- [Self Forcing 2025] Huang, X. et al. *Self Forcing: Bridging the Train-Test Gap in Autoregressive Video Diffusion.* arXiv:2506.08009, NeurIPS 2025 Spotlight.

---

*End of GEMEO Architecture Specification v1.0.*
reference_impl/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
"""GEMEO-CDF: Causal Diffusion Forcing for clinical trajectories.

Three "first in medicine" hooks:
  1. DIFFUSION FORCING (Chen MIT NeurIPS 2024 → Dreamer 4 Hafner 2025 backbone)
     Independent per-token noise levels unify AR + diffusion + counterfactual
     in ONE loss. Zero clinical port as of May 2026.

  2. LATENT ACTION MODEL (Genie / DeepMind 2024)
     VQ-VAE codebook over (state_t, state_{t+1}) deltas discovers a
     treatment vocabulary without RxNorm/ATC labels. Solves the APAC
     miscoding / sparsity / off-label labelling pain in DATASUS.

  3. PROCESS REWARD VERIFIER (o3 / MAI-DxO 2025 pattern)
     Small PRM scores top-K rollouts at inference, returns top-1 +
     uncertainty band. Deliberative trajectory generation, novel in EHR.

Modules:
  diffusion_forcing.py : core architecture (per-token noise + block-causal)
  lam.py               : Latent Action Model (VQ-VAE codebook)
  train_cdf.py         : training loop with diffusion forcing objective
  sample.py            : sampling: AR mode / denoise mode / counterfactual
  distill.py           : Shortcut Forcing distillation (Dreamer 4)
  prm.py               : Process Reward Verifier
"""
from .diffusion_forcing import CDFTransformer, CDFConfig
from .lam import LatentActionVQVAE, LAMConfig
from .train_cdf import train_cdf

__all__ = [
    "CDFTransformer", "CDFConfig",
    "LatentActionVQVAE", "LAMConfig",
    "train_cdf",
]
reference_impl/adaln_zero.py
ADDED
|
@@ -0,0 +1,148 @@
| 1 |
+
"""AdaLN-Zero conditioning module (DiT-style, Peebles 2023).
|
| 2 |
+
|
| 3 |
+
Used in: DiT (ICCV 2023), Stable Diffusion 3 (Esser 2024), Sora, Lumina-Next,
|
| 4 |
+
PixArt-Sigma. Standard for diffusion conditioning in 2025-2026.
|
| 5 |
+
|
| 6 |
+
Why for Diffusion Forcing on EHR:
|
| 7 |
+
- Per-token sigma + global cond/action → per-token (scale, shift, gate)
|
| 8 |
+
- Gates init to zero ⇒ block starts as identity ⇒ no catastrophic init
|
| 9 |
+
- Much better CFG (dropped condition path goes through zero gates,
|
| 10 |
+
not corrupting residual stream)
|
| 11 |
+
- DFoT (Diffusion Forcing Transformer 2, ICLR 2026) confirms +3-8% win
|
| 12 |
+
|
| 13 |
+
We fuse THREE conditioning signals:
|
| 14 |
+
- sigma (B, T) per-token noise level → time_emb (B, T, D)
|
| 15 |
+
- cond (B,) cohort-level treatment id → cond_emb (B, D) → broadcast
|
| 16 |
+
- action(B, T) per-token latent action id → action_emb (B, T, D)
|
| 17 |
+
|
| 18 |
+
Combined into c_t (B, T, D) → ConditioningMLP → 6 modulation tensors
|
| 19 |
+
per block. Each block uses them as:
|
| 20 |
+
|
| 21 |
+
h = x + gate_msa * Attn(scale_msa * Norm(x) + shift_msa)
|
| 22 |
+
h = h + gate_mlp * MLP(scale_mlp * Norm(h) + shift_mlp)
|
| 23 |
+
"""
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class AdaLNZeroModulator(nn.Module):
|
| 30 |
+
"""Generates per-token (scale, shift, gate) for AdaLN-Zero block.
|
| 31 |
+
|
| 32 |
+
Input: fused conditioning vector c (B, T, d_model).
|
| 33 |
+
Output: 6 tensors of shape (B, T, d_model) each:
|
| 34 |
+
(scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp)
|
| 35 |
+
"""
|
| 36 |
+
def __init__(self, d_model: int):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.modulator = nn.Sequential(
|
| 39 |
+
nn.SiLU(),
|
| 40 |
+
nn.Linear(d_model, 6 * d_model, bias=True),
|
| 41 |
+
)
|
| 42 |
+
# Zero-init for the gate-producing rows (AdaLN-Zero trick)
|
| 43 |
+
# We zero-init ALL outputs initially; gate stays zero so block is identity
|
| 44 |
+
nn.init.zeros_(self.modulator[-1].weight)
|
| 45 |
+
nn.init.zeros_(self.modulator[-1].bias)
|
| 46 |
+
|
| 47 |
+
def forward(self, c: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
| 48 |
+
# c: (B, T, d_model)
|
| 49 |
+
out = self.modulator(c) # (B, T, 6*d_model)
|
| 50 |
+
return out.chunk(6, dim=-1)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class AdaLNZeroBlock(nn.Module):
|
| 54 |
+
"""Transformer block with AdaLN-Zero modulation.
|
| 55 |
+
|
| 56 |
+
    Drop-in replacement for the standard pre-norm block. Reads
    pre-computed modulation tensors and applies them around Attn + MLP.
    """
    def __init__(self, d_model: int, n_heads: int, ffn: int, dropout: float,
                 rope=None, kg_xattn=None):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.rope = rope
        self.kg_xattn = kg_xattn

        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ffn, bias=False),
            nn.GELU(),
            nn.Linear(ffn, d_model, bias=False),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor,
                scale_msa, shift_msa, gate_msa,
                scale_mlp, shift_mlp, gate_mlp,
                kg_ctx: torch.Tensor | None = None) -> torch.Tensor:
        import torch.nn.functional as F
        B, T, D = x.shape
        # MSA branch
        h = self.norm1(x) * (1 + scale_msa) + shift_msa
        qkv = self.qkv(h).reshape(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        if self.rope is not None:
            q, k = self.rope(q, k, T)
        out = F.scaled_dot_product_attention(
            q, k, v,
            # attn_mask is True where attention is blocked; a boolean SDPA mask
            # is True where attention is allowed, so pass the inverse. A boolean
            # mask also avoids a dtype mismatch when q/k/v are fp16 or bf16.
            attn_mask=(~attn_mask)[None, None],
            dropout_p=self.dropout.p if self.training else 0.0,
        )
        out = out.transpose(1, 2).reshape(B, T, D)
        x = x + gate_msa * self.dropout(self.proj(out))
        # KG cross-attention (between MSA and MLP)
        if self.kg_xattn is not None and kg_ctx is not None:
            x = self.kg_xattn(x, kg_ctx)
        # MLP branch
        h = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + gate_mlp * self.dropout(self.mlp(h))
        return x


class FusedConditioner(nn.Module):
    """Fuse (sigma, cond, action) into one per-token conditioning vector.

    Output (B, T, d_model) consumed by AdaLNZeroModulator per layer.
    """
    def __init__(self, d_model: int, n_conditions: int, n_actions: int,
                 use_action: bool = True):
        super().__init__()
        self.d_model = d_model
        self.use_action = use_action
        # Sigma → sinusoidal embedding
        self.sigma_proj = nn.Sequential(
            nn.Linear(d_model, d_model), nn.SiLU(), nn.Linear(d_model, d_model),
        )
        self.cond_emb = nn.Embedding(n_conditions, d_model)
        if use_action:
            self.action_emb = nn.Embedding(n_actions + 1, d_model)
        self.fuse = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

    def sinusoidal(self, sigma: torch.Tensor) -> torch.Tensor:
        import math
        half = self.d_model // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, device=sigma.device) / half
        )
        ang = sigma.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
        return self.sigma_proj(emb)

    def forward(self, sigma: torch.Tensor, cond: torch.Tensor,
                action: torch.Tensor | None = None) -> torch.Tensor:
        # sigma (B, T) → time_emb (B, T, D)
        time_emb = self.sinusoidal(sigma)
        # cond (B,) → (B, D) → broadcast to (B, T, D)
        cond_emb = self.cond_emb(cond).unsqueeze(1).expand_as(time_emb)
        fused = time_emb + cond_emb
        if self.use_action and action is not None:
            fused = fused + self.action_emb(action)
        return self.fuse(fused)
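
The block above applies `x + gate * sublayer(norm(x) * (1 + scale) + shift)` twice, once around attention and once around the MLP. A minimal, dependency-free sketch of why a zero gate matters follows. It assumes (the modulator is not shown in this part of the diff) that the scale, shift, and gate tensors are all zero at initialisation, as the AdaLN-Zero name suggests. `layer_norm`, `adaln_branch`, and `toy_sublayer` are hypothetical stand-ins, not functions from `adaln_zero.py`.

```python
import math


def layer_norm(x, eps=1e-5):
    """LayerNorm without affine parameters over a 1-D list of floats."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def adaln_branch(x, scale, shift, gate, fn):
    """One gated residual branch: x + gate * fn(norm(x) * (1 + scale) + shift)."""
    h = [n * (1.0 + s) + b for n, s, b in zip(layer_norm(x), scale, shift)]
    return [xi + g * yi for xi, g, yi in zip(x, gate, fn(h))]


def toy_sublayer(h):
    """Hypothetical stand-in for attention or the MLP."""
    return [2.0 * v + 1.0 for v in h]


x = [0.5, -1.0, 2.0, 0.25]
zeros = [0.0] * len(x)
ones = [1.0] * len(x)

# Gate at zero: the branch contributes nothing, so the block is the identity.
out_init = adaln_branch(x, zeros, zeros, zeros, toy_sublayer)
# Gate opened: the sublayer output now reaches the residual stream.
out_open = adaln_branch(x, zeros, zeros, ones, toy_sublayer)

print(out_init == x)
print(out_open == x)
```

With the gate at zero the residual stream passes through unchanged, whatever the sublayer computes. Training can then open each branch gradually.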
reference_impl/diffusion_forcing_v13.py
ADDED
@@ -0,0 +1,314 @@
"""GEMEO-CDF v13: audit-driven, Chinchilla-correct architecture.

Per the SOTA audit (May 2026):
  - Path B (CLMBR fine-tune) BLOCKED: CLMBR-T-base is HF-gated (manual approval)
  - Path A adopted: small from-scratch model + KG adapters + MEDS interop

Architecture:
  - 12M backbone params (Chinchilla-respecting for ~20M token corpus)
  - d_model=384, n_layers=8, n_heads=6, ffn=1024, ctx=512
  - SwiGLU MLP (ffn:d_model = 2.67)
  - Tied embeddings (saves ~12M at vocab=32k)
  - Dropout 0.1 everywhere (small-data critical)
  - Block-causal attention (Diffusion Forcing)
  - Per-token sigma noise (independent)
  - GATED KG cross-attention (tanh(α)·xattn, α init=0)
      - Layers 4, 6, 7 (3 of 8)
      - Lets model learn to use KG progressively, doesn't disrupt early loss
  - DF objective + LM-aux loss (joint training, paper-grade)

Sources audited:
  - CoMET (Aug 2025): tokens-per-param ratio
  - CLMBR (Stanford): adapter pattern for cross-site transfer
  - MDLM (Sahoo 2024): masked diffusion, matches AR at equal FLOPs
  - Genie (DeepMind 2024): gated cross-attention pattern
  - SD3 (Esser 2024): AdaLN-Zero zero-init gates
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class CDFv13Config:
    # Vocab + sequence
    vocab_size: int = 32768          # MEDS-derived (will be much smaller in practice)
    mask_token: int = 32767
    max_seq_len: int = 512
    block_size: int = 16
    # Architecture (Chinchilla-correct for ~20M tokens)
    d_model: int = 384
    n_heads: int = 6
    n_layers: int = 8
    ffn: int = 1024                  # SwiGLU hidden width; use_swiglu adds a gate projection next to the up projection
    dropout: float = 0.1
    emb_dropout: float = 0.1
    use_swiglu: bool = True
    use_rmsnorm: bool = True
    tie_embeddings: bool = True
    # Diffusion forcing
    cond_dropout: float = 0.10
    # KG conditioning (GATED adapters)
    use_kg: bool = True
    kg_dim: int = 3072
    kg_attn_layers: list = field(default_factory=lambda: [4, 6, 7])
    # Latent action
    use_latent_action: bool = False  # Dropped per audit (concept shaky)
    n_latent_actions: int = 512
    # Conditioning
    n_conditions: int = 64


class RMSNorm(nn.Module):
    """Root-mean-square LayerNorm (LLaMA/Mistral style)."""
    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (norm * self.weight.float()).to(x.dtype)


class SwiGLU(nn.Module):
    """SwiGLU MLP (used in LLaMA/Gemma/Mistral)."""
    def __init__(self, d_in: int, d_hidden: int, dropout: float = 0.1):
        super().__init__()
        self.w_gate = nn.Linear(d_in, d_hidden, bias=False)
        self.w_up = nn.Linear(d_in, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_in, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))


class RotaryEmbedding(nn.Module):
    """RoPE (Su et al. 2021)."""
    def __init__(self, dim: int, max_seq: int = 8192, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos", emb.cos(), persistent=False)
        self.register_buffer("sin", emb.sin(), persistent=False)

    def forward(self, q, k, seq_len):
        cos = self.cos[:seq_len].to(q.dtype).to(q.device)
        sin = self.sin[:seq_len].to(q.dtype).to(q.device)
        def rot_half(x):
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat([-x2, x1], dim=-1)
        return (q * cos) + (rot_half(q) * sin), (k * cos) + (rot_half(k) * sin)


class PerTokenSigmaEmbed(nn.Module):
    """Sinusoidal embedding of per-position diffusion noise sigma in [0,1]."""
    def __init__(self, d: int):
        super().__init__()
        self.d = d
        self.proj = nn.Sequential(
            nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d),
        )

    def forward(self, sigma: torch.Tensor) -> torch.Tensor:
        half = self.d // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, device=sigma.device) / half
        )
        ang = sigma.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
        return self.proj(emb)


class GatedKGCrossAttention(nn.Module):
    """Cross-attention to KG ego-subgraph, with GATED output.

    `tanh(alpha) * cross_attn(x_seq, x_kg)` where alpha is a learnable scalar
    initialized to 0. This means at init the cross-attention contributes
    NOTHING to the residual stream, so the model trains identically to
    no-KG until it discovers KG is useful. Prevents catastrophic loss
    spikes on small data.

    Pattern from: Genie (DeepMind 2024), Flamingo (DeepMind 2022).
    """
    def __init__(self, d_model: int, kg_dim: int, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Project KG to d_model (run inline so we don't need separate KGProjector module)
        self.kg_in_proj = nn.Linear(kg_dim, d_model, bias=False)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm_q = RMSNorm(d_model)
        self.norm_kv = RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # Gate (scalar per block, init=0)
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x_seq: torch.Tensor, kg_raw: torch.Tensor) -> torch.Tensor:
        """
        x_seq: (B, T, d_model)
        kg_raw: (B, N_kg, kg_dim) -- raw KG embeddings (e.g. 3072)
        """
        B, T, D = x_seq.shape
        kg_proj = self.kg_in_proj(kg_raw)  # (B, N_kg, D)
        N_kg = kg_proj.size(1)
        q = self.q_proj(self.norm_q(x_seq))
        kv = self.kv_proj(self.norm_kv(kg_proj))
        k, v = kv.chunk(2, dim=-1)
        q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, N_kg, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, N_kg, self.n_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout.p if self.training else 0.0)
        out = out.transpose(1, 2).reshape(B, T, D)
        gate = torch.tanh(self.alpha)
        return x_seq + gate * self.dropout(self.out_proj(out))


class CDFv13Block(nn.Module):
    """Pre-norm transformer block + optional gated KG cross-attn."""
    def __init__(self, cfg: CDFv13Config, rope: RotaryEmbedding,
                 layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.rope = rope
        self.layer_idx = layer_idx
        norm_cls = RMSNorm if cfg.use_rmsnorm else nn.LayerNorm
        self.norm1 = norm_cls(cfg.d_model)
        self.norm2 = norm_cls(cfg.d_model)
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        if cfg.use_swiglu:
            self.mlp = SwiGLU(cfg.d_model, cfg.ffn, cfg.dropout)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(cfg.d_model, cfg.ffn, bias=False),
                nn.GELU(),
                nn.Linear(cfg.ffn, cfg.d_model, bias=False),
                nn.Dropout(cfg.dropout),
            )
        self.dropout = nn.Dropout(cfg.dropout)
        self.head_dim = cfg.d_model // cfg.n_heads
        # Gated KG cross-attention (only in specified layers)
        self.use_kg_in_layer = cfg.use_kg and layer_idx in cfg.kg_attn_layers
        if self.use_kg_in_layer:
            self.kg_xattn = GatedKGCrossAttention(
                cfg.d_model, cfg.kg_dim, cfg.n_heads, cfg.dropout)

    def forward(self, x, attn_mask, kg_raw=None):
        B, T, D = x.shape
        # MSA
        h = self.norm1(x)
        qkv = self.qkv(h).reshape(B, T, 3, self.cfg.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        q, k = self.rope(q, k, T)
        out = F.scaled_dot_product_attention(
            q, k, v,
            # attn_mask is True where attention is blocked; a boolean SDPA mask
            # is True where attention is allowed, so pass the inverse. A boolean
            # mask also avoids a dtype mismatch when q/k/v are fp16 or bf16.
            attn_mask=(~attn_mask)[None, None],
            dropout_p=self.cfg.dropout if self.training else 0.0,
        )
        out = out.transpose(1, 2).reshape(B, T, D)
        x = x + self.dropout(self.proj(out))
        # Gated KG cross-attn (if enabled at this layer)
        if self.use_kg_in_layer and kg_raw is not None:
            x = self.kg_xattn(x, kg_raw)
        # MLP
        x = x + self.mlp(self.norm2(x))
        return x


class CDFv13Transformer(nn.Module):
    """Audit-compliant CDF v13: 12M backbone + KG adapters + DF objective."""

    def __init__(self, cfg: CDFv13Config | None = None):
        super().__init__()
        self.cfg = cfg or CDFv13Config()
        c = self.cfg
        norm_cls = RMSNorm if c.use_rmsnorm else nn.LayerNorm

        self.tok_emb = nn.Embedding(c.vocab_size, c.d_model)
        self.emb_dropout = nn.Dropout(c.emb_dropout)

        # Per-token sigma embedding (additive)
        self.sigma_emb = PerTokenSigmaEmbed(c.d_model)
        # Global condition embedding (additive, broadcast)
        self.cond_emb = nn.Embedding(c.n_conditions, c.d_model)

        # RoPE
        self.rope = RotaryEmbedding(c.d_model // c.n_heads, max_seq=c.max_seq_len * 2)

        # Blocks
        self.blocks = nn.ModuleList([
            CDFv13Block(c, self.rope, layer_idx=i) for i in range(c.n_layers)
        ])
        self.final_norm = norm_cls(c.d_model)
        self.head = nn.Linear(c.d_model, c.vocab_size, bias=False)
        if c.tie_embeddings:
            self.head.weight = self.tok_emb.weight

        # Block-causal mask buffer, indexed [query, key]; True = blocked.
        # A query may attend to keys in its own block and in earlier blocks,
        # so only keys in later blocks are blocked. (The original comparison
        # used "<", which blocked earlier blocks and exposed later ones.)
        T = c.max_seq_len
        block_id = torch.arange(T) // c.block_size
        mask = block_id.unsqueeze(0) > block_id.unsqueeze(1)
        self.register_buffer("block_mask", mask, persistent=False)

        # Init
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, x, sigma, cond, kg_raw=None):
        B, T = x.shape
        h = self.tok_emb(x) + self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
        h = self.emb_dropout(h)
        mask = self.block_mask[:T, :T]
        for blk in self.blocks:
            h = blk(h, mask, kg_raw=kg_raw)
        h = self.final_norm(h)
        return self.head(h)

    def diffusion_forcing_loss(self, x_clean, cond, kg_raw=None,
                               mode: str = "uniform") -> torch.Tensor:
        """Standard absorbing-state DF loss with per-token sigma.

        mode: 'uniform' (default; safer for discrete than logit-normal per audit)
              'logit_normal' (SD3-style; keep as ablation only)
        """
        B, T = x_clean.shape
        device = x_clean.device
        # CFG cond dropout (condition index 0 doubles as the "no condition" slot)
        drop = torch.rand(B, device=device) < self.cfg.cond_dropout
        cond = torch.where(drop, torch.zeros_like(cond), cond)
        if kg_raw is not None:
            drop_kg = (torch.rand(B, device=device) < self.cfg.cond_dropout).float()
            kg_raw = kg_raw * (1 - drop_kg).reshape(B, 1, 1)
        # Sample per-token sigma
        if mode == "logit_normal":
            sigma = torch.sigmoid(torch.randn(B, T, device=device)).clamp(0.01, 0.99)
        else:
            sigma = torch.rand(B, T, device=device).clamp(0.01, 0.99)
        # Absorbing-state corruption: each token becomes the mask token with
        # probability sigma. full_like keeps dtype and device matched to x_clean.
        corrupt = torch.rand(B, T, device=device) < sigma
        x_noisy = torch.where(corrupt, torch.full_like(x_clean, self.cfg.mask_token), x_clean)
        logits = self.forward(x_noisy, sigma, cond, kg_raw=kg_raw)
        ce = F.cross_entropy(
            logits.reshape(-1, self.cfg.vocab_size),
            x_clean.reshape(-1),
            reduction="none",
        ).reshape(B, T)
        # Average the loss over corrupted positions only
        n = corrupt.float().sum().clamp(min=1.0)
        return (ce * corrupt.float()).sum() / n
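
The attention call in `CDFv13Block` blocks every position where the mask is True. For block-causal behaviour, True therefore has to mark keys that sit in a later block than the query. A small, dependency-free sketch of that convention follows; `block_causal_blocked` is a hypothetical helper written for illustration, not part of the reference implementation.

```python
def block_causal_blocked(seq_len, block_size):
    """Return blocked[i][j], True when query i must NOT attend to key j.

    Positions are grouped into blocks of `block_size`. A query sees every key
    in its own block and in earlier blocks, and nothing in later blocks.
    """
    block_id = [t // block_size for t in range(seq_len)]
    return [
        [block_id[j] > block_id[i] for j in range(seq_len)]
        for i in range(seq_len)
    ]


blocked = block_causal_blocked(seq_len=6, block_size=2)
for row in blocked:
    # "." = may attend, "x" = blocked
    print("".join("x" if b else "." for b in row))
```

For six positions in blocks of two, the first two rows block the last four keys, the middle two rows block the last two, and the final two rows block nothing. Within a block attention stays bidirectional, which is what lets the tokens of one block be denoised jointly.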
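
As a rough cross-check of the sizes quoted in the module docstring, the weight-matrix counts implied by the `CDFv13Config` defaults can be worked out by hand. This is a back-of-the-envelope sketch: it counts only the linear and embedding matrices defined in this file. It leaves out norm weights, the sigma and condition embeddings, and the scalar KG gates.

```python
# Defaults copied from CDFv13Config
d_model, n_layers, ffn = 384, 8, 1024
vocab_size, kg_dim, n_kg_layers = 32768, 3072, 3

# Self-attention: qkv (d -> 3d) plus output projection (d -> d)
attn = 3 * d_model * d_model + d_model * d_model
# SwiGLU MLP: w_gate, w_up (d -> ffn) and w_down (ffn -> d)
mlp = 3 * d_model * ffn
backbone = n_layers * (attn + mlp)

# One GatedKGCrossAttention: kg_in_proj, q_proj, kv_proj (d -> 2d), out_proj
kg_adapter = kg_dim * d_model + d_model * d_model + 2 * d_model * d_model + d_model * d_model
kg_total = n_kg_layers * kg_adapter

# Token embedding; shared with the output head when tie_embeddings is True
embedding = vocab_size * d_model

print(f"backbone (attn + MLP): {backbone:,}")
print(f"KG adapters (3 layers): {kg_total:,}")
print(f"token embedding: {embedding:,}")
```

Under these assumptions the attention and MLP matrices come to about 14.2M parameters, so the "12M backbone" figure in the docstring is best read as approximate. The embedding table is about 12.6M, which matches the "saves ~12M" note on tying.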
reference_impl/eval_sota.py
ADDED
@@ -0,0 +1,290 @@
"""SOTA evaluation suite for CDFv13 (audit-proof).

Per the May 2026 SOTA audit, replaces "Top-1 mid-position" (not recognized)
with the canonical EHR foundation model metric stack:

Classification (next-event, downstream tasks):
  - AUROC + AUPRC + Brier
  - Calibration: ICI (Austin & Steyerberg 2019)
  - Decision-curve analysis (Vickers)
  - Bootstrap 95% CI (≥2000 resamples), required for rare disease

Survival (DATASUS SIM mortality):
  - Uno's C (concordance_index_ipcw), preferred over Harrell at high censoring
  - Integrated Brier Score (1/3/5y)
  - Time-dependent AUC

Counterfactual / causal:
  - ATE with bootstrap CI
  - E-value (VanderWeele)
  - Negative-control outcome + exposure
  - Tipping-point analysis

Generation fidelity (CoMET / SynthEHRella):
  - Dim-wise probability match
  - MMD (Maximum Mean Discrepancy) with RBF kernel
  - TSTR (Train-on-Synthetic-Test-on-Real)

Subgroup fairness (npj DM requirement):
  - Stratified metrics: sex, age band, UF region

Split strategy (DATASUS rare disease):
  - Temporal: train ≤2022, val 2023, test 2024-2025
  - Geographic: train SE+S, test N+NE (UF cross-region = "external")
  - Patient-level 5-fold CV (variance estimation)
"""
from __future__ import annotations
import math
import numpy as np
import torch
from typing import Callable


# ---------- Classification ----------

def auroc(y: np.ndarray, p: np.ndarray) -> float:
    from sklearn.metrics import roc_auc_score
    if len(np.unique(y)) < 2:
        return float("nan")
    return roc_auc_score(y, p)


def auprc(y: np.ndarray, p: np.ndarray) -> float:
    from sklearn.metrics import average_precision_score
    if len(np.unique(y)) < 2:
        return float("nan")
    return average_precision_score(y, p)


def brier(y: np.ndarray, p: np.ndarray) -> float:
    from sklearn.metrics import brier_score_loss
    return brier_score_loss(y, p)


def ici(y: np.ndarray, p: np.ndarray, frac: float = 0.75) -> float:
    """Integrated Calibration Index (Austin & Steyerberg 2019).
    Lowess-smoothed deviation from perfect calibration.
    """
    from statsmodels.nonparametric.smoothers_lowess import lowess
    # Column 0 holds the sorted predictions, column 1 the smoothed outcome rate
    sm = lowess(y, p, frac=frac, return_sorted=True)
    return float(np.mean(np.abs(sm[:, 1] - sm[:, 0])))


def net_benefit(y: np.ndarray, p: np.ndarray, threshold: float) -> float:
    """Net benefit at a given decision threshold (Vickers DCA)."""
    if threshold >= 1.0:
        return 0.0
    tp = ((p >= threshold) & (y == 1)).sum()
    fp = ((p >= threshold) & (y == 0)).sum()
    n = len(y)
    return tp / n - (fp / n) * (threshold / (1 - threshold))


def decision_curve(y: np.ndarray, p: np.ndarray,
                   thresholds: list[float] | None = None) -> dict:
    """Decision-curve analysis: net benefit across thresholds vs treat-all/treat-none."""
    if thresholds is None:
        thresholds = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
    model_nb = [net_benefit(y, p, t) for t in thresholds]
    treat_all_nb = [(y.mean()) - (1 - y.mean()) * (t / (1 - t)) if t < 1 else 0
                    for t in thresholds]
    treat_none_nb = [0.0] * len(thresholds)
    return {
        "thresholds": thresholds,
        "model": model_nb,
        "treat_all": treat_all_nb,
        "treat_none": treat_none_nb,
    }


def bootstrap_ci(y: np.ndarray, p: np.ndarray, metric_fn: Callable,
                 n_boot: int = 2000, seed: int = 0,
                 ci: tuple[float, float] = (2.5, 97.5)) -> tuple[float, float, float]:
    """Percentile bootstrap CI (95% by default) for any (y, p) -> scalar metric.

    Returns (lower, median, upper). Resamples that contain a single class, or
    on which the metric raises, are skipped.
    """
    rng = np.random.default_rng(seed)
    n = len(y)
    stats = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, n)
        if len(np.unique(y[idx])) < 2:
            continue
        try:
            stats.append(metric_fn(y[idx], p[idx]))
        except Exception:
            continue
    if not stats:
        return (float("nan"),) * 3
    return (
        float(np.percentile(stats, ci[0])),
        float(np.median(stats)),
        float(np.percentile(stats, ci[1])),
    )


# ---------- Survival ----------

def uno_c_index(y_train_event, y_train_time, y_test_event, y_test_time,
                risk_score, tau: float | None = None) -> float:
    """Uno's C-index (IPCW concordance), preferred at high censoring.
    Requires scikit-survival.
    """
    try:
        from sksurv.metrics import concordance_index_ipcw
    except ImportError:
        return float("nan")
    # Build structured arrays
    y_train = np.array(
        list(zip(y_train_event.astype(bool), y_train_time.astype(float))),
        dtype=[("event", "?"), ("time", "<f8")],
    )
    y_test = np.array(
        list(zip(y_test_event.astype(bool), y_test_time.astype(float))),
        dtype=[("event", "?"), ("time", "<f8")],
    )
    if tau is None:
        tau = float(y_test_time.max()) * 0.95
    c, *_ = concordance_index_ipcw(y_train, y_test, risk_score, tau=tau)
    return float(c)


def integrated_brier_score(y_train_event, y_train_time, y_test_event, y_test_time,
                           surv_pred: np.ndarray, times: np.ndarray) -> float:
    """Integrated Brier Score (lower is better)."""
    try:
        from sksurv.metrics import integrated_brier_score as ibs_fn
    except ImportError:
        return float("nan")
    y_train = np.array(
        list(zip(y_train_event.astype(bool), y_train_time.astype(float))),
        dtype=[("event", "?"), ("time", "<f8")],
    )
    y_test = np.array(
        list(zip(y_test_event.astype(bool), y_test_time.astype(float))),
        dtype=[("event", "?"), ("time", "<f8")],
    )
    return float(ibs_fn(y_train, y_test, surv_pred, times))


# ---------- Causal / Counterfactual ----------

def e_value(rr: float) -> float:
    """E-value (VanderWeele & Ding 2017): min strength of unmeasured
    confounder needed to explain away an observed RR.
    """
    rr = max(rr, 1e-9)
    if rr >= 1.0:
        return rr + math.sqrt(rr * (rr - 1))
    rr_inv = 1.0 / rr
    return rr_inv + math.sqrt(rr_inv * (rr_inv - 1))


def negative_control_check(nc_ate: float, threshold: float = 0.02) -> bool:
    """Negative-control outcome: ATE on a control outcome should be ~0."""
    return abs(nc_ate) < threshold


def tipping_point(observed_effect: float, ci_half_width: float) -> float:
    """How much would unmeasured confounding need to shift effect to nullify?"""
    if abs(observed_effect) <= ci_half_width:
        return 0.0
    return float(abs(observed_effect) - ci_half_width)


# ---------- Generation fidelity (SynthEHRella triad) ----------

def dim_wise_probability(real_seq: torch.Tensor, synth_seq: torch.Tensor,
                         vocab_size: int) -> float:
    """Compare per-token Bernoulli rates between real and synthetic batches.

    Returns mean abs difference (lower = closer match).
    """
    # F is not imported at module level, so import it here; without this the
    # F.one_hot calls below raise NameError.
    import torch.nn.functional as F
    real_one_hot = F.one_hot(real_seq, vocab_size).float().mean(dim=(0, 1))
    synth_one_hot = F.one_hot(synth_seq, vocab_size).float().mean(dim=(0, 1))
    return float((real_one_hot - synth_one_hot).abs().mean())


def mmd_rbf(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> float:
    """Maximum Mean Discrepancy with RBF kernel.

    x, y: (B, D) flattened embeddings. Returns MMD^2 (lower = closer).
    """
    def rbf(a, b):
        d = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(-1)
        return torch.exp(-d / (2 * sigma ** 2))
    return float(rbf(x, x).mean() + rbf(y, y).mean() - 2 * rbf(x, y).mean())


# ---------- Subgroup fairness ----------

def stratified_metrics(y: np.ndarray, p: np.ndarray,
                       groups: np.ndarray,
                       metric_fn: Callable = auroc) -> dict[str, float]:
    """Compute metric per subgroup (sex, age band, UF region)."""
    out = {}
    for g in np.unique(groups):
        mask = groups == g
        if mask.sum() > 10:
            try:
                out[str(g)] = metric_fn(y[mask], p[mask])
            except Exception:
                out[str(g)] = float("nan")
    return out


# ---------- DATASUS split strategies ----------

def temporal_split(events: list[dict], train_until: int = 2022,
                   val_year: int = 2023):
    """Temporal split for DATASUS: train ≤2022, val 2023, test 2024+.

    Events with a missing year are assigned 2020. Any year that is neither
    ≤ train_until nor equal to val_year goes to test.
    """
    train, val, test = [], [], []
    for e in events:
        y = e.get("year") or 2020
        if y <= train_until:
            train.append(e)
        elif y == val_year:
            val.append(e)
        else:
            test.append(e)
    return train, val, test


def geographic_split(patients: list[dict], external_ufs: set | None = None):
    """Geographic split: train on SE+S, test on N+NE.
    For DATASUS this is the closest analog to "external validation."

    A patient goes to test when their first recorded uf_code is in
    external_ufs. Everyone else goes to train, including patients from
    states outside N+NE that are not in SE+S and patients with no uf_code.
    """
    if external_ufs is None:
        # North and Northeast states ("SE" here is the state of Sergipe)
        external_ufs = {"AC", "AL", "AP", "AM", "BA", "CE", "MA", "PA",
                        "PB", "PE", "PI", "RN", "SE", "TO", "RR", "RO"}
    train, test = [], []
    for p in patients:
        uf = next((e.get("uf_code") for e in p.get("events", []) if e.get("uf_code")),
                  None)
        (test if uf in external_ufs else train).append(p)
    return train, test


# ---------- Combined eval report ----------

def full_eval_report(y: np.ndarray, p: np.ndarray,
                     groups_sex: np.ndarray = None,
                     groups_age: np.ndarray = None,
                     groups_uf: np.ndarray = None,
                     n_boot: int = 2000) -> dict:
    """Generate a full audit-proof report for a binary classification task.

    Returns a dict with point estimates + bootstrap CIs + DCA + fairness.
    """
    auroc_lo, auroc_med, auroc_hi = bootstrap_ci(y, p, auroc, n_boot)
    auprc_lo, auprc_med, auprc_hi = bootstrap_ci(y, p, auprc, n_boot)
    brier_lo, brier_med, brier_hi = bootstrap_ci(y, p, brier, n_boot)

    report = {
        "n_eval": len(y),
        "prevalence": float(y.mean()),
        "auroc": {"point": auroc(y, p), "ci95": [auroc_lo, auroc_hi], "median": auroc_med},
        "auprc": {"point": auprc(y, p), "ci95": [auprc_lo, auprc_hi], "median": auprc_med},
        "brier": {"point": brier(y, p), "ci95": [brier_lo, brier_hi], "median": brier_med},
        "ici": ici(y, p),
        "decision_curve": decision_curve(y, p),
    }
    if groups_sex is not None:
        report["fairness_sex"] = stratified_metrics(y, p, groups_sex, auroc)
    if groups_age is not None:
        report["fairness_age"] = stratified_metrics(y, p, groups_age, auroc)
    if groups_uf is not None:
        report["fairness_uf"] = stratified_metrics(y, p, groups_uf, auroc)
    return report
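The report above unpacks `bootstrap_ci(y, p, metric, n_boot)` as `(lo, median, hi)`. The helper itself is not shown in this part of the diff, so the following is only a minimal sketch of a percentile bootstrap with that return shape, using the standard library. The names `toy_auroc` and `percentile_bootstrap_ci`, the `alpha` and `seed` parameters, and the toy data are assumptions made for illustration, not the repository's implementation.

```python
import random


def toy_auroc(y, p):
    """AUROC as P(score of a random positive > score of a random negative); ties count 0.5."""
    pos = [s for s, t in zip(p, y) if t == 1]
    neg = [s for s, t in zip(p, y) if t == 0]
    if not pos or not neg:
        return float("nan")  # undefined when only one class is present
    wins = sum((a > b) + 0.5 * (a == b) for a in pos for b in neg)
    return wins / (len(pos) * len(neg))


def percentile_bootstrap_ci(y, p, metric, n_boot=2000, alpha=0.05, seed=0):
    """Resample (y, p) pairs with replacement and return (lo, median, hi) percentiles."""
    rng = random.Random(seed)
    n = len(y)
    stats = []
    for _ in range(n_boot):
        idx = [rng.randrange(n) for _ in range(n)]
        val = metric([y[i] for i in idx], [p[i] for i in idx])
        if val == val:  # drop NaN from resamples that contain a single class
            stats.append(val)
    stats.sort()

    def q(frac):
        return stats[min(len(stats) - 1, int(frac * len(stats)))]

    return q(alpha / 2), q(0.5), q(1 - alpha / 2)


y = [0, 0, 1, 1, 0, 1, 0, 1]
p = [0.1, 0.4, 0.35, 0.8, 0.2, 0.9, 0.6, 0.7]
point = toy_auroc(y, p)
lo, med, hi = percentile_bootstrap_ci(y, p, toy_auroc, n_boot=500)
print(f"AUROC point={point:.3f}  95% CI=[{lo:.3f}, {hi:.3f}]  median={med:.3f}")
```

Reporting the point estimate alongside the bootstrap median, as the report dict does, makes it easy to spot a skewed resampling distribution on small evaluation sets.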
|
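The two split rules in `eval_sota.py` have edge cases that are easy to miss: events with no year default to 2020, and patients with no UF stay in train. A small sketch on toy records makes both visible. The functions below re-implement the rules under hypothetical names (`split_by_year`, `split_by_first_uf`), and the toy records are invented for illustration.

```python
def split_by_year(events, train_until=2022, val_year=2023, default_year=2020):
    """Same rule as the temporal split: <= train_until, == val_year, else test."""
    train, val, test = [], [], []
    for e in events:
        y = e.get("year") or default_year  # missing year falls back to default_year
        if y <= train_until:
            train.append(e)
        elif y == val_year:
            val.append(e)
        else:
            test.append(e)
    return train, val, test


def split_by_first_uf(patients, external_ufs):
    """Same rule as the geographic split: first non-empty uf_code decides."""
    train, test = [], []
    for p in patients:
        uf = next((e.get("uf_code") for e in p.get("events", []) if e.get("uf_code")), None)
        (test if uf in external_ufs else train).append(p)
    return train, test


events = [{"year": 2019}, {"year": 2023}, {"year": 2024}, {"year": None}]
ev_train, ev_val, ev_test = split_by_year(events)

patients = [
    {"id": "a", "events": [{"uf_code": "SP"}]},
    {"id": "b", "events": [{"uf_code": None}, {"uf_code": "BA"}]},
    {"id": "c", "events": []},  # no UF recorded
]
pt_train, pt_test = split_by_first_uf(patients, external_ufs={"AM", "BA"})

print(len(ev_train), len(ev_val), len(ev_test))
print([p["id"] for p in pt_train], [p["id"] for p in pt_test])
```

If silently routing undated events or UF-less patients into train is not wanted, filtering them out before calling the split is the simplest remedy.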
reference_impl/meds_export.py
ADDED
|
@@ -0,0 +1,336 @@
"""MEDS v0.4.1 exporter for DATASUS: audit-proof, interop-ready.

Verified against:
  - meds 0.4.1 schemas (DataSchema, CodeMetadataSchema)
  - https://github.com/Medical-Event-Data-Standard/meds
  - CLMBR/MOTOR/EHRSHOT/CoMET tokenization conventions

Code conventions (interop-compatible):
  - Static (time=None):  GENDER//, RACE//, UF//, MUN//, ORPHA//
  - Birth/Death:         MEDS_BIRTH, MEDS_DEATH (reserved)
  - Diagnoses:           ICD10//<cid> (NOT CID10//; interop with OHDSI/Athena)
  - Hospitalization:     SIH//ADM, SIH//DIS (numeric_value=LOS_days on DIS)
  - Procedures:          SIGTAP//<10-digit> (Brazil-local namespace)
  - Drugs (APAC):        APAC//<sigtap> (numeric_value=monthly_cost_brl)
  - Outpatient (BPA-I):  BPAI//<sigtap>
  - Visits:              Visit//{IP, OP, ER} (matches CLMBR convention)

Outputs canonical MEDS dataset:
  /out/
  ├── data/                   # parquet shards by subject
  │   ├── shard_0.parquet
  │   └── ...
  ├── metadata/
  │   └── codes.parquet       # REQUIRED: every unique code with description + parent_codes
  └── dataset_metadata.json   # MEDS dataset metadata
"""
from __future__ import annotations
import os
import json
import logging
from collections import defaultdict, Counter
from datetime import datetime
from typing import Iterator

import pyarrow as pa
import pyarrow.parquet as pq
import meds

log = logging.getLogger("gemeo.cdf.meds_export")


def _parse_date(s) -> datetime | None:
    """Parse date string from various DATASUS formats."""
    if s is None:
        return None
    s = str(s).strip()
    if not s or s in ("0", "None", "nan"):
        return None
    try:
        if "-" in s:
            return datetime.strptime(s[:10], "%Y-%m-%d")
        if len(s) == 8:
            return datetime.strptime(s, "%Y%m%d")
    except ValueError:
        return None
    return None


def _ym(year, month) -> datetime | None:
    if year is None:
        return None
    try:
        return datetime(int(year), int(month) if month else 1, 1)
    except (ValueError, TypeError):
        return None


def datasus_patient_to_meds_rows(p: dict, subject_id: int) -> list[tuple]:
    """Convert one DATASUS patient trajectory to a list of MEDS rows.

    Each row is (subject_id, time, code, numeric_value, text_value).
    Returns rows ready to write to a parquet shard.
    """
    rows = []

    # ---- Static (time=None) ----
    if p.get("sex"):
        rows.append((subject_id, None, f"GENDER//{p['sex']}", None, None))
    # ORPHA is rare-disease specific (parallel to ICD10)
    for orpha in p.get("orphas", []):
        rows.append((subject_id, None, f"ORPHA//{orpha}", None, None))

    # ---- Birth (use birth_year as Jan 1) ----
    # _ym returns None for a missing or malformed year instead of raising
    birth_dt = _ym(p.get("birth_year"), 1)
    if birth_dt:
        rows.append((subject_id, birth_dt, "MEDS_BIRTH", None, None))

    # ---- Events ----
    for e in p.get("events", []):
        et = e.get("type")

        if et == "admission":  # SIH-RD
            t = _ym(e.get("year"), e.get("month")) or _parse_date(e.get("admission_date"))
            if not t:
                continue
            rows.append((subject_id, t, "SIH//ADM", None, None))
            rows.append((subject_id, t, "Visit//IP", None, None))
            cid = e.get("cid_princ", "")
            if cid:
                rows.append((subject_id, t, f"ICD10//{cid}", None, None))
            proc = e.get("primary_procedure")
            if proc:
                rows.append((subject_id, t, f"SIGTAP//{proc[:10]}", None, None))
            los = e.get("los_days")
            disch_dt = _parse_date(e.get("discharge_date")) or t
            if e.get("death_during_stay"):
                rows.append((subject_id, disch_dt, "MEDS_DEATH", None, None))
            else:
                rows.append((subject_id, disch_dt, "SIH//DIS",
                             float(los) if los is not None else None, None))

        elif et == "treatment":  # APAC-SIA orphan drug
            t = _ym(e.get("year"), e.get("month"))
            if not t:
                continue
            cid = e.get("cid", "")
            if cid:
                rows.append((subject_id, t, f"ICD10//{cid}", None, None))
            # `or ""` guards against an explicit None value for the key
            proc = (e.get("procedure_code") or "")[:10]
            if proc:
                cost = e.get("monthly_cost_brl")
                rows.append((subject_id, t, f"APAC//{proc}",
                             float(cost) if cost is not None else None, None))

        elif et == "outpatient_proc":  # BPA-I
            t = _parse_date(e.get("auth_date")) or _ym(e.get("year"), e.get("month"))
            if not t:
                continue
            cid = e.get("cid", "")
            if cid:
                rows.append((subject_id, t, f"ICD10//{cid}", None, None))
            proc = (e.get("procedure_code") or "")[:10]
            if proc:
                rows.append((subject_id, t, f"BPAI//{proc}", None, None))

        elif et == "death":  # SIM
            t = _parse_date(e.get("date_of_death")) or _ym(e.get("year"), e.get("month"))
            if not t:
                continue
            rows.append((subject_id, t, "MEDS_DEATH", None, None))
            cid = (e.get("cause_cid") or e.get("cid_princ") or e.get("cid", ""))
            if cid:
                rows.append((subject_id, t, f"ICD10//{cid}", None, None))

    # Sort: nulls first (static), then by time
    rows.sort(key=lambda r: (r[1] is not None, r[1] or datetime(1900, 1, 1)))
    return rows


def export_to_meds(patients: list[dict], out_dir: str,
                   shard_size: int = 5000,
                   dataset_name: str = "GEMEO-DATASUS",
                   version: str = "v13"):
    """Export a list of DATASUS patient trajectories to MEDS v0.4.1 format.

    Parameters
    ----------
    patients : list of dict
        Each dict must have: patient_id, sex, birth_year, orphas (list),
        events (list of dicts with 'type', 'year', 'month', etc.)
    out_dir : str
        Output directory (will create data/ and metadata/ subdirs)
    shard_size : int
        Number of subjects per parquet shard
    """
    os.makedirs(f"{out_dir}/data", exist_ok=True)
    os.makedirs(f"{out_dir}/metadata", exist_ok=True)

    log.info(f"Exporting {len(patients)} patients to MEDS at {out_dir}")

    # Map patient_id (string hash) → int64 subject_id (MEDS requires int64)
    pid_to_sid = {p["patient_id"]: i for i, p in enumerate(patients)}

    # ---- Stream rows ----
    all_codes = Counter()
    shard_idx = 0
    shard_rows = []
    n_events = 0
    n_subjects = 0

    for p in patients:
        sid = pid_to_sid[p["patient_id"]]
        rows = datasus_patient_to_meds_rows(p, sid)
        shard_rows.extend(rows)
        n_events += len(rows)
        n_subjects += 1
        for r in rows:
            all_codes[r[2]] += 1
        # Write shard when full
        if n_subjects % shard_size == 0 and shard_rows:
            _write_shard(shard_rows, f"{out_dir}/data/shard_{shard_idx}.parquet")
            shard_idx += 1
            shard_rows = []

    # Write remaining
    if shard_rows:
        _write_shard(shard_rows, f"{out_dir}/data/shard_{shard_idx}.parquet")
        shard_idx += 1

    # shard_idx now equals the number of shards actually written
    log.info(f"  wrote {shard_idx} data shards, {n_events} rows, {n_subjects} subjects")

    # ---- codes.parquet (REQUIRED in MEDS v0.4) ----
    code_rows = []
    for code, count in all_codes.most_common():
        # parent_codes: empty for Brazil-local namespaces; populated for ICD10 -> SNOMED if mapped
        parent_codes = _get_parent_codes(code)
        code_rows.append({
            "code": code,
            "description": _get_description(code, count),
            "parent_codes": parent_codes,
        })
    code_table = pa.Table.from_pylist(code_rows, schema=meds.CodeMetadataSchema.schema())
    pq.write_table(code_table, f"{out_dir}/metadata/codes.parquet")
    log.info(f"  wrote metadata/codes.parquet ({len(code_rows)} unique codes)")

    # ---- dataset_metadata.json ----
    md = {
        "dataset_name": dataset_name,
        "dataset_version": version,
        "etl_name": "gemeo.cdf.meds_export",
        "etl_version": "1.0.0",
        "meds_version": meds.__version__,
        "n_subjects": n_subjects,
        "n_events": n_events,
        "n_unique_codes": len(all_codes),
        "top_codes": dict(all_codes.most_common(30)),
    }
    with open(f"{out_dir}/dataset_metadata.json", "w") as f:
        json.dump(md, f, indent=2, default=str)
    log.info("  wrote dataset_metadata.json")

    return md


def _write_shard(rows: list[tuple], path: str):
    """Write a list of (subject_id, time, code, numeric_value, text_value) to parquet."""
    if not rows:
        return
    # Build columnar arrays
    subject_id = pa.array([r[0] for r in rows], type=pa.int64())
    time = pa.array([r[1] for r in rows], type=pa.timestamp("us"))
    code = pa.array([r[2] for r in rows], type=pa.string())
    numeric_value = pa.array([r[3] for r in rows], type=pa.float32())
    text_value = pa.array([r[4] for r in rows], type=pa.large_string())
    table = pa.Table.from_arrays(
        [subject_id, time, code, numeric_value, text_value],
        names=["subject_id", "time", "code", "numeric_value", "text_value"],
    )
    # Validate against MEDS schema
    expected_schema = meds.DataSchema.schema()
    # Cast if needed
    table = table.cast(expected_schema, safe=False)
    pq.write_table(table, path, compression="zstd")


# Brazilian-specific mapping tables (extend as needed)
ICD10_CHAPTERS = {
    "A": "Certain infectious and parasitic diseases",
    "B": "Certain infectious and parasitic diseases",
    "C": "Neoplasms",
    "D": "Neoplasms / Diseases of the blood and immune",
    "E": "Endocrine, nutritional and metabolic diseases",
    "F": "Mental, Behavioral and Neurodevelopmental disorders",
    "G": "Diseases of the nervous system",
    "H": "Diseases of the eye / ear",
    "I": "Diseases of the circulatory system",
    "J": "Diseases of the respiratory system",
    "K": "Diseases of the digestive system",
    "L": "Diseases of the skin and subcutaneous tissue",
    "M": "Diseases of the musculoskeletal system",
    "N": "Diseases of the genitourinary system",
    "O": "Pregnancy, childbirth and the puerperium",
    "P": "Certain conditions originating in the perinatal period",
    "Q": "Congenital malformations, deformations and chromosomal abnormalities",
    "R": "Symptoms, signs and abnormal clinical and laboratory findings",
    "S": "Injury, poisoning and certain other consequences of external causes",
    "T": "Injury, poisoning and certain other consequences of external causes",
    "V": "External causes of morbidity",
    "W": "External causes of morbidity",
    "X": "External causes of morbidity",
    "Y": "External causes of morbidity",
    "Z": "Factors influencing health status and contact with health services",
}


def _get_description(code: str, count: int) -> str:
    """Generate a brief description for a code (used in codes.parquet)."""
    if code in ("MEDS_BIRTH",):
        return "Birth event (reserved)"
    if code in ("MEDS_DEATH",):
        return "Death event (reserved)"
    parts = code.split("//")
    if len(parts) < 2:
        return f"Unknown code (n={count})"
    domain, val = parts[0], "//".join(parts[1:])
    if domain == "GENDER":
        return f"Patient sex = {val}"
    if domain == "ORPHA":
        return f"Orphanet rare disease {val}"
    if domain == "ICD10":
        ch = ICD10_CHAPTERS.get(val[0], "Unknown chapter")
        return f"ICD-10 {val} ({ch})"
    if domain == "SIH":
        return f"SIH hospitalization {val}"
    if domain == "Visit":
        return f"Visit type {val}"
    if domain == "SIGTAP":
        return f"SIGTAP procedure {val}"
    if domain == "APAC":
        return f"APAC orphan-drug authorization {val}"
    if domain == "BPAI":
        return f"BPA-I outpatient procedure {val}"
    if domain == "UF":
        return f"Residence UF {val}"
    return f"{domain} code {val}"


def _get_parent_codes(code: str) -> list[str]:
    """Return parent codes for ontology hierarchy (currently minimal)."""
    parts = code.split("//")
    if len(parts) < 2:
        return []
    domain, val = parts[0], "//".join(parts[1:])
    parents = []
    if domain == "ICD10" and len(val) >= 3:
        # ICD-10 chapter as parent
        chapter = val[0]
        if chapter in ICD10_CHAPTERS:
            parents.append(f"ICD10//chapter_{chapter}")
        # 3-char prefix as parent (e.g., E84.0 → E84)
        if "." in val:
            parents.append(f"ICD10//{val.split('.')[0]}")
        elif len(val) > 3:
            parents.append(f"ICD10//{val[:3]}")
    if domain == "SIGTAP" and len(val) >= 4:
        # First 4 digits (2-digit group + 2-digit subgroup) as parent
        parents.append(f"SIGTAP//group_{val[:4]}")
    return parents


def load_meds_dataset(meds_dir: str) -> dict:
    """Load a MEDS dataset back from parquet for inspection or downstream processing."""
    import glob
    shards = sorted(glob.glob(f"{meds_dir}/data/*.parquet"))
    tables = [pq.read_table(p) for p in shards]
    data = pa.concat_tables(tables) if tables else None
    codes = pq.read_table(f"{meds_dir}/metadata/codes.parquet")
    # Use a context manager so the file handle is closed promptly
    with open(f"{meds_dir}/dataset_metadata.json") as f:
        md = json.load(f)
    return {"data": data, "codes": codes, "metadata": md}


if __name__ == "__main__":
    # Quick test on real patient data
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(levelname)s %(message)s")
    PATIENTS = "/tmp/datasus_patient_trajectories_v2.json"
    if os.path.exists(PATIENTS):
        with open(PATIENTS) as f:
            patients = json.load(f)[:50]  # 50 patients smoke test
        md = export_to_meds(patients, "/tmp/meds_smoke_test")
        print("\n=== smoke test result ===")
        print(json.dumps(md, indent=2, default=str))
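The exporter orders each subject's rows with static codes (`time=None`) first and timed events after, using a two-part sort key. A minimal sketch of that ordering on invented rows follows. The subject id, dates, and the ORPHA value are illustrative only, and `datetime.min` is used here as the placeholder where the exporter uses `datetime(1900, 1, 1)`.

```python
from datetime import datetime

# (subject_id, time, code, numeric_value, text_value), deliberately out of order
rows = [
    (7, datetime(2021, 3, 1), "SIH//ADM", None, None),
    (7, None, "GENDER//F", None, None),
    (7, datetime(1990, 1, 1), "MEDS_BIRTH", None, None),
    (7, datetime(2021, 3, 9), "SIH//DIS", 8.0, None),
    (7, None, "ORPHA//586", None, None),
]

# First key component: False (time is None) sorts before True, so static rows lead.
# Second component: a placeholder datetime avoids comparing None with datetime.
rows.sort(key=lambda r: (r[1] is not None, r[1] or datetime.min))

codes = [r[2] for r in rows]
print(codes)
```

Because Python's sort is stable, static rows keep their insertion order relative to each other, and events sharing a timestamp (such as `SIH//ADM` and `Visit//IP`) stay in the order they were appended.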
|
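Sharding by subject count is easy to get subtly wrong when the number of subjects is an exact multiple of `shard_size`, because the final "write remaining" step is then skipped. The sketch below isolates the bookkeeping as a generator so the shard count can be checked directly. `iter_shards` is a hypothetical helper, not part of the exporter, and the row payloads are placeholders.

```python
def iter_shards(per_subject_rows, shard_size):
    """Yield (shard_index, rows), flushing after every `shard_size` subjects."""
    buf, n_subjects, shard_idx = [], 0, 0
    for rows in per_subject_rows:
        buf.extend(rows)
        n_subjects += 1
        if n_subjects % shard_size == 0 and buf:
            yield shard_idx, buf
            shard_idx += 1
            buf = []
    if buf:  # trailing partial shard, only if anything is left
        yield shard_idx, buf


def toy_subjects(n):
    """n subjects with two placeholder rows each."""
    return [[(sid, "row_a"), (sid, "row_b")] for sid in range(n)]


shards_even = list(iter_shards(toy_subjects(4), shard_size=2))
shards_odd = list(iter_shards(toy_subjects(5), shard_size=2))
print(len(shards_even), len(shards_odd))
```

Counting the shards that were actually yielded, rather than deriving the count from the last index, keeps the reported number correct in both the exact-multiple and the remainder case.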
reference_impl/primekg_attention.py
ADDED
|
@@ -0,0 +1,262 @@
"""PrimeKG cross-attention: graph-RAG into the Diffusion Forcing denoiser.

Now uses REAL EDGES from raras-app/data/graph-ml/hetero_graph.json:
  - disease → has_phenotype → phenotype (curated phenotype linkage)
  - disease → associated_with → gene (causal gene evidence)
  - gene → interacts_with → gene (PPI network)
  - phenotype → is_a → phenotype (HPO ontology)

Ego-subgraph BFS:
  1. Start from disease node (ORPHA → PrimeKG index)
  2. 1-hop: pull connected phenotypes (top-K by edge weight or count)
  3. 1-hop: pull connected genes
  4. 2-hop: gene→gene neighbors (interacting partners)
  5. Concatenate fused embeddings of all selected nodes → cross-attn context

Falls back to cosine-similarity if graph not loaded.

White-space architecture (May 2026):
  - EHRWorld, CLARITY, Time-Aware G-Transformer all skip KG conditioning
  - PhenoKG/RareNet use KG for RETRIEVAL (rare disease diagnosis)
  - We use it for GENERATION (counterfactual trajectory completion)
"""
from __future__ import annotations
import os
import json
import logging
from functools import lru_cache

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

log = logging.getLogger("gemeo.cdf.kg")

# Try raras-app paths first (richer, including hetero_graph edges + node_texts)
RARAS_KG_DIR = "/Users/dimas/raras-app/data/graph-ml"
LOCAL_KG_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")


def _kg_path(name: str) -> str | None:
    """Prefer raras-app path if available, fall back to local; None if neither exists."""
    raras = os.path.join(RARAS_KG_DIR, name)
    if os.path.exists(raras):
        return raras
    local = os.path.join(LOCAL_KG_DIR, name)
    return local if os.path.exists(local) else None


@lru_cache(maxsize=1)
def load_kg(prefer_raras: bool = True) -> dict | None:
    """Load PrimeKG: fused embeddings + node ids + edges + texts.

    Returns dict:
      emb       : {kind: torch.Tensor(N, 3072)}
      idx2id    : {kind: {pos: id_str}}
      id2idx    : {kind: {id_str: pos}}
      edges     : {edge_type: {'src': [...], 'dst': [...]}}
      adj       : {edge_type: {src_idx: [dst_idx, ...]}}  -- precomputed
      texts     : {kind: [str, ...]}  -- aligned to position
      num_nodes : {kind: int}
    """
    # Try raras-app full file first, then local fp16
    emb_path = (os.path.join(RARAS_KG_DIR, "fused_embeddings.npz")
                if prefer_raras and os.path.exists(os.path.join(RARAS_KG_DIR, "fused_embeddings.npz"))
                else _kg_path("fused_embeddings_fp16.npz"))
    if not emb_path or not os.path.exists(emb_path):
        log.warning("PrimeKG fused embeddings not found")
        return None

    def _load_json(path, default):
        # Return `default` when the file is absent; close the handle after reading
        if not path:
            return default
        with open(path) as f:
            return json.load(f)

    fz = np.load(emb_path)
    nids = _load_json(_kg_path("node_ids.json"), {})
    graph = _load_json(_kg_path("hetero_graph.json"), {"edges": {}, "num_nodes": {}})
    texts = _load_json(_kg_path("node_texts.json"), {})

    out = {"emb": {}, "id2idx": {}, "idx2id": {}, "edges": {}, "adj": {},
           "texts": texts, "num_nodes": graph.get("num_nodes", {})}

    for kind in ("disease", "phenotype", "gene"):
        if kind in fz.files:
            out["emb"][kind] = torch.from_numpy(fz[kind].astype(np.float32))
        if kind in nids:
            out["idx2id"][kind] = {int(k): v for k, v in nids[kind].items()}
            out["id2idx"][kind] = {v: int(k) for k, v in nids[kind].items()}

    # Build adjacency from edges
    for edge_type, edata in graph.get("edges", {}).items():
        adj = {}
        srcs = edata.get("src", []) if isinstance(edata, dict) else []
        dsts = edata.get("dst", []) if isinstance(edata, dict) else []
        for s, d in zip(srcs, dsts):
            adj.setdefault(int(s), []).append(int(d))
        out["adj"][edge_type] = adj
        out["edges"][edge_type] = edata

    log.info(f"  KG loaded from {emb_path}")
    log.info(f"  disease={out['emb'].get('disease', torch.empty(0)).shape}, "
             f"phenotype={out['emb'].get('phenotype', torch.empty(0)).shape}, "
             f"gene={out['emb'].get('gene', torch.empty(0)).shape}")
    log.info(f"  edges: {list(out['edges'].keys())}")
    return out


def ego_subgraph_real(orpha_code: str, k_pheno: int = 16, k_gene: int = 16,
                      k_gene_2hop: int = 0, kg: dict | None = None) -> torch.Tensor | None:
    """BFS ego-subgraph using REAL PrimeKG edges (not cosine similarity).

    Returns concatenated embeddings (N, 3072) where:
      - 1 disease node (the query)
      - up to k_pheno phenotype nodes (direct edges)
      - up to k_gene gene nodes (direct edges)
      - up to k_gene_2hop gene-gene 2-hop neighbors

    Falls back to cosine similarity if no edges available.
    Returns None if the KG is not loaded or the ORPHA code is unknown.
    """
    if kg is None:
        kg = load_kg()
    if kg is None or "disease" not in kg["emb"]:
        return None

    # .get(..., {}) avoids a KeyError when node_ids.json was not found
    d_id = kg["id2idx"].get("disease", {}).get(str(orpha_code))
    if d_id is None:
        return None

    d_emb = kg["emb"]["disease"][d_id]
    nodes = [d_emb.unsqueeze(0)]

    # Phenotype neighbors (via disease__has_phenotype__phenotype)
    adj = kg["adj"].get("disease__has_phenotype__phenotype", {})
    pheno_neighbors = adj.get(d_id, [])
    if pheno_neighbors and "phenotype" in kg["emb"]:
        pheno_neighbors = pheno_neighbors[:k_pheno]
        nodes.append(kg["emb"]["phenotype"][pheno_neighbors])
    elif "phenotype" in kg["emb"]:
        # Fallback: cosine similarity
        pool = kg["emb"]["phenotype"]
        sim = F.cosine_similarity(d_emb.unsqueeze(0), pool, dim=-1)
        top = sim.topk(min(k_pheno, pool.size(0))).indices
        nodes.append(pool[top])

    # Gene neighbors (via disease__associated_with__gene)
    g_adj = kg["adj"].get("disease__associated_with__gene", {})
    gene_neighbors = g_adj.get(d_id, [])
    if gene_neighbors and "gene" in kg["emb"]:
        gene_neighbors = gene_neighbors[:k_gene]
        nodes.append(kg["emb"]["gene"][gene_neighbors])

        # 2-hop: gene-gene neighbors of the genes we just pulled
        if k_gene_2hop > 0:
            gg_adj = kg["adj"].get("gene__interacts_with__gene", {})
            seen = set(gene_neighbors)
            second_hop = []
            for g in gene_neighbors:
                for g2 in gg_adj.get(g, []):
                    if g2 not in seen:
                        second_hop.append(g2)
                        seen.add(g2)
                    if len(second_hop) >= k_gene_2hop:
                        break
                if len(second_hop) >= k_gene_2hop:
                    break
            if second_hop:
                nodes.append(kg["emb"]["gene"][second_hop])
    elif "gene" in kg["emb"]:
        # Fallback: cosine similarity
        pool = kg["emb"]["gene"]
        sim = F.cosine_similarity(d_emb.unsqueeze(0), pool, dim=-1)
        top = sim.topk(min(k_gene, pool.size(0))).indices
        nodes.append(pool[top])

    return torch.cat(nodes, dim=0)


# Keep old API name for backward compat
ego_subgraph = ego_subgraph_real


class KGCrossAttention(nn.Module):
    """Cross-attention from sequence (B, T, d_model) to KG ego (B, N, d_model)."""
    def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm_q = nn.LayerNorm(d_model)
        self.norm_kv = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_seq: torch.Tensor, x_kg: torch.Tensor) -> torch.Tensor:
        B, T, D = x_seq.shape
        _, N, _ = x_kg.shape
        q = self.q_proj(self.norm_q(x_seq))
        kv = self.kv_proj(self.norm_kv(x_kg))
        k, v = kv.chunk(2, dim=-1)
        q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout.p if self.training else 0.0)
        out = out.transpose(1, 2).reshape(B, T, D)
        return x_seq + self.dropout(self.out_proj(out))


class KGProjector(nn.Module):
    """Project 3072-d KG embeddings to d_model with LayerNorm."""
    def __init__(self, kg_dim: int, d_model: int):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(kg_dim, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def build_kg_batch(orpha_strings: list[str], d_model: int,
                   projector: KGProjector,
                   k_pheno: int = 16, k_gene: int = 16,
                   k_gene_2hop: int = 0) -> torch.Tensor:
    """Build (B, N, d_model) batched KG context for a batch of patient ORPHAs.

    Falls back to zero context for missing ORPHAs.
    """
    kg = load_kg()
    # Also guard against a KG without disease embeddings, which would
    # otherwise raise a KeyError in the zero-fill branch below
    if kg is None or "disease" not in kg["emb"]:
        return torch.zeros(len(orpha_strings), 1, d_model,
                           device=next(projector.parameters()).device)
    N = 1 + k_pheno + k_gene + k_gene_2hop
    egos = []
    for orpha in orpha_strings:
        e = ego_subgraph_real(orpha, k_pheno, k_gene, k_gene_2hop, kg)
        if e is None:
            e = torch.zeros(N, kg["emb"]["disease"].size(-1))
        elif e.size(0) < N:
            pad = torch.zeros(N - e.size(0), e.size(-1))
            e = torch.cat([e, pad], dim=0)
        egos.append(e[:N])
    egos = torch.stack(egos, dim=0)
    return projector(egos.to(next(projector.parameters()).device))


@torch.no_grad()  # precomputed context is a cache, so no autograd graph is kept
def precompute_kg_for_dataset(orpha_codes: list[str], projector: KGProjector,
                              k_pheno: int = 16, k_gene: int = 16,
                              batch_size: int = 32) -> torch.Tensor:
    """Pre-compute KG context for an entire dataset in batches.

    Returns a (N_patients, kg_nodes, d_model) tensor on CPU. The function
    does not write anything itself; the result can be cached to disk by the
    caller (e.g. with torch.save).
    """
    out = []
    for i in range(0, len(orpha_codes), batch_size):
        batch = orpha_codes[i:i + batch_size]
        ctx = build_kg_batch(batch, projector.proj[0].out_features,
                             projector, k_pheno, k_gene)
        out.append(ctx.cpu())
    return torch.cat(out, dim=0)
|
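The pad-or-truncate step in `build_kg_batch` can be tried in isolation. The following is a minimal sketch, assuming random tensors in place of real PrimeKG ego subgraphs; `pad_or_truncate` and the toy sizes are hypothetical names chosen for illustration and are not part of the reference implementation.

```python
import torch
import torch.nn as nn


def pad_or_truncate(e: torch.Tensor, n_nodes: int) -> torch.Tensor:
    """Zero-pad or truncate a (n, kg_dim) node matrix to exactly n_nodes rows."""
    if e.size(0) < n_nodes:
        pad = torch.zeros(n_nodes - e.size(0), e.size(-1), dtype=e.dtype)
        e = torch.cat([e, pad], dim=0)
    return e[:n_nodes]


kg_dim, d_model, n_nodes = 8, 16, 5  # toy sizes, chosen arbitrarily

# Same layer stack as KGProjector.proj: Linear -> GELU -> LayerNorm.
proj = nn.Sequential(nn.Linear(kg_dim, d_model), nn.GELU(), nn.LayerNorm(d_model))

# Three toy "ego subgraphs" with 3, 5, and 7 nodes (random stand-ins).
egos = [torch.randn(n, kg_dim) for n in (3, 5, 7)]
batch = torch.stack([pad_or_truncate(e, n_nodes) for e in egos], dim=0)
ctx = proj(batch)
print(tuple(batch.shape), tuple(ctx.shape))
```

Fixing the node count up front is what lets variable-size subgraphs be stacked into one tensor. Note that padded rows are zero only before projection: the linear bias and LayerNorm give them non-zero values afterwards.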
reference_impl/sample.py
ADDED
|
@@ -0,0 +1,230 @@
|
"""Sampling primitives for CDF: AR mode, denoise mode, counterfactual rollouts.

Diffusion Forcing flexibility: the same model handles three modes.

AR mode:
    sigma_future = 1, sigma_past = 0. Roll forward like an autoregressive
    transformer but with per-token noise control.

Denoise mode (bidirectional):
    Sigma is annealed from 1 to 0 on all free tokens. Run k denoise steps;
    the model fills the whole sequence.

Counterfactual mode (the TTE primitive):
    sigma = 0 on observed tokens (clamp them clean), sigma = 1 on tokens to
    generate. Condition on (cohort, intervention_action_id). Sample N times,
    compare distributions of outcome tokens.

CFG (classifier-free guidance) wraps any mode:
    logits_g = (1 + gamma) * logits(c) - gamma * logits(null_c)

Shortcut Forcing (Dreamer 4) reduces denoise steps from 32-64 to 4 via a
distilled student model, implemented in distill.py.
"""
from __future__ import annotations
import logging

import torch
import torch.nn.functional as F

# NOTE: this commit ships the module as diffusion_forcing_v13.py; adjust the
# import if no unversioned diffusion_forcing module is present.
from .diffusion_forcing import CDFTransformer

log = logging.getLogger("gemeo.cdf.sample")

@torch.no_grad()
def sample_denoise(
    model: CDFTransformer,
    cond: torch.Tensor,
    *,
    seed_prefix: torch.Tensor | None = None,
    observed_mask: torch.Tensor | None = None,  # (B, T) True = clamped clean
    action: torch.Tensor | None = None,
    gamma: float = 2.0,
    n_steps: int = 32,
    null_cond: int = 0,
    schedule: str = "cosine",
) -> torch.Tensor:
    """Denoise-mode sampling: fully-masked sequence + iterative refinement.

    Supports:
    - seed_prefix: clean tokens kept at sigma=0 for positions [0, L)
    - observed_mask: arbitrary positions to clamp at sigma=0 (counterfactual
      mode); their token values come from seed_prefix, the only source of
      clean tokens in this function
    - CFG via (cond, null_cond) pair

    Any schedule other than "cosine" falls back to a linear schedule.
    """
    cfg = model.cfg
    device = cond.device
    B = cond.size(0)
    T = cfg.max_seq_len

    # Init with MASK
    x = torch.full((B, T), cfg.mask_token, device=device, dtype=torch.long)
    fixed_mask = torch.zeros(B, T, dtype=torch.bool, device=device)
    if seed_prefix is not None:
        L = seed_prefix.size(1)
        x[:, :L] = seed_prefix
        fixed_mask[:, :L] = True
    if observed_mask is not None:
        fixed_mask |= observed_mask

    # Build noise schedule
    if schedule == "cosine":
        # smooth cosine from 1 -> 0
        ts = torch.cos(torch.linspace(0, torch.pi / 2, n_steps + 1, device=device))
    else:
        ts = torch.linspace(1.0, 0.0, n_steps + 1, device=device)

    null = torch.full_like(cond, null_cond)
    null_action = (torch.full_like(action, cfg.n_latent_actions)
                   if action is not None and cfg.use_latent_action else None)

    for k in range(n_steps):
        # Per-token sigma: fixed positions at 0, dynamic positions at ts[k]
        sigma = (~fixed_mask).to(ts.dtype) * ts[k]
        logits_c = model(x, sigma, cond, action)
        if gamma > 0:
            logits_n = model(x, sigma, null, null_action)
            logits = (1 + gamma) * logits_c - gamma * logits_n
        else:
            logits = logits_c
        logits[:, :, cfg.mask_token] = -1e9  # never predict MASK

        probs = F.softmax(logits, dim=-1)
        confs, preds = probs.max(dim=-1)  # greedy: most likely token per position

        # Confidence-based unmasking: fill the most confident masked positions
        # until a (1 - ts[k+1]) fraction of all T positions is revealed.
        t_next = ts[k + 1].item()
        target_kept = int(round((1 - t_next) * T))
        revealed = (x != cfg.mask_token) | fixed_mask
        already = revealed.sum(dim=-1)
        new_x = x.clone()
        for b in range(B):
            need = max(0, target_kept - int(already[b].item()))
            if need == 0:
                continue
            confs_b = torch.where(revealed[b], torch.full_like(confs[b], -1e9), confs[b])
            topi = confs_b.topk(need).indices
            new_x[b, topi] = preds[b, topi]
        x = new_x

    # Final cleanup: fill any position still holding MASK in one clean pass
    mask_left = x == cfg.mask_token
    if mask_left.any():
        sigma_final = torch.zeros(B, T, device=device)
        logits_c = model(x, sigma_final, cond, action)
        if gamma > 0:
            logits_n = model(x, sigma_final, null, null_action)
            logits = (1 + gamma) * logits_c - gamma * logits_n
        else:
            logits = logits_c
        logits[:, :, cfg.mask_token] = -1e9
        preds = logits.argmax(-1)
        x = torch.where(mask_left, preds, x)
    return x

@torch.no_grad()
def sample_ar(
    model: CDFTransformer,
    cond: torch.Tensor,
    prefix: torch.Tensor,
    *,
    action: torch.Tensor | None = None,
    max_new: int = 50,
    temperature: float = 1.0,
    gamma: float = 0.0,
    null_cond: int = 0,
) -> torch.Tensor:
    """AR-mode sampling: future tokens at sigma=1, past at sigma=0.

    Faster than denoise mode when you only want to continue a prefix.
    """
    cfg = model.cfg
    device = cond.device
    B = cond.size(0)
    x = prefix.clone().to(device)
    if x.dim() == 1:
        x = x.unsqueeze(0)
    null = torch.full_like(cond, null_cond)

    for _ in range(max_new):
        T_now = x.size(1)
        if T_now >= cfg.max_seq_len:
            break
        # Pad with MASK
        x_pad = torch.cat([x, torch.full((B, 1), cfg.mask_token,
                                         device=device, dtype=torch.long)], dim=1)
        sigma = torch.zeros(B, T_now + 1, device=device)
        sigma[:, -1] = 1.0
        a_pad = null_a = None
        if action is not None and cfg.use_latent_action:
            a_pad = torch.cat([action[:, :T_now],
                               torch.full((B, 1), cfg.n_latent_actions,
                                          device=device, dtype=torch.long)], dim=1)
            # Null action for the CFG pass, sized to match x_pad rather than
            # the full-length action tensor.
            null_a = torch.full_like(a_pad, cfg.n_latent_actions)
        logits = model(x_pad, sigma, cond, a_pad)
        if gamma > 0:
            logits_n = model(x_pad, sigma, null, null_a)
            logits = (1 + gamma) * logits - gamma * logits_n
        logits[:, :, cfg.mask_token] = -1e9
        p = F.softmax(logits[:, -1] / max(temperature, 1e-3), dim=-1)
        nxt = torch.multinomial(p, 1)
        x = torch.cat([x, nxt], dim=1)
    return x

@torch.no_grad()
def counterfactual_rollout(
    model: CDFTransformer,
    seed_prefix: torch.Tensor,
    treatment_cond: int,
    untreated_cond: int,
    *,
    treatment_action: int | None = None,
    untreated_action: int | None = None,
    n_samples: int = 100,
    gamma: float = 2.0,
    n_steps: int = 32,
) -> dict:
    """Sample paired counterfactual trajectories under treatment vs no-treatment.

    Two ways to specify the intervention:
    - via cond id (cohort-level): treatment_cond / untreated_cond
    - via latent action id (per-token): treatment_action / untreated_action

    seed_prefix is a 1-D tensor of clean tokens shared by all samples.
    Note: sample_denoise decodes greedily (argmax) and every row of a batch
    receives the same inputs here, so variation across the n_samples rows
    depends on stochasticity inside the model itself.
    """
    cfg = model.cfg
    device = next(model.parameters()).device
    seed = seed_prefix.unsqueeze(0).expand(n_samples, -1).to(device)
    T = cfg.max_seq_len

    cond_tx = torch.full((n_samples,), treatment_cond, device=device, dtype=torch.long)
    cond_null = torch.full((n_samples,), untreated_cond, device=device, dtype=torch.long)

    action_tx = action_null = None
    if cfg.use_latent_action:
        # An unspecified action falls back to the null action id.
        action_tx = torch.full((n_samples, T),
                               treatment_action if treatment_action is not None
                               else cfg.n_latent_actions,
                               device=device, dtype=torch.long)
        action_null = torch.full((n_samples, T),
                                 untreated_action if untreated_action is not None
                                 else cfg.n_latent_actions,
                                 device=device, dtype=torch.long)

    traj_tx = sample_denoise(model, cond_tx, seed_prefix=seed,
                             action=action_tx, gamma=gamma, n_steps=n_steps)
    traj_null = sample_denoise(model, cond_null, seed_prefix=seed,
                               action=action_null, gamma=gamma, n_steps=n_steps)
    return {
        "traj_treated": traj_tx, "traj_untreated": traj_null,
        "n": n_samples, "treatment_cond": treatment_cond,
        "untreated_cond": untreated_cond, "gamma": gamma,
    }

def outcome_rate(traj: torch.Tensor, target_ids: list[int]) -> float:
    """Fraction of trajectories that contain at least one token in target_ids."""
    if not target_ids:
        return 0.0
    target = torch.tensor(target_ids, device=traj.device)
    # torch.isin avoids the multi-dim any(), which older PyTorch versions reject.
    has = torch.isin(traj, target).any(dim=-1)
    return has.float().mean().item()
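Two pieces of arithmetic in `sample_denoise` are easy to check by hand: the classifier-free guidance combination and the number of positions the schedule asks to have revealed after each step. The sketch below restates both as standalone functions. The names `cfg_logits` and `reveal_targets` and the toy inputs are hypothetical and do not appear in the reference implementation.

```python
import torch


def cfg_logits(logits_c: torch.Tensor, logits_n: torch.Tensor, gamma: float) -> torch.Tensor:
    """Classifier-free guidance: extrapolate away from the null-conditioned logits."""
    return (1 + gamma) * logits_c - gamma * logits_n


def reveal_targets(T: int, n_steps: int, schedule: str = "cosine") -> list[int]:
    """Cumulative number of positions that should be filled after each step."""
    if schedule == "cosine":
        ts = torch.cos(torch.linspace(0, torch.pi / 2, n_steps + 1))
    else:
        ts = torch.linspace(1.0, 0.0, n_steps + 1)
    # After step k the sampler wants a (1 - ts[k+1]) fraction of T filled.
    return [int(round((1 - ts[k + 1].item()) * T)) for k in range(n_steps)]


targets = reveal_targets(T=64, n_steps=8)
print("cumulative reveal targets:", targets)

# Toy logits over a 2-token vocabulary.
lc = torch.tensor([2.0, 0.0])   # conditioned on the cohort / action
ln = torch.tensor([1.0, 1.0])   # conditioned on the null id
print("guided logits:", cfg_logits(lc, ln, gamma=2.0))
```

With the cosine schedule the targets grow slowly at first and quickly near the end, so early steps commit only the most confident tokens. Setting `gamma=0` recovers the plain conditional logits, which is why the sampler skips the second forward pass in that case.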
reference_impl/wsd_scheduler.py
ADDED
|
@@ -0,0 +1,72 @@
|
"""WSD (Warmup-Stable-Decay) LR scheduler, manual implementation.

Per MiniCPM (Hu et al. 2024) and the data-constrained scaling literature:
    Phase 1 (warmup, 1-5% of total_steps): linear 0 → peak_lr
    Phase 2 (stable, 60-80%): constant peak_lr
    Phase 3 (decay, 10-25%): linear, cosine, or 1/sqrt decay toward peak_lr * 0.1

Preferred over a plain cosine schedule for:
    - data-limited regimes (we can extend stable phase if loss still falls)
    - continue-pretrain (sharp decay enables clean fine-tune handoff)
"""
from __future__ import annotations
import math
import torch
from torch.optim.lr_scheduler import LambdaLR

def wsd_lr_schedule(step: int, total_steps: int,
                    warmup_steps: int = 500,
                    stable_frac: float = 0.80,
                    decay_frac: float = 0.15,
                    min_lr_ratio: float = 0.1,
                    decay_type: str = "linear") -> float:
    """Return the LR multiplier for a given step.

    The multiplier ramps over [0, 1) during warmup and stays within
    [min_lr_ratio, 1.0] afterwards. stable_frac and decay_frac are fractions
    of the post-warmup steps; any steps left after decay hold min_lr_ratio.
    """
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    # remainder of steps after warmup
    remaining = total_steps - warmup_steps
    if remaining <= 0:
        return 1.0
    stable_steps = int(stable_frac * remaining)
    decay_steps = int(decay_frac * remaining)
    pos = step - warmup_steps
    if pos < stable_steps:
        return 1.0
    decay_pos = pos - stable_steps
    if decay_pos >= decay_steps:
        return min_lr_ratio
    progress = decay_pos / max(1, decay_steps)
    if decay_type == "linear":
        return 1.0 - (1.0 - min_lr_ratio) * progress
    elif decay_type == "cosine":
        return min_lr_ratio + 0.5 * (1 - min_lr_ratio) * (1 + math.cos(math.pi * progress))
    elif decay_type == "inv_sqrt":
        # Approaches 1/sqrt(11) ~= 0.30 at the end of decay; with the default
        # min_lr_ratio the multiplier then steps down to min_lr_ratio.
        return max(min_lr_ratio, 1.0 / math.sqrt(1 + progress * 10))
    else:
        raise ValueError(f"unknown decay_type: {decay_type}")

def get_wsd_scheduler(optimizer: torch.optim.Optimizer,
                      total_steps: int,
                      warmup_steps: int = 500,
                      stable_frac: float = 0.80,
                      decay_frac: float = 0.15,
                      min_lr_ratio: float = 0.1,
                      decay_type: str = "linear") -> LambdaLR:
    """Build a LambdaLR scheduler with WSD schedule."""
    def fn(step):
        return wsd_lr_schedule(step, total_steps, warmup_steps,
                               stable_frac, decay_frac, min_lr_ratio, decay_type)
    return LambdaLR(optimizer, lr_lambda=fn)

if __name__ == "__main__":
    # Visualize the schedule
    total = 10000
    warmup = 500
    print(f"WSD schedule preview: total={total}, warmup={warmup}, stable=80%, decay=15%")
    print("  step  lr_mult")
    for s in [0, 250, 500, 1000, 5000, 8000, 8500, 9000, 9500, 9800, 9999]:
        m = wsd_lr_schedule(s, total, warmup, 0.80, 0.15, 0.1, "linear")
        print(f"  {s:>5}  {m:.4f}")
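For readers who want to see the schedule drive a real optimizer, here is a minimal sketch assuming a toy linear model and an arbitrary peak learning rate of 3e-4. `wsd_multiplier` is a condensed, linear-decay-only restatement of `wsd_lr_schedule` written for this example; it is not the reference function itself.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR


def wsd_multiplier(step: int, total: int, warmup: int,
                   stable_frac: float = 0.80, decay_frac: float = 0.15,
                   min_ratio: float = 0.1) -> float:
    """Linear-decay WSD multiplier (condensed sketch of the schedule)."""
    if step < warmup:
        return step / max(1, warmup)
    remaining = total - warmup
    if remaining <= 0:
        return 1.0
    stable = int(stable_frac * remaining)
    decay = int(decay_frac * remaining)
    pos = step - warmup
    if pos < stable:
        return 1.0
    if pos - stable >= decay:
        return min_ratio
    return 1.0 - (1.0 - min_ratio) * (pos - stable) / max(1, decay)


model = torch.nn.Linear(4, 2)                          # toy model
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)   # peak LR: arbitrary choice
sched = LambdaLR(opt, lr_lambda=lambda s: wsd_multiplier(s, total=1000, warmup=50))

lrs = []
for _ in range(1000):
    lrs.append(sched.get_last_lr()[0])
    opt.step()      # optimizer first, then scheduler
    sched.step()
print(f"start={lrs[0]:.2e} peak={max(lrs):.2e} end={lrs[-1]:.2e}")
```

Because `LambdaLR` multiplies the optimizer's base learning rate by the returned factor, the peak learning rate is whatever `lr` the optimizer was built with. Extending the stable phase mid-run only requires changing `total` (or `stable_frac`) before the decay phase begins.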