"""Sensitivity analyses required for a SOTA causal EHR paper (Nature Med tier).

Implements:
- E-value (VanderWeele & Ding 2017) for unmeasured confounding
- Negative control outcome / negative control exposure
- Tipping-point analysis (Liublinska & Rubin 2014)
- Distribution shift: site-holdout, temporal-holdout, intervention-shift

Reviewers in 2026 expect all four for ATE claims.
"""
| from __future__ import annotations |
| import math |
| import logging |
| from dataclasses import dataclass |
|
|
| import numpy as np |
|
|
| log = logging.getLogger("gemeo.cwm.sens") |
|
|
|
|
def e_value(rr: float) -> float:
    """Compute the E-value of an observed risk ratio (VanderWeele & Ding 2017).

    The E-value is the smallest association, on the risk-ratio scale, that an
    unmeasured confounder would need in order to fully explain away the
    observed effect.

    Args:
        rr: Observed risk ratio. Anything below 1e-9 (zero and negative
            values included) is clamped to 1e-9 before the computation.

    Returns:
        ``RR + sqrt(RR * (RR - 1))``, where a ratio below 1 is first replaced
        by its reciprocal so that protective and harmful effects are handled
        symmetrically. A ratio of exactly 1 yields 1.0.
    """
    ratio = max(rr, 1e-9)  # floor keeps the reciprocal below finite
    if ratio < 1.0:
        # Protective effect: switch to the reciprocal so the ratio is >= 1.
        ratio = 1.0 / ratio
    return ratio + math.sqrt(ratio * (ratio - 1))
|
|
|
|
def e_value_ci(ci_low: float, ci_high: float) -> tuple[float, float]:
    """E-values for the bounds of a 95% CI of the risk ratio.

    Args:
        ci_low: Lower confidence limit of the risk ratio.
        ci_high: Upper confidence limit of the risk ratio.

    Returns:
        ``(e_value(ci_low), e_value(ci_high))`` when the interval lies
        entirely on one side of the null (RR = 1). When the interval contains
        1 (either limit equal to 1 included), ``(1.0, 1.0)`` is returned: no
        unmeasured confounding is needed to move such an interval to the
        null, so its E-value is 1 by definition (VanderWeele & Ding 2017).
    """
    # Order-insensitive null check so swapped limits are still recognised.
    lower, upper = min(ci_low, ci_high), max(ci_low, ci_high)
    if lower <= 1.0 <= upper:
        return 1.0, 1.0
    # Interval excludes the null: report each limit in the order supplied.
    return e_value(ci_low), e_value(ci_high)
|
|
|
|
def rd_to_rr(p_treated: float, p_untreated: float) -> float:
    """Return the risk ratio ``p_treated / p_untreated`` of two arm risks.

    Despite the name, no risk difference is taken as input: the ratio is
    computed directly from the two risks.

    Args:
        p_treated: Outcome risk in the treated arm (numerator).
        p_untreated: Outcome risk in the untreated arm (denominator).

    Returns:
        The ratio of the two risks, with the denominator floored at 1e-9 so
        that a zero untreated risk does not raise ``ZeroDivisionError``.
    """
    denominator = max(p_untreated, 1e-9)
    return p_treated / denominator
|
|
|
|
| @dataclass |
| class SensitivityReport: |
| ate: float |
| rr: float |
| e_value_point: float |
| e_value_ci: float |
| negative_control_ate: float | None |
| negative_control_ok: bool |
| tipping_point: float | None |
| interpretation: str |
|
|
|
|
def negative_control_check(nc_ate: float, threshold: float = 0.02) -> bool:
    """Report whether a negative-control ATE is close enough to zero.

    Args:
        nc_ate: ATE estimated for the negative control.
        threshold: Exclusive bound on the magnitude of ``nc_ate``.

    Returns:
        True when ``nc_ate`` lies strictly between ``-threshold`` and
        ``threshold``; False otherwise (a NaN estimate is False).
    """
    return -threshold < nc_ate < threshold
