gemeo-twin-stack / src /gemeo /cwm /sensitivity.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""Sensitivity analyses required for SOTA causal EHR paper (Nature Med tier).
Implements:
- E-value (VanderWeele & Ding 2017) for unmeasured confounding
- Negative control outcome / negative control exposure
- Tipping-point analysis (Liublinska & Rubin 2014)
- Distribution shift: site-holdout, temporal-holdout, intervention-shift
Reviewers in 2026 expect all four for ATE claims.
"""
from __future__ import annotations
import math
import logging
from dataclasses import dataclass
import numpy as np
# Module-level logger under the "gemeo.cwm.sens" namespace.
# NOTE(review): none of the functions visible in this module section log
# through it — confirm whether it is used elsewhere before removing.
log = logging.getLogger("gemeo.cwm.sens")
def e_value(rr: float) -> float:
    """Return the E-value for an observed risk ratio (VanderWeele & Ding 2017).

    The E-value is the minimum strength of association, on the risk-ratio
    scale, that an unmeasured confounder would need with both treatment and
    outcome to fully explain away the observed effect.

    Args:
        rr: Observed risk ratio. Non-positive values are clamped to 1e-9,
            which yields a very large E-value (the RR -> 0 limit).

    Returns:
        The E-value; 1.0 when ``rr == 1``.
    """
    rr = max(rr, 1e-9)
    # Protective effects (RR < 1) use the reciprocal, so the closed form is
    # always evaluated on the >= 1 side of the null.
    if rr < 1.0:
        rr = 1.0 / rr
    return rr + math.sqrt(rr * (rr - 1))


def e_value_ci(ci_low: float, ci_high: float) -> tuple[float, float]:
    """Return the E-values of the two limits of a risk-ratio confidence interval.

    Each element is the E-value of the corresponding limit on its own, i.e.
    the confounder strength needed to move that limit to the null (RR = 1).

    The single "E-value for the CI" reported by VanderWeele & Ding (2017)
    refers to the limit closest to the null:
    - it is 1.0 whenever the interval contains 1;
    - otherwise it is the smaller of the two returned values.

    Args:
        ci_low: Lower confidence limit of the risk ratio.
        ci_high: Upper confidence limit of the risk ratio.

    Returns:
        ``(E-value of ci_low, E-value of ci_high)``.
    """
    return e_value(ci_low), e_value(ci_high)
def rd_to_rr(p_treated: float, p_untreated: float) -> float:
    """Return the risk ratio ``p_treated / p_untreated``.

    Despite the name, this takes the two arm-level risks, not a risk
    difference. The untreated risk is floored at 1e-9, so a zero or negative
    baseline divides by 1e-9 instead of raising ZeroDivisionError.
    """
    baseline = max(p_untreated, 1e-9)
    ratio = p_treated / baseline
    return ratio
@dataclass
class SensitivityReport:
    """Summary of the sensitivity analyses, as built by ``assess``."""

    ate: float  # estimated average treatment effect (the TTE result's ``ate``)
    rr: float  # risk ratio outcome_treated / outcome_untreated
    e_value_point: float  # E-value for the point estimate ``rr``
    e_value_ci: float  # E-value for the CI limit nearest the null (approximated in ``assess``)
    negative_control_ate: float | None  # ATE of the negative control; None if none was supplied
    negative_control_ok: bool  # True if |NC ATE| is below the check threshold; also True when no NC was supplied
    tipping_point: float | None  # shift in the effect estimate before its CI reaches the null
    interpretation: str  # human-readable summary, "; "-joined
def negative_control_check(nc_ate: float, threshold: float = 0.02) -> bool:
    """Return True when a negative-control ATE is effectively null.

    By design a negative control should show no effect, so a well-behaved
    pipeline estimates its ATE near zero. The check passes when the
    magnitude of ``nc_ate`` is strictly below ``threshold``.
    """
    magnitude = abs(nc_ate)
    return magnitude < threshold
def tipping_point(observed_effect: float, ci_half_width: float,
                  null: float = 0.0, steps: int = 100) -> float:
    """Return how far the effect estimate must shift to lose significance.

    This is a simple linear extrapolation. It assumes the confidence
    interval ``observed_effect ± ci_half_width`` keeps its width and only
    translates, and returns the distance the estimate must move toward
    ``null`` before the interval first touches it. On the risk-difference
    scale, shifting the untreated-arm outcome rate by ``d`` moves the
    estimate by exactly ``d``.

    Args:
        observed_effect: Point estimate (e.g. a risk difference).
        ci_half_width: Half-width of its confidence interval.
        null: Null value of the effect measure (0.0 for a risk difference).
        steps: Unused; kept so existing call sites keep working.

    Returns:
        0.0 if the interval already contains ``null``, otherwise
        ``|observed_effect - null| - ci_half_width``.
    """
    # Measure against the caller's null rather than a hard-coded zero.
    distance = abs(observed_effect - null)
    if distance <= ci_half_width:
        return 0.0
    return float(distance - ci_half_width)
def _ci_e_value(rr: float, rd_low: float, rd_high: float, p0: float) -> float:
    """Return the E-value for the confidence limit nearest the null.

    The risk-difference limit is mapped onto the risk-ratio scale while
    holding the untreated risk ``p0`` fixed (treated risk = ``p0 + limit``).
    As in VanderWeele & Ding (2017), the CI E-value is 1.0 whenever the
    interval already contains the null.
    """
    if rr >= 1.0:
        # Effect at/above the null: the lower limit is the one nearest it.
        limit = rd_low
        contains_null = rd_low <= 0.0
    else:
        # Effect below the null: the upper limit is the one nearest it.
        limit = rd_high
        contains_null = rd_high >= 0.0
    if contains_null:
        return 1.0
    return e_value(rd_to_rr(max(p0 + limit, 1e-9), p0))


def assess(tte_result, neg_control=None) -> SensitivityReport:
    """Build a full sensitivity report from a TTEResult.

    Args:
        tte_result: Object exposing the following attributes:
            - ``ate``: assumed to be a risk difference;
            - ``ate_ci``: its confidence interval, first two items used;
            - ``outcome_treated`` / ``outcome_untreated``: arm-level risks.
        neg_control: Optional negative-control result; only its ``ate``
            attribute is read.

    Returns:
        A ``SensitivityReport``. When no negative control is supplied,
        ``negative_control_ate`` is None and ``negative_control_ok`` is True.
    """
    rd = tte_result.ate
    p1 = tte_result.outcome_treated
    p0 = tte_result.outcome_untreated
    rr = rd_to_rr(p1, p0)
    ev_pt = e_value(rr)

    # Order the limits so a (high, low) tuple cannot produce a negative
    # half-width or select the wrong limit.
    ci_a, ci_b = tte_result.ate_ci[0], tte_result.ate_ci[1]
    rd_low, rd_high = min(ci_a, ci_b), max(ci_a, ci_b)
    # Symmetric CI half-width: a crude proxy used for the tipping point.
    half = (rd_high - rd_low) / 2.0
    # E-value for the CI limit closer to the null. Only that single limit is
    # moved onto the RR scale; shifting both arms would move the RD by a full
    # CI width, and in the wrong direction for protective effects.
    ev_ci = _ci_e_value(rr, rd_low, rd_high, p0)

    nc_ok = True
    nc_ate = None
    if neg_control is not None:
        nc_ate = neg_control.ate
        nc_ok = negative_control_check(nc_ate)

    interp = []
    if ev_pt < 1.5:
        interp.append(f"WEAK: E={ev_pt:.2f} (unmeasured confounder of RR>={ev_pt:.2f} "
                      "would explain away the effect)")
    elif ev_pt < 2.5:
        interp.append(f"MODERATE: E={ev_pt:.2f}")
    else:
        interp.append(f"STRONG: E={ev_pt:.2f}")
    if not nc_ok and nc_ate is not None:
        interp.append(f"FAIL negative control: ATE_nc={nc_ate:.4f} (should be ~0)")
    elif nc_ate is not None:
        interp.append(f"PASS negative control: ATE_nc={nc_ate:.4f}")

    tip = tipping_point(rd, half)
    interp.append(f"tipping-point shift = {tip:.4f}")
    return SensitivityReport(
        ate=rd, rr=rr, e_value_point=ev_pt, e_value_ci=ev_ci,
        negative_control_ate=nc_ate, negative_control_ok=nc_ok,
        tipping_point=tip, interpretation="; ".join(interp),
    )