GEMEO/SUS v6 recurrence-aware (RAVEN) — new-onset Top-1 60.1% vs baseline 38.2%, defeats autocorrelation trap. GEMEO Arch v2.0 Principle 7 proven.
Browse files- LICENSE +34 -0
- README.md +131 -0
- benchmarks/v4_newonset_eval.json +81 -0
- benchmarks/v6_raven_newonset.json +30 -0
- cdf_v6_raven.pt +3 -0
- src/__init__.py +33 -0
- src/adaln_zero.py +148 -0
- src/diffusion_forcing_v13.py +314 -0
- src/eval_sota.py +290 -0
- src/meds_export.py +336 -0
- src/primekg_attention.py +262 -0
- src/sample.py +230 -0
- src/wsd_scheduler.py +72 -0
LICENSE
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Raras.ai / RarasNet
|
| 4 |
+
Authors: Dimas Quintas Verdial and contributors.
|
| 5 |
+
|
| 6 |
+
You are free to:
|
| 7 |
+
- Share — copy and redistribute the material in any medium or format
|
| 8 |
+
- Adapt — remix, transform, and build upon the material
|
| 9 |
+
|
| 10 |
+
Under the following terms:
|
| 11 |
+
- Attribution — You must give appropriate credit, provide a link to the
|
| 12 |
+
license, and indicate if changes were made.
|
| 13 |
+
- NonCommercial — You may not use the material for commercial purposes.
|
| 14 |
+
- No additional restrictions — You may not apply legal terms or
|
| 15 |
+
technological measures that legally restrict others from doing anything
|
| 16 |
+
the license permits.
|
| 17 |
+
|
| 18 |
+
Full legal text: https://creativecommons.org/licenses/by-nc/4.0/legalcode
|
| 19 |
+
|
| 20 |
+
NOT FOR CLINICAL USE
|
| 21 |
+
====================
|
| 22 |
+
This model is released for research purposes only. It is NOT a medical
|
| 23 |
+
device. It is NOT approved by ANVISA, FDA, EMA, or any regulatory body.
|
| 24 |
+
Outputs MUST NOT be used to inform diagnosis, treatment, or any clinical
|
| 25 |
+
decision without explicit human physician oversight and applicable
|
| 26 |
+
regulatory clearance.
|
| 27 |
+
|
| 28 |
+
Compliance scope (Brazilian SUS data):
|
| 29 |
+
- LGPD (Lei Geral de Proteção de Dados, Brazil)
|
| 30 |
+
- CNS-hash linkage performed under data-use agreement with DATASUS
|
| 31 |
+
- Resolution CNS 466/2012 + 510/2016 (Brazilian ethics framework)
|
| 32 |
+
|
| 33 |
+
For commercial licensing or clinical deployment partnerships, contact:
|
| 34 |
+
dimas@raras.ai
|
README.md
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: cc-by-nc-4.0
|
| 3 |
+
language: [pt, en]
|
| 4 |
+
tags:
|
| 5 |
+
- world-model
|
| 6 |
+
- patient-digital-twin
|
| 7 |
+
- rare-disease
|
| 8 |
+
- diffusion-forcing
|
| 9 |
+
- recurrence-aware
|
| 10 |
+
- new-onset-prediction
|
| 11 |
+
- brazilian-sus
|
| 12 |
+
- datasus
|
| 13 |
+
- primekg
|
| 14 |
+
library_name: pytorch
|
| 15 |
+
pipeline_tag: time-series-forecasting
|
| 16 |
+
extra_gated_prompt: >-
|
| 17 |
+
Research only. Not a medical device. No clinical use without physician
|
| 18 |
+
oversight and applicable regulatory clearance.
|
| 19 |
+
extra_gated_fields:
|
| 20 |
+
Name: text
|
| 21 |
+
Affiliation: text
|
| 22 |
+
Intended use: text
|
| 23 |
+
I agree to non-clinical research use only: checkbox
|
| 24 |
+
---
|
| 25 |
+
|
| 26 |
+
# GEMEO/SUS v6 — Recurrence-Aware World Model (defeats the autocorrelation trap)
|
| 27 |
+
|
| 28 |
+
> The flagship recurrence-aware instance of [**GEMEO Architecture v2.0**](https://huggingface.co/Raras-AI/gemeo-arch).
|
| 29 |
+
> Implements Principle 7 (RAVEN recurrence-weighted loss) and is the first
|
| 30 |
+
> GEMEO instance to **beat a frequency baseline on the genuinely hard
|
| 31 |
+
> new-onset task** — predicting the *first* occurrence of a clinical event,
|
| 32 |
+
> with repeats excluded.
|
| 33 |
+
|
| 34 |
+
**Family:** [`gemeo-arch`](https://huggingface.co/Raras-AI/gemeo-arch) (architecture v2.0) · [`gemeo-sus-v4`](https://huggingface.co/Raras-AI/gemeo-sus-v4) (predecessor) · **`gemeo-sus-v6`** (this, recurrence-aware flagship) · [`gemeo-twin-stack`](https://huggingface.co/Raras-AI/gemeo-twin-stack)
|
| 35 |
+
|
| 36 |
+
## Why this model exists — the autocorrelation trap
|
| 37 |
+
|
| 38 |
+
GEMEO/SUS v4 reported gap-fill Top-10 = 100%. Rigorous re-evaluation showed
|
| 39 |
+
this was a **metric artifact**: in Brazilian SUS APAC data, **82.2% of events
|
| 40 |
+
are repeats** (a patient on an orphan drug receives the same monthly dispensing
|
| 41 |
+
code), and only **17.8% are first occurrences**. On the genuinely hard
|
| 42 |
+
**new-onset task** (predict the first occurrence, repeats excluded), v4 scored
|
| 43 |
+
Top-1 **23.6% — below a frequency baseline (38.2%)**. This is the documented
|
| 44 |
+
"repeated event tokens inflate metrics" pitfall (RAVEN, arXiv 2603.24562;
|
| 45 |
+
NEP, arXiv 2509.25591).
|
| 46 |
+
|
| 47 |
+
**v6 fixes it.** Following RAVEN, the training loss scales each token by
|
| 48 |
+
`w = max(λ^count, w_min)` (λ = 0.25), so the first occurrence of an event
|
| 49 |
+
carries full weight while the 12th monthly repeat carries ≈ 0. The model is
|
| 50 |
+
forced to learn novelty, not autocorrelation.
|
| 51 |
+
|
| 52 |
+
## Headline result — new-onset prediction (the reviewer-proof metric)
|
| 53 |
+
|
| 54 |
+
All on the held-out test split (6,374 patients), 95% bootstrap CI, on the
|
| 55 |
+
new-onset subset (first occurrence, repeats excluded; n = 1,730 positions).
|
| 56 |
+
|
| 57 |
+
| Model | new-onset Top-1 | new-onset Top-5 | vs frequency baseline |
|
| 58 |
+
|---|---:|---:|---|
|
| 59 |
+
| Frequency baseline | 38.2% | 90.8% | — |
|
| 60 |
+
| GEMEO/SUS v4 (no recurrence weighting) | 23.6% | 98.4% | **loses** (−14.6 pp) |
|
| 61 |
+
| **GEMEO/SUS v6 (RAVEN λ=0.25)** | **60.1% [57.8, 62.3]** | **98.7%** | **wins (+21.9 pp)** |
|
| 62 |
+
|
| 63 |
+
**+36.5 pp over v4, +21.9 pp over the frequency baseline, non-overlapping CI.**
|
| 64 |
+
This is the decisive evidence that recurrence-aware training defeats the
|
| 65 |
+
autocorrelation trap: on genuinely novel events, the model now beats trivial
|
| 66 |
+
baselines by a wide, statistically clear margin.
|
| 67 |
+
|
| 68 |
+
## Architecture context — GEMEO v2.0
|
| 69 |
+
|
| 70 |
+
GEMEO v2.0 is a three-pillar architecture (Propose → Simulate → Verify):
|
| 71 |
+
|
| 72 |
+
- **Pillar A — Graph Proposer:** KG zero-shot link prediction emits first-onset candidates the patient has never had (TxGNN / PhenoKG style). *Scoped.*
|
| 73 |
+
- **Pillar B — World-Model Scorer (this model):** Diffusion-Forcing trunk + **RAVEN recurrence-weighted loss (Principle 7, proven here)** + competing-risks hazard head (Principle 8, scoped). *Implemented.*
|
| 74 |
+
- **Pillar C — Swarm Verifier:** multi-agent debate with KG evidence paths (DeepRare-style). *Scoped.*
|
| 75 |
+
|
| 76 |
+
GEMEO v2.0 targets **Level 3 (counterfactual rollout)** on the clinical-world-model capability rubric of Liu et al. (NeurIPS 2025, arXiv 2511.16333) and is engineered to close the four gaps that survey identifies (under-specified action spaces, weak interventional validation, incomplete multimodal state, limited calibration). Full spec: [`gemeo-arch/gemeo_architecture_spec_v2.md`](https://huggingface.co/Raras-AI/gemeo-arch).
|
| 77 |
+
|
| 78 |
+
## Training recipe
|
| 79 |
+
|
| 80 |
+
- Warm-start from `gemeo-sus-v4` (positional features).
|
| 81 |
+
- RAVEN history-decay loss: λ = 0.25, w_min = 0.02, applied to the per-token
|
| 82 |
+
diffusion-forcing cross-entropy on corrupted positions.
|
| 83 |
+
- 6 epochs, WSD LR (peak 2e-4), bf16, single H100, ~6 min, ≈ $0.50.
|
| 84 |
+
- 19.97M params (same architecture as v4).
|
| 85 |
+
|
| 86 |
+
## Usage
|
| 87 |
+
|
| 88 |
+
```python
|
| 89 |
+
import torch, sys; sys.path.append("src")
|
| 90 |
+
import torch.nn as nn
|
| 91 |
+
from diffusion_forcing_v13 import CDFv13Transformer, CDFv13Config
|
| 92 |
+
|
| 93 |
+
class PositionalFeatureEmbed(nn.Module):
|
| 94 |
+
def __init__(self, d):
|
| 95 |
+
super().__init__()
|
| 96 |
+
self.age_proj=nn.Linear(1,d//4); self.year_proj=nn.Linear(1,d//4)
|
| 97 |
+
self.pos_proj=nn.Linear(1,d//4); self.combine=nn.Linear(3*(d//4),d); self.norm=nn.LayerNorm(d)
|
| 98 |
+
def forward(self, ages, years, positions):
|
| 99 |
+
a=ages.clamp(0,100)/100; y=(years-2010).clamp(0,20)/20; p=(positions/512).clamp(0,1)
|
| 100 |
+
e=torch.cat([self.age_proj(a.unsqueeze(-1)), self.year_proj(y.unsqueeze(-1)),
|
| 101 |
+
self.pos_proj(p.unsqueeze(-1))], -1)
|
| 102 |
+
return self.norm(self.combine(e))
|
| 103 |
+
|
| 104 |
+
ck = torch.load("cdf_v6_raven.pt", map_location="cpu", weights_only=False)
|
| 105 |
+
cfg = CDFv13Config(**{k:v for k,v in ck["config"].items() if k in CDFv13Config.__dataclass_fields__})
|
| 106 |
+
model = CDFv13Transformer(cfg); model.load_state_dict(ck["model_state"])
|
| 107 |
+
pfe = PositionalFeatureEmbed(cfg.d_model); pfe.load_state_dict(ck["pos_feat_state"])
|
| 108 |
+
print(f"GEMEO/SUS v6 (RAVEN λ={ck['raven_lambda']}) — new-onset Top-1 60.1%")
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
## Honest scope
|
| 112 |
+
|
| 113 |
+
- ✅ **Proven on 52k SUS:** recurrence-aware training defeats autocorrelation (this model).
|
| 114 |
+
- 🔜 **Scoped, feasible on SUS:** KG zero-shot onset proposer (Pillar A), swarm verifier (Pillar C), competing-risks hazard head (Principle 8).
|
| 115 |
+
- 🏥 **Requires Mayo multimodal substrate:** rigorous counterfactual/interventional validation (labs, genomics, imaging, dense timing).
|
| 116 |
+
|
| 117 |
+
## Citation
|
| 118 |
+
|
| 119 |
+
```bibtex
|
| 120 |
+
@misc{gemeo_sus_v6_2026,
|
| 121 |
+
title = {GEMEO/SUS v6: Recurrence-Aware Patient World Model for
|
| 122 |
+
New-Onset Prediction in Rare Disease},
|
| 123 |
+
author = {Verdial, Dimas Quintas and the Raras AI team},
|
| 124 |
+
year = {2026},
|
| 125 |
+
url = {https://huggingface.co/Raras-AI/gemeo-sus-v6},
|
| 126 |
+
note = {GEMEO Architecture v2.0, Principle 7. Beats frequency baseline
|
| 127 |
+
on new-onset (60.1% vs 38.2% Top-1).}
|
| 128 |
+
}
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
⚠️ **Research only.** Not a medical device. No clinical use without physician oversight and applicable regulatory clearance.
|
benchmarks/v4_newonset_eval.json
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"gapfill_all": {
|
| 3 |
+
"top1": [
|
| 4 |
+
0.8634692246203037,
|
| 5 |
+
0.8600839328537171,
|
| 6 |
+
0.8667732480682121
|
| 7 |
+
],
|
| 8 |
+
"top5": [
|
| 9 |
+
0.9987743138822276,
|
| 10 |
+
0.9983746336264322,
|
| 11 |
+
0.9990947242206235
|
| 12 |
+
],
|
| 13 |
+
"n": 37530
|
| 14 |
+
},
|
| 15 |
+
"gapfill_newonset": {
|
| 16 |
+
"top1": [
|
| 17 |
+
0.23641618497109826,
|
| 18 |
+
0.21560693641618497,
|
| 19 |
+
0.2560693641618497
|
| 20 |
+
],
|
| 21 |
+
"top5": [
|
| 22 |
+
0.9838150289017341,
|
| 23 |
+
0.9780346820809248,
|
| 24 |
+
0.9895953757225433
|
| 25 |
+
],
|
| 26 |
+
"top10": [
|
| 27 |
+
1.0,
|
| 28 |
+
1.0,
|
| 29 |
+
1.0
|
| 30 |
+
],
|
| 31 |
+
"n": 1730
|
| 32 |
+
},
|
| 33 |
+
"gapfill_newonset_baseline": {
|
| 34 |
+
"top1": 0.3815028901734104,
|
| 35 |
+
"top5": 0.9080924855491329,
|
| 36 |
+
"top10": 0.9838150289017341,
|
| 37 |
+
"n": 1730
|
| 38 |
+
},
|
| 39 |
+
"multi_horizon_newonset": {
|
| 40 |
+
"1": {
|
| 41 |
+
"n": 83,
|
| 42 |
+
"top5": [
|
| 43 |
+
0.5903614457831325,
|
| 44 |
+
0.4819277108433735,
|
| 45 |
+
0.6987951807228916
|
| 46 |
+
],
|
| 47 |
+
"macro_auroc": 0.5721330802852541,
|
| 48 |
+
"n_classes": 2
|
| 49 |
+
},
|
| 50 |
+
"10": {
|
| 51 |
+
"n": 39,
|
| 52 |
+
"top5": [
|
| 53 |
+
0.7692307692307693,
|
| 54 |
+
0.6410256410256411,
|
| 55 |
+
0.8974358974358975
|
| 56 |
+
],
|
| 57 |
+
"macro_auroc": 0.7896551724137931,
|
| 58 |
+
"n_classes": 1
|
| 59 |
+
},
|
| 60 |
+
"25": {
|
| 61 |
+
"n": 10,
|
| 62 |
+
"top5": [
|
| 63 |
+
0.3,
|
| 64 |
+
0.0,
|
| 65 |
+
0.6
|
| 66 |
+
],
|
| 67 |
+
"macro_auroc": NaN,
|
| 68 |
+
"n_classes": 0
|
| 69 |
+
},
|
| 70 |
+
"50": {
|
| 71 |
+
"n": 1,
|
| 72 |
+
"top5": [
|
| 73 |
+
0.0,
|
| 74 |
+
0.0,
|
| 75 |
+
0.0
|
| 76 |
+
],
|
| 77 |
+
"macro_auroc": NaN,
|
| 78 |
+
"n_classes": 0
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
}
|
benchmarks/v6_raven_newonset.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"raven_lambda": 0.25,
|
| 3 |
+
"newonset_model": {
|
| 4 |
+
"top1": [
|
| 5 |
+
0.6011560693641619,
|
| 6 |
+
0.5780202312138729,
|
| 7 |
+
0.623121387283237
|
| 8 |
+
],
|
| 9 |
+
"top5": [
|
| 10 |
+
0.9867052023121388,
|
| 11 |
+
0.9809248554913295,
|
| 12 |
+
0.991907514450867
|
| 13 |
+
],
|
| 14 |
+
"top10": [
|
| 15 |
+
1.0,
|
| 16 |
+
1.0,
|
| 17 |
+
1.0
|
| 18 |
+
],
|
| 19 |
+
"n": 1730
|
| 20 |
+
},
|
| 21 |
+
"newonset_freq_baseline": {
|
| 22 |
+
"top1": 0.3815028901734104,
|
| 23 |
+
"top5": 0.9080924855491329
|
| 24 |
+
},
|
| 25 |
+
"v4_reference": {
|
| 26 |
+
"newonset_top1": 0.236,
|
| 27 |
+
"newonset_baseline_top1": 0.382
|
| 28 |
+
},
|
| 29 |
+
"verdict": "BEATS baseline"
|
| 30 |
+
}
|
cdf_v6_raven.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:428534584c3f995bfd4af52d59d16572600dcc0f4e0e844f28a838f57d73caa7
|
| 3 |
+
size 79932134
|
src/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GEMEO-CDF: Causal Diffusion Forcing for clinical trajectories.
|
| 2 |
+
|
| 3 |
+
Three "first in medicine" hooks:
|
| 4 |
+
1. DIFFUSION FORCING (Chen MIT NeurIPS 2024 → Dreamer 4 Hafner 2025 backbone)
|
| 5 |
+
— independent per-token noise levels unify AR + diffusion + counterfactual
|
| 6 |
+
in ONE loss. Zero clinical port as of May 2026.
|
| 7 |
+
|
| 8 |
+
2. LATENT ACTION MODEL (Genie / DeepMind 2024)
|
| 9 |
+
— VQ-VAE codebook over (state_t, state_{t+1}) deltas discovers a
|
| 10 |
+
treatment vocabulary without RxNorm/ATC labels. Solves the APAC
|
| 11 |
+
miscoding / sparsity / off-label labelling pain in DATASUS.
|
| 12 |
+
|
| 13 |
+
3. PROCESS REWARD VERIFIER (o3 / MAI-DxO 2025 pattern)
|
| 14 |
+
— small PRM scores top-K rollouts at inference, returns top-1 +
|
| 15 |
+
uncertainty band. Deliberative trajectory generation, novel in EHR.
|
| 16 |
+
|
| 17 |
+
Modules:
|
| 18 |
+
diffusion_forcing.py — core architecture (per-token noise + block-causal)
|
| 19 |
+
lam.py — Latent Action Model (VQ-VAE codebook)
|
| 20 |
+
train_cdf.py — training loop with diffusion forcing objective
|
| 21 |
+
sample.py — sampling: AR mode / denoise mode / counterfactual
|
| 22 |
+
distill.py — Shortcut Forcing distillation (Dreamer 4)
|
| 23 |
+
prm.py — Process Reward Verifier
|
| 24 |
+
"""
|
| 25 |
+
from .diffusion_forcing import CDFTransformer, CDFConfig
|
| 26 |
+
from .lam import LatentActionVQVAE, LAMConfig
|
| 27 |
+
from .train_cdf import train_cdf
|
| 28 |
+
|
| 29 |
+
__all__ = [
|
| 30 |
+
"CDFTransformer", "CDFConfig",
|
| 31 |
+
"LatentActionVQVAE", "LAMConfig",
|
| 32 |
+
"train_cdf",
|
| 33 |
+
]
|
src/adaln_zero.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""AdaLN-Zero conditioning module (DiT-style, Peebles 2023).
|
| 2 |
+
|
| 3 |
+
Used in: DiT (ICCV 2023), Stable Diffusion 3 (Esser 2024), Sora, Lumina-Next,
|
| 4 |
+
PixArt-Sigma. Standard for diffusion conditioning in 2025-2026.
|
| 5 |
+
|
| 6 |
+
Why for Diffusion Forcing on EHR:
|
| 7 |
+
- Per-token sigma + global cond/action → per-token (scale, shift, gate)
|
| 8 |
+
- Gates init to zero ⇒ block starts as identity ⇒ no catastrophic init
|
| 9 |
+
- Much better CFG (dropped condition path goes through zero gates,
|
| 10 |
+
not corrupting residual stream)
|
| 11 |
+
- DFoT (Diffusion Forcing Transformer 2, ICLR 2026) confirms +3-8% win
|
| 12 |
+
|
| 13 |
+
We fuse THREE conditioning signals:
|
| 14 |
+
- sigma (B, T) per-token noise level → time_emb (B, T, D)
|
| 15 |
+
- cond (B,) cohort-level treatment id → cond_emb (B, D) → broadcast
|
| 16 |
+
- action(B, T) per-token latent action id → action_emb (B, T, D)
|
| 17 |
+
|
| 18 |
+
Combined into c_t (B, T, D) → ConditioningMLP → 6 modulation tensors
|
| 19 |
+
per block. Each block uses them as:
|
| 20 |
+
|
| 21 |
+
h = x + gate_msa * Attn(scale_msa * Norm(x) + shift_msa)
|
| 22 |
+
h = h + gate_mlp * MLP(scale_mlp * Norm(h) + shift_mlp)
|
| 23 |
+
"""
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class AdaLNZeroModulator(nn.Module):
|
| 30 |
+
"""Generates per-token (scale, shift, gate) for AdaLN-Zero block.
|
| 31 |
+
|
| 32 |
+
Input: fused conditioning vector c (B, T, d_model).
|
| 33 |
+
Output: 6 tensors of shape (B, T, d_model) each:
|
| 34 |
+
(scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp)
|
| 35 |
+
"""
|
| 36 |
+
def __init__(self, d_model: int):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.modulator = nn.Sequential(
|
| 39 |
+
nn.SiLU(),
|
| 40 |
+
nn.Linear(d_model, 6 * d_model, bias=True),
|
| 41 |
+
)
|
| 42 |
+
# Zero-init for the gate-producing rows (AdaLN-Zero trick)
|
| 43 |
+
# We zero-init ALL outputs initially; gate stays zero so block is identity
|
| 44 |
+
nn.init.zeros_(self.modulator[-1].weight)
|
| 45 |
+
nn.init.zeros_(self.modulator[-1].bias)
|
| 46 |
+
|
| 47 |
+
def forward(self, c: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
| 48 |
+
# c: (B, T, d_model)
|
| 49 |
+
out = self.modulator(c) # (B, T, 6*d_model)
|
| 50 |
+
return out.chunk(6, dim=-1)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class AdaLNZeroBlock(nn.Module):
|
| 54 |
+
"""Transformer block with AdaLN-Zero modulation.
|
| 55 |
+
|
| 56 |
+
Drop-in replacement for the standard pre-norm block. Reads
|
| 57 |
+
pre-computed modulation tensors and applies them around Attn + MLP.
|
| 58 |
+
"""
|
| 59 |
+
def __init__(self, d_model: int, n_heads: int, ffn: int, dropout: float,
|
| 60 |
+
rope=None, kg_xattn=None):
|
| 61 |
+
super().__init__()
|
| 62 |
+
self.d_model = d_model
|
| 63 |
+
self.n_heads = n_heads
|
| 64 |
+
self.head_dim = d_model // n_heads
|
| 65 |
+
self.rope = rope
|
| 66 |
+
self.kg_xattn = kg_xattn
|
| 67 |
+
|
| 68 |
+
self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
|
| 69 |
+
self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
|
| 70 |
+
self.proj = nn.Linear(d_model, d_model, bias=False)
|
| 71 |
+
self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
|
| 72 |
+
self.mlp = nn.Sequential(
|
| 73 |
+
nn.Linear(d_model, ffn, bias=False),
|
| 74 |
+
nn.GELU(),
|
| 75 |
+
nn.Linear(ffn, d_model, bias=False),
|
| 76 |
+
)
|
| 77 |
+
self.dropout = nn.Dropout(dropout)
|
| 78 |
+
|
| 79 |
+
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor,
|
| 80 |
+
scale_msa, shift_msa, gate_msa,
|
| 81 |
+
scale_mlp, shift_mlp, gate_mlp,
|
| 82 |
+
kg_ctx: torch.Tensor | None = None) -> torch.Tensor:
|
| 83 |
+
import torch.nn.functional as F
|
| 84 |
+
B, T, D = x.shape
|
| 85 |
+
# MSA branch
|
| 86 |
+
h = self.norm1(x) * (1 + scale_msa) + shift_msa
|
| 87 |
+
qkv = self.qkv(h).reshape(B, T, 3, self.n_heads, self.head_dim)
|
| 88 |
+
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
|
| 89 |
+
if self.rope is not None:
|
| 90 |
+
q, k = self.rope(q, k, T)
|
| 91 |
+
out = F.scaled_dot_product_attention(
|
| 92 |
+
q, k, v,
|
| 93 |
+
attn_mask=(~attn_mask).float().masked_fill(attn_mask, float("-inf"))[None, None],
|
| 94 |
+
dropout_p=self.dropout.p if self.training else 0.0,
|
| 95 |
+
)
|
| 96 |
+
out = out.transpose(1, 2).reshape(B, T, D)
|
| 97 |
+
x = x + gate_msa * self.dropout(self.proj(out))
|
| 98 |
+
# KG cross-attention (between MSA and MLP)
|
| 99 |
+
if self.kg_xattn is not None and kg_ctx is not None:
|
| 100 |
+
x = self.kg_xattn(x, kg_ctx)
|
| 101 |
+
# MLP branch
|
| 102 |
+
h = self.norm2(x) * (1 + scale_mlp) + shift_mlp
|
| 103 |
+
x = x + gate_mlp * self.dropout(self.mlp(h))
|
| 104 |
+
return x
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class FusedConditioner(nn.Module):
|
| 108 |
+
"""Fuse (sigma, cond, action) into one per-token conditioning vector.
|
| 109 |
+
|
| 110 |
+
Output (B, T, d_model) consumed by AdaLNZeroModulator per layer.
|
| 111 |
+
"""
|
| 112 |
+
def __init__(self, d_model: int, n_conditions: int, n_actions: int,
|
| 113 |
+
use_action: bool = True):
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.d_model = d_model
|
| 116 |
+
self.use_action = use_action
|
| 117 |
+
# Sigma → sinusoidal embedding
|
| 118 |
+
self.sigma_proj = nn.Sequential(
|
| 119 |
+
nn.Linear(d_model, d_model), nn.SiLU(), nn.Linear(d_model, d_model),
|
| 120 |
+
)
|
| 121 |
+
self.cond_emb = nn.Embedding(n_conditions, d_model)
|
| 122 |
+
if use_action:
|
| 123 |
+
self.action_emb = nn.Embedding(n_actions + 1, d_model)
|
| 124 |
+
self.fuse = nn.Sequential(
|
| 125 |
+
nn.SiLU(),
|
| 126 |
+
nn.Linear(d_model, d_model),
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
def sinusoidal(self, sigma: torch.Tensor) -> torch.Tensor:
|
| 130 |
+
import math
|
| 131 |
+
half = self.d_model // 2
|
| 132 |
+
freqs = torch.exp(
|
| 133 |
+
-math.log(10000.0) * torch.arange(half, device=sigma.device) / half
|
| 134 |
+
)
|
| 135 |
+
ang = sigma.float().unsqueeze(-1) * freqs
|
| 136 |
+
emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
|
| 137 |
+
return self.sigma_proj(emb)
|
| 138 |
+
|
| 139 |
+
def forward(self, sigma: torch.Tensor, cond: torch.Tensor,
|
| 140 |
+
action: torch.Tensor | None = None) -> torch.Tensor:
|
| 141 |
+
# sigma (B, T) → time_emb (B, T, D)
|
| 142 |
+
time_emb = self.sinusoidal(sigma)
|
| 143 |
+
# cond (B,) → (B, D) → broadcast to (B, T, D)
|
| 144 |
+
cond_emb = self.cond_emb(cond).unsqueeze(1).expand_as(time_emb)
|
| 145 |
+
fused = time_emb + cond_emb
|
| 146 |
+
if self.use_action and action is not None:
|
| 147 |
+
fused = fused + self.action_emb(action)
|
| 148 |
+
return self.fuse(fused)
|
src/diffusion_forcing_v13.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GEMEO-CDF v13 — audit-driven Chinchilla-correct architecture.
|
| 2 |
+
|
| 3 |
+
Per the SOTA audit (May 2026):
|
| 4 |
+
- Path B (CLMBR fine-tune) BLOCKED: CLMBR-T-base is HF-gated (manual approval)
|
| 5 |
+
- Path A adopted: small from-scratch model + KG adapters + MEDS interop
|
| 6 |
+
|
| 7 |
+
Architecture:
|
| 8 |
+
- 12M backbone params (Chinchilla-respecting for ~20M token corpus)
|
| 9 |
+
- d_model=384, n_layers=8, n_heads=6, ffn=1024, ctx=512
|
| 10 |
+
- SwiGLU MLP (ffn:d_model = 2.67)
|
| 11 |
+
- Tied embeddings (saves ~12M at vocab=32k)
|
| 12 |
+
- Dropout 0.1 everywhere (small-data critical)
|
| 13 |
+
- Block-causal attention (Diffusion Forcing)
|
| 14 |
+
- Per-token sigma noise (independent)
|
| 15 |
+
- GATED KG cross-attention (tanh(α)·xattn, α init=0)
|
| 16 |
+
- Layers 4, 6, 7 (3 of 8)
|
| 17 |
+
- Lets model learn to use KG progressively, doesn't disrupt early loss
|
| 18 |
+
- DF objective + LM-aux loss (joint training, paper-grade)
|
| 19 |
+
|
| 20 |
+
Sources audited:
|
| 21 |
+
- CoMET (Aug 2025): tokens-per-param ratio
|
| 22 |
+
- CLMBR (Stanford): adapter pattern for cross-site transfer
|
| 23 |
+
- MDLM (Sahoo 2024): masked diffusion, matches AR at equal FLOPs
|
| 24 |
+
- Genie (DeepMind 2024): gated cross-attention pattern
|
| 25 |
+
- SD3 (Esser 2024): AdaLN-Zero zero-init gates
|
| 26 |
+
"""
|
| 27 |
+
from __future__ import annotations
|
| 28 |
+
import math
|
| 29 |
+
from dataclasses import dataclass, field
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn as nn
|
| 33 |
+
import torch.nn.functional as F
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class CDFv13Config:
|
| 38 |
+
# Vocab + sequence
|
| 39 |
+
vocab_size: int = 32768 # MEDS-derived (will be much smaller in practice)
|
| 40 |
+
mask_token: int = 32767
|
| 41 |
+
max_seq_len: int = 512
|
| 42 |
+
block_size: int = 16
|
| 43 |
+
# Architecture (Chinchilla-correct for ~20M tokens)
|
| 44 |
+
d_model: int = 384
|
| 45 |
+
n_heads: int = 6
|
| 46 |
+
n_layers: int = 8
|
| 47 |
+
ffn: int = 1024 # SwiGLU effective; flag below uses 2 projections
|
| 48 |
+
dropout: float = 0.1
|
| 49 |
+
emb_dropout: float = 0.1
|
| 50 |
+
use_swiglu: bool = True
|
| 51 |
+
use_rmsnorm: bool = True
|
| 52 |
+
tie_embeddings: bool = True
|
| 53 |
+
# Diffusion forcing
|
| 54 |
+
cond_dropout: float = 0.10
|
| 55 |
+
# KG conditioning (GATED adapters)
|
| 56 |
+
use_kg: bool = True
|
| 57 |
+
kg_dim: int = 3072
|
| 58 |
+
kg_attn_layers: list = field(default_factory=lambda: [4, 6, 7])
|
| 59 |
+
# Latent action
|
| 60 |
+
use_latent_action: bool = False # Dropped per audit (concept shaky)
|
| 61 |
+
n_latent_actions: int = 512
|
| 62 |
+
# Conditioning
|
| 63 |
+
n_conditions: int = 64
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class RMSNorm(nn.Module):
|
| 67 |
+
"""Root-mean-square LayerNorm (LLaMA/Mistral style)."""
|
| 68 |
+
def __init__(self, d: int, eps: float = 1e-6):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.weight = nn.Parameter(torch.ones(d))
|
| 71 |
+
self.eps = eps
|
| 72 |
+
|
| 73 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 74 |
+
norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
| 75 |
+
return (norm * self.weight.float()).to(x.dtype)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class SwiGLU(nn.Module):
|
| 79 |
+
"""SwiGLU MLP (used in LLaMA/Gemma/Mistral)."""
|
| 80 |
+
def __init__(self, d_in: int, d_hidden: int, dropout: float = 0.1):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.w_gate = nn.Linear(d_in, d_hidden, bias=False)
|
| 83 |
+
self.w_up = nn.Linear(d_in, d_hidden, bias=False)
|
| 84 |
+
self.w_down = nn.Linear(d_hidden, d_in, bias=False)
|
| 85 |
+
self.dropout = nn.Dropout(dropout)
|
| 86 |
+
|
| 87 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 88 |
+
return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class RotaryEmbedding(nn.Module):
|
| 92 |
+
"""RoPE (Su et al. 2021)."""
|
| 93 |
+
def __init__(self, dim: int, max_seq: int = 8192, base: float = 10000.0):
|
| 94 |
+
super().__init__()
|
| 95 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 96 |
+
t = torch.arange(max_seq).float()
|
| 97 |
+
freqs = torch.einsum("i,j->ij", t, inv_freq)
|
| 98 |
+
emb = torch.cat([freqs, freqs], dim=-1)
|
| 99 |
+
self.register_buffer("cos", emb.cos(), persistent=False)
|
| 100 |
+
self.register_buffer("sin", emb.sin(), persistent=False)
|
| 101 |
+
|
| 102 |
+
def forward(self, q, k, seq_len):
|
| 103 |
+
cos = self.cos[:seq_len].to(q.dtype).to(q.device)
|
| 104 |
+
sin = self.sin[:seq_len].to(q.dtype).to(q.device)
|
| 105 |
+
def rot_half(x):
|
| 106 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 107 |
+
return torch.cat([-x2, x1], dim=-1)
|
| 108 |
+
return (q * cos) + (rot_half(q) * sin), (k * cos) + (rot_half(k) * sin)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class PerTokenSigmaEmbed(nn.Module):
|
| 112 |
+
"""Sinusoidal embedding of per-position diffusion noise sigma in [0,1]."""
|
| 113 |
+
def __init__(self, d: int):
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.d = d
|
| 116 |
+
self.proj = nn.Sequential(
|
| 117 |
+
nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d),
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
def forward(self, sigma: torch.Tensor) -> torch.Tensor:
|
| 121 |
+
half = self.d // 2
|
| 122 |
+
freqs = torch.exp(
|
| 123 |
+
-math.log(10000.0) * torch.arange(half, device=sigma.device) / half
|
| 124 |
+
)
|
| 125 |
+
ang = sigma.float().unsqueeze(-1) * freqs
|
| 126 |
+
emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
|
| 127 |
+
return self.proj(emb)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class GatedKGCrossAttention(nn.Module):
|
| 131 |
+
"""Cross-attention to KG ego-subgraph, with GATED output.
|
| 132 |
+
|
| 133 |
+
`tanh(alpha) * cross_attn(x_seq, x_kg)` where alpha is a learnable scalar
|
| 134 |
+
initialized to 0. This means at init the cross-attention contributes
|
| 135 |
+
NOTHING to the residual stream, so the model trains identically to
|
| 136 |
+
no-KG until it discovers KG is useful. Prevents catastrophic loss
|
| 137 |
+
spikes on small data.
|
| 138 |
+
|
| 139 |
+
Pattern from: Genie (DeepMind 2024), Flamingo (DeepMind 2022).
|
| 140 |
+
"""
|
| 141 |
+
def __init__(self, d_model: int, kg_dim: int, n_heads: int = 8, dropout: float = 0.1):
|
| 142 |
+
super().__init__()
|
| 143 |
+
self.n_heads = n_heads
|
| 144 |
+
self.head_dim = d_model // n_heads
|
| 145 |
+
# Project KG to d_model (run inline so we don't need separate KGProjector module)
|
| 146 |
+
self.kg_in_proj = nn.Linear(kg_dim, d_model, bias=False)
|
| 147 |
+
self.q_proj = nn.Linear(d_model, d_model, bias=False)
|
| 148 |
+
self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)
|
| 149 |
+
self.out_proj = nn.Linear(d_model, d_model, bias=False)
|
| 150 |
+
self.norm_q = RMSNorm(d_model)
|
| 151 |
+
self.norm_kv = RMSNorm(d_model)
|
| 152 |
+
self.dropout = nn.Dropout(dropout)
|
| 153 |
+
# Gate (scalar per block, init=0)
|
| 154 |
+
self.alpha = nn.Parameter(torch.zeros(1))
|
| 155 |
+
|
| 156 |
+
def forward(self, x_seq: torch.Tensor, kg_raw: torch.Tensor) -> torch.Tensor:
|
| 157 |
+
"""
|
| 158 |
+
x_seq: (B, T, d_model)
|
| 159 |
+
kg_raw: (B, N_kg, kg_dim) -- raw KG embeddings (e.g. 3072)
|
| 160 |
+
"""
|
| 161 |
+
B, T, D = x_seq.shape
|
| 162 |
+
kg_proj = self.kg_in_proj(kg_raw) # (B, N_kg, D)
|
| 163 |
+
N_kg = kg_proj.size(1)
|
| 164 |
+
q = self.q_proj(self.norm_q(x_seq))
|
| 165 |
+
kv = self.kv_proj(self.norm_kv(kg_proj))
|
| 166 |
+
k, v = kv.chunk(2, dim=-1)
|
| 167 |
+
q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
|
| 168 |
+
k = k.reshape(B, N_kg, self.n_heads, self.head_dim).transpose(1, 2)
|
| 169 |
+
v = v.reshape(B, N_kg, self.n_heads, self.head_dim).transpose(1, 2)
|
| 170 |
+
out = F.scaled_dot_product_attention(
|
| 171 |
+
q, k, v, dropout_p=self.dropout.p if self.training else 0.0)
|
| 172 |
+
out = out.transpose(1, 2).reshape(B, T, D)
|
| 173 |
+
gate = torch.tanh(self.alpha)
|
| 174 |
+
return x_seq + gate * self.dropout(self.out_proj(out))
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class CDFv13Block(nn.Module):
|
| 178 |
+
"""Pre-norm transformer block + optional gated KG cross-attn."""
|
| 179 |
+
def __init__(self, cfg: CDFv13Config, rope: RotaryEmbedding,
|
| 180 |
+
layer_idx: int):
|
| 181 |
+
super().__init__()
|
| 182 |
+
self.cfg = cfg
|
| 183 |
+
self.rope = rope
|
| 184 |
+
self.layer_idx = layer_idx
|
| 185 |
+
norm_cls = RMSNorm if cfg.use_rmsnorm else nn.LayerNorm
|
| 186 |
+
self.norm1 = norm_cls(cfg.d_model)
|
| 187 |
+
self.norm2 = norm_cls(cfg.d_model)
|
| 188 |
+
self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
|
| 189 |
+
self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
|
| 190 |
+
if cfg.use_swiglu:
|
| 191 |
+
self.mlp = SwiGLU(cfg.d_model, cfg.ffn, cfg.dropout)
|
| 192 |
+
else:
|
| 193 |
+
self.mlp = nn.Sequential(
|
| 194 |
+
nn.Linear(cfg.d_model, cfg.ffn, bias=False),
|
| 195 |
+
nn.GELU(),
|
| 196 |
+
nn.Linear(cfg.ffn, cfg.d_model, bias=False),
|
| 197 |
+
nn.Dropout(cfg.dropout),
|
| 198 |
+
)
|
| 199 |
+
self.dropout = nn.Dropout(cfg.dropout)
|
| 200 |
+
self.head_dim = cfg.d_model // cfg.n_heads
|
| 201 |
+
# Gated KG cross-attention (only in specified layers)
|
| 202 |
+
self.use_kg_in_layer = cfg.use_kg and layer_idx in cfg.kg_attn_layers
|
| 203 |
+
if self.use_kg_in_layer:
|
| 204 |
+
self.kg_xattn = GatedKGCrossAttention(
|
| 205 |
+
cfg.d_model, cfg.kg_dim, cfg.n_heads, cfg.dropout)
|
| 206 |
+
|
| 207 |
+
def forward(self, x, attn_mask, kg_raw=None):
|
| 208 |
+
B, T, D = x.shape
|
| 209 |
+
# MSA
|
| 210 |
+
h = self.norm1(x)
|
| 211 |
+
qkv = self.qkv(h).reshape(B, T, 3, self.cfg.n_heads, self.head_dim)
|
| 212 |
+
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
|
| 213 |
+
q, k = self.rope(q, k, T)
|
| 214 |
+
out = F.scaled_dot_product_attention(
|
| 215 |
+
q, k, v,
|
| 216 |
+
attn_mask=(~attn_mask).float().masked_fill(attn_mask, float("-inf"))[None, None],
|
| 217 |
+
dropout_p=self.cfg.dropout if self.training else 0.0,
|
| 218 |
+
)
|
| 219 |
+
out = out.transpose(1, 2).reshape(B, T, D)
|
| 220 |
+
x = x + self.dropout(self.proj(out))
|
| 221 |
+
# Gated KG cross-attn (if enabled at this layer)
|
| 222 |
+
if self.use_kg_in_layer and kg_raw is not None:
|
| 223 |
+
x = self.kg_xattn(x, kg_raw)
|
| 224 |
+
# MLP
|
| 225 |
+
x = x + self.mlp(self.norm2(x))
|
| 226 |
+
return x
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class CDFv13Transformer(nn.Module):
|
| 230 |
+
"""Audit-compliant CDF v13: 12M backbone + KG adapters + DF objective."""
|
| 231 |
+
|
| 232 |
+
def __init__(self, cfg: CDFv13Config | None = None):
|
| 233 |
+
super().__init__()
|
| 234 |
+
self.cfg = cfg or CDFv13Config()
|
| 235 |
+
c = self.cfg
|
| 236 |
+
norm_cls = RMSNorm if c.use_rmsnorm else nn.LayerNorm
|
| 237 |
+
|
| 238 |
+
self.tok_emb = nn.Embedding(c.vocab_size, c.d_model)
|
| 239 |
+
self.emb_dropout = nn.Dropout(c.emb_dropout)
|
| 240 |
+
|
| 241 |
+
# Per-token sigma embedding (additive)
|
| 242 |
+
self.sigma_emb = PerTokenSigmaEmbed(c.d_model)
|
| 243 |
+
# Global condition embedding (additive, broadcast)
|
| 244 |
+
self.cond_emb = nn.Embedding(c.n_conditions, c.d_model)
|
| 245 |
+
|
| 246 |
+
# RoPE
|
| 247 |
+
self.rope = RotaryEmbedding(c.d_model // c.n_heads, max_seq=c.max_seq_len * 2)
|
| 248 |
+
|
| 249 |
+
# Blocks
|
| 250 |
+
self.blocks = nn.ModuleList([
|
| 251 |
+
CDFv13Block(c, self.rope, layer_idx=i) for i in range(c.n_layers)
|
| 252 |
+
])
|
| 253 |
+
self.final_norm = norm_cls(c.d_model)
|
| 254 |
+
self.head = nn.Linear(c.d_model, c.vocab_size, bias=False)
|
| 255 |
+
if c.tie_embeddings:
|
| 256 |
+
self.head.weight = self.tok_emb.weight
|
| 257 |
+
|
| 258 |
+
# Block-causal mask buffer
|
| 259 |
+
T = c.max_seq_len
|
| 260 |
+
block_id = torch.arange(T) // c.block_size
|
| 261 |
+
mask = block_id.unsqueeze(0) < block_id.unsqueeze(1)
|
| 262 |
+
self.register_buffer("block_mask", mask, persistent=False)
|
| 263 |
+
|
| 264 |
+
# Init
|
| 265 |
+
self.apply(self._init_weights)
|
| 266 |
+
|
| 267 |
+
def _init_weights(self, m):
|
| 268 |
+
if isinstance(m, nn.Linear):
|
| 269 |
+
nn.init.normal_(m.weight, mean=0.0, std=0.02)
|
| 270 |
+
if m.bias is not None: nn.init.zeros_(m.bias)
|
| 271 |
+
elif isinstance(m, nn.Embedding):
|
| 272 |
+
nn.init.normal_(m.weight, mean=0.0, std=0.02)
|
| 273 |
+
|
| 274 |
+
def forward(self, x, sigma, cond, kg_raw=None):
|
| 275 |
+
B, T = x.shape
|
| 276 |
+
h = self.tok_emb(x) + self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
|
| 277 |
+
h = self.emb_dropout(h)
|
| 278 |
+
mask = self.block_mask[:T, :T]
|
| 279 |
+
for blk in self.blocks:
|
| 280 |
+
h = blk(h, mask, kg_raw=kg_raw)
|
| 281 |
+
h = self.final_norm(h)
|
| 282 |
+
return self.head(h)
|
| 283 |
+
|
| 284 |
+
def diffusion_forcing_loss(self, x_clean, cond, kg_raw=None,
|
| 285 |
+
mode: str = "uniform") -> torch.Tensor:
|
| 286 |
+
"""Standard absorbing-state DF loss with per-token sigma.
|
| 287 |
+
|
| 288 |
+
mode: 'uniform' (default — safer for discrete than logit-normal per audit)
|
| 289 |
+
'logit_normal' (SD3-style — keep as ablation only)
|
| 290 |
+
"""
|
| 291 |
+
B, T = x_clean.shape
|
| 292 |
+
device = x_clean.device
|
| 293 |
+
# CFG cond dropout
|
| 294 |
+
drop = torch.rand(B, device=device) < self.cfg.cond_dropout
|
| 295 |
+
cond = torch.where(drop, torch.zeros_like(cond), cond)
|
| 296 |
+
if kg_raw is not None:
|
| 297 |
+
drop_kg = (torch.rand(B, device=device) < self.cfg.cond_dropout).float()
|
| 298 |
+
kg_raw = kg_raw * (1 - drop_kg).reshape(B, 1, 1)
|
| 299 |
+
# Sample per-token sigma
|
| 300 |
+
if mode == "logit_normal":
|
| 301 |
+
sigma = torch.sigmoid(torch.randn(B, T, device=device)).clamp(0.01, 0.99)
|
| 302 |
+
else:
|
| 303 |
+
sigma = torch.rand(B, T, device=device).clamp(0.01, 0.99)
|
| 304 |
+
# Absorbing-state corruption
|
| 305 |
+
corrupt = torch.rand(B, T, device=device) < sigma
|
| 306 |
+
x_noisy = torch.where(corrupt, self.cfg.mask_token, x_clean)
|
| 307 |
+
logits = self.forward(x_noisy, sigma, cond, kg_raw=kg_raw)
|
| 308 |
+
ce = F.cross_entropy(
|
| 309 |
+
logits.reshape(-1, self.cfg.vocab_size),
|
| 310 |
+
x_clean.reshape(-1),
|
| 311 |
+
reduction="none",
|
| 312 |
+
).reshape(B, T)
|
| 313 |
+
n = corrupt.float().sum().clamp(min=1.0)
|
| 314 |
+
return (ce * corrupt.float()).sum() / n
|
src/eval_sota.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SOTA evaluation suite for CDFv13 — audit-proof.
|
| 2 |
+
|
| 3 |
+
Per the May 2026 SOTA audit, replaces "Top-1 mid-position" (not recognized)
|
| 4 |
+
with the canonical EHR foundation model metric stack:
|
| 5 |
+
|
| 6 |
+
Classification (next-event, downstream tasks):
|
| 7 |
+
- AUROC + AUPRC + Brier
|
| 8 |
+
- Calibration: ICI (Austin & Steyerberg 2019)
|
| 9 |
+
- Decision-curve analysis (Vickers)
|
| 10 |
+
- Bootstrap 95% CI (≥2000 resamples) — required for rare disease
|
| 11 |
+
|
| 12 |
+
Survival (DATASUS SIM mortality):
|
| 13 |
+
- Uno's C (concordance_index_ipcw) — preferred over Harrell at high censoring
|
| 14 |
+
- Integrated Brier Score (1/3/5y)
|
| 15 |
+
- Time-dependent AUC
|
| 16 |
+
|
| 17 |
+
Counterfactual / causal:
|
| 18 |
+
- ATE with bootstrap CI
|
| 19 |
+
- E-value (VanderWeele)
|
| 20 |
+
- Negative-control outcome + exposure
|
| 21 |
+
- Tipping-point analysis
|
| 22 |
+
|
| 23 |
+
Generation fidelity (CoMET / SynthEHRella):
|
| 24 |
+
- Dim-wise probability match
|
| 25 |
+
- MMD (Maximum Mean Discrepancy) with RBF kernel
|
| 26 |
+
- TSTR (Train-on-Synthetic-Test-on-Real)
|
| 27 |
+
|
| 28 |
+
Subgroup fairness (npj DM requirement):
|
| 29 |
+
- Stratified metrics: sex, age band, UF region
|
| 30 |
+
|
| 31 |
+
Split strategy (DATASUS rare disease):
|
| 32 |
+
- Temporal: train ≤2022, val 2023, test 2024-2025
|
| 33 |
+
- Geographic: train SE+S, test N+NE (UF cross-region = "external")
|
| 34 |
+
- Patient-level 5-fold CV (variance estimation)
|
| 35 |
+
"""
|
| 36 |
+
from __future__ import annotations
|
| 37 |
+
import math
|
| 38 |
+
import numpy as np
|
| 39 |
+
import torch
|
| 40 |
+
from typing import Callable
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ---------- Classification ----------
|
| 44 |
+
|
| 45 |
+
def auroc(y: np.ndarray, p: np.ndarray) -> float:
|
| 46 |
+
from sklearn.metrics import roc_auc_score
|
| 47 |
+
if len(np.unique(y)) < 2: return float("nan")
|
| 48 |
+
return roc_auc_score(y, p)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def auprc(y: np.ndarray, p: np.ndarray) -> float:
|
| 52 |
+
from sklearn.metrics import average_precision_score
|
| 53 |
+
if len(np.unique(y)) < 2: return float("nan")
|
| 54 |
+
return average_precision_score(y, p)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def brier(y: np.ndarray, p: np.ndarray) -> float:
|
| 58 |
+
from sklearn.metrics import brier_score_loss
|
| 59 |
+
return brier_score_loss(y, p)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def ici(y: np.ndarray, p: np.ndarray, frac: float = 0.75) -> float:
|
| 63 |
+
"""Integrated Calibration Index (Austin & Steyerberg 2019).
|
| 64 |
+
Lowess-smoothed deviation from perfect calibration.
|
| 65 |
+
"""
|
| 66 |
+
from statsmodels.nonparametric.smoothers_lowess import lowess
|
| 67 |
+
sm = lowess(y, p, frac=frac, return_sorted=True)
|
| 68 |
+
return float(np.mean(np.abs(sm[:, 1] - sm[:, 0])))
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def net_benefit(y: np.ndarray, p: np.ndarray, threshold: float) -> float:
|
| 72 |
+
"""Net benefit at a given decision threshold (Vickers DCA)."""
|
| 73 |
+
tp = ((p >= threshold) & (y == 1)).sum()
|
| 74 |
+
fp = ((p >= threshold) & (y == 0)).sum()
|
| 75 |
+
n = len(y)
|
| 76 |
+
if threshold >= 1.0: return 0.0
|
| 77 |
+
return tp / n - (fp / n) * (threshold / (1 - threshold))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def decision_curve(y: np.ndarray, p: np.ndarray,
|
| 81 |
+
thresholds: list[float] = None) -> dict:
|
| 82 |
+
"""Decision-curve analysis: net benefit across thresholds vs treat-all/treat-none."""
|
| 83 |
+
if thresholds is None:
|
| 84 |
+
thresholds = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
|
| 85 |
+
model_nb = [net_benefit(y, p, t) for t in thresholds]
|
| 86 |
+
treat_all_nb = [(y.mean()) - (1 - y.mean()) * (t / (1 - t)) if t < 1 else 0
|
| 87 |
+
for t in thresholds]
|
| 88 |
+
treat_none_nb = [0.0] * len(thresholds)
|
| 89 |
+
return {
|
| 90 |
+
"thresholds": thresholds,
|
| 91 |
+
"model": model_nb,
|
| 92 |
+
"treat_all": treat_all_nb,
|
| 93 |
+
"treat_none": treat_none_nb,
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def bootstrap_ci(y: np.ndarray, p: np.ndarray, metric_fn: Callable,
|
| 98 |
+
n_boot: int = 2000, seed: int = 0,
|
| 99 |
+
ci: tuple[float, float] = (2.5, 97.5)) -> tuple[float, float, float]:
|
| 100 |
+
"""Bootstrap 95% CI for any (y, p) -> scalar metric."""
|
| 101 |
+
rng = np.random.default_rng(seed)
|
| 102 |
+
n = len(y)
|
| 103 |
+
stats = []
|
| 104 |
+
for _ in range(n_boot):
|
| 105 |
+
idx = rng.integers(0, n, n)
|
| 106 |
+
if len(np.unique(y[idx])) < 2: continue
|
| 107 |
+
try:
|
| 108 |
+
stats.append(metric_fn(y[idx], p[idx]))
|
| 109 |
+
except Exception:
|
| 110 |
+
continue
|
| 111 |
+
if not stats: return (float("nan"),) * 3
|
| 112 |
+
return (
|
| 113 |
+
float(np.percentile(stats, ci[0])),
|
| 114 |
+
float(np.median(stats)),
|
| 115 |
+
float(np.percentile(stats, ci[1])),
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# ---------- Survival ----------
|
| 120 |
+
|
| 121 |
+
def uno_c_index(y_train_event, y_train_time, y_test_event, y_test_time,
|
| 122 |
+
risk_score, tau: float = None) -> float:
|
| 123 |
+
"""Uno's C-index (IPCW concordance), preferred at high censoring.
|
| 124 |
+
Requires scikit-survival.
|
| 125 |
+
"""
|
| 126 |
+
try:
|
| 127 |
+
from sksurv.metrics import concordance_index_ipcw
|
| 128 |
+
except ImportError:
|
| 129 |
+
return float("nan")
|
| 130 |
+
# Build structured arrays
|
| 131 |
+
y_train = np.array(
|
| 132 |
+
list(zip(y_train_event.astype(bool), y_train_time.astype(float))),
|
| 133 |
+
dtype=[("event", "?"), ("time", "<f8")],
|
| 134 |
+
)
|
| 135 |
+
y_test = np.array(
|
| 136 |
+
list(zip(y_test_event.astype(bool), y_test_time.astype(float))),
|
| 137 |
+
dtype=[("event", "?"), ("time", "<f8")],
|
| 138 |
+
)
|
| 139 |
+
if tau is None:
|
| 140 |
+
tau = float(y_test_time.max()) * 0.95
|
| 141 |
+
c, *_ = concordance_index_ipcw(y_train, y_test, risk_score, tau=tau)
|
| 142 |
+
return float(c)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def integrated_brier_score(y_train_event, y_train_time, y_test_event, y_test_time,
|
| 146 |
+
surv_pred: np.ndarray, times: np.ndarray) -> float:
|
| 147 |
+
"""Integrated Brier Score (lower is better)."""
|
| 148 |
+
try:
|
| 149 |
+
from sksurv.metrics import integrated_brier_score as ibs_fn
|
| 150 |
+
except ImportError:
|
| 151 |
+
return float("nan")
|
| 152 |
+
y_train = np.array(
|
| 153 |
+
list(zip(y_train_event.astype(bool), y_train_time.astype(float))),
|
| 154 |
+
dtype=[("event", "?"), ("time", "<f8")],
|
| 155 |
+
)
|
| 156 |
+
y_test = np.array(
|
| 157 |
+
list(zip(y_test_event.astype(bool), y_test_time.astype(float))),
|
| 158 |
+
dtype=[("event", "?"), ("time", "<f8")],
|
| 159 |
+
)
|
| 160 |
+
return float(ibs_fn(y_train, y_test, surv_pred, times))
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# ---------- Causal / Counterfactual ----------
|
| 164 |
+
|
| 165 |
+
def e_value(rr: float) -> float:
|
| 166 |
+
"""E-value (VanderWeele & Ding 2017): min strength of unmeasured
|
| 167 |
+
confounder needed to explain away an observed RR.
|
| 168 |
+
"""
|
| 169 |
+
rr = max(rr, 1e-9)
|
| 170 |
+
if rr >= 1.0:
|
| 171 |
+
return rr + math.sqrt(rr * (rr - 1))
|
| 172 |
+
rr_inv = 1.0 / rr
|
| 173 |
+
return rr_inv + math.sqrt(rr_inv * (rr_inv - 1))
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def negative_control_check(nc_ate: float, threshold: float = 0.02) -> bool:
|
| 177 |
+
"""Negative-control outcome: ATE on a control outcome should be ~0."""
|
| 178 |
+
return abs(nc_ate) < threshold
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def tipping_point(observed_effect: float, ci_half_width: float) -> float:
|
| 182 |
+
"""How much would unmeasured confounding need to shift effect to nullify?"""
|
| 183 |
+
if abs(observed_effect) <= ci_half_width:
|
| 184 |
+
return 0.0
|
| 185 |
+
return float(abs(observed_effect) - ci_half_width)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# ---------- Generation fidelity (SynthEHRella triad) ----------
|
| 189 |
+
|
| 190 |
+
def dim_wise_probability(real_seq: torch.Tensor, synth_seq: torch.Tensor,
|
| 191 |
+
vocab_size: int) -> float:
|
| 192 |
+
"""Compare per-token Bernoulli rates between real and synthetic batches.
|
| 193 |
+
|
| 194 |
+
Returns mean abs difference (lower = closer match).
|
| 195 |
+
"""
|
| 196 |
+
real_one_hot = F.one_hot(real_seq, vocab_size).float().mean(dim=(0, 1))
|
| 197 |
+
synth_one_hot = F.one_hot(synth_seq, vocab_size).float().mean(dim=(0, 1))
|
| 198 |
+
return float((real_one_hot - synth_one_hot).abs().mean())
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def mmd_rbf(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> float:
|
| 202 |
+
"""Maximum Mean Discrepancy with RBF kernel.
|
| 203 |
+
|
| 204 |
+
x, y: (B, D) flattened embeddings. Returns MMD^2 (lower = closer).
|
| 205 |
+
"""
|
| 206 |
+
def rbf(a, b):
|
| 207 |
+
d = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(-1)
|
| 208 |
+
return torch.exp(-d / (2 * sigma ** 2))
|
| 209 |
+
return float(rbf(x, x).mean() + rbf(y, y).mean() - 2 * rbf(x, y).mean())
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# ---------- Subgroup fairness ----------
|
| 213 |
+
|
| 214 |
+
def stratified_metrics(y: np.ndarray, p: np.ndarray,
|
| 215 |
+
groups: np.ndarray,
|
| 216 |
+
metric_fn: Callable = auroc) -> dict[str, float]:
|
| 217 |
+
"""Compute metric per subgroup (sex, age band, UF region)."""
|
| 218 |
+
out = {}
|
| 219 |
+
for g in np.unique(groups):
|
| 220 |
+
mask = groups == g
|
| 221 |
+
if mask.sum() > 10:
|
| 222 |
+
try:
|
| 223 |
+
out[str(g)] = metric_fn(y[mask], p[mask])
|
| 224 |
+
except Exception:
|
| 225 |
+
out[str(g)] = float("nan")
|
| 226 |
+
return out
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
# ---------- DATASUS split strategies ----------
|
| 230 |
+
|
| 231 |
+
def temporal_split(events: list[dict], train_until: int = 2022,
|
| 232 |
+
val_year: int = 2023):
|
| 233 |
+
"""Temporal split for DATASUS: train ≤2022, val 2023, test 2024+."""
|
| 234 |
+
train, val, test = [], [], []
|
| 235 |
+
for e in events:
|
| 236 |
+
y = e.get("year") or 2020
|
| 237 |
+
if y <= train_until: train.append(e)
|
| 238 |
+
elif y == val_year: val.append(e)
|
| 239 |
+
else: test.append(e)
|
| 240 |
+
return train, val, test
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def geographic_split(patients: list[dict], external_ufs: set = None):
|
| 244 |
+
"""Geographic split: train on SE+S, test on N+NE.
|
| 245 |
+
For DATASUS this is the closest analog to "external validation."
|
| 246 |
+
"""
|
| 247 |
+
if external_ufs is None:
|
| 248 |
+
external_ufs = {"AC", "AL", "AP", "AM", "BA", "CE", "MA", "PA",
|
| 249 |
+
"PB", "PE", "PI", "RN", "SE", "TO", "RR", "RO"}
|
| 250 |
+
train, test = [], []
|
| 251 |
+
for p in patients:
|
| 252 |
+
uf = next((e.get("uf_code") for e in p.get("events", []) if e.get("uf_code")),
|
| 253 |
+
None)
|
| 254 |
+
(test if uf in external_ufs else train).append(p)
|
| 255 |
+
return train, test
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# ---------- Combined eval report ----------
|
| 259 |
+
|
| 260 |
+
def full_eval_report(y: np.ndarray, p: np.ndarray,
|
| 261 |
+
groups_sex: np.ndarray = None,
|
| 262 |
+
groups_age: np.ndarray = None,
|
| 263 |
+
groups_uf: np.ndarray = None,
|
| 264 |
+
n_boot: int = 2000) -> dict:
|
| 265 |
+
"""Generate a full audit-proof report for a binary classification task.
|
| 266 |
+
|
| 267 |
+
Returns a dict with point estimates + bootstrap CIs + DCA + fairness.
|
| 268 |
+
"""
|
| 269 |
+
import torch.nn.functional as F # local import to keep top clean
|
| 270 |
+
|
| 271 |
+
auroc_lo, auroc_med, auroc_hi = bootstrap_ci(y, p, auroc, n_boot)
|
| 272 |
+
auprc_lo, auprc_med, auprc_hi = bootstrap_ci(y, p, auprc, n_boot)
|
| 273 |
+
brier_lo, brier_med, brier_hi = bootstrap_ci(y, p, brier, n_boot)
|
| 274 |
+
|
| 275 |
+
report = {
|
| 276 |
+
"n_eval": len(y),
|
| 277 |
+
"prevalence": float(y.mean()),
|
| 278 |
+
"auroc": {"point": auroc(y, p), "ci95": [auroc_lo, auroc_hi], "median": auroc_med},
|
| 279 |
+
"auprc": {"point": auprc(y, p), "ci95": [auprc_lo, auprc_hi], "median": auprc_med},
|
| 280 |
+
"brier": {"point": brier(y, p), "ci95": [brier_lo, brier_hi], "median": brier_med},
|
| 281 |
+
"ici": ici(y, p),
|
| 282 |
+
"decision_curve": decision_curve(y, p),
|
| 283 |
+
}
|
| 284 |
+
if groups_sex is not None:
|
| 285 |
+
report["fairness_sex"] = stratified_metrics(y, p, groups_sex, auroc)
|
| 286 |
+
if groups_age is not None:
|
| 287 |
+
report["fairness_age"] = stratified_metrics(y, p, groups_age, auroc)
|
| 288 |
+
if groups_uf is not None:
|
| 289 |
+
report["fairness_uf"] = stratified_metrics(y, p, groups_uf, auroc)
|
| 290 |
+
return report
|
src/meds_export.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MEDS v0.4.1 exporter for DATASUS — audit-proof, interop-ready.
|
| 2 |
+
|
| 3 |
+
Verified against:
|
| 4 |
+
- meds 0.4.1 schemas (DataSchema, CodeMetadataSchema)
|
| 5 |
+
- https://github.com/Medical-Event-Data-Standard/meds
|
| 6 |
+
- CLMBR/MOTOR/EHRSHOT/CoMET tokenization conventions
|
| 7 |
+
|
| 8 |
+
Code conventions (interop-compatible):
|
| 9 |
+
- Static (time=None): GENDER//, RACE//, UF//, MUN//, ORPHA//
|
| 10 |
+
- Birth/Death: MEDS_BIRTH, MEDS_DEATH (reserved)
|
| 11 |
+
- Diagnoses: ICD10//<cid> (NOT CID10// — interop with OHDSI/Athena)
|
| 12 |
+
- Hospitalization: SIH//ADM, SIH//DIS (numeric_value=LOS_days on DIS)
|
| 13 |
+
- Procedures: SIGTAP//<10-digit> (Brazil-local namespace)
|
| 14 |
+
- Drugs (APAC): APAC//<sigtap> (numeric_value=monthly_cost_brl)
|
| 15 |
+
- Outpatient (BPA-I): BPAI//<sigtap>
|
| 16 |
+
- Visits: Visit//{IP, OP, ER} (matches CLMBR convention)
|
| 17 |
+
|
| 18 |
+
Outputs canonical MEDS dataset:
|
| 19 |
+
/out/
|
| 20 |
+
├── data/ # parquet shards by subject
|
| 21 |
+
│ ├── shard_0.parquet
|
| 22 |
+
│ └── ...
|
| 23 |
+
├── metadata/
|
| 24 |
+
│ └── codes.parquet # REQUIRED: every unique code with description + parent_codes
|
| 25 |
+
└── dataset_metadata.json # MEDS dataset metadata
|
| 26 |
+
"""
|
| 27 |
+
from __future__ import annotations
|
| 28 |
+
import os
|
| 29 |
+
import json
|
| 30 |
+
import logging
|
| 31 |
+
from collections import defaultdict, Counter
|
| 32 |
+
from datetime import datetime
|
| 33 |
+
from typing import Iterator
|
| 34 |
+
|
| 35 |
+
import pyarrow as pa
|
| 36 |
+
import pyarrow.parquet as pq
|
| 37 |
+
import meds
|
| 38 |
+
|
| 39 |
+
log = logging.getLogger("gemeo.cdf.meds_export")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _parse_date(s) -> datetime | None:
|
| 43 |
+
"""Parse date string from various DATASUS formats."""
|
| 44 |
+
if s is None: return None
|
| 45 |
+
s = str(s).strip()
|
| 46 |
+
if not s or s in ("0", "None", "nan"): return None
|
| 47 |
+
try:
|
| 48 |
+
if "-" in s:
|
| 49 |
+
return datetime.strptime(s[:10], "%Y-%m-%d")
|
| 50 |
+
if len(s) == 8:
|
| 51 |
+
return datetime.strptime(s, "%Y%m%d")
|
| 52 |
+
except ValueError:
|
| 53 |
+
return None
|
| 54 |
+
return None
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _ym(year, month) -> datetime | None:
|
| 58 |
+
if year is None: return None
|
| 59 |
+
try:
|
| 60 |
+
return datetime(int(year), int(month) if month else 1, 1)
|
| 61 |
+
except (ValueError, TypeError):
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def datasus_patient_to_meds_rows(p: dict, subject_id: int) -> list[tuple]:
|
| 66 |
+
"""Convert one DATASUS patient trajectory to a list of MEDS rows.
|
| 67 |
+
|
| 68 |
+
Each row is (subject_id, time, code, numeric_value, text_value).
|
| 69 |
+
Returns rows ready to write to a parquet shard.
|
| 70 |
+
"""
|
| 71 |
+
rows = []
|
| 72 |
+
|
| 73 |
+
# ---- Static (time=None) ----
|
| 74 |
+
if p.get("sex"):
|
| 75 |
+
rows.append((subject_id, None, f"GENDER//{p['sex']}", None, None))
|
| 76 |
+
# ORPHA is rare-disease specific (parallel to ICD10)
|
| 77 |
+
for orpha in p.get("orphas", []):
|
| 78 |
+
rows.append((subject_id, None, f"ORPHA//{orpha}", None, None))
|
| 79 |
+
|
| 80 |
+
# ---- Birth (use birth_year as Jan 1) ----
|
| 81 |
+
birth_year = p.get("birth_year")
|
| 82 |
+
birth_dt = datetime(int(birth_year), 1, 1) if birth_year else None
|
| 83 |
+
if birth_dt:
|
| 84 |
+
rows.append((subject_id, birth_dt, "MEDS_BIRTH", None, None))
|
| 85 |
+
|
| 86 |
+
# ---- Events ----
|
| 87 |
+
for e in p.get("events", []):
|
| 88 |
+
et = e.get("type")
|
| 89 |
+
|
| 90 |
+
if et == "admission": # SIH-RD
|
| 91 |
+
t = _ym(e.get("year"), e.get("month")) or _parse_date(e.get("admission_date"))
|
| 92 |
+
if not t: continue
|
| 93 |
+
rows.append((subject_id, t, "SIH//ADM", None, None))
|
| 94 |
+
rows.append((subject_id, t, "Visit//IP", None, None))
|
| 95 |
+
cid = e.get("cid_princ", "")
|
| 96 |
+
if cid: rows.append((subject_id, t, f"ICD10//{cid}", None, None))
|
| 97 |
+
proc = e.get("primary_procedure")
|
| 98 |
+
if proc: rows.append((subject_id, t, f"SIGTAP//{proc[:10]}", None, None))
|
| 99 |
+
los = e.get("los_days")
|
| 100 |
+
disch_dt = _parse_date(e.get("discharge_date")) or t
|
| 101 |
+
if e.get("death_during_stay"):
|
| 102 |
+
rows.append((subject_id, disch_dt, "MEDS_DEATH", None, None))
|
| 103 |
+
else:
|
| 104 |
+
rows.append((subject_id, disch_dt, "SIH//DIS",
|
| 105 |
+
float(los) if los is not None else None, None))
|
| 106 |
+
|
| 107 |
+
elif et == "treatment": # APAC-SIA orphan drug
|
| 108 |
+
t = _ym(e.get("year"), e.get("month"))
|
| 109 |
+
if not t: continue
|
| 110 |
+
cid = e.get("cid", "")
|
| 111 |
+
if cid: rows.append((subject_id, t, f"ICD10//{cid}", None, None))
|
| 112 |
+
proc = e.get("procedure_code", "")[:10]
|
| 113 |
+
if proc:
|
| 114 |
+
cost = e.get("monthly_cost_brl")
|
| 115 |
+
rows.append((subject_id, t, f"APAC//{proc}",
|
| 116 |
+
float(cost) if cost is not None else None, None))
|
| 117 |
+
|
| 118 |
+
elif et == "outpatient_proc": # BPA-I
|
| 119 |
+
t = _parse_date(e.get("auth_date")) or _ym(e.get("year"), e.get("month"))
|
| 120 |
+
if not t: continue
|
| 121 |
+
cid = e.get("cid", "")
|
| 122 |
+
if cid: rows.append((subject_id, t, f"ICD10//{cid}", None, None))
|
| 123 |
+
proc = e.get("procedure_code", "")[:10]
|
| 124 |
+
if proc:
|
| 125 |
+
rows.append((subject_id, t, f"BPAI//{proc}", None, None))
|
| 126 |
+
|
| 127 |
+
elif et == "death": # SIM
|
| 128 |
+
t = _parse_date(e.get("date_of_death")) or _ym(e.get("year"), e.get("month"))
|
| 129 |
+
if not t: continue
|
| 130 |
+
rows.append((subject_id, t, "MEDS_DEATH", None, None))
|
| 131 |
+
cid = (e.get("cause_cid") or e.get("cid_princ") or e.get("cid", ""))
|
| 132 |
+
if cid: rows.append((subject_id, t, f"ICD10//{cid}", None, None))
|
| 133 |
+
|
| 134 |
+
# Sort: nulls first (static), then by time
|
| 135 |
+
rows.sort(key=lambda r: (r[1] is not None, r[1] or datetime(1900, 1, 1)))
|
| 136 |
+
return rows
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def export_to_meds(patients: list[dict], out_dir: str,
|
| 140 |
+
shard_size: int = 5000,
|
| 141 |
+
dataset_name: str = "GEMEO-DATASUS",
|
| 142 |
+
version: str = "v13"):
|
| 143 |
+
"""Export a list of DATASUS patient trajectories to MEDS v0.4.1 format.
|
| 144 |
+
|
| 145 |
+
Parameters
|
| 146 |
+
----------
|
| 147 |
+
patients : list of dict
|
| 148 |
+
Each dict must have: patient_id, sex, birth_year, orphas (list),
|
| 149 |
+
events (list of dicts with 'type', 'year', 'month', etc.)
|
| 150 |
+
out_dir : str
|
| 151 |
+
Output directory (will create data/ and metadata/ subdirs)
|
| 152 |
+
shard_size : int
|
| 153 |
+
Number of subjects per parquet shard
|
| 154 |
+
"""
|
| 155 |
+
os.makedirs(f"{out_dir}/data", exist_ok=True)
|
| 156 |
+
os.makedirs(f"{out_dir}/metadata", exist_ok=True)
|
| 157 |
+
|
| 158 |
+
log.info(f"Exporting {len(patients)} patients to MEDS at {out_dir}")
|
| 159 |
+
|
| 160 |
+
# Map patient_id (string hash) → int64 subject_id (MEDS requires int64)
|
| 161 |
+
pid_to_sid = {p["patient_id"]: i for i, p in enumerate(patients)}
|
| 162 |
+
|
| 163 |
+
# ---- Stream rows ----
|
| 164 |
+
all_codes = Counter()
|
| 165 |
+
shard_idx = 0
|
| 166 |
+
shard_rows = []
|
| 167 |
+
n_events = 0
|
| 168 |
+
n_subjects = 0
|
| 169 |
+
|
| 170 |
+
for p in patients:
|
| 171 |
+
sid = pid_to_sid[p["patient_id"]]
|
| 172 |
+
rows = datasus_patient_to_meds_rows(p, sid)
|
| 173 |
+
shard_rows.extend(rows)
|
| 174 |
+
n_events += len(rows)
|
| 175 |
+
n_subjects += 1
|
| 176 |
+
for r in rows:
|
| 177 |
+
all_codes[r[2]] += 1
|
| 178 |
+
# Write shard when full
|
| 179 |
+
if n_subjects % shard_size == 0 and shard_rows:
|
| 180 |
+
_write_shard(shard_rows, f"{out_dir}/data/shard_{shard_idx}.parquet")
|
| 181 |
+
shard_idx += 1
|
| 182 |
+
shard_rows = []
|
| 183 |
+
|
| 184 |
+
# Write remaining
|
| 185 |
+
if shard_rows:
|
| 186 |
+
_write_shard(shard_rows, f"{out_dir}/data/shard_{shard_idx}.parquet")
|
| 187 |
+
|
| 188 |
+
log.info(f" wrote {shard_idx + 1} data shards, {n_events} rows, {n_subjects} subjects")
|
| 189 |
+
|
| 190 |
+
# ---- codes.parquet (REQUIRED in MEDS v0.4) ----
|
| 191 |
+
code_rows = []
|
| 192 |
+
for code, count in all_codes.most_common():
|
| 193 |
+
# parent_codes: empty for Brazil-local namespaces; populated for ICD10 -> SNOMED if mapped
|
| 194 |
+
parent_codes = _get_parent_codes(code)
|
| 195 |
+
code_rows.append({
|
| 196 |
+
"code": code,
|
| 197 |
+
"description": _get_description(code, count),
|
| 198 |
+
"parent_codes": parent_codes,
|
| 199 |
+
})
|
| 200 |
+
code_table = pa.Table.from_pylist(code_rows, schema=meds.CodeMetadataSchema.schema())
|
| 201 |
+
pq.write_table(code_table, f"{out_dir}/metadata/codes.parquet")
|
| 202 |
+
log.info(f" wrote metadata/codes.parquet ({len(code_rows)} unique codes)")
|
| 203 |
+
|
| 204 |
+
# ---- dataset_metadata.json ----
|
| 205 |
+
md = {
|
| 206 |
+
"dataset_name": dataset_name,
|
| 207 |
+
"dataset_version": version,
|
| 208 |
+
"etl_name": "gemeo.cdf.meds_export",
|
| 209 |
+
"etl_version": "1.0.0",
|
| 210 |
+
"meds_version": meds.__version__,
|
| 211 |
+
"n_subjects": n_subjects,
|
| 212 |
+
"n_events": n_events,
|
| 213 |
+
"n_unique_codes": len(all_codes),
|
| 214 |
+
"top_codes": dict(all_codes.most_common(30)),
|
| 215 |
+
}
|
| 216 |
+
with open(f"{out_dir}/dataset_metadata.json", "w") as f:
|
| 217 |
+
json.dump(md, f, indent=2, default=str)
|
| 218 |
+
log.info(f" wrote dataset_metadata.json")
|
| 219 |
+
|
| 220 |
+
return md
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def _write_shard(rows: list[tuple], path: str):
|
| 224 |
+
"""Write a list of (subject_id, time, code, numeric_value, text_value) to parquet."""
|
| 225 |
+
if not rows: return
|
| 226 |
+
# Build columnar arrays
|
| 227 |
+
subject_id = pa.array([r[0] for r in rows], type=pa.int64())
|
| 228 |
+
time = pa.array([r[1] for r in rows], type=pa.timestamp("us"))
|
| 229 |
+
code = pa.array([r[2] for r in rows], type=pa.string())
|
| 230 |
+
numeric_value = pa.array([r[3] for r in rows], type=pa.float32())
|
| 231 |
+
text_value = pa.array([r[4] for r in rows], type=pa.large_string())
|
| 232 |
+
table = pa.Table.from_arrays(
|
| 233 |
+
[subject_id, time, code, numeric_value, text_value],
|
| 234 |
+
names=["subject_id", "time", "code", "numeric_value", "text_value"],
|
| 235 |
+
)
|
| 236 |
+
# Validate against MEDS schema
|
| 237 |
+
expected_schema = meds.DataSchema.schema()
|
| 238 |
+
# Cast if needed
|
| 239 |
+
table = table.cast(expected_schema, safe=False)
|
| 240 |
+
pq.write_table(table, path, compression="zstd")
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
# Brazilian-specific mapping tables (extend as needed)
|
| 244 |
+
ICD10_CHAPTERS = {
|
| 245 |
+
"A": "Certain infectious and parasitic diseases",
|
| 246 |
+
"B": "Certain infectious and parasitic diseases",
|
| 247 |
+
"C": "Neoplasms",
|
| 248 |
+
"D": "Neoplasms / Diseases of the blood and immune",
|
| 249 |
+
"E": "Endocrine, nutritional and metabolic diseases",
|
| 250 |
+
"F": "Mental, Behavioral and Neurodevelopmental disorders",
|
| 251 |
+
"G": "Diseases of the nervous system",
|
| 252 |
+
"H": "Diseases of the eye / ear",
|
| 253 |
+
"I": "Diseases of the circulatory system",
|
| 254 |
+
"J": "Diseases of the respiratory system",
|
| 255 |
+
"K": "Diseases of the digestive system",
|
| 256 |
+
"L": "Diseases of the skin and subcutaneous tissue",
|
| 257 |
+
"M": "Diseases of the musculoskeletal system",
|
| 258 |
+
"N": "Diseases of the genitourinary system",
|
| 259 |
+
"O": "Pregnancy, childbirth and the puerperium",
|
| 260 |
+
"P": "Certain conditions originating in the perinatal period",
|
| 261 |
+
"Q": "Congenital malformations, deformations and chromosomal abnormalities",
|
| 262 |
+
"R": "Symptoms, signs and abnormal clinical and laboratory findings",
|
| 263 |
+
"S": "Injury, poisoning and certain other consequences of external causes",
|
| 264 |
+
"T": "Injury, poisoning and certain other consequences of external causes",
|
| 265 |
+
"V": "External causes of morbidity",
|
| 266 |
+
"W": "External causes of morbidity",
|
| 267 |
+
"X": "External causes of morbidity",
|
| 268 |
+
"Y": "External causes of morbidity",
|
| 269 |
+
"Z": "Factors influencing health status and contact with health services",
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def _get_description(code: str, count: int) -> str:
|
| 274 |
+
"""Generate a brief description for a code (used in codes.parquet)."""
|
| 275 |
+
if code in ("MEDS_BIRTH",): return "Birth event (reserved)"
|
| 276 |
+
if code in ("MEDS_DEATH",): return "Death event (reserved)"
|
| 277 |
+
parts = code.split("//")
|
| 278 |
+
if len(parts) < 2: return f"Unknown code (n={count})"
|
| 279 |
+
domain, val = parts[0], "//".join(parts[1:])
|
| 280 |
+
if domain == "GENDER": return f"Patient sex = {val}"
|
| 281 |
+
if domain == "ORPHA": return f"Orphanet rare disease {val}"
|
| 282 |
+
if domain == "ICD10":
|
| 283 |
+
ch = ICD10_CHAPTERS.get(val[0], "Unknown chapter")
|
| 284 |
+
return f"ICD-10 {val} ({ch})"
|
| 285 |
+
if domain == "SIH": return f"SIH hospitalization {val}"
|
| 286 |
+
if domain == "Visit": return f"Visit type {val}"
|
| 287 |
+
if domain == "SIGTAP": return f"SIGTAP procedure {val}"
|
| 288 |
+
if domain == "APAC": return f"APAC orphan-drug authorization {val}"
|
| 289 |
+
if domain == "BPAI": return f"BPA-I outpatient procedure {val}"
|
| 290 |
+
if domain == "UF": return f"Residence UF {val}"
|
| 291 |
+
return f"{domain} code {val}"
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def _get_parent_codes(code: str) -> list[str]:
|
| 295 |
+
"""Return parent codes for ontology hierarchy (currently minimal)."""
|
| 296 |
+
parts = code.split("//")
|
| 297 |
+
if len(parts) < 2: return []
|
| 298 |
+
domain, val = parts[0], "//".join(parts[1:])
|
| 299 |
+
parents = []
|
| 300 |
+
if domain == "ICD10" and len(val) >= 3:
|
| 301 |
+
# ICD-10 chapter as parent
|
| 302 |
+
chapter = val[0]
|
| 303 |
+
if chapter in ICD10_CHAPTERS:
|
| 304 |
+
parents.append(f"ICD10//chapter_{chapter}")
|
| 305 |
+
# 3-char prefix as parent (e.g., E84.0 → E84)
|
| 306 |
+
if "." in val:
|
| 307 |
+
parents.append(f"ICD10//{val.split('.')[0]}")
|
| 308 |
+
elif len(val) > 3:
|
| 309 |
+
parents.append(f"ICD10//{val[:3]}")
|
| 310 |
+
if domain == "SIGTAP" and len(val) >= 4:
|
| 311 |
+
# 4-digit group as parent (SIGTAP 10-digit → 4-digit group)
|
| 312 |
+
parents.append(f"SIGTAP//group_{val[:4]}")
|
| 313 |
+
return parents
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def load_meds_dataset(meds_dir: str) -> dict:
|
| 317 |
+
"""Load a MEDS dataset back from parquet for inspection or downstream processing."""
|
| 318 |
+
import glob
|
| 319 |
+
shards = sorted(glob.glob(f"{meds_dir}/data/*.parquet"))
|
| 320 |
+
tables = [pq.read_table(p) for p in shards]
|
| 321 |
+
data = pa.concat_tables(tables) if tables else None
|
| 322 |
+
codes = pq.read_table(f"{meds_dir}/metadata/codes.parquet")
|
| 323 |
+
md = json.load(open(f"{meds_dir}/dataset_metadata.json"))
|
| 324 |
+
return {"data": data, "codes": codes, "metadata": md}
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
if __name__ == "__main__":
|
| 328 |
+
# Quick test on real patient data
|
| 329 |
+
logging.basicConfig(level=logging.INFO,
|
| 330 |
+
format="%(asctime)s %(levelname)s %(message)s")
|
| 331 |
+
PATIENTS = "/tmp/datasus_patient_trajectories_v2.json"
|
| 332 |
+
if os.path.exists(PATIENTS):
|
| 333 |
+
patients = json.load(open(PATIENTS))[:50] # 50 patients smoke test
|
| 334 |
+
md = export_to_meds(patients, "/tmp/meds_smoke_test")
|
| 335 |
+
print("\n=== smoke test result ===")
|
| 336 |
+
print(json.dumps(md, indent=2, default=str))
|
src/primekg_attention.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PrimeKG cross-attention — graph-RAG into the Diffusion Forcing denoiser.
|
| 2 |
+
|
| 3 |
+
Now uses REAL EDGES from raras-app/data/graph-ml/hetero_graph.json:
|
| 4 |
+
- disease → has_phenotype → phenotype (curated phenotype linkage)
|
| 5 |
+
- disease → associated_with → gene (causal gene evidence)
|
| 6 |
+
- gene → interacts_with → gene (PPI network)
|
| 7 |
+
- phenotype → is_a → phenotype (HPO ontology)
|
| 8 |
+
|
| 9 |
+
Ego-subgraph BFS:
|
| 10 |
+
1. Start from disease node (ORPHA → PrimeKG index)
|
| 11 |
+
2. 1-hop: pull connected phenotypes (top-K by edge weight or count)
|
| 12 |
+
3. 1-hop: pull connected genes
|
| 13 |
+
4. 2-hop: gene→gene neighbors (interacting partners)
|
| 14 |
+
5. Concatenate fused embeddings of all selected nodes → cross-attn context
|
| 15 |
+
|
| 16 |
+
Falls back to cosine-similarity if graph not loaded.
|
| 17 |
+
|
| 18 |
+
White-space architecture (May 2026):
|
| 19 |
+
- EHRWorld, CLARITY, Time-Aware G-Transformer all skip KG conditioning
|
| 20 |
+
- PhenoKG/RareNet use KG for RETRIEVAL (rare disease diagnosis)
|
| 21 |
+
- We use it for GENERATION (counterfactual trajectory completion)
|
| 22 |
+
"""
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
import os
|
| 25 |
+
import json
|
| 26 |
+
import logging
|
| 27 |
+
from functools import lru_cache
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn as nn
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
|
| 34 |
+
log = logging.getLogger("gemeo.cdf.kg")
|
| 35 |
+
|
| 36 |
+
# Try raras-app paths first (richer, including hetero_graph edges + node_texts)
|
| 37 |
+
RARAS_KG_DIR = "/Users/dimas/raras-app/data/graph-ml"
|
| 38 |
+
LOCAL_KG_DIR = os.path.join(
|
| 39 |
+
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _kg_path(name: str) -> str:
|
| 43 |
+
"""Prefer raras-app path if available, fall back to local fp16."""
|
| 44 |
+
raras = os.path.join(RARAS_KG_DIR, name)
|
| 45 |
+
if os.path.exists(raras):
|
| 46 |
+
return raras
|
| 47 |
+
local = os.path.join(LOCAL_KG_DIR, name)
|
| 48 |
+
return local if os.path.exists(local) else None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@lru_cache(maxsize=1)
|
| 52 |
+
def load_kg(prefer_raras: bool = True) -> dict | None:
|
| 53 |
+
"""Load PrimeKG: fused embeddings + node ids + edges + texts.
|
| 54 |
+
|
| 55 |
+
Returns dict:
|
| 56 |
+
emb : {kind: torch.Tensor(N, 3072)}
|
| 57 |
+
idx2id : {kind: {pos: id_str}}
|
| 58 |
+
id2idx : {kind: {id_str: pos}}
|
| 59 |
+
edges : {edge_type: {'src': [...], 'dst': [...]}}
|
| 60 |
+
adj : {edge_type: {src_idx: [dst_idx, ...]}} -- precomputed
|
| 61 |
+
texts : {kind: [str, ...]} -- aligned to position
|
| 62 |
+
num_nodes : {kind: int}
|
| 63 |
+
"""
|
| 64 |
+
# Try raras-app full file first, then local fp16
|
| 65 |
+
emb_path = (os.path.join(RARAS_KG_DIR, "fused_embeddings.npz")
|
| 66 |
+
if prefer_raras and os.path.exists(os.path.join(RARAS_KG_DIR, "fused_embeddings.npz"))
|
| 67 |
+
else _kg_path("fused_embeddings_fp16.npz"))
|
| 68 |
+
if not emb_path or not os.path.exists(emb_path):
|
| 69 |
+
log.warning("PrimeKG fused embeddings not found")
|
| 70 |
+
return None
|
| 71 |
+
|
| 72 |
+
nids_path = _kg_path("node_ids.json")
|
| 73 |
+
graph_path = _kg_path("hetero_graph.json")
|
| 74 |
+
texts_path = _kg_path("node_texts.json")
|
| 75 |
+
|
| 76 |
+
fz = np.load(emb_path)
|
| 77 |
+
nids = json.load(open(nids_path)) if nids_path else {}
|
| 78 |
+
graph = json.load(open(graph_path)) if graph_path else {"edges": {}, "num_nodes": {}}
|
| 79 |
+
texts = json.load(open(texts_path)) if texts_path else {}
|
| 80 |
+
|
| 81 |
+
out = {"emb": {}, "id2idx": {}, "idx2id": {}, "edges": {}, "adj": {},
|
| 82 |
+
"texts": texts, "num_nodes": graph.get("num_nodes", {})}
|
| 83 |
+
|
| 84 |
+
for kind in ("disease", "phenotype", "gene"):
|
| 85 |
+
if kind in fz.files:
|
| 86 |
+
out["emb"][kind] = torch.from_numpy(fz[kind].astype(np.float32))
|
| 87 |
+
if kind in nids:
|
| 88 |
+
out["idx2id"][kind] = {int(k): v for k, v in nids[kind].items()}
|
| 89 |
+
out["id2idx"][kind] = {v: int(k) for k, v in nids[kind].items()}
|
| 90 |
+
|
| 91 |
+
# Build adjacency from edges
|
| 92 |
+
for edge_type, edata in graph.get("edges", {}).items():
|
| 93 |
+
adj = {}
|
| 94 |
+
srcs = edata.get("src", []) if isinstance(edata, dict) else []
|
| 95 |
+
dsts = edata.get("dst", []) if isinstance(edata, dict) else []
|
| 96 |
+
for s, d in zip(srcs, dsts):
|
| 97 |
+
adj.setdefault(int(s), []).append(int(d))
|
| 98 |
+
out["adj"][edge_type] = adj
|
| 99 |
+
out["edges"][edge_type] = edata
|
| 100 |
+
|
| 101 |
+
log.info(f" KG loaded from {emb_path}")
|
| 102 |
+
log.info(f" disease={out['emb'].get('disease', torch.empty(0)).shape}, "
|
| 103 |
+
f"phenotype={out['emb'].get('phenotype', torch.empty(0)).shape}, "
|
| 104 |
+
f"gene={out['emb'].get('gene', torch.empty(0)).shape}")
|
| 105 |
+
log.info(f" edges: {list(out['edges'].keys())}")
|
| 106 |
+
return out
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def ego_subgraph_real(orpha_code: str, k_pheno: int = 16, k_gene: int = 16,
|
| 110 |
+
k_gene_2hop: int = 0, kg: dict | None = None) -> torch.Tensor:
|
| 111 |
+
"""BFS ego-subgraph using REAL PrimeKG edges (not cosine similarity).
|
| 112 |
+
|
| 113 |
+
Returns concatenated embeddings (N, 3072) where:
|
| 114 |
+
- 1 disease node (the query)
|
| 115 |
+
- up to k_pheno phenotype nodes (direct edges)
|
| 116 |
+
- up to k_gene gene nodes (direct edges)
|
| 117 |
+
- up to k_gene_2hop gene-gene 2-hop neighbors
|
| 118 |
+
|
| 119 |
+
Falls back to cosine similarity if no edges available.
|
| 120 |
+
"""
|
| 121 |
+
if kg is None:
|
| 122 |
+
kg = load_kg()
|
| 123 |
+
if kg is None or "disease" not in kg["emb"]:
|
| 124 |
+
return None
|
| 125 |
+
|
| 126 |
+
d_id = kg["id2idx"]["disease"].get(str(orpha_code))
|
| 127 |
+
if d_id is None:
|
| 128 |
+
return None
|
| 129 |
+
|
| 130 |
+
d_emb = kg["emb"]["disease"][d_id]
|
| 131 |
+
nodes = [d_emb.unsqueeze(0)]
|
| 132 |
+
|
| 133 |
+
# Phenotype neighbors (via disease__has_phenotype__phenotype)
|
| 134 |
+
adj = kg["adj"].get("disease__has_phenotype__phenotype", {})
|
| 135 |
+
pheno_neighbors = adj.get(d_id, [])
|
| 136 |
+
if pheno_neighbors and "phenotype" in kg["emb"]:
|
| 137 |
+
pheno_neighbors = pheno_neighbors[:k_pheno]
|
| 138 |
+
nodes.append(kg["emb"]["phenotype"][pheno_neighbors])
|
| 139 |
+
elif "phenotype" in kg["emb"]:
|
| 140 |
+
# Fallback: cosine similarity
|
| 141 |
+
pool = kg["emb"]["phenotype"]
|
| 142 |
+
sim = F.cosine_similarity(d_emb.unsqueeze(0), pool, dim=-1)
|
| 143 |
+
top = sim.topk(min(k_pheno, pool.size(0))).indices
|
| 144 |
+
nodes.append(pool[top])
|
| 145 |
+
|
| 146 |
+
# Gene neighbors (via disease__associated_with__gene)
|
| 147 |
+
g_adj = kg["adj"].get("disease__associated_with__gene", {})
|
| 148 |
+
gene_neighbors = g_adj.get(d_id, [])
|
| 149 |
+
if gene_neighbors and "gene" in kg["emb"]:
|
| 150 |
+
gene_neighbors = gene_neighbors[:k_gene]
|
| 151 |
+
nodes.append(kg["emb"]["gene"][gene_neighbors])
|
| 152 |
+
|
| 153 |
+
# 2-hop: gene-gene neighbors of the genes we just pulled
|
| 154 |
+
if k_gene_2hop > 0:
|
| 155 |
+
gg_adj = kg["adj"].get("gene__interacts_with__gene", {})
|
| 156 |
+
seen = set(gene_neighbors)
|
| 157 |
+
second_hop = []
|
| 158 |
+
for g in gene_neighbors:
|
| 159 |
+
for g2 in gg_adj.get(g, []):
|
| 160 |
+
if g2 not in seen:
|
| 161 |
+
second_hop.append(g2)
|
| 162 |
+
seen.add(g2)
|
| 163 |
+
if len(second_hop) >= k_gene_2hop: break
|
| 164 |
+
if len(second_hop) >= k_gene_2hop: break
|
| 165 |
+
if second_hop:
|
| 166 |
+
nodes.append(kg["emb"]["gene"][second_hop])
|
| 167 |
+
elif "gene" in kg["emb"]:
|
| 168 |
+
pool = kg["emb"]["gene"]
|
| 169 |
+
sim = F.cosine_similarity(d_emb.unsqueeze(0), pool, dim=-1)
|
| 170 |
+
top = sim.topk(min(k_gene, pool.size(0))).indices
|
| 171 |
+
nodes.append(pool[top])
|
| 172 |
+
|
| 173 |
+
return torch.cat(nodes, dim=0)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# Keep old API name for backward compat
|
| 177 |
+
ego_subgraph = ego_subgraph_real
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class KGCrossAttention(nn.Module):
|
| 181 |
+
"""Cross-attention from sequence (B, T, d_model) to KG ego (B, N, d_model)."""
|
| 182 |
+
def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.1):
|
| 183 |
+
super().__init__()
|
| 184 |
+
self.n_heads = n_heads
|
| 185 |
+
self.head_dim = d_model // n_heads
|
| 186 |
+
self.q_proj = nn.Linear(d_model, d_model, bias=False)
|
| 187 |
+
self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)
|
| 188 |
+
self.out_proj = nn.Linear(d_model, d_model, bias=False)
|
| 189 |
+
self.norm_q = nn.LayerNorm(d_model)
|
| 190 |
+
self.norm_kv = nn.LayerNorm(d_model)
|
| 191 |
+
self.dropout = nn.Dropout(dropout)
|
| 192 |
+
|
| 193 |
+
def forward(self, x_seq: torch.Tensor, x_kg: torch.Tensor) -> torch.Tensor:
|
| 194 |
+
B, T, D = x_seq.shape
|
| 195 |
+
_, N, _ = x_kg.shape
|
| 196 |
+
q = self.q_proj(self.norm_q(x_seq))
|
| 197 |
+
kv = self.kv_proj(self.norm_kv(x_kg))
|
| 198 |
+
k, v = kv.chunk(2, dim=-1)
|
| 199 |
+
q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
|
| 200 |
+
k = k.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)
|
| 201 |
+
v = v.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)
|
| 202 |
+
out = F.scaled_dot_product_attention(
|
| 203 |
+
q, k, v, dropout_p=self.dropout.p if self.training else 0.0)
|
| 204 |
+
out = out.transpose(1, 2).reshape(B, T, D)
|
| 205 |
+
return x_seq + self.dropout(self.out_proj(out))
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class KGProjector(nn.Module):
|
| 209 |
+
"""Project 3072-d KG embeddings to d_model with LayerNorm."""
|
| 210 |
+
def __init__(self, kg_dim: int, d_model: int):
|
| 211 |
+
super().__init__()
|
| 212 |
+
self.proj = nn.Sequential(
|
| 213 |
+
nn.Linear(kg_dim, d_model),
|
| 214 |
+
nn.GELU(),
|
| 215 |
+
nn.LayerNorm(d_model),
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 219 |
+
return self.proj(x)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def build_kg_batch(orpha_strings: list[str], d_model: int,
|
| 223 |
+
projector: KGProjector,
|
| 224 |
+
k_pheno: int = 16, k_gene: int = 16,
|
| 225 |
+
k_gene_2hop: int = 0) -> torch.Tensor:
|
| 226 |
+
"""Build (B, N, d_model) batched KG context for a batch of patient ORPHAs.
|
| 227 |
+
|
| 228 |
+
Falls back to zero context for missing ORPHAs.
|
| 229 |
+
"""
|
| 230 |
+
kg = load_kg()
|
| 231 |
+
if kg is None:
|
| 232 |
+
return torch.zeros(len(orpha_strings), 1, d_model,
|
| 233 |
+
device=next(projector.parameters()).device)
|
| 234 |
+
N = 1 + k_pheno + k_gene + k_gene_2hop
|
| 235 |
+
egos = []
|
| 236 |
+
for orpha in orpha_strings:
|
| 237 |
+
e = ego_subgraph_real(orpha, k_pheno, k_gene, k_gene_2hop, kg)
|
| 238 |
+
if e is None:
|
| 239 |
+
e = torch.zeros(N, kg["emb"]["disease"].size(-1))
|
| 240 |
+
elif e.size(0) < N:
|
| 241 |
+
pad = torch.zeros(N - e.size(0), e.size(-1))
|
| 242 |
+
e = torch.cat([e, pad], dim=0)
|
| 243 |
+
egos.append(e[:N])
|
| 244 |
+
egos = torch.stack(egos, dim=0)
|
| 245 |
+
return projector(egos.to(next(projector.parameters()).device))
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def precompute_kg_for_dataset(orpha_codes: list[str], projector: KGProjector,
|
| 249 |
+
k_pheno: int = 16, k_gene: int = 16,
|
| 250 |
+
batch_size: int = 32) -> torch.Tensor:
|
| 251 |
+
"""Pre-compute KG context for an entire dataset in batches.
|
| 252 |
+
|
| 253 |
+
Returns (N_patients, kg_nodes, d_model) tensor on projector device.
|
| 254 |
+
Saves to disk-cacheable format.
|
| 255 |
+
"""
|
| 256 |
+
out = []
|
| 257 |
+
for i in range(0, len(orpha_codes), batch_size):
|
| 258 |
+
batch = orpha_codes[i:i + batch_size]
|
| 259 |
+
ctx = build_kg_batch(batch, projector.proj[0].out_features,
|
| 260 |
+
projector, k_pheno, k_gene)
|
| 261 |
+
out.append(ctx.cpu())
|
| 262 |
+
return torch.cat(out, dim=0)
|
src/sample.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Sampling primitives for CDF: AR mode, denoise mode, counterfactual rollouts.
|
| 2 |
+
|
| 3 |
+
Diffusion Forcing flexibility — the same model handles:
|
| 4 |
+
|
| 5 |
+
AR mode:
|
| 6 |
+
Sigma_future = 1, sigma_past = 0. Roll forward like an autoregressive
|
| 7 |
+
transformer but with per-token noise control.
|
| 8 |
+
|
| 9 |
+
Denoise mode (bidirectional):
|
| 10 |
+
Sigma low everywhere. Run k denoise steps, model fills the whole sequence.
|
| 11 |
+
|
| 12 |
+
Counterfactual mode (the TTE primitive):
|
| 13 |
+
Sigma=0 on observed tokens (clamp them clean), sigma=1 on tokens to
|
| 14 |
+
generate. Condition on (cohort, intervention_action_id). Sample N times,
|
| 15 |
+
compare distributions of outcome tokens.
|
| 16 |
+
|
| 17 |
+
CFG (classifier-free guidance) wraps any mode:
|
| 18 |
+
logits_g = (1 + gamma) * logits(c) - gamma * logits(null_c)
|
| 19 |
+
|
| 20 |
+
Shortcut Forcing (Dreamer 4) reduces denoise steps from 32-64 to 4 via
|
| 21 |
+
distilled student model — implemented in distill.py.
|
| 22 |
+
"""
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
import logging
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
|
| 29 |
+
from .diffusion_forcing import CDFTransformer
|
| 30 |
+
|
| 31 |
+
log = logging.getLogger("gemeo.cdf.sample")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@torch.no_grad()
|
| 35 |
+
def sample_denoise(
|
| 36 |
+
model: CDFTransformer,
|
| 37 |
+
cond: torch.Tensor,
|
| 38 |
+
*,
|
| 39 |
+
seed_prefix: torch.Tensor | None = None,
|
| 40 |
+
observed_mask: torch.Tensor | None = None, # (B, T) True = clamped clean
|
| 41 |
+
action: torch.Tensor | None = None,
|
| 42 |
+
gamma: float = 2.0,
|
| 43 |
+
n_steps: int = 32,
|
| 44 |
+
null_cond: int = 0,
|
| 45 |
+
schedule: str = "cosine",
|
| 46 |
+
) -> torch.Tensor:
|
| 47 |
+
"""Denoise-mode sampling: fully-masked sequence + iterative refinement.
|
| 48 |
+
|
| 49 |
+
Supports:
|
| 50 |
+
- seed_prefix: clean tokens kept at sigma=0 for positions [0, L)
|
| 51 |
+
- observed_mask: arbitrary positions to clamp (counterfactual mode)
|
| 52 |
+
- CFG via (cond, null_cond) pair
|
| 53 |
+
"""
|
| 54 |
+
cfg = model.cfg
|
| 55 |
+
device = cond.device
|
| 56 |
+
B = cond.size(0)
|
| 57 |
+
T = cfg.max_seq_len
|
| 58 |
+
|
| 59 |
+
# Init with MASK
|
| 60 |
+
x = torch.full((B, T), cfg.mask_token, device=device, dtype=torch.long)
|
| 61 |
+
fixed_mask = torch.zeros(B, T, dtype=torch.bool, device=device)
|
| 62 |
+
if seed_prefix is not None:
|
| 63 |
+
L = seed_prefix.size(1)
|
| 64 |
+
x[:, :L] = seed_prefix
|
| 65 |
+
fixed_mask[:, :L] = True
|
| 66 |
+
if observed_mask is not None:
|
| 67 |
+
fixed_mask |= observed_mask
|
| 68 |
+
|
| 69 |
+
# Build noise schedule
|
| 70 |
+
if schedule == "cosine":
|
| 71 |
+
# smooth cosine from 1 -> 0
|
| 72 |
+
ts = torch.cos(torch.linspace(0, torch.pi/2, n_steps+1, device=device))
|
| 73 |
+
else:
|
| 74 |
+
ts = torch.linspace(1.0, 0.0, n_steps+1, device=device)
|
| 75 |
+
|
| 76 |
+
null = torch.full_like(cond, null_cond)
|
| 77 |
+
null_action = (torch.full_like(action, cfg.n_latent_actions)
|
| 78 |
+
if action is not None and cfg.use_latent_action else None)
|
| 79 |
+
|
| 80 |
+
for k in range(n_steps):
|
| 81 |
+
# Per-token sigma: fixed positions at 0, dynamic positions at ts[k]
|
| 82 |
+
sigma = torch.where(fixed_mask, torch.zeros_like(ts[k:k+1]).expand(B, T),
|
| 83 |
+
torch.full((B, T), ts[k].item(), device=device))
|
| 84 |
+
logits_c = model(x, sigma, cond, action)
|
| 85 |
+
if gamma > 0:
|
| 86 |
+
logits_n = model(x, sigma, null, null_action)
|
| 87 |
+
logits = (1 + gamma) * logits_c - gamma * logits_n
|
| 88 |
+
else:
|
| 89 |
+
logits = logits_c
|
| 90 |
+
logits[:, :, cfg.mask_token] = -1e9
|
| 91 |
+
|
| 92 |
+
probs = F.softmax(logits, dim=-1)
|
| 93 |
+
confs, preds = probs.max(dim=-1)
|
| 94 |
+
|
| 95 |
+
# Confidence-based remasking: reveal top-(1 - ts[k+1]) fraction of free tokens
|
| 96 |
+
t_next = ts[k+1].item()
|
| 97 |
+
target_kept = int(round((1 - t_next) * T))
|
| 98 |
+
revealed = (x != cfg.mask_token) | fixed_mask
|
| 99 |
+
already = revealed.sum(dim=-1)
|
| 100 |
+
new_x = x.clone()
|
| 101 |
+
for b in range(B):
|
| 102 |
+
need = max(0, target_kept - int(already[b].item()))
|
| 103 |
+
if need == 0:
|
| 104 |
+
continue
|
| 105 |
+
confs_b = torch.where(revealed[b], torch.full_like(confs[b], -1e9), confs[b])
|
| 106 |
+
topi = confs_b.topk(need).indices
|
| 107 |
+
new_x[b, topi] = preds[b, topi]
|
| 108 |
+
x = new_x
|
| 109 |
+
|
| 110 |
+
# Final cleanup
|
| 111 |
+
mask_left = x == cfg.mask_token
|
| 112 |
+
if mask_left.any():
|
| 113 |
+
sigma_final = torch.zeros(B, T, device=device)
|
| 114 |
+
logits_c = model(x, sigma_final, cond, action)
|
| 115 |
+
if gamma > 0:
|
| 116 |
+
logits_n = model(x, sigma_final, null, null_action)
|
| 117 |
+
logits = (1 + gamma) * logits_c - gamma * logits_n
|
| 118 |
+
else:
|
| 119 |
+
logits = logits_c
|
| 120 |
+
logits[:, :, cfg.mask_token] = -1e9
|
| 121 |
+
preds = logits.argmax(-1)
|
| 122 |
+
x = torch.where(mask_left, preds, x)
|
| 123 |
+
return x
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@torch.no_grad()
|
| 127 |
+
def sample_ar(
|
| 128 |
+
model: CDFTransformer,
|
| 129 |
+
cond: torch.Tensor,
|
| 130 |
+
prefix: torch.Tensor,
|
| 131 |
+
*,
|
| 132 |
+
action: torch.Tensor | None = None,
|
| 133 |
+
max_new: int = 50,
|
| 134 |
+
temperature: float = 1.0,
|
| 135 |
+
gamma: float = 0.0,
|
| 136 |
+
null_cond: int = 0,
|
| 137 |
+
) -> torch.Tensor:
|
| 138 |
+
"""AR-mode sampling: future tokens at sigma=1, past at sigma=0.
|
| 139 |
+
|
| 140 |
+
Faster than denoise mode when you only want to continue a prefix.
|
| 141 |
+
"""
|
| 142 |
+
cfg = model.cfg
|
| 143 |
+
device = cond.device
|
| 144 |
+
B = cond.size(0)
|
| 145 |
+
x = prefix.clone().to(device)
|
| 146 |
+
if x.dim() == 1: x = x.unsqueeze(0)
|
| 147 |
+
null = torch.full_like(cond, null_cond)
|
| 148 |
+
null_action = (torch.full_like(action, cfg.n_latent_actions)
|
| 149 |
+
if action is not None and cfg.use_latent_action else None)
|
| 150 |
+
|
| 151 |
+
for _ in range(max_new):
|
| 152 |
+
T_now = x.size(1)
|
| 153 |
+
if T_now >= cfg.max_seq_len:
|
| 154 |
+
break
|
| 155 |
+
# Pad with MASK
|
| 156 |
+
x_pad = torch.cat([x, torch.full((B, 1), cfg.mask_token,
|
| 157 |
+
device=device, dtype=torch.long)], dim=1)
|
| 158 |
+
sigma = torch.zeros(B, T_now + 1, device=device)
|
| 159 |
+
sigma[:, -1] = 1.0
|
| 160 |
+
a_pad = None
|
| 161 |
+
if action is not None and cfg.use_latent_action:
|
| 162 |
+
a_pad = torch.cat([action[:, :T_now],
|
| 163 |
+
torch.full((B, 1), cfg.n_latent_actions,
|
| 164 |
+
device=device, dtype=torch.long)], dim=1)
|
| 165 |
+
logits = model(x_pad, sigma, cond, a_pad)
|
| 166 |
+
if gamma > 0:
|
| 167 |
+
logits_n = model(x_pad, sigma, null, null_action)
|
| 168 |
+
logits = (1 + gamma) * logits - gamma * logits_n
|
| 169 |
+
logits[:, :, cfg.mask_token] = -1e9
|
| 170 |
+
p = F.softmax(logits[:, -1] / max(temperature, 1e-3), dim=-1)
|
| 171 |
+
nxt = torch.multinomial(p, 1)
|
| 172 |
+
x = torch.cat([x, nxt], dim=1)
|
| 173 |
+
return x
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
@torch.no_grad()
|
| 177 |
+
def counterfactual_rollout(
|
| 178 |
+
model: CDFTransformer,
|
| 179 |
+
seed_prefix: torch.Tensor,
|
| 180 |
+
treatment_cond: int,
|
| 181 |
+
untreated_cond: int,
|
| 182 |
+
*,
|
| 183 |
+
treatment_action: int | None = None,
|
| 184 |
+
untreated_action: int | None = None,
|
| 185 |
+
n_samples: int = 100,
|
| 186 |
+
gamma: float = 2.0,
|
| 187 |
+
n_steps: int = 32,
|
| 188 |
+
) -> dict:
|
| 189 |
+
"""Sample paired counterfactual trajectories under treatment vs no-treatment.
|
| 190 |
+
|
| 191 |
+
Two ways to specify the intervention:
|
| 192 |
+
- via cond id (cohort-level): treatment_cond / untreated_cond
|
| 193 |
+
- via latent action id (per-token): treatment_action / untreated_action
|
| 194 |
+
"""
|
| 195 |
+
cfg = model.cfg
|
| 196 |
+
device = next(model.parameters()).device
|
| 197 |
+
seed = seed_prefix.unsqueeze(0).expand(n_samples, -1).to(device)
|
| 198 |
+
T = cfg.max_seq_len
|
| 199 |
+
|
| 200 |
+
cond_tx = torch.full((n_samples,), treatment_cond, device=device, dtype=torch.long)
|
| 201 |
+
cond_null = torch.full((n_samples,), untreated_cond, device=device, dtype=torch.long)
|
| 202 |
+
|
| 203 |
+
action_tx = action_null = None
|
| 204 |
+
if cfg.use_latent_action:
|
| 205 |
+
action_tx = torch.full((n_samples, T),
|
| 206 |
+
treatment_action if treatment_action is not None
|
| 207 |
+
else cfg.n_latent_actions,
|
| 208 |
+
device=device, dtype=torch.long)
|
| 209 |
+
action_null = torch.full((n_samples, T),
|
| 210 |
+
untreated_action if untreated_action is not None
|
| 211 |
+
else cfg.n_latent_actions,
|
| 212 |
+
device=device, dtype=torch.long)
|
| 213 |
+
|
| 214 |
+
traj_tx = sample_denoise(model, cond_tx, seed_prefix=seed,
|
| 215 |
+
action=action_tx, gamma=gamma, n_steps=n_steps)
|
| 216 |
+
traj_null = sample_denoise(model, cond_null, seed_prefix=seed,
|
| 217 |
+
action=action_null, gamma=gamma, n_steps=n_steps)
|
| 218 |
+
return {
|
| 219 |
+
"traj_treated": traj_tx, "traj_untreated": traj_null,
|
| 220 |
+
"n": n_samples, "treatment_cond": treatment_cond,
|
| 221 |
+
"untreated_cond": untreated_cond, "gamma": gamma,
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def outcome_rate(traj: torch.Tensor, target_ids: list[int]) -> float:
|
| 226 |
+
if not target_ids:
|
| 227 |
+
return 0.0
|
| 228 |
+
target = torch.tensor(target_ids, device=traj.device)
|
| 229 |
+
has = (traj.unsqueeze(-1) == target).any(dim=(-1, -2))
|
| 230 |
+
return has.float().mean().item()
|
src/wsd_scheduler.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""WSD (Warmup-Stable-Decay) LR scheduler — manual implementation.
|
| 2 |
+
|
| 3 |
+
Per MiniCPM (Hu et al. 2024) and the data-constrained scaling literature:
|
| 4 |
+
Phase 1 (warmup, 1-5% of total_steps): linear 0 → peak_lr
|
| 5 |
+
Phase 2 (stable, 60-80%): constant peak_lr
|
| 6 |
+
Phase 3 (decay, 10-25%): linear or 1/sqrt to peak_lr * 0.1
|
| 7 |
+
|
| 8 |
+
Beats cosine for:
|
| 9 |
+
- data-limited regimes (we can extend stable phase if loss still falls)
|
| 10 |
+
- continue-pretrain (sharp decay enables clean fine-tune handoff)
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
import math
|
| 14 |
+
import torch
|
| 15 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def wsd_lr_schedule(step: int, total_steps: int,
|
| 19 |
+
warmup_steps: int = 500,
|
| 20 |
+
stable_frac: float = 0.80,
|
| 21 |
+
decay_frac: float = 0.15,
|
| 22 |
+
min_lr_ratio: float = 0.1,
|
| 23 |
+
decay_type: str = "linear") -> float:
|
| 24 |
+
"""Return LR multiplier in [min_lr_ratio, 1.0] for a given step."""
|
| 25 |
+
if step < warmup_steps:
|
| 26 |
+
return step / max(1, warmup_steps)
|
| 27 |
+
# remainder of steps after warmup
|
| 28 |
+
remaining = total_steps - warmup_steps
|
| 29 |
+
if remaining <= 0:
|
| 30 |
+
return 1.0
|
| 31 |
+
stable_steps = int(stable_frac * remaining)
|
| 32 |
+
decay_steps = int(decay_frac * remaining)
|
| 33 |
+
pos = step - warmup_steps
|
| 34 |
+
if pos < stable_steps:
|
| 35 |
+
return 1.0
|
| 36 |
+
decay_pos = pos - stable_steps
|
| 37 |
+
if decay_pos >= decay_steps:
|
| 38 |
+
return min_lr_ratio
|
| 39 |
+
progress = decay_pos / max(1, decay_steps)
|
| 40 |
+
if decay_type == "linear":
|
| 41 |
+
return 1.0 - (1.0 - min_lr_ratio) * progress
|
| 42 |
+
elif decay_type == "cosine":
|
| 43 |
+
return min_lr_ratio + 0.5 * (1 - min_lr_ratio) * (1 + math.cos(math.pi * progress))
|
| 44 |
+
elif decay_type == "inv_sqrt":
|
| 45 |
+
return max(min_lr_ratio, 1.0 / math.sqrt(1 + progress * 10))
|
| 46 |
+
else:
|
| 47 |
+
raise ValueError(f"unknown decay_type: {decay_type}")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_wsd_scheduler(optimizer: torch.optim.Optimizer,
|
| 51 |
+
total_steps: int,
|
| 52 |
+
warmup_steps: int = 500,
|
| 53 |
+
stable_frac: float = 0.80,
|
| 54 |
+
decay_frac: float = 0.15,
|
| 55 |
+
min_lr_ratio: float = 0.1,
|
| 56 |
+
decay_type: str = "linear") -> LambdaLR:
|
| 57 |
+
"""Build a LambdaLR scheduler with WSD schedule."""
|
| 58 |
+
def fn(step):
|
| 59 |
+
return wsd_lr_schedule(step, total_steps, warmup_steps,
|
| 60 |
+
stable_frac, decay_frac, min_lr_ratio, decay_type)
|
| 61 |
+
return LambdaLR(optimizer, lr_lambda=fn)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if __name__ == "__main__":
|
| 65 |
+
# Visualize the schedule
|
| 66 |
+
total = 10000
|
| 67 |
+
warmup = 500
|
| 68 |
+
print(f"WSD schedule preview: total={total}, warmup={warmup}, stable=80%, decay=15%")
|
| 69 |
+
print(f" step lr_mult")
|
| 70 |
+
for s in [0, 250, 500, 1000, 5000, 8000, 8500, 9000, 9500, 9800, 9999]:
|
| 71 |
+
m = wsd_lr_schedule(s, total, warmup, 0.80, 0.15, 0.1, "linear")
|
| 72 |
+
print(f" {s:>5} {m:.4f}")
|