---
tags:
- causal-inference
- time-series
- foundation-model
- temporal-causal
- causal-judgment
license: other
license_name: proprietary
license_link: LICENSE
---

# TCPFN: Temporal Causal Prior-Data Fitted Networks

A family of causal-reasoning foundation models that predict effects, judge the trustworthiness of their own estimates, and operate zero-shot.

Three checkpoints share one architecture (12-layer transformer, `embed_dim=512`, 8 heads, HL-Gaussian output head). They differ only in training-data distribution and curriculum.

## Pick by task

| Task | Best checkpoint | Path |
|------|-----------------|------|
| General causal discovery (biology, cross-sectional, short-lag) | **v2.1** | `models/temporal/final.pt` |
| Industrial / long-range discovery (12+ h lags, e.g. digester → paper machine) | **v2.2** | `models/v2.2/final.pt` |
| Effect estimation (CATE / PEHE) | **v3** | `models/v3/final.pt` |

All three are zero-shot. Pick the one that matches your task: on every task we have measured, the specialised checkpoint beats the other two.

## Shared contributions

1. **Temporal Token Design**: the first PFN for temporal panel data.
2. **Causal Judgment Head**: learned reliability signals (null detection, regime classification, identifiability, mediation, confounding).
3. **Causal Regime Prior**: direct / confounded / mediated / feedback training structures.
4. **Self-Calibration**: automatically detects natural experiments in sensor data.
5. **End-to-End System**: discovery, estimation, judgment and RCA from one forward pass.

## Shared capabilities

- **Causal Discovery**: pairwise interventional CATE with judgment-aware edge scoring, natural-experiment detection, continuous treatment, multi-lag estimation and an asymmetry penalty.
- **Effect Estimation**: temporal CATE trajectories with distributional output.
- **Causal Judgment**: null-effect detection and regime classification (learned heuristics, not formal guarantees).
- **Root Cause Analysis**: 8-method ensemble (AERCA, ESD, ProRCA, GCM noise, ICC, Shapley, counterfactual, chain tracing).

---

## v2.1: default discovery model

- 200K steps, curriculum-trained (Phase 1 CATE-only → Phase 2 +Null → Phase 3 Full).
- Mixed prior: 40% CausalTimePrior + 30% base + 30% CausalFM.
- Training window: `max_T_pre=50, max_T_post=30`.
- Hardware: RTX 5090, ~4.1 h, 13.9 steps/s.

### Discovery benchmarks (14 datasets, 6 domains, zero-shot)

- Sachs (11 proteins, biological): F1 0.412, **AUROC 0.725** (vs Granger F1 0.291 / AUROC 0.621); **champion**.
- Causal Rivers (environmental): F1 0.319, AUROC 0.955.
- Tennessee Eastman (52 vars, industrial): F1 0.314, AUROC 0.904.
- SWaT (51 vars, water treatment): F1 0.265, AUROC 0.859.
- CauseMe NVAR-5 / NVAR-10: F1 0.571 / 0.439.
- Highest default-threshold F1 on 6 of 14 datasets.
- Hallucination FPR: 0.02–0.08 (down from 1.0 in v2.0).

### Training metrics (mean over steps 150K–200K)

- EffectLoss ~2.9 | JudgmentLoss ~2.8
- NullF1 0.94 | NullAUROC 0.99 | NullBrier 0.04 | NullSep 0.86
- RegimeAcc 0.68 | RegimeMacroF1 0.48

### Limitations

- CATE estimation is weak (PEHE 0.92) because of per-group Z-standardisation; **use v3 for estimation**.

---

## v2.2: industrial / long-range specialist

Built for causal lags of 12+ hours in industrial control loops (digester → paper machine, reactor → downstream controller). The training window is extended roughly 4× (`max_T_pre` 50 → 200, `max_T_post` 30 → 100), and the curriculum is rebalanced to include null-effect batches in Phase 2.

- 200K steps, BF16 mixed precision, `head_lr_scale=0.1` (decouples output-head learning from the backbone to prevent late-stage drift collapse).
- Training window: `max_T_pre=200, max_T_post=100, max_horizon=500`; supports lags up to ~16 h at 2-min sampling (500 × 2 min ≈ 16.7 h).
- Manual NaN-skip with observability (saves the first NaN-producing batch; aborts if the skip rate reaches the threshold).
- Hardware: RTX 5090, ~14.9 h.
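The ~16 h figure in the v2.2 training-window bullet is plain arithmetic: `max_horizon` samples times the sampling interval. The sketch below makes that calculation explicit and adds a rule-of-thumb checkpoint router based on the "Pick by task" table. All three helpers (`max_lag_hours`, `steps_needed`, `pick_checkpoint`) are hypothetical and are not part of the `tcpfn` package. The router encodes only the ~1 h lag threshold and assumes a lag must fit within `max_horizon` samples.

```python
import math

# Hypothetical helpers for illustration only; NOT part of the tcpfn package.
CHECKPOINTS = {
    "v2.1": "models/temporal/final.pt",
    "v2.2": "models/v2.2/final.pt",
    "v3": "models/v3/final.pt",
}


def max_lag_hours(max_horizon: int, sample_minutes: float) -> float:
    """Longest lag (hours) that `max_horizon` samples span at the given sampling interval."""
    return max_horizon * sample_minutes / 60.0


def steps_needed(lag_hours: float, sample_minutes: float) -> int:
    """Number of samples a lag of `lag_hours` occupies at the given sampling interval."""
    return math.ceil(lag_hours * 60.0 / sample_minutes)


def pick_checkpoint(task: str, expected_lag_hours: float = 0.0) -> str:
    """Route a task to a checkpoint path (simplified rule of thumb, ~1 h lag cut-off)."""
    if task == "estimation":
        return CHECKPOINTS["v3"]
    if task == "discovery":
        # v2.2 is recommended once lags exceed ~1 h; shorter lags go to v2.1.
        return CHECKPOINTS["v2.2"] if expected_lag_hours > 1.0 else CHECKPOINTS["v2.1"]
    raise ValueError(f"unknown task: {task!r}")


if __name__ == "__main__":
    print(max_lag_hours(500, 2.0))   # v2.2 horizon at 2-min sampling, ~16.7 h
    print(steps_needed(12, 2.0))     # 360 samples for a 12 h lag at 2-min sampling
    print(steps_needed(12, 1.0))     # 720 samples for the same lag at 1-min sampling
    print(pick_checkpoint("discovery", expected_lag_hours=12))
```

At 1-min sampling the same 12 h lag needs 720 samples, more than `max_horizon=500`, so the ~16 h reach quoted for v2.2 holds only for 2-min (or coarser) sampling.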
### Discovery benchmarks (default threshold 0.5)

Strong on industrial / multivariate temporal data. **Use this checkpoint** when lags exceed ~1 h or when the data is genuinely time-series (not stitched cross-sectional data).

| Dataset | Default F1 | Best-threshold F1 | AUROC |
|---------|-----------|---------|-------|
| Tennessee Eastman | **0.512** | 0.545 | **0.972** |
| SWaT | **0.463** | 0.552 | **0.945** |
| CauseMe VAR-5 | 0.769 | 0.800 | 0.960 |
| CauseMe NVAR-5 | 0.800 | 0.800 | 0.863 |
| CauseMe VAR-10 | 0.488 | 0.643 | 0.812 |
| CauseMe NVAR-10 | 0.634 | 0.634 | 0.759 |
| CauseMe Lorenz96-10 | 0.484 | 0.638 | 0.699 |
| Sachs | 0.174 | 0.308 | 0.565 |

Granger and PCMCI collapse on industrial data: they over-predict (1897 edges on TE vs 38 true), giving F1 ~0.04. TCPFN v2.2 is the only method tested that achieves usable precision and recall together.

### Estimation benchmarks

- Overall PEHE 0.917 | ATE MAE 0.504 | trajectory correlation ≈ 0; **use v3 for CATE**.

### Limitations

- **Sachs regressed** vs v2.1 (AUROC 0.565 vs 0.725). Use v2.1 for cross-sectional biological graphs.
- Estimation degraded: v2.2 trades short-range precision for long-range reach (see scar-tissue entry L-33 in the project docs).

---

## v3: estimation champion (experimental)

Tag: `3.0.0-exp-global-std`. Applies a global-standardisation fix for the per-group Z-score bias that caps v2.1/v2.2 estimation quality.

- 200K steps.
- **PEHE 0.72** (vs 0.92 for both v2.1 and v2.2): the best of the three on CATE estimation.
- Discovery regressed slightly as a trade-off, and v3 has not yet been benchmarked on all 14 discovery datasets; **use v2.1 or v2.2 for discovery**.

### Limitations

- Experimental tag: the standardisation change has not yet been battle-tested beyond estimation.
- Full benchmark matrix still pending.
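Why can per-group standardisation cap estimation quality? The toy sketch below illustrates one plausible mechanism, assuming "per-group" means each treatment arm is Z-scored with its own mean and standard deviation (the exact TCPFN scheme is not spelled out here). Per-arm centring removes the arm-level shift that *is* the treatment effect, whereas a single shared (global) mean and scale is an invertible transform applied to both arms, so the effect can be recovered. `pehe` uses the usual definition: the root-mean-squared error between estimated and true per-unit effects. This is a hand-built illustration, not TCPFN's preprocessing code.

```python
import math
import statistics


def zscore(xs, mean, std):
    """Standardise a list with the given mean and standard deviation."""
    return [(x - mean) / std for x in xs]


def per_group_standardise(control, treated):
    """ASSUMED per-group scheme: each arm is scaled with its OWN mean/std."""
    return (
        zscore(control, statistics.fmean(control), statistics.pstdev(control)),
        zscore(treated, statistics.fmean(treated), statistics.pstdev(treated)),
    )


def global_standardise(control, treated):
    """Global scheme: both arms share ONE mean/std; the scale is returned so it can be undone."""
    pooled = control + treated
    mean, std = statistics.fmean(pooled), statistics.pstdev(pooled)
    return zscore(control, mean, std), zscore(treated, mean, std), std


def mean_effect(control, treated):
    """Difference in arm means: a naive average-effect estimate."""
    return statistics.fmean(treated) - statistics.fmean(control)


def pehe(tau_hat, tau_true):
    """Root-mean-squared error between estimated and true per-unit effects."""
    return math.sqrt(statistics.fmean([(a - b) ** 2 for a, b in zip(tau_hat, tau_true)]))


control = [1.0, 2.0, 3.0]
treated = [3.0, 4.0, 5.0]   # every unit shifted by +2.0
tau_true = [2.0, 2.0, 2.0]

c_pg, t_pg = per_group_standardise(control, treated)
c_gl, t_gl, scale = global_standardise(control, treated)

effect_pg = mean_effect(c_pg, t_pg)           # shift removed by per-arm centring
effect_gl = mean_effect(c_gl, t_gl) * scale   # undo the shared scale to return to raw units

print(effect_pg, pehe([effect_pg] * 3, tau_true))
print(effect_gl, pehe([effect_gl] * 3, tau_true))
```

In this toy the per-group version estimates an effect of zero, so its PEHE equals the full true effect (2.0); the global version recovers the +2.0 shift up to floating-point error.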
---

## Usage

```python
from tcpfn import TemporalCausalAnalyzer

# Load ONE checkpoint, chosen by task (uncomment the one you need).

# Discovery on general data (biology, cross-sectional, short-lag)
analyzer = TemporalCausalAnalyzer(temporal_model="models/temporal/final.pt")

# Industrial / long-range discovery (lags in hours)
# analyzer = TemporalCausalAnalyzer(temporal_model="models/v2.2/final.pt")

# Effect estimation (CATE trajectories, PEHE-sensitive work)
# analyzer = TemporalCausalAnalyzer(temporal_model="models/v3/final.pt")

# Full analysis of a CSV file
report = analyzer.run("sensor_data.csv")
print(report.edges)      # causal graph with edge strengths and lags
print(report.summary())  # human-readable summary

# Root-cause analysis for a specific event
result = analyzer.explain_event(
    data_path="sensor_data.csv",
    target_var="temperature_sensor",
    event_time="2025-11-15 14:15",
)
print(result.summary())  # ranked root causes + causal chains
```

## Cross-cutting limitations

- Regime classification is noisy (~0.68 accuracy, high eval variance). Judgment heads are **learned heuristics**, not formal guarantees.
- Low-dimensional cross-sectional data stitched into a pseudo-time-series is out-of-distribution for v2.2 and v3; use v2.1.
- v3 has not yet been run on the full discovery benchmark suite.

## Paper

Stalupula et al., "Temporal Causal Prior-Data Fitted Networks for Panel Data with Learned Reliability Signals"