timmers committed on
Commit 908ea05 · verified · 1 Parent(s): 32a27f8

GEMEO/SUS v6 recurrence-aware (RAVEN) — new-onset Top-1 60.1% vs baseline 38.2%, defeats autocorrelation trap. GEMEO Arch v2.0 Principle 7 proven.

LICENSE ADDED
@@ -0,0 +1,34 @@
1
+ Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
2
+
3
+ Copyright (c) 2026 Raras.ai / RarasNet
4
+ Authors: Dimas Quintas Verdial and contributors.
5
+
6
+ You are free to:
7
+ - Share — copy and redistribute the material in any medium or format
8
+ - Adapt — remix, transform, and build upon the material
9
+
10
+ Under the following terms:
11
+ - Attribution — You must give appropriate credit, provide a link to the
12
+ license, and indicate if changes were made.
13
+ - NonCommercial — You may not use the material for commercial purposes.
14
+ - No additional restrictions — You may not apply legal terms or
15
+ technological measures that legally restrict others from doing anything
16
+ the license permits.
17
+
18
+ Full legal text: https://creativecommons.org/licenses/by-nc/4.0/legalcode
19
+
20
+ NOT FOR CLINICAL USE
21
+ ====================
22
+ This model is released for research purposes only. It is NOT a medical
23
+ device. It is NOT approved by ANVISA, FDA, EMA, or any regulatory body.
24
+ Outputs MUST NOT be used to inform diagnosis, treatment, or any clinical
25
+ decision without explicit human physician oversight and applicable
26
+ regulatory clearance.
27
+
28
+ Compliance scope (Brazilian SUS data):
29
+ - LGPD (Lei Geral de Proteção de Dados, Brazil)
30
+ - CNS-hash linkage performed under data-use agreement with DATASUS
31
+ - Resolution CNS 466/2012 + 510/2016 (Brazilian ethics framework)
32
+
33
+ For commercial licensing or clinical deployment partnerships, contact:
34
+ dimas@raras.ai
README.md ADDED
@@ -0,0 +1,131 @@
1
+ ---
2
+ license: cc-by-nc-4.0
3
+ language: [pt, en]
4
+ tags:
5
+ - world-model
6
+ - patient-digital-twin
7
+ - rare-disease
8
+ - diffusion-forcing
9
+ - recurrence-aware
10
+ - new-onset-prediction
11
+ - brazilian-sus
12
+ - datasus
13
+ - primekg
14
+ library_name: pytorch
15
+ pipeline_tag: time-series-forecasting
16
+ extra_gated_prompt: >-
17
+ Research only. Not a medical device. No clinical use without physician
18
+ oversight and applicable regulatory clearance.
19
+ extra_gated_fields:
20
+ Name: text
21
+ Affiliation: text
22
+ Intended use: text
23
+ I agree to non-clinical research use only: checkbox
24
+ ---
25
+
26
+ # GEMEO/SUS v6 — Recurrence-Aware World Model (defeats the autocorrelation trap)
27
+
28
+ > The flagship recurrence-aware instance of [**GEMEO Architecture v2.0**](https://huggingface.co/Raras-AI/gemeo-arch).
29
+ > Implements Principle 7 (RAVEN recurrence-weighted loss) and is the first
30
+ > GEMEO instance to **beat a frequency baseline on the genuinely hard
31
+ > new-onset task** — predicting the *first* occurrence of a clinical event,
32
+ > with repeats excluded.
33
+
34
+ **Family:** [`gemeo-arch`](https://huggingface.co/Raras-AI/gemeo-arch) (architecture v2.0) · [`gemeo-sus-v4`](https://huggingface.co/Raras-AI/gemeo-sus-v4) (predecessor) · **`gemeo-sus-v6`** (this, recurrence-aware flagship) · [`gemeo-twin-stack`](https://huggingface.co/Raras-AI/gemeo-twin-stack)
35
+
36
+ ## Why this model exists — the autocorrelation trap
37
+
38
+ GEMEO/SUS v4 reported gap-fill Top-10 = 100%. Rigorous re-evaluation showed
39
+ this was a **metric artifact**: in Brazilian SUS APAC data, **82.2% of events
40
+ are repeats** (a patient on an orphan drug receives the same monthly dispensing
41
+ code), and only **17.8% are first occurrences**. On the genuinely hard
42
+ **new-onset task** (predict the first occurrence, repeats excluded), v4 scored
43
+ Top-1 **23.6% — below a frequency baseline (38.2%)**. This is the documented
44
+ "repeated event tokens inflate metrics" pitfall (RAVEN, arXiv 2603.24562;
45
+ NEP, arXiv 2509.25591).
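The repeat versus first-occurrence split described in this paragraph can be illustrated with a small sketch. It assumes a patient trajectory is a plain list of event codes and that "first occurrence" means the code has not appeared earlier in that same patient's sequence. The helper names, the toy codes, and the way the frequency baseline is built are all hypothetical, not taken from the repository.

```python
from collections import Counter

def first_occurrence_mask(events):
    """Return booleans: True where a code appears for the first time in the sequence."""
    seen = set()
    mask = []
    for code in events:
        mask.append(code not in seen)
        seen.add(code)
    return mask

def repeat_fraction(trajectories):
    """Fraction of all events that repeat an earlier code of the same patient."""
    total = repeats = 0
    for events in trajectories:
        mask = first_occurrence_mask(events)
        total += len(mask)
        repeats += mask.count(False)
    return repeats / total if total else 0.0

def frequency_baseline_topk(train_trajectories, k):
    """Hypothetical frequency baseline: always predict the k most common codes."""
    counts = Counter(code for events in train_trajectories for code in events)
    return [code for code, _ in counts.most_common(k)]

# Toy data (made up): one patient with a recurring dispensing code, one without repeats.
patients = [["DRUG_A", "DRUG_A", "DRUG_A", "EXAM_X", "DRUG_A"], ["EXAM_X", "EXAM_Y"]]
print(first_occurrence_mask(patients[0]))   # new-onset positions for the first patient
print(repeat_fraction(patients))            # share of events that are repeats
print(frequency_baseline_topk(patients, 1)) # most common code overall
```

Evaluating only on positions where the mask is `True` is what removes the reward for simply echoing the previous dispensing code.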
46
+
47
+ **v6 fixes it.** Following RAVEN, the training loss scales each token by
48
+ `w = max(λ^count, w_min)` (λ = 0.25), so the first occurrence of an event
49
+ carries full weight while the 12th monthly repeat is clamped to the floor `w_min` (0.02). The model is
50
+ forced to learn novelty, not autocorrelation.
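A minimal sketch of the weighting rule `w = max(λ^count, w_min)` follows, using λ = 0.25 and w_min = 0.02 from the training recipe. It assumes `count` is the number of times the same code has already appeared earlier in the patient's sequence (0 for a first occurrence), and that the weighted loss is normalised by the weight sum. Both are assumptions; the actual training code is not shown in this commit view, and the function names are hypothetical.

```python
def raven_weights(events, lam=0.25, w_min=0.02):
    """Per-token loss weights: w = max(lam ** count, w_min).

    Assumption: `count` is how often the same code already occurred earlier
    in this sequence, so a first occurrence gets lam ** 0 = 1.0.
    """
    seen_counts = {}
    weights = []
    for code in events:
        count = seen_counts.get(code, 0)
        weights.append(max(lam ** count, w_min))
        seen_counts[code] = count + 1
    return weights

def weighted_mean_loss(losses, weights):
    """Weighted average of per-token losses (normalising by the weight sum is an assumption)."""
    return sum(l * w for l, w in zip(losses, weights)) / sum(weights)

events = ["DRUG_A"] * 4 + ["EXAM_X"]
print(raven_weights(events))  # [1.0, 0.25, 0.0625, 0.02, 1.0]
```

With these settings the floor is reached quickly: the fourth occurrence would get 0.25³ = 0.015625, which is below 0.02, so it and every later repeat are weighted at `w_min`.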
51
+
52
+ ## Headline result — new-onset prediction (the reviewer-proof metric)
53
+
54
+ All on the held-out test split (6,374 patients), 95% bootstrap CI, on the
55
+ new-onset subset (first occurrence, repeats excluded; n = 1,730 positions).
56
+
57
+ | Model | new-onset Top-1 | new-onset Top-5 | vs frequency baseline |
58
+ |---|---:|---:|---|
59
+ | Frequency baseline | 38.2% | 90.8% | — |
60
+ | GEMEO/SUS v4 (no recurrence weighting) | 23.6% | 98.4% | **loses** (−14.6 pp) |
61
+ | **GEMEO/SUS v6 (RAVEN λ=0.25)** | **60.1% [57.8, 62.3]** | **98.7%** | **wins (+21.9 pp)** |
62
+
63
+ **+36.5 pp over v4 and +21.9 pp over the frequency baseline; the v6 and v4 95% CIs ([57.8, 62.3] vs. [21.6, 25.6]) do not overlap.**
64
+ This is the decisive evidence that recurrence-aware training defeats the
65
+ autocorrelation trap: on genuinely novel events, the model now beats trivial
66
+ baselines by a wide, statistically clear margin.
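For readers who want to see how an interval like the one in the table can be obtained, here is a hedged sketch of a percentile bootstrap over a 0/1 hit vector. It assumes positions are resampled independently; the evaluation script may instead resample patients, and its resample count and seed are not stated here, so the bounds will not match the reported ones exactly. The function name is hypothetical.

```python
import random

def bootstrap_accuracy_ci(hits, n_resamples=2000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for the mean of a 0/1 hit vector."""
    rng = random.Random(seed)
    n = len(hits)
    # Resample positions with replacement and record the accuracy of each resample.
    means = sorted(sum(rng.choices(hits, k=n)) / n for _ in range(n_resamples))
    lo = means[int((alpha / 2) * n_resamples)]
    hi = means[int((1 - alpha / 2) * n_resamples) - 1]
    return sum(hits) / n, lo, hi

# 1,040 hits out of 1,730 positions is consistent with the reported 60.1% Top-1.
hits = [1] * 1040 + [0] * 690
point, lo, hi = bootstrap_accuracy_ci(hits)
print(f"Top-1 {point:.1%} [{lo:.1%}, {hi:.1%}]")
```

The same call on a 660-hit vector (38.2% of 1,730) would give an interval for the frequency baseline, which the benchmark files report only as a point estimate.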
67
+
68
+ ## Architecture context — GEMEO v2.0
69
+
70
+ GEMEO v2.0 is a three-pillar architecture (Propose → Simulate → Verify):
71
+
72
+ - **Pillar A — Graph Proposer:** KG zero-shot link prediction emits first-onset candidates the patient has never had (TxGNN / PhenoKG style). *Scoped.*
73
+ - **Pillar B — World-Model Scorer (this model):** Diffusion-Forcing trunk + **RAVEN recurrence-weighted loss (Principle 7, proven here)** + competing-risks hazard head (Principle 8, scoped). *Implemented.*
74
+ - **Pillar C — Swarm Verifier:** multi-agent debate with KG evidence paths (DeepRare-style). *Scoped.*
75
+
76
+ GEMEO v2.0 targets **Level 3 (counterfactual rollout)** on the clinical-world-model capability rubric of Liu et al. (NeurIPS 2025, arXiv 2511.16333) and is engineered to close the four gaps that survey identifies (under-specified action spaces, weak interventional validation, incomplete multimodal state, limited calibration). Full spec: [`gemeo-arch/gemeo_architecture_spec_v2.md`](https://huggingface.co/Raras-AI/gemeo-arch).
77
+
78
+ ## Training recipe
79
+
80
+ - Warm-start from `gemeo-sus-v4` (positional features).
81
+ - RAVEN history-decay loss: λ = 0.25, w_min = 0.02, applied to the per-token
82
+ diffusion-forcing cross-entropy on corrupted positions.
83
+ - 6 epochs, WSD LR (peak 2e-4), bf16, single H100, ~6 min, ≈ $0.50.
84
+ - 19.97M params (same architecture as v4).
85
+
86
+ ## Usage
87
+
88
+ ```python
89
+ import torch, sys; sys.path.append("src")
90
+ import torch.nn as nn
91
+ from diffusion_forcing_v13 import CDFv13Transformer, CDFv13Config
92
+
93
+ class PositionalFeatureEmbed(nn.Module):
94
+     def __init__(self, d):
95
+         super().__init__()
96
+         self.age_proj=nn.Linear(1,d//4); self.year_proj=nn.Linear(1,d//4)
97
+         self.pos_proj=nn.Linear(1,d//4); self.combine=nn.Linear(3*(d//4),d); self.norm=nn.LayerNorm(d)
98
+     def forward(self, ages, years, positions):
99
+         a=ages.clamp(0,100)/100; y=(years-2010).clamp(0,20)/20; p=(positions/512).clamp(0,1)
100
+         e=torch.cat([self.age_proj(a.unsqueeze(-1)), self.year_proj(y.unsqueeze(-1)),
101
+                      self.pos_proj(p.unsqueeze(-1))], -1)
102
+         return self.norm(self.combine(e))
103
+
104
+ ck = torch.load("cdf_v6_raven.pt", map_location="cpu", weights_only=False)
105
+ cfg = CDFv13Config(**{k:v for k,v in ck["config"].items() if k in CDFv13Config.__dataclass_fields__})
106
+ model = CDFv13Transformer(cfg); model.load_state_dict(ck["model_state"])
107
+ pfe = PositionalFeatureEmbed(cfg.d_model); pfe.load_state_dict(ck["pos_feat_state"])
108
+ print(f"GEMEO/SUS v6 (RAVEN λ={ck['raven_lambda']}) — new-onset Top-1 60.1%")
109
+ ```
110
+
111
+ ## Honest scope
112
+
113
+ - ✅ **Proven on 52k SUS:** recurrence-aware training defeats autocorrelation (this model).
114
+ - 🔜 **Scoped, feasible on SUS:** KG zero-shot onset proposer (Pillar A), swarm verifier (Pillar C), competing-risks hazard head (Principle 8).
115
+ - 🏥 **Requires Mayo multimodal substrate:** rigorous counterfactual/interventional validation (labs, genomics, imaging, dense timing).
116
+
117
+ ## Citation
118
+
119
+ ```bibtex
120
+ @misc{gemeo_sus_v6_2026,
121
+ title = {GEMEO/SUS v6: Recurrence-Aware Patient World Model for
122
+ New-Onset Prediction in Rare Disease},
123
+ author = {Verdial, Dimas Quintas and the Raras AI team},
124
+ year = {2026},
125
+ url = {https://huggingface.co/Raras-AI/gemeo-sus-v6},
126
+ note = {GEMEO Architecture v2.0, Principle 7. Beats frequency baseline
127
+ on new-onset (60.1% vs 38.2% Top-1).}
128
+ }
129
+ ```
130
+
131
+ ⚠️ **Research only.** Not a medical device. No clinical use without physician oversight and applicable regulatory clearance.
benchmarks/v4_newonset_eval.json ADDED
@@ -0,0 +1,81 @@
1
+ {
2
+ "gapfill_all": {
3
+ "top1": [
4
+ 0.8634692246203037,
5
+ 0.8600839328537171,
6
+ 0.8667732480682121
7
+ ],
8
+ "top5": [
9
+ 0.9987743138822276,
10
+ 0.9983746336264322,
11
+ 0.9990947242206235
12
+ ],
13
+ "n": 37530
14
+ },
15
+ "gapfill_newonset": {
16
+ "top1": [
17
+ 0.23641618497109826,
18
+ 0.21560693641618497,
19
+ 0.2560693641618497
20
+ ],
21
+ "top5": [
22
+ 0.9838150289017341,
23
+ 0.9780346820809248,
24
+ 0.9895953757225433
25
+ ],
26
+ "top10": [
27
+ 1.0,
28
+ 1.0,
29
+ 1.0
30
+ ],
31
+ "n": 1730
32
+ },
33
+ "gapfill_newonset_baseline": {
34
+ "top1": 0.3815028901734104,
35
+ "top5": 0.9080924855491329,
36
+ "top10": 0.9838150289017341,
37
+ "n": 1730
38
+ },
39
+ "multi_horizon_newonset": {
40
+ "1": {
41
+ "n": 83,
42
+ "top5": [
43
+ 0.5903614457831325,
44
+ 0.4819277108433735,
45
+ 0.6987951807228916
46
+ ],
47
+ "macro_auroc": 0.5721330802852541,
48
+ "n_classes": 2
49
+ },
50
+ "10": {
51
+ "n": 39,
52
+ "top5": [
53
+ 0.7692307692307693,
54
+ 0.6410256410256411,
55
+ 0.8974358974358975
56
+ ],
57
+ "macro_auroc": 0.7896551724137931,
58
+ "n_classes": 1
59
+ },
60
+ "25": {
61
+ "n": 10,
62
+ "top5": [
63
+ 0.3,
64
+ 0.0,
65
+ 0.6
66
+ ],
67
+ "macro_auroc": NaN,
68
+ "n_classes": 0
69
+ },
70
+ "50": {
71
+ "n": 1,
72
+ "top5": [
73
+ 0.0,
74
+ 0.0,
75
+ 0.0
76
+ ],
77
+ "macro_auroc": NaN,
78
+ "n_classes": 0
79
+ }
80
+ }
81
+ }
benchmarks/v6_raven_newonset.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "raven_lambda": 0.25,
3
+ "newonset_model": {
4
+ "top1": [
5
+ 0.6011560693641619,
6
+ 0.5780202312138729,
7
+ 0.623121387283237
8
+ ],
9
+ "top5": [
10
+ 0.9867052023121388,
11
+ 0.9809248554913295,
12
+ 0.991907514450867
13
+ ],
14
+ "top10": [
15
+ 1.0,
16
+ 1.0,
17
+ 1.0
18
+ ],
19
+ "n": 1730
20
+ },
21
+ "newonset_freq_baseline": {
22
+ "top1": 0.3815028901734104,
23
+ "top5": 0.9080924855491329
24
+ },
25
+ "v4_reference": {
26
+ "newonset_top1": 0.236,
27
+ "newonset_baseline_top1": 0.382
28
+ },
29
+ "verdict": "BEATS baseline"
30
+ }
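The three-element lists in the benchmark files appear to be `[point, lower, upper]`; that reading is inferred from the README, where 60.1% [57.8, 62.3] lines up with the v6 `top1` entry. Under that assumption, a small sketch for turning a triple into the README notation looks like this. The helper name is hypothetical, and the `macro_auroc` value mirrors the `NaN` entries in the v4 file.

```python
import json
import math

def format_ci(triple):
    """Format [point, lower, upper] fractions as 'P% [L, U]' with one decimal."""
    point, lower, upper = (100 * v for v in triple)
    return f"{point:.1f}% [{lower:.1f}, {upper:.1f}]"

# top1 values copied from benchmarks/v6_raven_newonset.json.
raw = '{"top1": [0.6011560693641619, 0.5780202312138729, 0.623121387283237], "macro_auroc": NaN}'

# Python's json module accepts the non-standard NaN literal by default.
metrics = json.loads(raw)
print(format_ci(metrics["top1"]))          # 60.1% [57.8, 62.3]
print(math.isnan(metrics["macro_auroc"]))  # True
```

Strict JSON parsers in other languages may reject the bare `NaN` token, so consumers outside Python may need to preprocess the v4 file.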
cdf_v6_raven.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:428534584c3f995bfd4af52d59d16572600dcc0f4e0e844f28a838f57d73caa7
3
+ size 79932134
src/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ """GEMEO-CDF: Causal Diffusion Forcing for clinical trajectories.
2
+
3
+ Three "first in medicine" hooks:
4
+ 1. DIFFUSION FORCING (Chen MIT NeurIPS 2024 → Dreamer 4 Hafner 2025 backbone)
5
+ — independent per-token noise levels unify AR + diffusion + counterfactual
6
+ in ONE loss. Zero clinical port as of May 2026.
7
+
8
+ 2. LATENT ACTION MODEL (Genie / DeepMind 2024)
9
+ — VQ-VAE codebook over (state_t, state_{t+1}) deltas discovers a
10
+ treatment vocabulary without RxNorm/ATC labels. Solves the APAC
11
+ miscoding / sparsity / off-label labelling pain in DATASUS.
12
+
13
+ 3. PROCESS REWARD VERIFIER (o3 / MAI-DxO 2025 pattern)
14
+ — small PRM scores top-K rollouts at inference, returns top-1 +
15
+ uncertainty band. Deliberative trajectory generation, novel in EHR.
16
+
17
+ Modules:
18
+ diffusion_forcing.py — core architecture (per-token noise + block-causal)
19
+ lam.py — Latent Action Model (VQ-VAE codebook)
20
+ train_cdf.py — training loop with diffusion forcing objective
21
+ sample.py — sampling: AR mode / denoise mode / counterfactual
22
+ distill.py — Shortcut Forcing distillation (Dreamer 4)
23
+ prm.py — Process Reward Verifier
24
+ """
25
+ from .diffusion_forcing import CDFTransformer, CDFConfig
26
+ from .lam import LatentActionVQVAE, LAMConfig
27
+ from .train_cdf import train_cdf
28
+
29
+ __all__ = [
30
+ "CDFTransformer", "CDFConfig",
31
+ "LatentActionVQVAE", "LAMConfig",
32
+ "train_cdf",
33
+ ]
src/adaln_zero.py ADDED
@@ -0,0 +1,148 @@
1
+ """AdaLN-Zero conditioning module (DiT-style, Peebles 2023).
2
+
3
+ Used in: DiT (ICCV 2023), Stable Diffusion 3 (Esser 2024), Sora, Lumina-Next,
4
+ PixArt-Sigma. Standard for diffusion conditioning in 2025-2026.
5
+
6
+ Why for Diffusion Forcing on EHR:
7
+ - Per-token sigma + global cond/action → per-token (scale, shift, gate)
8
+ - Gates init to zero ⇒ block starts as identity ⇒ no catastrophic init
9
+ - Much better CFG (dropped condition path goes through zero gates,
10
+ not corrupting residual stream)
11
+ - DFoT (Diffusion Forcing Transformer 2, ICLR 2026) confirms +3-8% win
12
+
13
+ We fuse THREE conditioning signals:
14
+ - sigma (B, T) per-token noise level → time_emb (B, T, D)
15
+ - cond (B,) cohort-level treatment id → cond_emb (B, D) → broadcast
16
+ - action(B, T) per-token latent action id → action_emb (B, T, D)
17
+
18
+ Combined into c_t (B, T, D) → ConditioningMLP → 6 modulation tensors
19
+ per block. Each block uses them as:
20
+
21
+ h = x + gate_msa * Attn((1 + scale_msa) * Norm(x) + shift_msa)
22
+ h = h + gate_mlp * MLP((1 + scale_mlp) * Norm(h) + shift_mlp)
23
+ """
24
+ from __future__ import annotations
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+
29
+ class AdaLNZeroModulator(nn.Module):
30
+ """Generates per-token (scale, shift, gate) for AdaLN-Zero block.
31
+
32
+ Input: fused conditioning vector c (B, T, d_model).
33
+ Output: 6 tensors of shape (B, T, d_model) each:
34
+ (scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp)
35
+ """
36
+ def __init__(self, d_model: int):
37
+ super().__init__()
38
+ self.modulator = nn.Sequential(
39
+ nn.SiLU(),
40
+ nn.Linear(d_model, 6 * d_model, bias=True),
41
+ )
42
+ # AdaLN-Zero trick: zero-init the whole output layer, so scale, shift,
43
+ # and gate all start at 0 and both modulated branches are no-ops at init.
44
+ nn.init.zeros_(self.modulator[-1].weight)
45
+ nn.init.zeros_(self.modulator[-1].bias)
46
+
47
+ def forward(self, c: torch.Tensor) -> tuple[torch.Tensor, ...]:
48
+ # c: (B, T, d_model)
49
+ out = self.modulator(c) # (B, T, 6*d_model)
50
+ return out.chunk(6, dim=-1)
51
+
52
+
53
+ class AdaLNZeroBlock(nn.Module):
54
+ """Transformer block with AdaLN-Zero modulation.
55
+
56
+ Drop-in replacement for the standard pre-norm block. Reads
57
+ pre-computed modulation tensors and applies them around Attn + MLP.
58
+ """
59
+ def __init__(self, d_model: int, n_heads: int, ffn: int, dropout: float,
60
+ rope=None, kg_xattn=None):
61
+ super().__init__()
62
+ self.d_model = d_model
63
+ self.n_heads = n_heads
64
+ self.head_dim = d_model // n_heads
65
+ self.rope = rope
66
+ self.kg_xattn = kg_xattn
67
+
68
+ self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
69
+ self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
70
+ self.proj = nn.Linear(d_model, d_model, bias=False)
71
+ self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
72
+ self.mlp = nn.Sequential(
73
+ nn.Linear(d_model, ffn, bias=False),
74
+ nn.GELU(),
75
+ nn.Linear(ffn, d_model, bias=False),
76
+ )
77
+ self.dropout = nn.Dropout(dropout)
78
+
79
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor,
80
+ scale_msa, shift_msa, gate_msa,
81
+ scale_mlp, shift_mlp, gate_mlp,
82
+ kg_ctx: torch.Tensor | None = None) -> torch.Tensor:
83
+ import torch.nn.functional as F
84
+ B, T, D = x.shape
85
+ # MSA branch
86
+ h = self.norm1(x) * (1 + scale_msa) + shift_msa
87
+ qkv = self.qkv(h).reshape(B, T, 3, self.n_heads, self.head_dim)
88
+ q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
89
+ if self.rope is not None:
90
+ q, k = self.rope(q, k, T)
91
+ out = F.scaled_dot_product_attention(
92
+ q, k, v,
93
+ attn_mask=(~attn_mask).float().masked_fill(attn_mask, float("-inf"))[None, None],
94
+ dropout_p=self.dropout.p if self.training else 0.0,
95
+ )
96
+ out = out.transpose(1, 2).reshape(B, T, D)
97
+ x = x + gate_msa * self.dropout(self.proj(out))
98
+ # KG cross-attention (between MSA and MLP)
99
+ if self.kg_xattn is not None and kg_ctx is not None:
100
+ x = self.kg_xattn(x, kg_ctx)
101
+ # MLP branch
102
+ h = self.norm2(x) * (1 + scale_mlp) + shift_mlp
103
+ x = x + gate_mlp * self.dropout(self.mlp(h))
104
+ return x
105
+
106
+
107
+ class FusedConditioner(nn.Module):
108
+ """Fuse (sigma, cond, action) into one per-token conditioning vector.
109
+
110
+ Output (B, T, d_model) consumed by AdaLNZeroModulator per layer.
111
+ """
112
+ def __init__(self, d_model: int, n_conditions: int, n_actions: int,
113
+ use_action: bool = True):
114
+ super().__init__()
115
+ self.d_model = d_model
116
+ self.use_action = use_action
117
+ # Sigma → sinusoidal embedding
118
+ self.sigma_proj = nn.Sequential(
119
+ nn.Linear(d_model, d_model), nn.SiLU(), nn.Linear(d_model, d_model),
120
+ )
121
+ self.cond_emb = nn.Embedding(n_conditions, d_model)
122
+ if use_action:
123
+ self.action_emb = nn.Embedding(n_actions + 1, d_model)
124
+ self.fuse = nn.Sequential(
125
+ nn.SiLU(),
126
+ nn.Linear(d_model, d_model),
127
+ )
128
+
129
+ def sinusoidal(self, sigma: torch.Tensor) -> torch.Tensor:
130
+ import math
131
+ half = self.d_model // 2
132
+ freqs = torch.exp(
133
+ -math.log(10000.0) * torch.arange(half, device=sigma.device) / half
134
+ )
135
+ ang = sigma.float().unsqueeze(-1) * freqs
136
+ emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
137
+ return self.sigma_proj(emb)
138
+
139
+ def forward(self, sigma: torch.Tensor, cond: torch.Tensor,
140
+ action: torch.Tensor | None = None) -> torch.Tensor:
141
+ # sigma (B, T) → time_emb (B, T, D)
142
+ time_emb = self.sinusoidal(sigma)
143
+ # cond (B,) → (B, D) → broadcast to (B, T, D)
144
+ cond_emb = self.cond_emb(cond).unsqueeze(1).expand_as(time_emb)
145
+ fused = time_emb + cond_emb
146
+ if self.use_action and action is not None:
147
+ fused = fused + self.action_emb(action)
148
+ return self.fuse(fused)
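The "gates init to zero, so the block starts as identity" property of this file can be checked with a dependency-free sketch. It reproduces the residual form `x + gate * branch(norm(x) * (1 + scale) + shift)` on plain Python lists; `toy_branch` is a hypothetical stand-in for the attention or MLP sub-layer, and none of these helpers exist in the repository.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalise a vector to zero mean and unit variance (no learned affine)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def adaln_zero_residual(x, branch, scale, shift, gate):
    """x + gate * branch(norm(x) * (1 + scale) + shift), applied element-wise."""
    h = [n * (1 + s) + b for n, s, b in zip(layer_norm(x), scale, shift)]
    return [xi + g * bi for xi, g, bi in zip(x, gate, branch(h))]

def toy_branch(h):
    """Stand-in for attention or the MLP; any function of h works for this demo."""
    return [math.tanh(2.0 * v) + 0.5 for v in h]

x = [0.3, -1.2, 2.0, 0.7]
zeros = [0.0] * len(x)

# A zero-initialised modulator emits zeros, so the block returns x unchanged.
print(adaln_zero_residual(x, toy_branch, zeros, zeros, zeros) == x)

# Once the gate moves away from zero, the branch starts to contribute.
print(adaln_zero_residual(x, toy_branch, zeros, zeros, [0.1] * len(x)) != x)
```

Because scale and shift are also zero at initialisation, the branch input is just the normalised activations, and only the gate decides how much of the branch output reaches the residual stream.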
src/diffusion_forcing_v13.py ADDED
@@ -0,0 +1,314 @@
1
+ """GEMEO-CDF v13 — audit-driven Chinchilla-correct architecture.
2
+
3
+ Per the SOTA audit (May 2026):
4
+ - Path B (CLMBR fine-tune) BLOCKED: CLMBR-T-base is HF-gated (manual approval)
5
+ - Path A adopted: small from-scratch model + KG adapters + MEDS interop
6
+
7
+ Architecture:
8
+ - 12M backbone params (Chinchilla-respecting for ~20M token corpus)
9
+ - d_model=384, n_layers=8, n_heads=6, ffn=1024, ctx=512
10
+ - SwiGLU MLP (ffn:d_model = 2.67)
11
+ - Tied embeddings (saves ~12M at vocab=32k)
12
+ - Dropout 0.1 everywhere (small-data critical)
13
+ - Block-causal attention (Diffusion Forcing)
14
+ - Per-token sigma noise (independent)
15
+ - GATED KG cross-attention (tanh(α)·xattn, α init=0)
16
+ - Layers 4, 6, 7 (3 of 8)
17
+ - Lets model learn to use KG progressively, doesn't disrupt early loss
18
+ - DF objective + LM-aux loss (joint training, paper-grade)
19
+
20
+ Sources audited:
21
+ - CoMET (Aug 2025): tokens-per-param ratio
22
+ - CLMBR (Stanford): adapter pattern for cross-site transfer
23
+ - MDLM (Sahoo 2024): masked diffusion, matches AR at equal FLOPs
24
+ - Genie (DeepMind 2024): gated cross-attention pattern
25
+ - SD3 (Esser 2024): AdaLN-Zero zero-init gates
26
+ """
27
+ from __future__ import annotations
28
+ import math
29
+ from dataclasses import dataclass, field
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+
35
+
36
+ @dataclass
37
+ class CDFv13Config:
38
+ # Vocab + sequence
39
+ vocab_size: int = 32768 # MEDS-derived (will be much smaller in practice)
40
+ mask_token: int = 32767
41
+ max_seq_len: int = 512
42
+ block_size: int = 16
43
+ # Architecture (Chinchilla-correct for ~20M tokens)
44
+ d_model: int = 384
45
+ n_heads: int = 6
46
+ n_layers: int = 8
47
+ ffn: int = 1024 # SwiGLU effective; flag below uses 2 projections
48
+ dropout: float = 0.1
49
+ emb_dropout: float = 0.1
50
+ use_swiglu: bool = True
51
+ use_rmsnorm: bool = True
52
+ tie_embeddings: bool = True
53
+ # Diffusion forcing
54
+ cond_dropout: float = 0.10
55
+ # KG conditioning (GATED adapters)
56
+ use_kg: bool = True
57
+ kg_dim: int = 3072
58
+ kg_attn_layers: list = field(default_factory=lambda: [4, 6, 7])
59
+ # Latent action
60
+ use_latent_action: bool = False # Dropped per audit (concept shaky)
61
+ n_latent_actions: int = 512
62
+ # Conditioning
63
+ n_conditions: int = 64
64
+
65
+
66
+ class RMSNorm(nn.Module):
67
+ """Root-mean-square LayerNorm (LLaMA/Mistral style)."""
68
+ def __init__(self, d: int, eps: float = 1e-6):
69
+ super().__init__()
70
+ self.weight = nn.Parameter(torch.ones(d))
71
+ self.eps = eps
72
+
73
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
74
+ norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
75
+ return (norm * self.weight.float()).to(x.dtype)
76
+
77
+
78
+ class SwiGLU(nn.Module):
79
+ """SwiGLU MLP (used in LLaMA/Gemma/Mistral)."""
80
+ def __init__(self, d_in: int, d_hidden: int, dropout: float = 0.1):
81
+ super().__init__()
82
+ self.w_gate = nn.Linear(d_in, d_hidden, bias=False)
83
+ self.w_up = nn.Linear(d_in, d_hidden, bias=False)
84
+ self.w_down = nn.Linear(d_hidden, d_in, bias=False)
85
+ self.dropout = nn.Dropout(dropout)
86
+
87
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
88
+ return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))
89
+
90
+
91
+ class RotaryEmbedding(nn.Module):
92
+ """RoPE (Su et al. 2021)."""
93
+ def __init__(self, dim: int, max_seq: int = 8192, base: float = 10000.0):
94
+ super().__init__()
95
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
96
+ t = torch.arange(max_seq).float()
97
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
98
+ emb = torch.cat([freqs, freqs], dim=-1)
99
+ self.register_buffer("cos", emb.cos(), persistent=False)
100
+ self.register_buffer("sin", emb.sin(), persistent=False)
101
+
102
+ def forward(self, q, k, seq_len):
103
+ cos = self.cos[:seq_len].to(q.dtype).to(q.device)
104
+ sin = self.sin[:seq_len].to(q.dtype).to(q.device)
105
+ def rot_half(x):
106
+ x1, x2 = x.chunk(2, dim=-1)
107
+ return torch.cat([-x2, x1], dim=-1)
108
+ return (q * cos) + (rot_half(q) * sin), (k * cos) + (rot_half(k) * sin)
109
+
110
+
111
+ class PerTokenSigmaEmbed(nn.Module):
112
+ """Sinusoidal embedding of per-position diffusion noise sigma in [0,1]."""
113
+ def __init__(self, d: int):
114
+ super().__init__()
115
+ self.d = d
116
+ self.proj = nn.Sequential(
117
+ nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d),
118
+ )
119
+
120
+ def forward(self, sigma: torch.Tensor) -> torch.Tensor:
121
+ half = self.d // 2
122
+ freqs = torch.exp(
123
+ -math.log(10000.0) * torch.arange(half, device=sigma.device) / half
124
+ )
125
+ ang = sigma.float().unsqueeze(-1) * freqs
126
+ emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
127
+ return self.proj(emb)
128
+
129
+
130
+ class GatedKGCrossAttention(nn.Module):
131
+ """Cross-attention to KG ego-subgraph, with GATED output.
132
+
133
+ `tanh(alpha) * cross_attn(x_seq, x_kg)` where alpha is a learnable scalar
134
+ initialized to 0. This means at init the cross-attention contributes
135
+ NOTHING to the residual stream, so the model trains identically to
136
+ no-KG until it discovers KG is useful. Prevents catastrophic loss
137
+ spikes on small data.
138
+
139
+ Pattern from: Genie (DeepMind 2024), Flamingo (DeepMind 2022).
140
+ """
141
+ def __init__(self, d_model: int, kg_dim: int, n_heads: int = 8, dropout: float = 0.1):
142
+ super().__init__()
143
+ self.n_heads = n_heads
144
+ self.head_dim = d_model // n_heads
145
+ # Project KG to d_model (run inline so we don't need separate KGProjector module)
146
+ self.kg_in_proj = nn.Linear(kg_dim, d_model, bias=False)
147
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
148
+ self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)
149
+ self.out_proj = nn.Linear(d_model, d_model, bias=False)
150
+ self.norm_q = RMSNorm(d_model)
151
+ self.norm_kv = RMSNorm(d_model)
152
+ self.dropout = nn.Dropout(dropout)
153
+ # Gate (scalar per block, init=0)
154
+ self.alpha = nn.Parameter(torch.zeros(1))
155
+
156
+ def forward(self, x_seq: torch.Tensor, kg_raw: torch.Tensor) -> torch.Tensor:
157
+ """
158
+ x_seq: (B, T, d_model)
159
+ kg_raw: (B, N_kg, kg_dim) -- raw KG embeddings (e.g. 3072)
160
+ """
161
+ B, T, D = x_seq.shape
162
+ kg_proj = self.kg_in_proj(kg_raw) # (B, N_kg, D)
163
+ N_kg = kg_proj.size(1)
164
+ q = self.q_proj(self.norm_q(x_seq))
165
+ kv = self.kv_proj(self.norm_kv(kg_proj))
166
+ k, v = kv.chunk(2, dim=-1)
167
+ q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
168
+ k = k.reshape(B, N_kg, self.n_heads, self.head_dim).transpose(1, 2)
169
+ v = v.reshape(B, N_kg, self.n_heads, self.head_dim).transpose(1, 2)
170
+ out = F.scaled_dot_product_attention(
171
+ q, k, v, dropout_p=self.dropout.p if self.training else 0.0)
172
+ out = out.transpose(1, 2).reshape(B, T, D)
173
+ gate = torch.tanh(self.alpha)
174
+ return x_seq + gate * self.dropout(self.out_proj(out))
175
+
176
+
177
+ class CDFv13Block(nn.Module):
178
+ """Pre-norm transformer block + optional gated KG cross-attn."""
179
+ def __init__(self, cfg: CDFv13Config, rope: RotaryEmbedding,
180
+ layer_idx: int):
181
+ super().__init__()
182
+ self.cfg = cfg
183
+ self.rope = rope
184
+ self.layer_idx = layer_idx
185
+ norm_cls = RMSNorm if cfg.use_rmsnorm else nn.LayerNorm
186
+ self.norm1 = norm_cls(cfg.d_model)
187
+ self.norm2 = norm_cls(cfg.d_model)
188
+ self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
189
+ self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
190
+ if cfg.use_swiglu:
191
+ self.mlp = SwiGLU(cfg.d_model, cfg.ffn, cfg.dropout)
192
+ else:
193
+ self.mlp = nn.Sequential(
194
+ nn.Linear(cfg.d_model, cfg.ffn, bias=False),
195
+ nn.GELU(),
196
+ nn.Linear(cfg.ffn, cfg.d_model, bias=False),
197
+ nn.Dropout(cfg.dropout),
198
+ )
199
+ self.dropout = nn.Dropout(cfg.dropout)
200
+ self.head_dim = cfg.d_model // cfg.n_heads
201
+ # Gated KG cross-attention (only in specified layers)
202
+ self.use_kg_in_layer = cfg.use_kg and layer_idx in cfg.kg_attn_layers
203
+ if self.use_kg_in_layer:
204
+ self.kg_xattn = GatedKGCrossAttention(
205
+ cfg.d_model, cfg.kg_dim, cfg.n_heads, cfg.dropout)
206
+
207
+ def forward(self, x, attn_mask, kg_raw=None):
208
+ B, T, D = x.shape
209
+ # MSA
210
+ h = self.norm1(x)
211
+ qkv = self.qkv(h).reshape(B, T, 3, self.cfg.n_heads, self.head_dim)
212
+ q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
213
+ q, k = self.rope(q, k, T)
214
+ out = F.scaled_dot_product_attention(
215
+ q, k, v,
216
+ attn_mask=(~attn_mask).float().masked_fill(attn_mask, float("-inf"))[None, None],
217
+ dropout_p=self.cfg.dropout if self.training else 0.0,
218
+ )
219
+ out = out.transpose(1, 2).reshape(B, T, D)
220
+ x = x + self.dropout(self.proj(out))
221
+ # Gated KG cross-attn (if enabled at this layer)
222
+ if self.use_kg_in_layer and kg_raw is not None:
223
+ x = self.kg_xattn(x, kg_raw)
224
+ # MLP
225
+ x = x + self.mlp(self.norm2(x))
226
+ return x
227
+
228
+
229
+ class CDFv13Transformer(nn.Module):
230
+ """Audit-compliant CDF v13: 12M backbone + KG adapters + DF objective."""
231
+
232
+ def __init__(self, cfg: CDFv13Config | None = None):
233
+ super().__init__()
234
+ self.cfg = cfg or CDFv13Config()
235
+ c = self.cfg
236
+ norm_cls = RMSNorm if c.use_rmsnorm else nn.LayerNorm
237
+
238
+ self.tok_emb = nn.Embedding(c.vocab_size, c.d_model)
239
+ self.emb_dropout = nn.Dropout(c.emb_dropout)
240
+
241
+ # Per-token sigma embedding (additive)
242
+ self.sigma_emb = PerTokenSigmaEmbed(c.d_model)
243
+ # Global condition embedding (additive, broadcast)
244
+ self.cond_emb = nn.Embedding(c.n_conditions, c.d_model)
245
+
246
+ # RoPE
247
+ self.rope = RotaryEmbedding(c.d_model // c.n_heads, max_seq=c.max_seq_len * 2)
248
+
249
+ # Blocks
250
+ self.blocks = nn.ModuleList([
251
+ CDFv13Block(c, self.rope, layer_idx=i) for i in range(c.n_layers)
252
+ ])
253
+ self.final_norm = norm_cls(c.d_model)
254
+ self.head = nn.Linear(c.d_model, c.vocab_size, bias=False)
255
+ if c.tie_embeddings:
256
+ self.head.weight = self.tok_emb.weight
257
+
258
+ # Block-causal mask buffer
259
+ T = c.max_seq_len
260
+ block_id = torch.arange(T) // c.block_size
261
+ mask = block_id.unsqueeze(0) < block_id.unsqueeze(1)
262
+ self.register_buffer("block_mask", mask, persistent=False)
263
+
264
+ # Init
265
+ self.apply(self._init_weights)
266
+
267
+ def _init_weights(self, m):
268
+ if isinstance(m, nn.Linear):
269
+ nn.init.normal_(m.weight, mean=0.0, std=0.02)
270
+ if m.bias is not None: nn.init.zeros_(m.bias)
271
+ elif isinstance(m, nn.Embedding):
272
+ nn.init.normal_(m.weight, mean=0.0, std=0.02)
273
+
274
+ def forward(self, x, sigma, cond, kg_raw=None):
275
+ B, T = x.shape
276
+ h = self.tok_emb(x) + self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
277
+ h = self.emb_dropout(h)
278
+ mask = self.block_mask[:T, :T]
279
+ for blk in self.blocks:
280
+ h = blk(h, mask, kg_raw=kg_raw)
281
+ h = self.final_norm(h)
282
+ return self.head(h)
283
+
284
+ def diffusion_forcing_loss(self, x_clean, cond, kg_raw=None,
285
+ mode: str = "uniform") -> torch.Tensor:
286
+ """Standard absorbing-state DF loss with per-token sigma.
287
+
288
+ mode: 'uniform' (default — safer for discrete than logit-normal per audit)
289
+ 'logit_normal' (SD3-style — keep as ablation only)
290
+ """
291
+ B, T = x_clean.shape
292
+ device = x_clean.device
293
+ # CFG cond dropout
294
+ drop = torch.rand(B, device=device) < self.cfg.cond_dropout
295
+ cond = torch.where(drop, torch.zeros_like(cond), cond)
296
+ if kg_raw is not None:
297
+ drop_kg = (torch.rand(B, device=device) < self.cfg.cond_dropout).float()
298
+ kg_raw = kg_raw * (1 - drop_kg).reshape(B, 1, 1)
299
+ # Sample per-token sigma
300
+ if mode == "logit_normal":
301
+ sigma = torch.sigmoid(torch.randn(B, T, device=device)).clamp(0.01, 0.99)
302
+ else:
303
+ sigma = torch.rand(B, T, device=device).clamp(0.01, 0.99)
304
+ # Absorbing-state corruption
305
+ corrupt = torch.rand(B, T, device=device) < sigma
306
+ x_noisy = torch.where(corrupt, self.cfg.mask_token, x_clean)
307
+ logits = self.forward(x_noisy, sigma, cond, kg_raw=kg_raw)
308
+ ce = F.cross_entropy(
309
+ logits.reshape(-1, self.cfg.vocab_size),
310
+ x_clean.reshape(-1),
311
+ reduction="none",
312
+ ).reshape(B, T)
313
+ n = corrupt.float().sum().clamp(min=1.0)
314
+ return (ce * corrupt.float()).sum() / n
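The loss above masks each token independently with probability sigma and averages cross-entropy over the masked positions only. The sketch below illustrates that step in plain Python, without PyTorch. It is not the repository's implementation: `MASK_TOKEN`, `corrupt_tokens`, and `masked_mean_ce` are hypothetical names, and the probabilities passed to `masked_mean_ce` stand in for the model's softmax output at the true token.

```python
import math
import random

MASK_TOKEN = -1  # hypothetical mask id, stands in for cfg.mask_token


def corrupt_tokens(tokens, sigmas, rng):
    """Replace token i with MASK_TOKEN with probability sigmas[i]."""
    flags = [rng.random() < s for s in sigmas]
    noisy = [MASK_TOKEN if f else t for t, f in zip(tokens, flags)]
    return noisy, flags


def masked_mean_ce(true_token_probs, flags):
    """Mean cross-entropy over corrupted positions only.

    true_token_probs[i] is the probability the model assigns to the
    clean token at position i. The denominator is floored at 1 so a
    batch with no corrupted positions yields 0 instead of dividing by 0.
    """
    n_masked = max(sum(flags), 1)
    total = sum(-math.log(p) for p, f in zip(true_token_probs, flags) if f)
    return total / n_masked
```

Unmasked positions contribute nothing to the loss, so the model is only trained to recover tokens it could not see. The sketch omits the clamp of sigma to [0.01, 0.99] that the original applies.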
src/eval_sota.py ADDED
@@ -0,0 +1,290 @@
1
+ """SOTA evaluation suite for CDFv13 — audit-proof.
2
+
3
+ Per the May 2026 SOTA audit, replaces "Top-1 mid-position" (not recognized)
4
+ with the canonical EHR foundation model metric stack:
5
+
6
+ Classification (next-event, downstream tasks):
7
+ - AUROC + AUPRC + Brier
8
+ - Calibration: ICI (Austin & Steyerberg 2019)
9
+ - Decision-curve analysis (Vickers)
10
+ - Bootstrap 95% CI (≥2000 resamples) — required for rare disease
11
+
12
+ Survival (DATASUS SIM mortality):
13
+ - Uno's C (concordance_index_ipcw) — preferred over Harrell at high censoring
14
+ - Integrated Brier Score (1/3/5y)
15
+ - Time-dependent AUC
16
+
17
+ Counterfactual / causal:
18
+ - ATE with bootstrap CI
19
+ - E-value (VanderWeele)
20
+ - Negative-control outcome + exposure
21
+ - Tipping-point analysis
22
+
23
+ Generation fidelity (CoMET / SynthEHRella):
24
+ - Dim-wise probability match
25
+ - MMD (Maximum Mean Discrepancy) with RBF kernel
26
+ - TSTR (Train-on-Synthetic-Test-on-Real)
27
+
28
+ Subgroup fairness (npj DM requirement):
29
+ - Stratified metrics: sex, age band, UF region
30
+
31
+ Split strategy (DATASUS rare disease):
32
+ - Temporal: train ≤2022, val 2023, test 2024-2025
33
+ - Geographic: train SE+S, test N+NE (UF cross-region = "external")
34
+ - Patient-level 5-fold CV (variance estimation)
35
+ """
36
+ from __future__ import annotations
37
+ import math
38
+ import numpy as np
39
+ import torch
40
+ from typing import Callable
41
+
42
+
43
+ # ---------- Classification ----------
44
+
45
+ def auroc(y: np.ndarray, p: np.ndarray) -> float:
46
+ from sklearn.metrics import roc_auc_score
47
+ if len(np.unique(y)) < 2: return float("nan")
48
+ return roc_auc_score(y, p)
49
+
50
+
51
+ def auprc(y: np.ndarray, p: np.ndarray) -> float:
52
+ from sklearn.metrics import average_precision_score
53
+ if len(np.unique(y)) < 2: return float("nan")
54
+ return average_precision_score(y, p)
55
+
56
+
57
+ def brier(y: np.ndarray, p: np.ndarray) -> float:
58
+ from sklearn.metrics import brier_score_loss
59
+ return brier_score_loss(y, p)
60
+
61
+
62
+ def ici(y: np.ndarray, p: np.ndarray, frac: float = 0.75) -> float:
63
+ """Integrated Calibration Index (Austin & Steyerberg 2019).
64
+ Lowess-smoothed deviation from perfect calibration.
65
+ """
66
+ from statsmodels.nonparametric.smoothers_lowess import lowess
67
+ sm = lowess(y, p, frac=frac, return_sorted=True)
68
+ return float(np.mean(np.abs(sm[:, 1] - sm[:, 0])))
69
+
70
+
71
+ def net_benefit(y: np.ndarray, p: np.ndarray, threshold: float) -> float:
72
+ """Net benefit at a given decision threshold (Vickers DCA)."""
73
+ tp = ((p >= threshold) & (y == 1)).sum()
74
+ fp = ((p >= threshold) & (y == 0)).sum()
75
+ n = len(y)
76
+ if threshold >= 1.0: return 0.0
77
+ return tp / n - (fp / n) * (threshold / (1 - threshold))
78
+
79
+
80
+ def decision_curve(y: np.ndarray, p: np.ndarray,
81
+ thresholds: list[float] = None) -> dict:
82
+ """Decision-curve analysis: net benefit across thresholds vs treat-all/treat-none."""
83
+ if thresholds is None:
84
+ thresholds = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
85
+ model_nb = [net_benefit(y, p, t) for t in thresholds]
86
+ treat_all_nb = [(y.mean()) - (1 - y.mean()) * (t / (1 - t)) if t < 1 else 0
87
+ for t in thresholds]
88
+ treat_none_nb = [0.0] * len(thresholds)
89
+ return {
90
+ "thresholds": thresholds,
91
+ "model": model_nb,
92
+ "treat_all": treat_all_nb,
93
+ "treat_none": treat_none_nb,
94
+ }
95
+
96
+
97
+ def bootstrap_ci(y: np.ndarray, p: np.ndarray, metric_fn: Callable,
98
+ n_boot: int = 2000, seed: int = 0,
99
+ ci: tuple[float, float] = (2.5, 97.5)) -> tuple[float, float, float]:
100
+ """Percentile bootstrap CI (default 95%) for any (y, p) -> scalar metric; returns (lower, median, upper)."""
101
+ rng = np.random.default_rng(seed)
102
+ n = len(y)
103
+ stats = []
104
+ for _ in range(n_boot):
105
+ idx = rng.integers(0, n, n)
106
+ if len(np.unique(y[idx])) < 2: continue
107
+ try:
108
+ stats.append(metric_fn(y[idx], p[idx]))
109
+ except Exception:
110
+ continue
111
+ if not stats: return (float("nan"),) * 3
112
+ return (
113
+ float(np.percentile(stats, ci[0])),
114
+ float(np.median(stats)),
115
+ float(np.percentile(stats, ci[1])),
116
+ )
117
+
118
+
119
+ # ---------- Survival ----------
120
+
121
+ def uno_c_index(y_train_event, y_train_time, y_test_event, y_test_time,
122
+ risk_score, tau: float = None) -> float:
123
+ """Uno's C-index (IPCW concordance), preferred at high censoring.
124
+ Requires scikit-survival.
125
+ """
126
+ try:
127
+ from sksurv.metrics import concordance_index_ipcw
128
+ except ImportError:
129
+ return float("nan")
130
+ # Build structured arrays
131
+ y_train = np.array(
132
+ list(zip(y_train_event.astype(bool), y_train_time.astype(float))),
133
+ dtype=[("event", "?"), ("time", "<f8")],
134
+ )
135
+ y_test = np.array(
136
+ list(zip(y_test_event.astype(bool), y_test_time.astype(float))),
137
+ dtype=[("event", "?"), ("time", "<f8")],
138
+ )
139
+ if tau is None:
140
+ tau = float(y_test_time.max()) * 0.95
141
+ c, *_ = concordance_index_ipcw(y_train, y_test, risk_score, tau=tau)
142
+ return float(c)
143
+
144
+
145
+ def integrated_brier_score(y_train_event, y_train_time, y_test_event, y_test_time,
146
+ surv_pred: np.ndarray, times: np.ndarray) -> float:
147
+ """Integrated Brier Score (lower is better)."""
148
+ try:
149
+ from sksurv.metrics import integrated_brier_score as ibs_fn
150
+ except ImportError:
151
+ return float("nan")
152
+ y_train = np.array(
153
+ list(zip(y_train_event.astype(bool), y_train_time.astype(float))),
154
+ dtype=[("event", "?"), ("time", "<f8")],
155
+ )
156
+ y_test = np.array(
157
+ list(zip(y_test_event.astype(bool), y_test_time.astype(float))),
158
+ dtype=[("event", "?"), ("time", "<f8")],
159
+ )
160
+ return float(ibs_fn(y_train, y_test, surv_pred, times))
161
+
162
+
163
+ # ---------- Causal / Counterfactual ----------
164
+
165
+ def e_value(rr: float) -> float:
166
+ """E-value (VanderWeele & Ding 2017): min strength of unmeasured
167
+ confounder needed to explain away an observed RR.
168
+ """
169
+ rr = max(rr, 1e-9)
170
+ if rr >= 1.0:
171
+ return rr + math.sqrt(rr * (rr - 1))
172
+ rr_inv = 1.0 / rr
173
+ return rr_inv + math.sqrt(rr_inv * (rr_inv - 1))
174
+
175
+
176
+ def negative_control_check(nc_ate: float, threshold: float = 0.02) -> bool:
177
+ """Negative-control outcome: ATE on a control outcome should be ~0."""
178
+ return abs(nc_ate) < threshold
179
+
180
+
181
+ def tipping_point(observed_effect: float, ci_half_width: float) -> float:
182
+ """How much would unmeasured confounding need to shift effect to nullify?"""
183
+ if abs(observed_effect) <= ci_half_width:
184
+ return 0.0
185
+ return float(abs(observed_effect) - ci_half_width)
186
+
187
+
188
+ # ---------- Generation fidelity (SynthEHRella triad) ----------
189
+
190
+ def dim_wise_probability(real_seq: torch.Tensor, synth_seq: torch.Tensor,
191
+ vocab_size: int) -> float:
192
+ """Compare per-token Bernoulli rates between real and synthetic batches.
193
+
194
+ Returns mean abs difference (lower = closer match).
195
+ """
196
+ import torch.nn.functional as F  # F is not imported at module level in this file
+ real_one_hot = F.one_hot(real_seq, vocab_size).float().mean(dim=(0, 1))
197
+ synth_one_hot = F.one_hot(synth_seq, vocab_size).float().mean(dim=(0, 1))
198
+ return float((real_one_hot - synth_one_hot).abs().mean())
199
+
200
+
201
+ def mmd_rbf(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> float:
202
+ """Maximum Mean Discrepancy with RBF kernel.
203
+
204
+ x, y: (B, D) flattened embeddings. Returns MMD^2 (lower = closer).
205
+ """
206
+ def rbf(a, b):
207
+ d = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(-1)
208
+ return torch.exp(-d / (2 * sigma ** 2))
209
+ return float(rbf(x, x).mean() + rbf(y, y).mean() - 2 * rbf(x, y).mean())
210
+
211
+
212
+ # ---------- Subgroup fairness ----------
213
+
214
+ def stratified_metrics(y: np.ndarray, p: np.ndarray,
215
+ groups: np.ndarray,
216
+ metric_fn: Callable = auroc) -> dict[str, float]:
217
+ """Compute metric per subgroup (sex, age band, UF region)."""
218
+ out = {}
219
+ for g in np.unique(groups):
220
+ mask = groups == g
221
+ if mask.sum() > 10:
222
+ try:
223
+ out[str(g)] = metric_fn(y[mask], p[mask])
224
+ except Exception:
225
+ out[str(g)] = float("nan")
226
+ return out
227
+
228
+
229
+ # ---------- DATASUS split strategies ----------
230
+
231
+ def temporal_split(events: list[dict], train_until: int = 2022,
232
+ val_year: int = 2023):
233
+ """Temporal split for DATASUS: train ≤2022, val 2023, test 2024+. Events with no year are treated as 2020 and land in train."""
234
+ train, val, test = [], [], []
235
+ for e in events:
236
+ y = e.get("year") or 2020
237
+ if y <= train_until: train.append(e)
238
+ elif y == val_year: val.append(e)
239
+ else: test.append(e)
240
+ return train, val, test
241
+
242
+
243
+ def geographic_split(patients: list[dict], external_ufs: set = None):
244
+ """Geographic split: test on N+NE UFs, train on all other UFs (SE, S, and CO regions) plus patients with no UF.
245
+ For DATASUS this is the closest analog to "external validation."
246
+ """
247
+ if external_ufs is None:
248
+ external_ufs = {"AC", "AL", "AP", "AM", "BA", "CE", "MA", "PA",
249
+ "PB", "PE", "PI", "RN", "SE", "TO", "RR", "RO"}
250
+ train, test = [], []
251
+ for p in patients:
252
+ uf = next((e.get("uf_code") for e in p.get("events", []) if e.get("uf_code")),
253
+ None)
254
+ (test if uf in external_ufs else train).append(p)
255
+ return train, test
256
+
257
+
258
+ # ---------- Combined eval report ----------
259
+
260
+ def full_eval_report(y: np.ndarray, p: np.ndarray,
261
+ groups_sex: np.ndarray = None,
262
+ groups_age: np.ndarray = None,
263
+ groups_uf: np.ndarray = None,
264
+ n_boot: int = 2000) -> dict:
265
+ """Generate a full audit-proof report for a binary classification task.
266
+
267
+ Returns a dict with point estimates + bootstrap CIs + DCA + fairness.
268
+ """
269
+ import torch.nn.functional as F # local import to keep top clean
270
+
271
+ auroc_lo, auroc_med, auroc_hi = bootstrap_ci(y, p, auroc, n_boot)
272
+ auprc_lo, auprc_med, auprc_hi = bootstrap_ci(y, p, auprc, n_boot)
273
+ brier_lo, brier_med, brier_hi = bootstrap_ci(y, p, brier, n_boot)
274
+
275
+ report = {
276
+ "n_eval": len(y),
277
+ "prevalence": float(y.mean()),
278
+ "auroc": {"point": auroc(y, p), "ci95": [auroc_lo, auroc_hi], "median": auroc_med},
279
+ "auprc": {"point": auprc(y, p), "ci95": [auprc_lo, auprc_hi], "median": auprc_med},
280
+ "brier": {"point": brier(y, p), "ci95": [brier_lo, brier_hi], "median": brier_med},
281
+ "ici": ici(y, p),
282
+ "decision_curve": decision_curve(y, p),
283
+ }
284
+ if groups_sex is not None:
285
+ report["fairness_sex"] = stratified_metrics(y, p, groups_sex, auroc)
286
+ if groups_age is not None:
287
+ report["fairness_age"] = stratified_metrics(y, p, groups_age, auroc)
288
+ if groups_uf is not None:
289
+ report["fairness_uf"] = stratified_metrics(y, p, groups_uf, auroc)
290
+ return report
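A small worked example may help when reading the decision-curve output. The sketch below reimplements the net-benefit formula from `net_benefit` and the treat-all baseline from `decision_curve` in plain Python, then applies them to four made-up patients. The data and the helper name `treat_all_net_benefit` are illustrative assumptions, not part of the module.

```python
def net_benefit(y, p, threshold):
    """Net benefit = TP/n - FP/n * (t / (1 - t)) at decision threshold t."""
    n = len(y)
    tp = sum(1 for yi, pi in zip(y, p) if pi >= threshold and yi == 1)
    fp = sum(1 for yi, pi in zip(y, p) if pi >= threshold and yi == 0)
    return tp / n - (fp / n) * (threshold / (1 - threshold))


def treat_all_net_benefit(y, threshold):
    """Net benefit of flagging every patient as positive."""
    prevalence = sum(y) / len(y)
    return prevalence - (1 - prevalence) * (threshold / (1 - threshold))


# Toy data: two events, two non-events (made up for illustration).
y = [1, 1, 0, 0]
p = [0.9, 0.4, 0.6, 0.1]

model_nb = net_benefit(y, p, 0.2)          # 2 TP, 1 FP at t = 0.2
baseline_nb = treat_all_net_benefit(y, 0.2)
```

At a threshold of 0.2 the model flags three patients (two true positives, one false positive), giving 2/4 - (1/4)(0.25) = 0.4375, against 0.375 for treat-all. A model is useful at a given threshold only when its net benefit exceeds both treat-all and treat-none (zero).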
src/meds_export.py ADDED
@@ -0,0 +1,336 @@
1
+ """MEDS v0.4.1 exporter for DATASUS — audit-proof, interop-ready.
2
+
3
+ Verified against:
4
+ - meds 0.4.1 schemas (DataSchema, CodeMetadataSchema)
5
+ - https://github.com/Medical-Event-Data-Standard/meds
6
+ - CLMBR/MOTOR/EHRSHOT/CoMET tokenization conventions
7
+
8
+ Code conventions (interop-compatible):
9
+ - Static (time=None): GENDER//, RACE//, UF//, MUN//, ORPHA//
10
+ - Birth/Death: MEDS_BIRTH, MEDS_DEATH (reserved)
11
+ - Diagnoses: ICD10//<cid> (NOT CID10// — interop with OHDSI/Athena)
12
+ - Hospitalization: SIH//ADM, SIH//DIS (numeric_value=LOS_days on DIS)
13
+ - Procedures: SIGTAP//<10-digit> (Brazil-local namespace)
14
+ - Drugs (APAC): APAC//<sigtap> (numeric_value=monthly_cost_brl)
15
+ - Outpatient (BPA-I): BPAI//<sigtap>
16
+ - Visits: Visit//{IP, OP, ER} (matches CLMBR convention)
17
+
18
+ Outputs canonical MEDS dataset:
19
+ /out/
20
+ ├── data/ # parquet shards by subject
21
+ │ ├── shard_0.parquet
22
+ │ └── ...
23
+ ├── metadata/
24
+ │ └── codes.parquet # REQUIRED: every unique code with description + parent_codes
25
+ └── dataset_metadata.json # MEDS dataset metadata
26
+ """
27
+ from __future__ import annotations
28
+ import os
29
+ import json
30
+ import logging
31
+ from collections import defaultdict, Counter
32
+ from datetime import datetime
33
+ from typing import Iterator
34
+
35
+ import pyarrow as pa
36
+ import pyarrow.parquet as pq
37
+ import meds
38
+
39
+ log = logging.getLogger("gemeo.cdf.meds_export")
40
+
41
+
42
+ def _parse_date(s) -> datetime | None:
43
+ """Parse date string from various DATASUS formats."""
44
+ if s is None: return None
45
+ s = str(s).strip()
46
+ if not s or s in ("0", "None", "nan"): return None
47
+ try:
48
+ if "-" in s:
49
+ return datetime.strptime(s[:10], "%Y-%m-%d")
50
+ if len(s) == 8:
51
+ return datetime.strptime(s, "%Y%m%d")
52
+ except ValueError:
53
+ return None
54
+ return None
55
+
56
+
57
+ def _ym(year, month) -> datetime | None:
58
+ if year is None: return None
59
+ try:
60
+ return datetime(int(year), int(month) if month else 1, 1)
61
+ except (ValueError, TypeError):
62
+ return None
63
+
64
+
65
+ def datasus_patient_to_meds_rows(p: dict, subject_id: int) -> list[tuple]:
66
+ """Convert one DATASUS patient trajectory to a list of MEDS rows.
67
+
68
+ Each row is (subject_id, time, code, numeric_value, text_value).
69
+ Returns rows ready to write to a parquet shard.
70
+ """
71
+ rows = []
72
+
73
+ # ---- Static (time=None) ----
74
+ if p.get("sex"):
75
+ rows.append((subject_id, None, f"GENDER//{p['sex']}", None, None))
76
+ # ORPHA is rare-disease specific (parallel to ICD10)
77
+ for orpha in p.get("orphas", []):
78
+ rows.append((subject_id, None, f"ORPHA//{orpha}", None, None))
79
+
80
+ # ---- Birth (use birth_year as Jan 1) ----
81
+ birth_year = p.get("birth_year")
82
+ birth_dt = datetime(int(birth_year), 1, 1) if birth_year else None
83
+ if birth_dt:
84
+ rows.append((subject_id, birth_dt, "MEDS_BIRTH", None, None))
85
+
86
+ # ---- Events ----
87
+ for e in p.get("events", []):
88
+ et = e.get("type")
89
+
90
+ if et == "admission": # SIH-RD
91
+ t = _parse_date(e.get("admission_date")) or _ym(e.get("year"), e.get("month"))  # prefer the exact date, as the BPA-I and SIM branches do
92
+ if not t: continue
93
+ rows.append((subject_id, t, "SIH//ADM", None, None))
94
+ rows.append((subject_id, t, "Visit//IP", None, None))
95
+ cid = e.get("cid_princ", "")
96
+ if cid: rows.append((subject_id, t, f"ICD10//{cid}", None, None))
97
+ proc = e.get("primary_procedure")
98
+ if proc: rows.append((subject_id, t, f"SIGTAP//{proc[:10]}", None, None))
99
+ los = e.get("los_days")
100
+ disch_dt = _parse_date(e.get("discharge_date")) or t
101
+ if e.get("death_during_stay"):
102
+ rows.append((subject_id, disch_dt, "MEDS_DEATH", None, None))
103
+ else:
104
+ rows.append((subject_id, disch_dt, "SIH//DIS",
105
+ float(los) if los is not None else None, None))
106
+
107
+ elif et == "treatment": # APAC-SIA orphan drug
108
+ t = _ym(e.get("year"), e.get("month"))
109
+ if not t: continue
110
+ cid = e.get("cid", "")
111
+ if cid: rows.append((subject_id, t, f"ICD10//{cid}", None, None))
112
+ proc = e.get("procedure_code", "")[:10]
113
+ if proc:
114
+ cost = e.get("monthly_cost_brl")
115
+ rows.append((subject_id, t, f"APAC//{proc}",
116
+ float(cost) if cost is not None else None, None))
117
+
118
+ elif et == "outpatient_proc": # BPA-I
119
+ t = _parse_date(e.get("auth_date")) or _ym(e.get("year"), e.get("month"))
120
+ if not t: continue
121
+ cid = e.get("cid", "")
122
+ if cid: rows.append((subject_id, t, f"ICD10//{cid}", None, None))
123
+ proc = e.get("procedure_code", "")[:10]
124
+ if proc:
125
+ rows.append((subject_id, t, f"BPAI//{proc}", None, None))
126
+
127
+ elif et == "death": # SIM
128
+ t = _parse_date(e.get("date_of_death")) or _ym(e.get("year"), e.get("month"))
129
+ if not t: continue
130
+ rows.append((subject_id, t, "MEDS_DEATH", None, None))
131
+ cid = (e.get("cause_cid") or e.get("cid_princ") or e.get("cid", ""))
132
+ if cid: rows.append((subject_id, t, f"ICD10//{cid}", None, None))
133
+
134
+ # Sort: nulls first (static), then by time
135
+ rows.sort(key=lambda r: (r[1] is not None, r[1] or datetime(1900, 1, 1)))
136
+ return rows
137
+
138
+
139
+ def export_to_meds(patients: list[dict], out_dir: str,
140
+ shard_size: int = 5000,
141
+ dataset_name: str = "GEMEO-DATASUS",
142
+ version: str = "v13"):
143
+ """Export a list of DATASUS patient trajectories to MEDS v0.4.1 format.
144
+
145
+ Parameters
146
+ ----------
147
+ patients : list of dict
148
+ Each dict must have: patient_id, sex, birth_year, orphas (list),
149
+ events (list of dicts with 'type', 'year', 'month', etc.)
150
+ out_dir : str
151
+ Output directory (will create data/ and metadata/ subdirs)
152
+ shard_size : int
153
+ Number of subjects per parquet shard
154
+ """
155
+ os.makedirs(f"{out_dir}/data", exist_ok=True)
156
+ os.makedirs(f"{out_dir}/metadata", exist_ok=True)
157
+
158
+ log.info(f"Exporting {len(patients)} patients to MEDS at {out_dir}")
159
+
160
+ # Map patient_id (string hash) → int64 subject_id (MEDS requires int64)
161
+ pid_to_sid = {p["patient_id"]: i for i, p in enumerate(patients)}
162
+
163
+ # ---- Stream rows ----
164
+ all_codes = Counter()
165
+ shard_idx = 0
166
+ shard_rows = []
167
+ n_events = 0
168
+ n_subjects = 0
169
+
170
+ for p in patients:
171
+ sid = pid_to_sid[p["patient_id"]]
172
+ rows = datasus_patient_to_meds_rows(p, sid)
173
+ shard_rows.extend(rows)
174
+ n_events += len(rows)
175
+ n_subjects += 1
176
+ for r in rows:
177
+ all_codes[r[2]] += 1
178
+ # Write shard when full
179
+ if n_subjects % shard_size == 0 and shard_rows:
180
+ _write_shard(shard_rows, f"{out_dir}/data/shard_{shard_idx}.parquet")
181
+ shard_idx += 1
182
+ shard_rows = []
183
+
184
+ # Write remaining
185
+ if shard_rows:
186
+ _write_shard(shard_rows, f"{out_dir}/data/shard_{shard_idx}.parquet")
187
+
188
+ log.info(f" wrote {shard_idx + (1 if shard_rows else 0)} data shards, {n_events} rows, {n_subjects} subjects")
189
+
190
+ # ---- codes.parquet (REQUIRED in MEDS v0.4) ----
191
+ code_rows = []
192
+ for code, count in all_codes.most_common():
193
+ # parent_codes: empty for Brazil-local namespaces; populated for ICD10 -> SNOMED if mapped
194
+ parent_codes = _get_parent_codes(code)
195
+ code_rows.append({
196
+ "code": code,
197
+ "description": _get_description(code, count),
198
+ "parent_codes": parent_codes,
199
+ })
200
+ code_table = pa.Table.from_pylist(code_rows, schema=meds.CodeMetadataSchema.schema())
201
+ pq.write_table(code_table, f"{out_dir}/metadata/codes.parquet")
202
+ log.info(f" wrote metadata/codes.parquet ({len(code_rows)} unique codes)")
203
+
204
+ # ---- dataset_metadata.json ----
205
+ md = {
206
+ "dataset_name": dataset_name,
207
+ "dataset_version": version,
208
+ "etl_name": "gemeo.cdf.meds_export",
209
+ "etl_version": "1.0.0",
210
+ "meds_version": meds.__version__,
211
+ "n_subjects": n_subjects,
212
+ "n_events": n_events,
213
+ "n_unique_codes": len(all_codes),
214
+ "top_codes": dict(all_codes.most_common(30)),
215
+ }
216
+ with open(f"{out_dir}/dataset_metadata.json", "w") as f:
217
+ json.dump(md, f, indent=2, default=str)
218
+ log.info(f" wrote dataset_metadata.json")
219
+
220
+ return md
221
+
222
+
223
+ def _write_shard(rows: list[tuple], path: str):
224
+ """Write a list of (subject_id, time, code, numeric_value, text_value) to parquet."""
225
+ if not rows: return
226
+ # Build columnar arrays
227
+ subject_id = pa.array([r[0] for r in rows], type=pa.int64())
228
+ time = pa.array([r[1] for r in rows], type=pa.timestamp("us"))
229
+ code = pa.array([r[2] for r in rows], type=pa.string())
230
+ numeric_value = pa.array([r[3] for r in rows], type=pa.float32())
231
+ text_value = pa.array([r[4] for r in rows], type=pa.large_string())
232
+ table = pa.Table.from_arrays(
233
+ [subject_id, time, code, numeric_value, text_value],
234
+ names=["subject_id", "time", "code", "numeric_value", "text_value"],
235
+ )
236
+ # Validate against MEDS schema
237
+ expected_schema = meds.DataSchema.schema()
238
+ # Cast if needed
239
+ table = table.cast(expected_schema, safe=False)
240
+ pq.write_table(table, path, compression="zstd")
241
+
242
+
243
+ # Brazilian-specific mapping tables (extend as needed)
244
+ ICD10_CHAPTERS = {
245
+ "A": "Certain infectious and parasitic diseases",
246
+ "B": "Certain infectious and parasitic diseases",
247
+ "C": "Neoplasms",
248
+ "D": "Neoplasms / Diseases of the blood and immune",
249
+ "E": "Endocrine, nutritional and metabolic diseases",
250
+ "F": "Mental, Behavioral and Neurodevelopmental disorders",
251
+ "G": "Diseases of the nervous system",
252
+ "H": "Diseases of the eye / ear",
253
+ "I": "Diseases of the circulatory system",
254
+ "J": "Diseases of the respiratory system",
255
+ "K": "Diseases of the digestive system",
256
+ "L": "Diseases of the skin and subcutaneous tissue",
257
+ "M": "Diseases of the musculoskeletal system",
258
+ "N": "Diseases of the genitourinary system",
259
+ "O": "Pregnancy, childbirth and the puerperium",
260
+ "P": "Certain conditions originating in the perinatal period",
261
+ "Q": "Congenital malformations, deformations and chromosomal abnormalities",
262
+ "R": "Symptoms, signs and abnormal clinical and laboratory findings",
263
+ "S": "Injury, poisoning and certain other consequences of external causes",
264
+ "T": "Injury, poisoning and certain other consequences of external causes",
265
+ "V": "External causes of morbidity",
266
+ "W": "External causes of morbidity",
267
+ "X": "External causes of morbidity",
268
+ "Y": "External causes of morbidity",
269
+ "Z": "Factors influencing health status and contact with health services",
270
+ }
271
+
272
+
273
+ def _get_description(code: str, count: int) -> str:
274
+ """Generate a brief description for a code (used in codes.parquet)."""
275
+ if code in ("MEDS_BIRTH",): return "Birth event (reserved)"
276
+ if code in ("MEDS_DEATH",): return "Death event (reserved)"
277
+ parts = code.split("//")
278
+ if len(parts) < 2: return f"Unknown code (n={count})"
279
+ domain, val = parts[0], "//".join(parts[1:])
280
+ if domain == "GENDER": return f"Patient sex = {val}"
281
+ if domain == "ORPHA": return f"Orphanet rare disease {val}"
282
+ if domain == "ICD10":
283
+ ch = ICD10_CHAPTERS.get(val[0], "Unknown chapter")
284
+ return f"ICD-10 {val} ({ch})"
285
+ if domain == "SIH": return f"SIH hospitalization {val}"
286
+ if domain == "Visit": return f"Visit type {val}"
287
+ if domain == "SIGTAP": return f"SIGTAP procedure {val}"
288
+ if domain == "APAC": return f"APAC orphan-drug authorization {val}"
289
+ if domain == "BPAI": return f"BPA-I outpatient procedure {val}"
290
+ if domain == "UF": return f"Residence UF {val}"
291
+ return f"{domain} code {val}"
292
+
293
+
294
+ def _get_parent_codes(code: str) -> list[str]:
295
+ """Return parent codes for ontology hierarchy (currently minimal)."""
296
+ parts = code.split("//")
297
+ if len(parts) < 2: return []
298
+ domain, val = parts[0], "//".join(parts[1:])
299
+ parents = []
300
+ if domain == "ICD10" and len(val) >= 3:
301
+ # ICD-10 chapter as parent
302
+ chapter = val[0]
303
+ if chapter in ICD10_CHAPTERS:
304
+ parents.append(f"ICD10//chapter_{chapter}")
305
+ # 3-char prefix as parent (e.g., E84.0 → E84)
306
+ if "." in val:
307
+ parents.append(f"ICD10//{val.split('.')[0]}")
308
+ elif len(val) > 3:
309
+ parents.append(f"ICD10//{val[:3]}")
310
+ if domain == "SIGTAP" and len(val) >= 4:
311
+ # 4-digit group as parent (SIGTAP 10-digit → 4-digit group)
312
+ parents.append(f"SIGTAP//group_{val[:4]}")
313
+ return parents
314
+
315
+
316
+ def load_meds_dataset(meds_dir: str) -> dict:
317
+ """Load a MEDS dataset back from parquet for inspection or downstream processing."""
318
+ import glob
319
+ shards = sorted(glob.glob(f"{meds_dir}/data/*.parquet"))
320
+ tables = [pq.read_table(p) for p in shards]
321
+ data = pa.concat_tables(tables) if tables else None
322
+ codes = pq.read_table(f"{meds_dir}/metadata/codes.parquet")
323
+ md = json.load(open(f"{meds_dir}/dataset_metadata.json"))
324
+ return {"data": data, "codes": codes, "metadata": md}
325
+
326
+
327
+ if __name__ == "__main__":
328
+ # Quick test on real patient data
329
+ logging.basicConfig(level=logging.INFO,
330
+ format="%(asctime)s %(levelname)s %(message)s")
331
+ PATIENTS = "/tmp/datasus_patient_trajectories_v2.json"
332
+ if os.path.exists(PATIENTS):
333
+ patients = json.load(open(PATIENTS))[:50] # 50 patients smoke test
334
+ md = export_to_meds(patients, "/tmp/meds_smoke_test")
335
+ print("\n=== smoke test result ===")
336
+ print(json.dumps(md, indent=2, default=str))
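The exporter orders each subject's rows with static rows (time is None) first, followed by timed events in chronological order. The sketch below isolates that sort key on three hand-written rows. `meds_sort_key` is a hypothetical helper name, and the rows are illustrative rather than real DATASUS data.

```python
from datetime import datetime


def meds_sort_key(row):
    """Sort static rows (time is None) first, then by event time.

    The first tuple element is False for static rows, so they sort
    ahead of timed rows. The placeholder datetime is only compared
    among static rows, where the ordering does not matter.
    """
    time = row[1]
    return (time is not None, time or datetime(1900, 1, 1))


rows = [
    (0, datetime(2021, 5, 1), "SIH//ADM"),
    (0, None, "GENDER//F"),
    (0, datetime(1990, 1, 1), "MEDS_BIRTH"),
]
ordered = sorted(rows, key=meds_sort_key)
codes_in_order = [r[2] for r in ordered]
```

Because Python's sort is stable, rows sharing a timestamp keep their insertion order, so an admission's `SIH//ADM` row stays ahead of the diagnosis and procedure rows appended after it.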
src/primekg_attention.py ADDED
@@ -0,0 +1,262 @@
1
+ """PrimeKG cross-attention — graph-RAG into the Diffusion Forcing denoiser.
2
+
3
+ Now uses REAL EDGES from raras-app/data/graph-ml/hetero_graph.json:
4
+ - disease → has_phenotype → phenotype (curated phenotype linkage)
5
+ - disease → associated_with → gene (causal gene evidence)
6
+ - gene → interacts_with → gene (PPI network)
7
+ - phenotype → is_a → phenotype (HPO ontology)
8
+
9
+ Ego-subgraph BFS:
10
+ 1. Start from disease node (ORPHA → PrimeKG index)
11
+ 2. 1-hop: pull connected phenotypes (first K in stored edge order; edge weights are not used)
12
+ 3. 1-hop: pull connected genes
13
+ 4. 2-hop: gene→gene neighbors (interacting partners)
14
+ 5. Concatenate fused embeddings of all selected nodes → cross-attn context
15
+
16
+ Falls back to cosine-similarity if graph not loaded.
17
+
18
+ White-space architecture (May 2026):
19
+ - EHRWorld, CLARITY, Time-Aware G-Transformer all skip KG conditioning
20
+ - PhenoKG/RareNet use KG for RETRIEVAL (rare disease diagnosis)
21
+ - We use it for GENERATION (counterfactual trajectory completion)
22
+ """
23
+ from __future__ import annotations
24
+ import os
25
+ import json
26
+ import logging
27
+ from functools import lru_cache
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+ log = logging.getLogger("gemeo.cdf.kg")
35
+
36
+ # Try raras-app paths first (richer, including hetero_graph edges + node_texts)
37
+ RARAS_KG_DIR = "/Users/dimas/raras-app/data/graph-ml"
38
+ LOCAL_KG_DIR = os.path.join(
39
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
40
+
41
+
42
+ def _kg_path(name: str) -> str | None:
43
+ """Prefer raras-app path if available, fall back to local fp16."""
44
+ raras = os.path.join(RARAS_KG_DIR, name)
45
+ if os.path.exists(raras):
46
+ return raras
47
+ local = os.path.join(LOCAL_KG_DIR, name)
48
+ return local if os.path.exists(local) else None
49
+
50
+
51
+ @lru_cache(maxsize=1)
52
+ def load_kg(prefer_raras: bool = True) -> dict | None:
53
+ """Load PrimeKG: fused embeddings + node ids + edges + texts.
54
+
55
+ Returns dict:
56
+ emb : {kind: torch.Tensor(N, 3072)}
57
+ idx2id : {kind: {pos: id_str}}
58
+ id2idx : {kind: {id_str: pos}}
59
+ edges : {edge_type: {'src': [...], 'dst': [...]}}
60
+ adj : {edge_type: {src_idx: [dst_idx, ...]}} -- precomputed
61
+ texts : {kind: [str, ...]} -- aligned to position
62
+ num_nodes : {kind: int}
63
+ """
64
+ # Try raras-app full file first, then local fp16
65
+ emb_path = (os.path.join(RARAS_KG_DIR, "fused_embeddings.npz")
66
+ if prefer_raras and os.path.exists(os.path.join(RARAS_KG_DIR, "fused_embeddings.npz"))
67
+ else _kg_path("fused_embeddings_fp16.npz"))
68
+ if not emb_path or not os.path.exists(emb_path):
69
+ log.warning("PrimeKG fused embeddings not found")
70
+ return None
71
+
72
+ nids_path = _kg_path("node_ids.json")
73
+ graph_path = _kg_path("hetero_graph.json")
74
+ texts_path = _kg_path("node_texts.json")
75
+
76
+ fz = np.load(emb_path)
77
+ nids = json.load(open(nids_path)) if nids_path else {}
78
+ graph = json.load(open(graph_path)) if graph_path else {"edges": {}, "num_nodes": {}}
79
+ texts = json.load(open(texts_path)) if texts_path else {}
80
+
81
+ out = {"emb": {}, "id2idx": {}, "idx2id": {}, "edges": {}, "adj": {},
82
+ "texts": texts, "num_nodes": graph.get("num_nodes", {})}
83
+
84
+ for kind in ("disease", "phenotype", "gene"):
85
+ if kind in fz.files:
86
+ out["emb"][kind] = torch.from_numpy(fz[kind].astype(np.float32))
87
+ if kind in nids:
88
+ out["idx2id"][kind] = {int(k): v for k, v in nids[kind].items()}
89
+ out["id2idx"][kind] = {v: int(k) for k, v in nids[kind].items()}
90
+
91
+ # Build adjacency from edges
92
+ for edge_type, edata in graph.get("edges", {}).items():
93
+ adj = {}
94
+ srcs = edata.get("src", []) if isinstance(edata, dict) else []
95
+ dsts = edata.get("dst", []) if isinstance(edata, dict) else []
96
+ for s, d in zip(srcs, dsts):
97
+ adj.setdefault(int(s), []).append(int(d))
98
+ out["adj"][edge_type] = adj
99
+ out["edges"][edge_type] = edata
100
+
101
+ log.info(f" KG loaded from {emb_path}")
102
+ log.info(f" disease={out['emb'].get('disease', torch.empty(0)).shape}, "
103
+ f"phenotype={out['emb'].get('phenotype', torch.empty(0)).shape}, "
104
+ f"gene={out['emb'].get('gene', torch.empty(0)).shape}")
105
+ log.info(f" edges: {list(out['edges'].keys())}")
106
+ return out
107
+
108
+
109
+ def ego_subgraph_real(orpha_code: str, k_pheno: int = 16, k_gene: int = 16,
110
+ k_gene_2hop: int = 0, kg: dict | None = None) -> torch.Tensor | None:
111
+ """BFS ego-subgraph using REAL PrimeKG edges (not cosine similarity).
112
+
113
+ Returns concatenated embeddings (N, 3072) where:
114
+ - 1 disease node (the query)
115
+ - up to k_pheno phenotype nodes (direct edges)
116
+ - up to k_gene gene nodes (direct edges)
117
+ - up to k_gene_2hop gene-gene 2-hop neighbors
118
+
119
+ Falls back to cosine similarity if no edges available.
120
+ """
121
+ if kg is None:
122
+ kg = load_kg()
123
+ if kg is None or "disease" not in kg["emb"]:
124
+ return None
125
+
126
+ d_id = kg["id2idx"]["disease"].get(str(orpha_code))
127
+ if d_id is None:
128
+ return None
129
+
130
+ d_emb = kg["emb"]["disease"][d_id]
131
+ nodes = [d_emb.unsqueeze(0)]
132
+
133
+ # Phenotype neighbors (via disease__has_phenotype__phenotype)
134
+ adj = kg["adj"].get("disease__has_phenotype__phenotype", {})
135
+ pheno_neighbors = adj.get(d_id, [])
136
+ if pheno_neighbors and "phenotype" in kg["emb"]:
137
+ pheno_neighbors = pheno_neighbors[:k_pheno]
138
+ nodes.append(kg["emb"]["phenotype"][pheno_neighbors])
139
+ elif "phenotype" in kg["emb"]:
140
+ # Fallback: cosine similarity
141
+ pool = kg["emb"]["phenotype"]
142
+ sim = F.cosine_similarity(d_emb.unsqueeze(0), pool, dim=-1)
143
+ top = sim.topk(min(k_pheno, pool.size(0))).indices
144
+ nodes.append(pool[top])
145
+
146
+ # Gene neighbors (via disease__associated_with__gene)
147
+ g_adj = kg["adj"].get("disease__associated_with__gene", {})
148
+ gene_neighbors = g_adj.get(d_id, [])
149
+ if gene_neighbors and "gene" in kg["emb"]:
150
+ gene_neighbors = gene_neighbors[:k_gene]
151
+ nodes.append(kg["emb"]["gene"][gene_neighbors])
152
+
153
+ # 2-hop: gene-gene neighbors of the genes we just pulled
154
+ if k_gene_2hop > 0:
155
+ gg_adj = kg["adj"].get("gene__interacts_with__gene", {})
156
+ seen = set(gene_neighbors)
157
+ second_hop = []
158
+ for g in gene_neighbors:
159
+ for g2 in gg_adj.get(g, []):
160
+ if g2 not in seen:
161
+ second_hop.append(g2)
162
+ seen.add(g2)
163
+ if len(second_hop) >= k_gene_2hop: break
164
+ if len(second_hop) >= k_gene_2hop: break
165
+ if second_hop:
166
+ nodes.append(kg["emb"]["gene"][second_hop])
167
+ elif "gene" in kg["emb"]:
168
+ pool = kg["emb"]["gene"]
169
+ sim = F.cosine_similarity(d_emb.unsqueeze(0), pool, dim=-1)
170
+ top = sim.topk(min(k_gene, pool.size(0))).indices
171
+ nodes.append(pool[top])
172
+
173
+ return torch.cat(nodes, dim=0)
174
+
175
+
176
+ # Keep old API name for backward compat
177
+ ego_subgraph = ego_subgraph_real
178
+
179
+
+ class KGCrossAttention(nn.Module):
+     """Cross-attention from sequence (B, T, d_model) to KG ego (B, N, d_model)."""
+     def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.1):
+         super().__init__()
+         self.n_heads = n_heads
+         self.head_dim = d_model // n_heads
+         self.q_proj = nn.Linear(d_model, d_model, bias=False)
+         self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)
+         self.out_proj = nn.Linear(d_model, d_model, bias=False)
+         self.norm_q = nn.LayerNorm(d_model)
+         self.norm_kv = nn.LayerNorm(d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x_seq: torch.Tensor, x_kg: torch.Tensor) -> torch.Tensor:
+         B, T, D = x_seq.shape
+         _, N, _ = x_kg.shape
+         q = self.q_proj(self.norm_q(x_seq))
+         kv = self.kv_proj(self.norm_kv(x_kg))
+         k, v = kv.chunk(2, dim=-1)
+         q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k = k.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)
+         v = v.reshape(B, N, self.n_heads, self.head_dim).transpose(1, 2)
+         out = F.scaled_dot_product_attention(
+             q, k, v, dropout_p=self.dropout.p if self.training else 0.0)
+         out = out.transpose(1, 2).reshape(B, T, D)
+         return x_seq + self.dropout(self.out_proj(out))
+
+
+ class KGProjector(nn.Module):
+     """Project 3072-d KG embeddings to d_model with LayerNorm."""
+     def __init__(self, kg_dim: int, d_model: int):
+         super().__init__()
+         self.proj = nn.Sequential(
+             nn.Linear(kg_dim, d_model),
+             nn.GELU(),
+             nn.LayerNorm(d_model),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.proj(x)
+
+
+ def build_kg_batch(orpha_strings: list[str], d_model: int,
+                    projector: KGProjector,
+                    k_pheno: int = 16, k_gene: int = 16,
+                    k_gene_2hop: int = 0) -> torch.Tensor:
+     """Build (B, N, d_model) batched KG context for a batch of patient ORPHAs.
+
+     Missing ORPHAs and short ego-subgraphs are zero-padded to
+     N = 1 + k_pheno + k_gene + k_gene_2hop rows before projection.
+     If the KG cannot be loaded, returns an all-zero (B, 1, d_model) tensor.
+     """
+     kg = load_kg()
+     if kg is None:
+         return torch.zeros(len(orpha_strings), 1, d_model,
+                            device=next(projector.parameters()).device)
+     N = 1 + k_pheno + k_gene + k_gene_2hop
+     egos = []
+     for orpha in orpha_strings:
+         e = ego_subgraph_real(orpha, k_pheno, k_gene, k_gene_2hop, kg)
+         if e is None:
+             e = torch.zeros(N, kg["emb"]["disease"].size(-1))
+         elif e.size(0) < N:
+             pad = torch.zeros(N - e.size(0), e.size(-1))
+             e = torch.cat([e, pad], dim=0)
+         egos.append(e[:N])
+     egos = torch.stack(egos, dim=0)
+     return projector(egos.to(next(projector.parameters()).device))
+
+
+ def precompute_kg_for_dataset(orpha_codes: list[str], projector: KGProjector,
+                               k_pheno: int = 16, k_gene: int = 16,
+                               batch_size: int = 32) -> torch.Tensor:
+     """Pre-compute KG context for an entire dataset in batches.
+
+     Returns a (N_patients, kg_nodes, d_model) tensor on CPU (each batch is
+     moved off the projector device). Nothing is written to disk here; the
+     caller can cache the returned tensor.
+     """
+     out = []
+     for i in range(0, len(orpha_codes), batch_size):
+         batch = orpha_codes[i:i + batch_size]
+         ctx = build_kg_batch(batch, projector.proj[0].out_features,
+                              projector, k_pheno, k_gene)
+         out.append(ctx.cpu())
+     return torch.cat(out, dim=0)
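The adjacency construction at the end of `load_kg` and the bounded 2-hop expansion in `ego_subgraph_real` can be sketched without any tensors. The following is a minimal pure-Python illustration, not the repository's code: the helper names `build_adjacency` and `bounded_two_hop` and the toy edge lists are hypothetical.

```python
def build_adjacency(src, dst):
    """Turn parallel src/dst index lists into {src: [dst, ...]}."""
    adj = {}
    for s, d in zip(src, dst):
        adj.setdefault(int(s), []).append(int(d))
    return adj


def bounded_two_hop(first_hop, gene_gene_adj, limit):
    """Collect up to `limit` unseen neighbors of the first-hop genes."""
    seen = set(first_hop)
    out = []
    for g in first_hop:
        for g2 in gene_gene_adj.get(g, []):
            if len(out) >= limit:
                return out
            if g2 not in seen:
                seen.add(g2)
                out.append(g2)
    return out


# Toy data (hypothetical): disease 0 links to genes 5 and 6.
adj = build_adjacency(src=[0, 0, 1], dst=[5, 6, 7])
gg = {5: [6, 8, 9], 6: [9, 10]}
hop2 = bounded_two_hop(adj[0], gg, limit=3)
```

Genes already pulled in the first hop are never repeated in the second hop, which is why gene 6 is skipped when expanding gene 5 in this toy graph.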
src/sample.py ADDED
@@ -0,0 +1,230 @@
+ """Sampling primitives for CDF: AR mode, denoise mode, counterfactual rollouts.
+
+ Diffusion Forcing flexibility. The same model handles:
+
+   AR mode:
+     sigma_future = 1, sigma_past = 0. Roll forward like an autoregressive
+     transformer but with per-token noise control.
+
+   Denoise mode (bidirectional):
+     Sigma is annealed from 1 to 0 on all free tokens over n_steps denoise
+     steps; the model fills the whole sequence.
+
+   Counterfactual mode (the TTE primitive):
+     sigma=0 on observed tokens (clamp them clean), sigma=1 on tokens to
+     generate. Condition on (cohort, intervention_action_id). Sample N times,
+     compare distributions of outcome tokens.
+
+ CFG (classifier-free guidance) wraps any mode:
+     logits_g = (1 + gamma) * logits(c) - gamma * logits(null_c)
+
+ Shortcut Forcing (Dreamer 4) reduces denoise steps from 32-64 to 4 via a
+ distilled student model (implemented in distill.py).
+ """
+ from __future__ import annotations
+ import logging
+
+ import torch
+ import torch.nn.functional as F
+
+ from .diffusion_forcing import CDFTransformer
+
+ log = logging.getLogger("gemeo.cdf.sample")
+
+
+ @torch.no_grad()
+ def sample_denoise(
+     model: CDFTransformer,
+     cond: torch.Tensor,
+     *,
+     seed_prefix: torch.Tensor | None = None,
+     observed_mask: torch.Tensor | None = None,  # (B, T) True = clamped clean
+     action: torch.Tensor | None = None,
+     gamma: float = 2.0,
+     n_steps: int = 32,
+     null_cond: int = 0,
+     schedule: str = "cosine",
+ ) -> torch.Tensor:
+     """Denoise-mode sampling: fully-masked sequence + iterative refinement.
+
+     Supports:
+       - seed_prefix: clean tokens kept at sigma=0 for positions [0, L)
+       - observed_mask: arbitrary positions to clamp (counterfactual mode)
+       - CFG via (cond, null_cond) pair
+     """
+     cfg = model.cfg
+     device = cond.device
+     B = cond.size(0)
+     T = cfg.max_seq_len
+
+     # Init with MASK
+     x = torch.full((B, T), cfg.mask_token, device=device, dtype=torch.long)
+     fixed_mask = torch.zeros(B, T, dtype=torch.bool, device=device)
+     if seed_prefix is not None:
+         L = seed_prefix.size(1)
+         x[:, :L] = seed_prefix
+         fixed_mask[:, :L] = True
+     if observed_mask is not None:
+         fixed_mask |= observed_mask
+
+     # Build noise schedule
+     if schedule == "cosine":
+         # smooth cosine from 1 -> 0
+         ts = torch.cos(torch.linspace(0, torch.pi/2, n_steps+1, device=device))
+     else:
+         ts = torch.linspace(1.0, 0.0, n_steps+1, device=device)
+
+     null = torch.full_like(cond, null_cond)
+     null_action = (torch.full_like(action, cfg.n_latent_actions)
+                    if action is not None and cfg.use_latent_action else None)
+
+     for k in range(n_steps):
+         # Per-token sigma: fixed positions at 0, dynamic positions at ts[k]
+         sigma = torch.where(fixed_mask, torch.zeros_like(ts[k:k+1]).expand(B, T),
+                             torch.full((B, T), ts[k].item(), device=device))
+         logits_c = model(x, sigma, cond, action)
+         if gamma > 0:
+             logits_n = model(x, sigma, null, null_action)
+             logits = (1 + gamma) * logits_c - gamma * logits_n
+         else:
+             logits = logits_c
+         logits[:, :, cfg.mask_token] = -1e9
+
+         # Greedy decode: argmax token per position, its probability as confidence
+         probs = F.softmax(logits, dim=-1)
+         confs, preds = probs.max(dim=-1)
+
+         # Confidence-based unmasking: reveal the highest-confidence tokens until
+         # a (1 - ts[k+1]) fraction of all T positions is filled
+         t_next = ts[k+1].item()
+         target_kept = int(round((1 - t_next) * T))
+         revealed = (x != cfg.mask_token) | fixed_mask
+         already = revealed.sum(dim=-1)
+         new_x = x.clone()
+         for b in range(B):
+             need = max(0, target_kept - int(already[b].item()))
+             if need == 0:
+                 continue
+             confs_b = torch.where(revealed[b], torch.full_like(confs[b], -1e9), confs[b])
+             topi = confs_b.topk(need).indices
+             new_x[b, topi] = preds[b, topi]
+         x = new_x
+
+     # Final cleanup: fill any position still holding MASK
+     mask_left = x == cfg.mask_token
+     if mask_left.any():
+         sigma_final = torch.zeros(B, T, device=device)
+         logits_c = model(x, sigma_final, cond, action)
+         if gamma > 0:
+             logits_n = model(x, sigma_final, null, null_action)
+             logits = (1 + gamma) * logits_c - gamma * logits_n
+         else:
+             logits = logits_c
+         logits[:, :, cfg.mask_token] = -1e9
+         preds = logits.argmax(-1)
+         x = torch.where(mask_left, preds, x)
+     return x
+
+
+ @torch.no_grad()
+ def sample_ar(
+     model: CDFTransformer,
+     cond: torch.Tensor,
+     prefix: torch.Tensor,
+     *,
+     action: torch.Tensor | None = None,
+     max_new: int = 50,
+     temperature: float = 1.0,
+     gamma: float = 0.0,
+     null_cond: int = 0,
+ ) -> torch.Tensor:
+     """AR-mode sampling: future tokens at sigma=1, past at sigma=0.
+
+     Faster than denoise mode when you only want to continue a prefix.
+     """
+     cfg = model.cfg
+     device = cond.device
+     B = cond.size(0)
+     x = prefix.clone().to(device)
+     if x.dim() == 1:
+         x = x.unsqueeze(0)
+     null = torch.full_like(cond, null_cond)
+
+     for _ in range(max_new):
+         T_now = x.size(1)
+         if T_now >= cfg.max_seq_len:
+             break
+         # Pad with MASK
+         x_pad = torch.cat([x, torch.full((B, 1), cfg.mask_token,
+                                          device=device, dtype=torch.long)], dim=1)
+         sigma = torch.zeros(B, T_now + 1, device=device)
+         sigma[:, -1] = 1.0
+         a_pad = None
+         if action is not None and cfg.use_latent_action:
+             a_pad = torch.cat([action[:, :T_now],
+                                torch.full((B, 1), cfg.n_latent_actions,
+                                           device=device, dtype=torch.long)], dim=1)
+         logits = model(x_pad, sigma, cond, a_pad)
+         if gamma > 0:
+             # The null action must match a_pad's (B, T_now + 1) shape,
+             # not the shape of the full `action` tensor
+             null_a_pad = (torch.full_like(a_pad, cfg.n_latent_actions)
+                           if a_pad is not None else None)
+             logits_n = model(x_pad, sigma, null, null_a_pad)
+             logits = (1 + gamma) * logits - gamma * logits_n
+         logits[:, :, cfg.mask_token] = -1e9
+         p = F.softmax(logits[:, -1] / max(temperature, 1e-3), dim=-1)
+         nxt = torch.multinomial(p, 1)
+         x = torch.cat([x, nxt], dim=1)
+     return x
+
+
+ @torch.no_grad()
+ def counterfactual_rollout(
+     model: CDFTransformer,
+     seed_prefix: torch.Tensor,
+     treatment_cond: int,
+     untreated_cond: int,
+     *,
+     treatment_action: int | None = None,
+     untreated_action: int | None = None,
+     n_samples: int = 100,
+     gamma: float = 2.0,
+     n_steps: int = 32,
+ ) -> dict:
+     """Sample paired counterfactual trajectories under treatment vs no-treatment.
+
+     Two ways to specify the intervention:
+       - via cond id (cohort-level): treatment_cond / untreated_cond
+       - via latent action id (per-token): treatment_action / untreated_action
+
+     seed_prefix is a 1-D token tensor that is broadcast to n_samples rows.
+     Note: sample_denoise decodes greedily (argmax), so with the model in eval
+     mode the n_samples rows of each arm are expected to be identical.
+     """
+     cfg = model.cfg
+     device = next(model.parameters()).device
+     seed = seed_prefix.unsqueeze(0).expand(n_samples, -1).to(device)
+     T = cfg.max_seq_len
+
+     cond_tx = torch.full((n_samples,), treatment_cond, device=device, dtype=torch.long)
+     cond_null = torch.full((n_samples,), untreated_cond, device=device, dtype=torch.long)
+
+     action_tx = action_null = None
+     if cfg.use_latent_action:
+         action_tx = torch.full((n_samples, T),
+                                treatment_action if treatment_action is not None
+                                else cfg.n_latent_actions,
+                                device=device, dtype=torch.long)
+         action_null = torch.full((n_samples, T),
+                                  untreated_action if untreated_action is not None
+                                  else cfg.n_latent_actions,
+                                  device=device, dtype=torch.long)
+
+     traj_tx = sample_denoise(model, cond_tx, seed_prefix=seed,
+                              action=action_tx, gamma=gamma, n_steps=n_steps)
+     traj_null = sample_denoise(model, cond_null, seed_prefix=seed,
+                                action=action_null, gamma=gamma, n_steps=n_steps)
+     return {
+         "traj_treated": traj_tx, "traj_untreated": traj_null,
+         "n": n_samples, "treatment_cond": treatment_cond,
+         "untreated_cond": untreated_cond, "gamma": gamma,
+     }
+
+
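+ def outcome_rate(traj: torch.Tensor, target_ids: list[int]) -> float:
+     """Fraction of trajectories (rows) containing at least one token in target_ids."""
+     if not target_ids:
+         return 0.0
+     target = torch.tensor(target_ids, device=traj.device)
+     # (n, T, 1) == (K,) -> (n, T, K); chained any() also works on PyTorch
+     # versions that do not accept a tuple of dims
+     has = (traj.unsqueeze(-1) == target).any(dim=-1).any(dim=-1)
+     return has.float().mean().item()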
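The two pieces of arithmetic that drive `sample_denoise`, the CFG logit combination and the cosine reveal schedule, can be checked by hand. The following is a small pure-Python sketch under stated assumptions: the helper names `cfg_combine`, `softmax`, and `reveal_targets` and the two-token toy logits are hypothetical, and lists stand in for tensors.

```python
import math


def cfg_combine(logits_c, logits_n, gamma):
    """Classifier-free guidance: (1 + gamma) * cond - gamma * null."""
    return [(1 + gamma) * c - gamma * n for c, n in zip(logits_c, logits_n)]


def softmax(logits):
    """Numerically stable softmax over a plain list."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [e / z for e in exps]


def reveal_targets(n_steps, seq_len):
    """Number of positions that should be filled after each denoise step."""
    ts = [math.cos(i * (math.pi / 2) / n_steps) for i in range(n_steps + 1)]
    return [int(round((1 - ts[k + 1]) * seq_len)) for k in range(n_steps)]


# Toy logits (hypothetical): the conditional model prefers token 0.
guided = cfg_combine([2.0, 0.0], [1.0, 0.0], gamma=2.0)
p_cond = softmax([2.0, 0.0])[0]
p_guided = softmax(guided)[0]

# Cosine schedule: few tokens are revealed early, most of them late.
targets = reveal_targets(n_steps=4, seq_len=10)
```

With `gamma=0` the combination reduces to the conditional logits, which matches the `else` branch in the samplers. The reveal targets grow slowly at first under the cosine schedule, so early steps commit only the most confident positions.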
src/wsd_scheduler.py ADDED
@@ -0,0 +1,72 @@
+ """WSD (Warmup-Stable-Decay) LR scheduler, manual implementation.
+
+ Per MiniCPM (Hu et al. 2024) and the data-constrained scaling literature:
+   Phase 1 (warmup, 1-5% of total_steps): linear 0 → peak_lr
+   Phase 2 (stable, 60-80%): constant peak_lr
+   Phase 3 (decay, 10-25%): linear or 1/sqrt to peak_lr * 0.1
+
+ Beats cosine for:
+   - data-limited regimes (we can extend stable phase if loss still falls)
+   - continue-pretrain (sharp decay enables clean fine-tune handoff)
+ """
+ from __future__ import annotations
+ import math
+ import torch
+ from torch.optim.lr_scheduler import LambdaLR
+
+
+ def wsd_lr_schedule(step: int, total_steps: int,
+                     warmup_steps: int = 500,
+                     stable_frac: float = 0.80,
+                     decay_frac: float = 0.15,
+                     min_lr_ratio: float = 0.1,
+                     decay_type: str = "linear") -> float:
+     """Return the LR multiplier for a given step.
+
+     Ramps linearly from 0 to 1.0 during warmup, then stays within
+     [min_lr_ratio, 1.0]. stable_frac and decay_frac are fractions of the
+     post-warmup steps; any steps left after the decay phase hold min_lr_ratio.
+     """
+     if step < warmup_steps:
+         return step / max(1, warmup_steps)
+     # remainder of steps after warmup
+     remaining = total_steps - warmup_steps
+     if remaining <= 0:
+         return 1.0
+     stable_steps = int(stable_frac * remaining)
+     decay_steps = int(decay_frac * remaining)
+     pos = step - warmup_steps
+     if pos < stable_steps:
+         return 1.0
+     decay_pos = pos - stable_steps
+     if decay_pos >= decay_steps:
+         return min_lr_ratio
+     progress = decay_pos / max(1, decay_steps)
+     if decay_type == "linear":
+         return 1.0 - (1.0 - min_lr_ratio) * progress
+     elif decay_type == "cosine":
+         return min_lr_ratio + 0.5 * (1 - min_lr_ratio) * (1 + math.cos(math.pi * progress))
+     elif decay_type == "inv_sqrt":
+         # Only reaches about 0.30 at progress=1, so the multiplier steps down
+         # to min_lr_ratio once the decay phase ends
+         return max(min_lr_ratio, 1.0 / math.sqrt(1 + progress * 10))
+     else:
+         raise ValueError(f"unknown decay_type: {decay_type}")
+
+
+ def get_wsd_scheduler(optimizer: torch.optim.Optimizer,
+                       total_steps: int,
+                       warmup_steps: int = 500,
+                       stable_frac: float = 0.80,
+                       decay_frac: float = 0.15,
+                       min_lr_ratio: float = 0.1,
+                       decay_type: str = "linear") -> LambdaLR:
+     """Build a LambdaLR scheduler with WSD schedule."""
+     def fn(step):
+         return wsd_lr_schedule(step, total_steps, warmup_steps,
+                                stable_frac, decay_frac, min_lr_ratio, decay_type)
+     return LambdaLR(optimizer, lr_lambda=fn)
+
+
+ if __name__ == "__main__":
+     # Visualize the schedule
+     total = 10000
+     warmup = 500
+     print(f"WSD schedule preview: total={total}, warmup={warmup}, stable=80%, decay=15%")
+     print(" step lr_mult")
+     for s in [0, 250, 500, 1000, 5000, 8000, 8500, 9000, 9500, 9800, 9999]:
+         m = wsd_lr_schedule(s, total, warmup, 0.80, 0.15, 0.1, "linear")
+         print(f" {s:>5} {m:.4f}")
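To see where the three phases fall for the preview settings above, the boundaries can be computed directly. This is a minimal sketch, assuming the same defaults as `wsd_lr_schedule`. `wsd_multiplier` is a compact restatement of the linear-decay branch for illustration only, and `phase_boundaries` is a hypothetical helper that does not exist in the module.

```python
def wsd_multiplier(step, total_steps, warmup_steps=500,
                   stable_frac=0.80, decay_frac=0.15, min_lr_ratio=0.1):
    """Linear-decay WSD multiplier (illustrative restatement)."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    remaining = total_steps - warmup_steps
    if remaining <= 0:
        return 1.0
    stable_steps = int(stable_frac * remaining)
    decay_steps = int(decay_frac * remaining)
    decay_pos = step - warmup_steps - stable_steps
    if decay_pos < 0:
        return 1.0  # still in the stable phase
    if decay_pos >= decay_steps:
        return min_lr_ratio  # past the decay phase: hold the floor
    return 1.0 - (1.0 - min_lr_ratio) * decay_pos / max(1, decay_steps)


def phase_boundaries(total_steps, warmup_steps=500,
                     stable_frac=0.80, decay_frac=0.15):
    """Return (warmup_end, stable_end, decay_end) as step indices."""
    remaining = max(0, total_steps - warmup_steps)
    stable_end = warmup_steps + int(stable_frac * remaining)
    decay_end = stable_end + int(decay_frac * remaining)
    return warmup_steps, stable_end, decay_end


bounds = phase_boundaries(10_000)
halfway_warmup = wsd_multiplier(250, 10_000)
tail = wsd_multiplier(9_800, 10_000)
```

Because `stable_frac + decay_frac` is 0.95 with the defaults, the last 5% of post-warmup steps run at `min_lr_ratio` rather than continuing to decay.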