gemeo-twin-stack / src /gemeo /cwm /tte_validate.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""Target-Trial Emulation framework for GEMEO-CWM validation.
Implements the Hernan-Robins TTE protocol over Brazilian PCDT natural
experiments. For each PCDT, we:
1. Define eligibility (ORPHA + age + UF + ICD anchor)
2. Define treatment assignment (T=1 if drug X observed within window W
after diagnosis; T=0 otherwise)
3. Anchor time-zero at first qualifying admission (avoids immortal-time)
4. Define outcome (mortality, hospitalization rate, specific procedure)
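5. Estimate ATE three ways:
- G-formula: model.simulate(T=1) - model.simulate(T=0) per cohort
- IPW: empirical reweighting by propensity
- AIPW: doubly-robust combination
(`emulate_trial` in this module computes only the model-based,
G-formula-style contrast via CFG counterfactual sampling; it does not
compute IPW or AIPW estimates.)
6. Report point estimate + paired-bootstrap 95% CI (resampled over cohorts)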
The PCDT NATURAL_EXPERIMENTS table is the registry. Each row maps to one
emulated trial. Validated against published RCT effect sizes when available
(e.g., CHERISH for SMA Nusinersena).
"""
from __future__ import annotations
import logging
import math
from dataclasses import dataclass
import torch
import numpy as np
from .block_diffusion import BlockDiffusionTransformer
from .cfg_sample import counterfactual_pair, outcome_rate
log = logging.getLogger("gemeo.cwm.tte")
# Registry of emulated trials. Each entry describes one PCDT natural
# experiment tied to a published policy go-live date; sources are listed in the
# agent research output (Conass, CONITEC, BVS-MS).
#
# Entry keys:
#   name               - human-readable label, echoed in TTEResult.name
#   orpha              - ORPHA code, compared against cohort_keys[i][0]
#   drug_token_prefix  - prefix used to look up the treatment condition token
#   go_live            - (year, month) of the policy go-live
#   pre_window         - [(year, month), (year, month)] before go-live
#   post_window        - [(year, month), (year, month)] from go-live onward
#   outcome_tokens     - vocabulary tokens counted as the outcome
#   rct_anchor         - published trial used as the external reference
#   expected_direction - "negative", "positive" or "null"
#   is_negative_control (optional) - marks a negative-control entry
NATURAL_EXPERIMENTS = [
    dict(
        name="SMA-I Nusinersena 2019",
        orpha="70",
        # Nusinersena APAC subgroup (approximate code).
        drug_token_prefix="drug_0604515",
        go_live=(2019, 10),
        pre_window=[(2017, 1), (2019, 9)],
        post_window=[(2019, 10), (2022, 12)],
        outcome_tokens=["outcome_death", "EV_DEATH"],
        # NOTE(review): a mortality / permanent-ventilation endpoint looks like
        # the infantile-onset ENDEAR trial rather than CHERISH - confirm the
        # citation before relying on this anchor.
        rct_anchor="CHERISH (NCT02292537): 51% reduction in mortality/perm vent",
        # Treatment is expected to reduce mortality.
        expected_direction="negative",
    ),
    dict(
        name="SMA-II Risdiplam 2023",
        # Per the original author: mapped to SMA type 3 in v1; 83418 = type 2
        # (may need a separate entry).
        # NOTE(review): Orphanet appears to list 83330 as SMA type 1 - verify
        # the code against the cohort keys actually present in the dataset.
        orpha="83330",
        drug_token_prefix="drug_0604598",
        go_live=(2023, 5),
        pre_window=[(2020, 1), (2023, 4)],
        post_window=[(2023, 5), (2023, 12)],
        # Outcome here is the hospitalization rate.
        outcome_tokens=["EV_ADM"],
        rct_anchor="FIREFISH/SUNFISH (NCT02913482)",
        expected_direction="negative",
    ),
    dict(
        name="CF Trikafta 2024",
        orpha="586",
        # Elexacaftor/Tezacaftor/Ivacaftor.
        drug_token_prefix="drug_0604950",
        go_live=(2024, 3),
        pre_window=[(2021, 1), (2024, 2)],
        post_window=[(2024, 3), (2024, 12)],
        outcome_tokens=["outcome_death", "EV_DEATH"],
        rct_anchor="VX17-445-102 (NCT03525444)",
        expected_direction="negative",
    ),
    dict(
        name="Hemofilia A Emicizumab 2018",
        orpha="98878",
        drug_token_prefix="drug_0604421",
        go_live=(2018, 7),
        pre_window=[(2016, 1), (2018, 6)],
        post_window=[(2018, 7), (2020, 12)],
        outcome_tokens=["EV_ADM"],
        rct_anchor="HAVEN 1/3/4 (NCT02622321)",
        expected_direction="negative",
    ),
    dict(
        name="HAP Macitentan 2022",
        orpha="182090",
        drug_token_prefix="drug_0604724",
        go_live=(2022, 1),
        pre_window=[(2019, 1), (2021, 12)],
        post_window=[(2022, 1), (2023, 12)],
        outcome_tokens=["outcome_death", "EV_DEATH"],
        rct_anchor="SERAPHIN (NCT00660179)",
        expected_direction="negative",
    ),
    # Negative control: Hemofilia B should NOT respond to Emicizumab, so the
    # estimated effect is expected to be indistinguishable from zero.
    dict(
        name="NEG-CTRL Hemofilia B Emicizumab",
        orpha="98879",
        drug_token_prefix="drug_0604421",
        go_live=(2018, 7),
        pre_window=[(2016, 1), (2018, 6)],
        post_window=[(2018, 7), (2020, 12)],
        outcome_tokens=["EV_ADM"],
        rct_anchor="N/A — Emicizumab is FVIII-mimetic, no effect on hemB",
        expected_direction="null",
        is_negative_control=True,
    ),
]
@dataclass
class TTEResult:
    """Outcome of one emulated target trial.

    Produced once per registry entry; also used as a placeholder (all-zero
    numbers, ``direction_ok=False``) when a trial could not be run, in which
    case ``notes`` carries the reason.
    """

    # Label of the experiment this result belongs to.
    name: str
    # Number of simulated trajectories in the treated arm.
    n_treated: int
    # Number of simulated trajectories in the untreated arm.
    n_untreated: int
    # Mean outcome rate under treatment.
    outcome_treated: float
    # Mean outcome rate without treatment.
    outcome_untreated: float
    # Mean treatment effect, computed as treated minus untreated.
    ate: float
    # (lower, upper) bounds of the confidence interval around ``ate``.
    ate_ci: tuple[float, float]
    # True when the sign/size of ``ate`` matches the expected direction.
    direction_ok: bool
    # Free-text description of how the estimate was produced.
    method: str
    # Optional diagnostics; empty unless there is something to report.
    notes: str = ""
def _closest_condition(cond2id, prefix: str) -> str | None:
    """Pick the condition name in *cond2id* that best matches *prefix*.

    Candidates are the names sharing the first five characters of *prefix*
    (the token family, e.g. ``drug_``); if there are none, any ``drug_*`` name
    is a candidate. Among the candidates, the one with the longest leading run
    of characters in common with *prefix* wins; ties keep vocabulary order.

    Returns ``None`` when there is no candidate at all.
    """
    candidates = [c for c in cond2id if c.startswith(prefix[:5])]
    if not candidates:
        candidates = [c for c in cond2id if c.startswith("drug_")]
    if not candidates:
        return None

    def shared_len(cond_name: str) -> int:
        # Length of the common leading run between cond_name and prefix.
        n = 0
        for a, b in zip(cond_name, prefix):
            if a != b:
                break
            n += 1
        return n

    # max() returns the first maximal element, so when nothing matches beyond
    # the family prefix this degrades to "first candidate in the vocabulary".
    return max(candidates, key=shared_len)


def emulate_trial(
    model: BlockDiffusionTransformer,
    experiment: dict,
    dataset,  # CWMDataset (for cohort prefixes + vocab)
    *,
    n_samples: int = 200,
    n_bootstrap: int = 500,
    gamma: float = 2.0,
    device: torch.device | None = None,
) -> TTEResult:
    """Emulate a target trial via CFG counterfactual sampling.

    For each cohort whose key starts with ``experiment["orpha"]``:
      - take the first five tokens of the cohort sequence as the seed prefix
        (intended to be BOS + orpha + uf + sex + birth)
      - call ``counterfactual_pair`` to sample ``n_samples`` trajectories
        under the treatment condition and under ``<NO_TX>``
      - measure ``outcome_rate`` in each arm

    The per-cohort rates are averaged into the ATE (treated minus untreated)
    and a paired percentile bootstrap over cohorts gives the 95% CI.

    Args:
        model: the world model; also supplies the default device.
        experiment: one registry entry. Reads ``name``, ``orpha``,
            ``drug_token_prefix``, ``outcome_tokens`` and, optionally,
            ``expected_direction``.
        dataset: object exposing ``tok2id``, ``cond2id``, ``cohort_keys`` and
            ``sequences``.
        n_samples: trajectories sampled per arm and per cohort.
        n_bootstrap: number of bootstrap resamples; with 0 the CI is NaN.
        gamma: guidance strength forwarded to ``counterfactual_pair``.
        device: device forwarded to ``counterfactual_pair``; defaults to the
            device of the model's first parameter.

    Returns:
        A ``TTEResult``. When no treatment condition, no eligible cohort or no
        outcome token can be found, an all-zero result with
        ``direction_ok=False`` is returned and ``notes`` gives the reason. On
        success ``notes`` names the treatment condition that was used.
    """
    if device is None:
        device = next(model.parameters()).device

    trial_name = experiment["name"]

    def empty_result(reason: str) -> TTEResult:
        # Shared shape for every "could not run" early exit.
        return TTEResult(
            name=trial_name, n_treated=0, n_untreated=0,
            outcome_treated=0.0, outcome_untreated=0.0,
            ate=0.0, ate_ci=(0.0, 0.0), direction_ok=False,
            method="CFG counterfactual",
            notes=reason,
        )

    tok2id = dataset.tok2id
    cond2id = dataset.cond2id

    # Resolve the treatment condition. Matching only on the first five
    # characters would accept *any* "drug_*" token, so the closest match to
    # the full prefix is selected instead.
    pfx = experiment["drug_token_prefix"]
    tx_cond = _closest_condition(cond2id, pfx)
    if tx_cond is None:
        return empty_result(f"no matching drug condition for prefix {pfx}")
    tx_cond_id = cond2id[tx_cond]
    # Falls back to id 1 when the vocabulary has no explicit "<NO_TX>" entry.
    null_cond_id = cond2id.get("<NO_TX>", 1)
    match_kind = "exact-prefix" if tx_cond.startswith(pfx) else "approximate"

    # Eligible cohorts: first element of the cohort key is the ORPHA code.
    eligible_idx = [
        i for i, key in enumerate(dataset.cohort_keys)
        if key[0] == experiment["orpha"]
    ]
    if not eligible_idx:
        return empty_result(
            f"no eligible cohorts for ORPHA:{experiment['orpha']}"
        )

    outcome_ids = [
        tok2id[t] for t in experiment["outcome_tokens"] if t in tok2id
    ]
    if not outcome_ids:
        return empty_result("outcome tokens missing from vocab")

    # One (treated, untreated) outcome-rate pair per eligible cohort.
    per_cohort_effect = []
    for i in eligible_idx:
        # Cohort prefix = first 5 tokens of the stored sequence.
        seed = torch.tensor(dataset.sequences[i][:5], dtype=torch.long)
        pair = counterfactual_pair(
            model, tx_cond_id, null_cond_id,
            seed_prefix=seed, n_samples=n_samples, gamma=gamma, device=device,
        )
        y1 = outcome_rate(pair["traj_treated"], outcome_ids)
        y0 = outcome_rate(pair["traj_untreated"], outcome_ids)
        per_cohort_effect.append((y1, y0))

    arr = np.array(per_cohort_effect)
    y1_mean, y0_mean = arr[:, 0].mean(), arr[:, 1].mean()
    ate_mean = float(y1_mean - y0_mean)

    # Paired bootstrap CI: the same resampled cohorts feed both arms. The
    # fixed seed keeps the interval reproducible. With a single cohort the
    # interval collapses onto the point estimate.
    rng = np.random.default_rng(0)
    n_cohorts = len(arr)
    boots = []
    for _ in range(n_bootstrap):
        idx = rng.integers(0, n_cohorts, size=n_cohorts)
        boots.append(arr[idx, 0].mean() - arr[idx, 1].mean())
    if boots:
        lo = float(np.percentile(boots, 2.5))
        hi = float(np.percentile(boots, 97.5))
    else:
        # No resamples requested: there is no interval to report.
        lo = hi = float("nan")

    expected = experiment.get("expected_direction", "any")
    # bool(...) so the dataclass holds a plain Python bool, not numpy.bool_.
    direction_ok = bool(
        (expected == "negative" and ate_mean < 0) or
        (expected == "positive" and ate_mean > 0) or
        # "null" tolerates an absolute effect below 0.02.
        (expected == "null" and abs(ate_mean) < 0.02) or
        (expected == "any")
    )

    n_per_arm = len(eligible_idx) * n_samples
    return TTEResult(
        name=trial_name,
        n_treated=n_per_arm,
        n_untreated=n_per_arm,
        outcome_treated=float(y1_mean),
        outcome_untreated=float(y0_mean),
        ate=ate_mean,
        ate_ci=(lo, hi),
        direction_ok=direction_ok,
        method=f"CFG counterfactual gamma={gamma}, n_cohorts={len(eligible_idx)}, "
               f"n_samples={n_samples}, bootstrap={n_bootstrap}",
        notes=f"treatment condition {tx_cond!r} "
              f"({match_kind} match for prefix {pfx!r})",
    )
def ate_with_ci(
    model: BlockDiffusionTransformer,
    experiments: list[dict] | None = None,
    dataset=None,
    **kw,
) -> list[TTEResult]:
    """Run all (or selected) natural experiments and return a result list.

    Args:
        model: model forwarded to ``emulate_trial``.
        experiments: registry entries to run. ``None`` runs every entry of
            ``NATURAL_EXPERIMENTS``; an explicit empty list runs nothing.
        dataset: dataset forwarded to ``emulate_trial``.
        **kw: extra keyword arguments forwarded to ``emulate_trial``.

    Returns:
        One ``TTEResult`` per experiment, in input order. A trial that raises
        does not abort the run: it is logged with its traceback and reported
        as a zeroed result with ``method="ERROR"`` and the exception text in
        ``notes``.
    """
    # "is None" rather than "or": an empty selection must stay empty instead
    # of silently expanding to the whole registry.
    if experiments is None:
        experiments = NATURAL_EXPERIMENTS

    results = []
    for exp in experiments:
        try:
            r = emulate_trial(model, exp, dataset, **kw)
        except Exception as e:
            # Best-effort batch boundary: keep going, but keep the traceback.
            log.warning("failed %s: %s", exp["name"], e, exc_info=True)
            r = TTEResult(
                name=exp["name"], n_treated=0, n_untreated=0,
                outcome_treated=0.0, outcome_untreated=0.0, ate=0.0,
                ate_ci=(0.0, 0.0), direction_ok=False, method="ERROR",
                notes=str(e),
            )
        results.append(r)
        log.info(
            " %s: ATE=%+.4f CI95=(%+.4f,%+.4f) direction_ok=%s",
            r.name, r.ate, r.ate_ci[0], r.ate_ci[1], r.direction_ok,
        )
    return results