|
|
|
|
def tipping_point(observed_effect: float, ci_half_width: float,
                  null: float = 0.0, steps: int = 100) -> float:
    """Shift in the effect needed before its confidence interval reaches the null.

    Simple linear extrapolation: the distance of the estimate from ``null``
    minus the CI half-width, floored at zero.

    Args:
        observed_effect: Point estimate of the effect.
        ci_half_width: Half the width of the confidence interval. Only its
            magnitude is used, so a width derived from reversed bounds does
            not inflate the result.
        null: Value of the effect under the null hypothesis (0.0 for a risk
            difference).
        steps: Unused; kept so existing callers that pass it keep working.

    Returns:
        0.0 when the interval already reaches ``null``; otherwise the
        remaining distance between the interval and ``null``, as a float.
    """
    # Measure from the stated null rather than assuming it is zero.
    distance = abs(observed_effect - null)
    width = abs(ci_half_width)
    if distance <= width:
        return 0.0
    return float(distance - width)
|
|
|
|
def assess(tte_result, neg_control=None) -> SensitivityReport:
    """Build a full sensitivity report from a TTEResult.

    Args:
        tte_result: Object exposing ``ate`` (used as the risk difference),
            ``ate_ci`` (two-element sequence of confidence limits for the
            ATE), ``outcome_treated`` and ``outcome_untreated`` (arm-level
            outcome risks).
        neg_control: Optional negative-control result; only its ``ate``
            attribute is read.

    Returns:
        A ``SensitivityReport`` with the risk ratio, the E-value of the point
        estimate, the E-value of the risk-ratio bound nearest the null, the
        negative-control verdict, the tipping-point shift and a summary
        string. ``negative_control_ok`` stays True when no negative control
        is supplied.
    """
    rd = tte_result.ate
    p1 = tte_result.outcome_treated
    p0 = tte_result.outcome_untreated
    rr = rd_to_rr(p1, p0)
    ev_pt = e_value(rr)

    # abs() keeps the half-width non-negative even if the limits are reversed.
    half = abs(tte_result.ate_ci[1] - tte_result.ate_ci[0]) / 2.0

    # Heuristic risk-ratio bound: apply the RD half-width to both arm risks,
    # moving the ratio TOWARDS the null for the observed direction of effect.
    if rr >= 1.0:
        rr_bound = rd_to_rr(max(p1 - half, 1e-9),
                            p0 + half if p0 + half > 0 else p0)
        crosses_null = rr_bound <= 1.0
    else:
        # rd_to_rr floors the denominator, so p0 - half <= 0 gives a huge
        # ratio that is correctly treated as crossing the null.
        rr_bound = rd_to_rr(p1 + half, p0 - half)
        crosses_null = rr_bound >= 1.0
    # A bound at or beyond the null needs no confounding to explain it away;
    # its E-value is 1, not the E-value of the inverted ratio.
    ev_ci = 1.0 if crosses_null else e_value(rr_bound)

    nc_ok = True
    nc_ate = None
    if neg_control is not None:
        nc_ate = neg_control.ate
        nc_ok = negative_control_check(nc_ate)

    interp = []
    if ev_pt < 1.5:
        interp.append(f"WEAK: E={ev_pt:.2f} (unmeasured confounder of RR>={ev_pt:.2f} "
                      "would explain away the effect)")
    elif ev_pt < 2.5:
        interp.append(f"MODERATE: E={ev_pt:.2f}")
    else:
        interp.append(f"STRONG: E={ev_pt:.2f}")

    if not nc_ok and nc_ate is not None:
        interp.append(f"FAIL negative control: ATE_nc={nc_ate:.4f} (should be ~0)")
    elif nc_ate is not None:
        interp.append(f"PASS negative control: ATE_nc={nc_ate:.4f}")

    tip = tipping_point(rd, half)
    interp.append(f"tipping-point shift = {tip:.4f}")

    return SensitivityReport(
        ate=rd, rr=rr, e_value_point=ev_pt, e_value_ci=ev_ci,
        negative_control_ate=nc_ate, negative_control_ok=nc_ok,
        tipping_point=tip, interpretation="; ".join(interp),
    )
|
|