| """Target-Trial Emulation framework for GEMEO-CWM validation. |
| |
| Implements the Hernan-Robins TTE protocol over Brazilian PCDT natural |
| experiments. For each PCDT, we: |
| |
| 1. Define eligibility (ORPHA + age + UF + ICD anchor) |
| 2. Define treatment assignment (T=1 if drug X observed within window W |
| after diagnosis; T=0 otherwise) |
| 3. Anchor time-zero at first qualifying admission (avoids immortal-time) |
| 4. Define outcome (mortality, hospitalization rate, specific procedure) |
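5. Estimate the ATE. `emulate_trial` implements the g-formula estimator:
   for each eligible cohort it samples CFG counterfactual pairs of
   trajectories (conditioned on the treatment drug vs. on <NO_TX>) and
   contrasts their outcome rates. IPW (empirical reweighting by propensity)
   and AIPW (doubly-robust combination) are not implemented in this module.
6. Report the point estimate + a paired, cohort-level bootstrap 95% CI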
| |
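The NATURAL_EXPERIMENTS table is the registry of PCDT natural experiments;
each row maps to one emulated trial. Each estimate is checked automatically
only for its sign against the row's `expected_direction`; the published RCT
cited in `rct_anchor` is a human reference, not an automated comparison.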
| """ |
| from __future__ import annotations |
| import logging |
| import math |
| from dataclasses import dataclass |
|
|
| import torch |
| import numpy as np |
|
|
| from .block_diffusion import BlockDiffusionTransformer |
| from .cfg_sample import counterfactual_pair, outcome_rate |
|
|
# Module logger: per-experiment progress and failures are reported here.
log: logging.Logger = logging.getLogger(name="gemeo.cwm.tte")
|
|
|
|
| |
| |
# Registry of PCDT natural experiments; each entry is one emulated trial.
#
# Keys read by emulate_trial:
#   "name"               -- label copied into TTEResult.name
#   "orpha"              -- compared with element [0] of dataset.cohort_keys
#   "drug_token_prefix"  -- prefix of the treatment condition in dataset.cond2id
#   "outcome_tokens"     -- vocab tokens whose rate is the trial outcome
#   "expected_direction" -- "negative" | "positive" | "null" | "any"; checked
#                           against the sign/size of the estimated ATE
# "go_live", "pre_window", "post_window", "rct_anchor" and
# "is_negative_control" are descriptive metadata that emulate_trial does not
# read. The date fields appear to be (year, month) tuples.
#
# NOTE(review): in Orphanet, ORPHA:70 is the proximal-SMA group and
# ORPHA:83330 is SMA type 1 (the "SMA-II" entry below uses 83330); confirm
# these codes match how the dataset keys its cohorts.
NATURAL_EXPERIMENTS = [
    {
        "name": "SMA-I Nusinersena 2019",
        "orpha": "70",
        "drug_token_prefix": "drug_0604515",
        "go_live": (2019, 10),
        "pre_window": [(2017, 1), (2019, 9)],
        "post_window": [(2019, 10), (2022, 12)],
        "outcome_tokens": ["outcome_death", "EV_DEATH"],
        # ENDEAR is the infantile-onset (SMA type 1) nusinersen RCT with a
        # death / permanent-ventilation endpoint. CHERISH (NCT02292537),
        # previously cited here, enrolled later-onset SMA and measured motor
        # function, so it is not the right anchor for a mortality outcome.
        "rct_anchor": "ENDEAR (NCT02193074): 47% lower risk of death/perm vent (HR 0.53)",
        "expected_direction": "negative",
    },
    {
        "name": "SMA-II Risdiplam 2023",
        "orpha": "83330",
        "drug_token_prefix": "drug_0604598",
        "go_live": (2023, 5),
        "pre_window": [(2020, 1), (2023, 4)],
        "post_window": [(2023, 5), (2023, 12)],
        "outcome_tokens": ["EV_ADM"],
        "rct_anchor": "FIREFISH (NCT02913482) / SUNFISH (NCT02908685)",
        "expected_direction": "negative",
    },
    {
        "name": "CF Trikafta 2024",
        "orpha": "586",
        "drug_token_prefix": "drug_0604950",
        "go_live": (2024, 3),
        "pre_window": [(2021, 1), (2024, 2)],
        "post_window": [(2024, 3), (2024, 12)],
        "outcome_tokens": ["outcome_death", "EV_DEATH"],
        "rct_anchor": "VX17-445-102 (NCT03525444)",
        "expected_direction": "negative",
    },
    {
        "name": "Hemofilia A Emicizumab 2018",
        "orpha": "98878",
        "drug_token_prefix": "drug_0604421",
        "go_live": (2018, 7),
        "pre_window": [(2016, 1), (2018, 6)],
        "post_window": [(2018, 7), (2020, 12)],
        "outcome_tokens": ["EV_ADM"],
        "rct_anchor": "HAVEN 1/3/4 (NCT02622321)",
        "expected_direction": "negative",
    },
    {
        "name": "HAP Macitentan 2022",
        "orpha": "182090",
        "drug_token_prefix": "drug_0604724",
        "go_live": (2022, 1),
        "pre_window": [(2019, 1), (2021, 12)],
        "post_window": [(2022, 1), (2023, 12)],
        "outcome_tokens": ["outcome_death", "EV_DEATH"],
        "rct_anchor": "SERAPHIN (NCT00660179)",
        "expected_direction": "negative",
    },
    # Negative control: same drug prefix, windows and outcome as the
    # Hemofilia A entry, applied to a different ORPHA code; an ATE near
    # zero is expected.
    {
        "name": "NEG-CTRL Hemofilia B Emicizumab",
        "orpha": "98879",
        "drug_token_prefix": "drug_0604421",
        "go_live": (2018, 7),
        "pre_window": [(2016, 1), (2018, 6)],
        "post_window": [(2018, 7), (2020, 12)],
        "outcome_tokens": ["EV_ADM"],
        "rct_anchor": "N/A — Emicizumab is FVIII-mimetic, no effect on hemB",
        "expected_direction": "null",
        "is_negative_control": True,
    },
]
|
|
|
|
@dataclass
class TTEResult:
    """Outcome of one emulated target trial.

    Produced by ``emulate_trial``; ``ate_with_ci`` builds a zero-valued
    instance with ``method="ERROR"`` when an experiment raises.
    """

    name: str  # copied from experiment["name"]
    n_treated: int  # simulated treated trajectories (n_cohorts * n_samples); 0 if not emulated
    n_untreated: int  # simulated untreated trajectories; same count as n_treated
    outcome_treated: float  # mean over cohorts of the treated-arm outcome rate
    outcome_untreated: float  # mean over cohorts of the untreated-arm outcome rate
    ate: float  # outcome_treated - outcome_untreated
    ate_ci: tuple[float, float]  # bootstrap percentile interval (2.5th, 97.5th)
    direction_ok: bool  # whether the ATE matches the experiment's expected_direction
    method: str  # human-readable description of the estimator and its settings
    notes: str = ""  # free-text notes, e.g. why the trial could not be emulated
|
|
|
|
def _empty_result(name: str, notes: str) -> TTEResult:
    """Return a zero-valued TTEResult for a trial that could not be emulated."""
    return TTEResult(
        name=name, n_treated=0, n_untreated=0,
        outcome_treated=0.0, outcome_untreated=0.0,
        ate=0.0, ate_ci=(0.0, 0.0), direction_ok=False,
        method="CFG counterfactual",
        notes=notes,
    )


def emulate_trial(
    model: BlockDiffusionTransformer,
    experiment: dict,
    dataset,
    *,
    n_samples: int = 200,
    n_bootstrap: int = 500,
    gamma: float = 2.0,
    device: torch.device | None = None,
    null_tolerance: float = 0.02,
) -> TTEResult:
    """Emulate one target trial via CFG counterfactual sampling.

    For every cohort whose ORPHA code equals ``experiment["orpha"]``:

    - seed the model with the first 5 tokens of the cohort's sequence
      (intended as BOS + orpha + uf + sex + birth -- TODO confirm against
      the dataset's tokenisation);
    - draw ``n_samples`` trajectories conditioned on the treatment drug and
      ``n_samples`` conditioned on ``<NO_TX>`` via ``counterfactual_pair``;
    - measure each batch's outcome rate with ``outcome_rate``.

    The ATE is the unweighted mean over cohorts of (treated rate - untreated
    rate). The 95% CI is a percentile bootstrap with a fixed seed (0) that
    resamples whole cohorts, keeping each treated/untreated pair together.
    It therefore reflects between-cohort spread only, and collapses to the
    point estimate when a single cohort is eligible.

    Args:
        model: generative model; when ``device`` is None, the device of its
            first parameter is used.
        experiment: a NATURAL_EXPERIMENTS-style dict. Reads "name", "orpha",
            "drug_token_prefix", "outcome_tokens" and, optionally,
            "expected_direction" (default "any").
        dataset: object exposing ``tok2id``, ``cond2id``, ``cohort_keys``
            (entries whose element [0] is the ORPHA code) and ``sequences``
            (token-id sequences aligned with ``cohort_keys``).
        n_samples: trajectories per arm per cohort.
        n_bootstrap: bootstrap replicates; values < 1 give a (nan, nan) CI.
        gamma: forwarded to ``counterfactual_pair``; reported in ``method``
            as the CFG gamma.
        device: sampling device.
        null_tolerance: ``|ATE|`` below which an expected "null" effect
            counts as matched.

    Returns:
        A TTEResult. If no condition in ``dataset.cond2id`` starts with the
        experiment's ``drug_token_prefix``, no cohort is eligible, or none of
        the outcome tokens is in the vocabulary, a zero-valued result is
        returned with the reason in ``notes``.
    """
    if device is None:
        device = next(model.parameters()).device
    tok2id = dataset.tok2id
    cond2id = dataset.cond2id
    name = experiment["name"]

    # Treatment condition: must start with the experiment's full drug prefix.
    # The old check used pfx[:5], which is always "drug_". It matched EVERY
    # drug condition and silently used whichever came first in cond2id, so all
    # trials (including the negative control) shared one arbitrary drug.
    pfx = experiment["drug_token_prefix"]
    matching = [c for c in cond2id if c.startswith(pfx)]
    if not matching:
        return _empty_result(name, f"no matching drug condition for prefix {pfx}")
    tx_token = matching[0]
    if len(matching) > 1:
        log.debug("%s: %d conditions match %s; using %s",
                  name, len(matching), pfx, tx_token)
    tx_cond_id = cond2id[tx_token]
    # NOTE(review): falls back to id 1 when "<NO_TX>" is absent; confirm that
    # id 1 is the no-treatment condition in cond2id.
    null_cond_id = cond2id.get("<NO_TX>", 1)

    orpha = experiment["orpha"]
    elig = [i for i, key in enumerate(dataset.cohort_keys) if key[0] == orpha]
    if not elig:
        return _empty_result(name, f"no eligible cohorts for ORPHA:{orpha}")

    outcome_ids = [tok2id[t] for t in experiment["outcome_tokens"] if t in tok2id]
    if not outcome_ids:
        return _empty_result(name, "outcome tokens missing from vocab")

    # One (treated rate, untreated rate) pair per eligible cohort.
    per_cohort = []
    for i in elig:
        seed = torch.tensor(dataset.sequences[i][:5], dtype=torch.long)
        pair = counterfactual_pair(
            model, tx_cond_id, null_cond_id,
            seed_prefix=seed, n_samples=n_samples, gamma=gamma, device=device,
        )
        per_cohort.append((
            outcome_rate(pair["traj_treated"], outcome_ids),
            outcome_rate(pair["traj_untreated"], outcome_ids),
        ))

    arr = np.asarray(per_cohort, dtype=float)
    y1_mean = float(arr[:, 0].mean())
    y0_mean = float(arr[:, 1].mean())
    ate_mean = y1_mean - y0_mean

    # Paired cohort-level percentile bootstrap. The same resampled indices
    # are used for both arms so within-cohort pairing is preserved.
    if n_bootstrap < 1:
        lo = hi = float("nan")
    else:
        rng = np.random.default_rng(0)
        n_cohorts = len(arr)
        boots = []
        for _ in range(n_bootstrap):
            idx = rng.integers(0, n_cohorts, size=n_cohorts)
            boots.append(arr[idx, 0].mean() - arr[idx, 1].mean())
        lo, hi = (float(v) for v in np.percentile(boots, [2.5, 97.5]))

    expected = experiment.get("expected_direction", "any")
    # bool(): comparisons on numpy scalars would otherwise leak numpy.bool_.
    direction_ok = bool(
        expected == "any"
        or (expected == "negative" and ate_mean < 0)
        or (expected == "positive" and ate_mean > 0)
        or (expected == "null" and abs(ate_mean) < null_tolerance)
    )

    return TTEResult(
        name=name,
        # Counts of simulated trajectories, not of real patients.
        n_treated=len(elig) * n_samples,
        n_untreated=len(elig) * n_samples,
        outcome_treated=y1_mean,
        outcome_untreated=y0_mean,
        ate=ate_mean,
        ate_ci=(lo, hi),
        direction_ok=direction_ok,
        method=f"CFG counterfactual gamma={gamma}, n_cohorts={len(elig)}, "
               f"n_samples={n_samples}, bootstrap={n_bootstrap}",
        notes=f"treatment condition {tx_token}",
    )
|
|
|
|
def ate_with_ci(
    model: BlockDiffusionTransformer,
    experiments: list[dict] | None = None,
    dataset=None,
    **kw,
) -> list[TTEResult]:
    """Run a batch of natural experiments through ``emulate_trial``.

    Args:
        model: generative model forwarded to ``emulate_trial``.
        experiments: experiment dicts to run. ``None`` runs every entry of
            ``NATURAL_EXPERIMENTS``; an empty list runs nothing.
        dataset: dataset forwarded to ``emulate_trial``. It is required; the
            ``None`` default exists only for signature compatibility.
        **kw: extra keyword arguments forwarded to ``emulate_trial``
            (e.g. ``n_samples``, ``n_bootstrap``, ``gamma``, ``device``).

    Returns:
        One TTEResult per experiment, in input order. An experiment that
        raises gets a zero-valued placeholder with ``method="ERROR"`` and the
        exception text in ``notes``; the traceback is logged.

    Raises:
        ValueError: if ``dataset`` is None. Previously every experiment then
            failed individually and the batch returned only ERROR rows.
    """
    if dataset is None:
        raise ValueError("ate_with_ci requires a dataset")
    # Explicit None check: `experiments or ...` turned an empty selection
    # into "run everything".
    if experiments is None:
        experiments = NATURAL_EXPERIMENTS

    results = []
    for exp in experiments:
        try:
            r = emulate_trial(model, exp, dataset, **kw)
        except Exception as e:  # best-effort: one failure must not abort the batch
            log.warning("failed %s: %s", exp["name"], e, exc_info=True)
            r = TTEResult(
                name=exp["name"], n_treated=0, n_untreated=0,
                outcome_treated=0.0, outcome_untreated=0.0, ate=0.0,
                ate_ci=(0.0, 0.0), direction_ok=False, method="ERROR",
                notes=str(e),
            )
        results.append(r)
        log.info(" %s: ATE=%+.4f CI95=(%+.4f,%+.4f) direction_ok=%s",
                 r.name, r.ate, r.ate_ci[0], r.ate_ci[1], r.direction_ok)
    return results
|
|