---
tags:
- causal-inference
- time-series
- foundation-model
- temporal-causal
- causal-judgment
license: other
license_name: proprietary
license_link: LICENSE
---
# TCPFN: Temporal Causal Prior-Data Fitted Networks
A family of causal-reasoning foundation models that predict effects, judge how trustworthy each causal claim is, and operate zero-shot. Three checkpoints share one architecture (12-layer transformer, `embed_dim=512`, 8 heads, HL-Gaussian output head) and differ only in training-data distribution and curriculum.
## Pick by task
| Task | Best checkpoint | Path |
|------|-----------------|------|
| General causal discovery (biology, cross-sectional, short-lag) | **v2.1** | `models/temporal/final.pt` |
| Industrial / long-range discovery (12+ h lags, digester→paper machine etc.) | **v2.2** | `models/v2.2/final.pt` |
| Effect estimation (CATE / PEHE) | **v3** | `models/v3/final.pt` |
All three run zero-shot. Pick the checkpoint that matches your task: on every task we have measured, the specialist beats the generalist.
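The selection rule in the table, together with the v2.2 guidance to prefer it when lags exceed about 1 h, fits in a small helper. This is a hypothetical sketch: `pick_checkpoint`, its task labels, and the hard 1-hour cutoff are illustrative assumptions, not part of the `tcpfn` package.

```python
# Hypothetical helper (not part of tcpfn): encodes the "Pick by task" table
# plus the v2.2 rule of thumb that lags beyond ~1 h favour the long-range model.
CHECKPOINTS = {
    "v2.1": "models/temporal/final.pt",
    "v2.2": "models/v2.2/final.pt",
    "v3": "models/v3/final.pt",
}


def pick_checkpoint(task: str, max_lag_hours: float = 0.0) -> str:
    """Return the checkpoint path the README suggests for a task.

    task: "discovery" or "estimation" (illustrative labels).
    max_lag_hours: longest causal lag expected in the data.
    """
    if task == "estimation":
        # CATE / PEHE-sensitive work goes to v3.
        return CHECKPOINTS["v3"]
    if task != "discovery":
        raise ValueError(f"unknown task: {task!r}")
    # Assumed cutoff: the README says "when lags exceed ~1 h".
    if max_lag_hours > 1.0:
        return CHECKPOINTS["v2.2"]
    return CHECKPOINTS["v2.1"]


print(pick_checkpoint("discovery", max_lag_hours=12))  # prints models/v2.2/final.pt
```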
## Shared contributions
1. **Temporal Token Design**: the first PFN for temporal panel data.
2. **Causal Judgment Head**: learned reliability signals (null detection, regime classification, identifiability, mediation, confounding).
3. **Causal Regime Prior**: training structures covering direct, confounded, mediated, and feedback regimes.
4. **Self-Calibration**: automatically detects natural experiments in sensor data.
5. **End-to-End System**: discovery, estimation, judgment, and RCA from a single forward pass.
## Shared capabilities
- **Causal Discovery**: pairwise interventional CATE with judgment-aware edge scoring, natural-experiment detection, continuous treatments, multi-lag estimation, and an asymmetry penalty.
- **Effect Estimation**: temporal CATE trajectories with distributional output.
- **Causal Judgment**: null-effect detection and regime classification (learned heuristics, not formal guarantees).
- **Root Cause Analysis (RCA)**: 8-method ensemble (AERCA, ESD, ProRCA, GCM noise, ICC, Shapley, counterfactual, chain tracing).
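The README does not say how the outputs of the eight RCA methods are combined. One common way to ensemble ranked lists is a Borda count, sketched below purely as an illustration of rank aggregation under that assumption; it is not TCPFN's implementation, and `borda_aggregate`, the candidate variable names, and the rankings are made up.

```python
from collections import defaultdict


def borda_aggregate(rankings: dict[str, list[str]]) -> list[tuple[str, int]]:
    """Combine per-method root-cause rankings with a Borda count (illustrative).

    rankings maps a method name to its candidate causes, best first.
    A cause at 0-based position r in a list of length n earns n - r points;
    causes missing from a method's list earn nothing from that method.
    Returns (cause, points) sorted by points, ties broken alphabetically.
    """
    scores: dict[str, int] = defaultdict(int)
    for ranked in rankings.values():
        n = len(ranked)
        for r, cause in enumerate(ranked):
            scores[cause] += n - r
    return sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))


# Method names come from the README; the rankings are invented for illustration.
example = {
    "AERCA": ["valve_3", "pump_1", "heater_2"],
    "ESD": ["pump_1", "valve_3"],
    "Shapley": ["valve_3", "heater_2", "pump_1"],
}
print(borda_aggregate(example))  # [('valve_3', 7), ('pump_1', 5), ('heater_2', 3)]
```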
---
## v2.1: default discovery model
- 200K steps, curriculum-trained (Phase 1 CATE-only → Phase 2 +Null → Phase 3 Full).
- Mixed prior: 40% CausalTimePrior + 30% base + 30% CausalFM.
- Training window: `max_T_pre=50, max_T_post=30`.
- Hardware: RTX 5090, ~4.1 h, 13.9 steps/s.
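The README does not say at what granularity the 40/30/30 prior mix is applied. Assuming each training batch draws its generating prior independently, the mixture reduces to weighted sampling, as in this sketch (`sample_prior_sequence` is a hypothetical name, and the prior names are only labels here).

```python
import random
from collections import Counter

# Mixture weights from the v2.1 recipe; names are labels, not real generators.
PRIOR_WEIGHTS = {"CausalTimePrior": 0.4, "base": 0.3, "CausalFM": 0.3}


def sample_prior_sequence(num_batches: int, seed: int = 0) -> list[str]:
    """Pick which prior generates each batch (hypothetical per-batch sampling).

    Each batch independently chooses a prior with probability equal to its
    mixture weight, so over many batches the shares approach 40/30/30.
    """
    rng = random.Random(seed)
    names = list(PRIOR_WEIGHTS)
    weights = list(PRIOR_WEIGHTS.values())
    return rng.choices(names, weights=weights, k=num_batches)


counts = Counter(sample_prior_sequence(10_000))
for name, n in counts.most_common():
    print(f"{name}: {n / 10_000:.3f}")
```

As a sanity check on the hardware line, 200,000 steps at 13.9 steps/s is about 14,400 s, roughly 4.0 h of pure stepping, consistent with the ~4.1 h reported.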
### Discovery benchmarks (14 datasets, 6 domains, zero-shot)
- Sachs (11 proteins, biological): F1 0.412, **AUROC 0.725** (Granger: F1 0.291, AUROC 0.621); **champion** on this dataset.
- Causal Rivers (environmental): F1 0.319, AUROC 0.955.
- Tennessee Eastman (52 vars, industrial): F1 0.314, AUROC 0.904.
- SWaT (51 vars, water treatment): F1 0.265, AUROC 0.859.
- CauseMe NVAR-5 / NVAR-10: F1 0.571 / 0.439.
- Highest default-threshold F1 on 6 of 14 datasets.
- Hallucination FPR: 0.02–0.08 (down from 1.0 in v2.0).
### Training metrics (mean over steps 150K–200K)
- EffectLoss ~2.9 | JudgmentLoss ~2.8
- NullF1 0.94 | NullAUROC 0.99 | NullBrier 0.04 | NullSep 0.86
- RegimeAcc 0.68 | RegimeMacroF1 0.48
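NullBrier and NullAUROC presumably score the null-effect detector with the standard Brier score and AUROC. The README does not define them, so the label convention (1 = the true effect is null) and the toy data below are assumptions. A minimal stdlib sketch of both metrics:

```python
def brier_score(probs: list[float], labels: list[int]) -> float:
    """Mean squared error between predicted probabilities and 0/1 labels."""
    return sum((p - y) ** 2 for p, y in zip(probs, labels)) / len(probs)


def auroc(scores: list[float], labels: list[int]) -> float:
    """Probability that a random positive outscores a random negative.

    Pairwise comparison with ties counted as half; O(P * N), fine for a sketch.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = 0.0
    for sp in pos:
        for sn in neg:
            if sp > sn:
                wins += 1.0
            elif sp == sn:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy example (assumed convention): label 1 means the true effect is null.
p_null = [0.9, 0.5, 0.3, 0.2, 0.6]
is_null = [1, 1, 0, 0, 0]
print(f"Brier {brier_score(p_null, is_null):.3f}")  # prints Brier 0.150
print(f"AUROC {auroc(p_null, is_null):.3f}")        # prints AUROC 0.833
```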
### Limitations
- CATE estimation is weak (PEHE 0.92) because of per-group Z-standardisation; **use v3 for estimation**.
---
## v2.2: industrial / long-range specialist
Built for causal lags of 12+ hours in industrial control loops (digester → paper machine, reactor → downstream controller). The training window is extended 4× and the curriculum is rebalanced to include null-effect batches in Phase 2.
- 200K steps, BF16 mixed precision, `head_lr_scale=0.1` (decouples output-head learning from the backbone to prevent late-stage drift collapse).
- Training window: `max_T_pre=200, max_T_post=100, max_horizon=500`, which supports lags up to ~16 h at 2-min sampling.
- Manual NaN-skip with observability: the first NaN-producing batch is saved, and training aborts once the skip rate is ≥ a set threshold.
- Hardware: RTX 5090, ~14.9 h.
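Two details above can be made concrete. First, assuming `max_horizon` is counted in samples, 500 samples at 2-minute sampling span 1,000 minutes, about 16.7 h, which is where "~16 h" comes from. Second, the NaN-skip policy can be sketched as a small guard object. The README gives neither the threshold value nor the exact mechanics, so `NaNSkipGuard`, `max_skip_rate`, and `min_steps` are hypothetical names and this is not TCPFN's training code.

```python
import math


def max_lag_hours(max_horizon: int, sample_minutes: float) -> float:
    """Longest lag, in hours, that `max_horizon` samples can span."""
    return max_horizon * sample_minutes / 60.0


class NaNSkipGuard:
    """Hypothetical NaN-skip policy with observability (illustrative only)."""

    def __init__(self, max_skip_rate: float = 0.01, min_steps: int = 100):
        self.max_skip_rate = max_skip_rate  # abort threshold (assumed value)
        self.min_steps = min_steps          # avoid aborting on a tiny sample
        self.steps = 0
        self.skipped = 0
        self.first_bad_batch = None         # kept for post-mortem inspection

    def should_skip(self, loss: float, batch) -> bool:
        """Return True if this step's update should be skipped.

        Raises RuntimeError once the skip rate reaches max_skip_rate.
        """
        self.steps += 1
        is_nan = math.isnan(loss)
        if is_nan:
            self.skipped += 1
            if self.first_bad_batch is None:
                self.first_bad_batch = batch  # a real trainer would save this to disk
        rate = self.skipped / self.steps
        if self.steps >= self.min_steps and rate >= self.max_skip_rate:
            raise RuntimeError(f"NaN skip rate {rate:.3f} reached limit {self.max_skip_rate}")
        return is_nan


print(f"{max_lag_hours(500, 2):.1f} h")  # prints 16.7 h (v2.2: max_horizon=500, 2-min sampling)
```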
### Discovery benchmarks (default threshold 0.5)
Strong on industrial / multivariate temporal data. **Use this** when lags exceed ~1 h or when the data is a genuine time series (not stitched cross-sectional data).
| Dataset | Default F1 | Best F1 | AUROC |
|---------|-----------|---------|-------|
| Tennessee Eastman | **0.512** | 0.545 | **0.972** |
| SWaT | **0.463** | 0.552 | **0.945** |
| CauseMe VAR-5 | 0.769 | 0.800 | 0.960 |
| CauseMe NVAR-5 | 0.800 | 0.800 | 0.863 |
| CauseMe VAR-10 | 0.488 | 0.643 | 0.812 |
| CauseMe NVAR-10 | 0.634 | 0.634 | 0.759 |
| CauseMe Lorenz96-10 | 0.484 | 0.638 | 0.699 |
| Sachs | 0.174 | 0.308 | 0.565 |
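The table reports F1 at the fixed 0.5 threshold alongside a "Best F1". We assume Best F1 is the maximum F1 over all thresholds, which requires the ground-truth graph and is therefore an oracle upper bound rather than something achievable zero-shot; AUROC, by contrast, is threshold-free. A sketch of both F1 variants on a toy six-edge candidate set (the scores and labels are invented):

```python
def f1_at(scores: list[float], truth: list[bool], threshold: float) -> float:
    """F1 of the predicted edge set {score >= threshold} against the true edges."""
    pred = [s >= threshold for s in scores]
    tp = sum(p and t for p, t in zip(pred, truth))
    fp = sum(p and not t for p, t in zip(pred, truth))
    fn = sum((not p) and t for p, t in zip(pred, truth))
    return 0.0 if tp == 0 else 2 * tp / (2 * tp + fp + fn)


def best_f1(scores: list[float], truth: list[bool]) -> float:
    """Oracle F1: best value over every distinct score used as a threshold."""
    return max(f1_at(scores, truth, th) for th in set(scores))


# Invented edge scores for six candidate edges and their ground truth.
scores = [0.9, 0.7, 0.45, 0.4, 0.3, 0.1]
truth = [True, False, True, True, False, False]
print(f"default F1 {f1_at(scores, truth, 0.5):.3f}")  # prints default F1 0.400
print(f"best F1    {best_f1(scores, truth):.3f}")     # prints best F1    0.857
```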
Granger and PCMCI collapse on industrial data: they over-predict (1897 edges on Tennessee Eastman vs 38 true), giving F1 ~0.04. TCPFN v2.2 is the only method with usable precision and recall together.
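The F1 ~0.04 figure follows directly from the edge counts. Since F1 = 2 * TP / (2 * TP + FP + FN) = 2 * TP / (predicted edges + true edges), 1897 predicted edges against 38 true ones cap F1 at 76 / 1935, about 0.039, even if every true edge is recovered:

```python
def f1_from_counts(n_pred: int, n_true: int, tp: int) -> float:
    """F1 from edge counts: 2*TP / (n_pred + n_true), since n_pred = TP+FP and n_true = TP+FN."""
    return 2 * tp / (n_pred + n_true)


# Best case for the baselines on Tennessee Eastman: all 38 true edges are among the 1897 predicted.
print(f"{f1_from_counts(1897, 38, 38):.3f}")  # prints 0.039
```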
### Estimation benchmarks
- Overall PEHE 0.917 | ATE MAE 0.504 | trajectory correlation ≈ 0. **Use v3 for CATE**.
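The README does not define its three estimation metrics. Under standard definitions (assumed here, including that PEHE is reported in root-mean-squared form), the toy example below shows how trajectory correlation can sit at zero even when the average effect is estimated exactly: the prediction has the right mean but none of the true trajectory's shape. `pehe`, `ate_abs_error`, and `pearson` are illustrative helpers, and the trajectories are invented.

```python
import math
from statistics import mean


def pehe(tau_hat: list[float], tau: list[float]) -> float:
    """Root mean squared error between estimated and true effects."""
    return math.sqrt(mean((a - b) ** 2 for a, b in zip(tau_hat, tau)))


def ate_abs_error(tau_hat: list[float], tau: list[float]) -> float:
    """Absolute error of the average treatment effect."""
    return abs(mean(tau_hat) - mean(tau))


def pearson(x: list[float], y: list[float]) -> float:
    """Pearson correlation between two equal-length trajectories."""
    mx, my = mean(x), mean(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


# Invented post-treatment effect trajectories over four time steps.
tau_true = [0.0, 1.0, 2.0, 3.0]  # effect ramps up
tau_pred = [1.6, 1.4, 1.4, 1.6]  # right average level, wrong shape
print(f"PEHE {pehe(tau_pred, tau_true):.3f}")
print(f"ATE error {ate_abs_error(tau_pred, tau_true):.3f}")
print(f"trajectory corr {pearson(tau_pred, tau_true):.3f}")
```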
### Limitations
- **Sachs regressed** vs v2.1 (AUROC 0.565 vs 0.725). Use v2.1 for cross-sectional biological graphs.
- Estimation degraded: v2.2 trades short-range precision for long-range reach (see scar-tissue entry L-33 in the project docs).
---
## v3: estimation champion (experimental)
Tag: `3.0.0-exp-global-std`. Switches to global standardisation, fixing the per-group Z-score bias that caps estimation quality in v2.1/v2.2.
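The README does not say which groups the v2.1/v2.2 Z-score is computed over. Purely to illustrate the mechanism, the sketch below assumes the groups are treated and control units: standardising each group by its own mean erases the mean shift between them, which is the effect itself, while a single global mean and standard deviation preserve it (in standard-deviation units, recoverable by multiplying by the global std). The data and the `zscore` helper are illustrative.

```python
from statistics import mean, pstdev


def zscore(values: list[float], mu: float, sigma: float) -> list[float]:
    """Standardise values with a given mean and standard deviation."""
    return [(v - mu) / sigma for v in values]


# Invented outcomes: the treated group is shifted up by a true effect of +2.0.
control = [0.0, 1.0, 2.0, 3.0]
treated = [2.0, 3.0, 4.0, 5.0]

# Per-group standardisation (assumed grouping): each group uses its own stats.
ctrl_pg = zscore(control, mean(control), pstdev(control))
trt_pg = zscore(treated, mean(treated), pstdev(treated))
print("per-group effect:", mean(trt_pg) - mean(ctrl_pg))  # 0.0: the shift is gone

# Global standardisation: one mean and std shared by both groups.
pooled = control + treated
mu, sigma = mean(pooled), pstdev(pooled)
effect_std_units = mean(zscore(treated, mu, sigma)) - mean(zscore(control, mu, sigma))
print("global effect (raw units):", effect_std_units * sigma)  # recovers 2.0
```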
- 200K steps.
- **PEHE 0.72** (lower is better; 0.92 for both v2.1 and v2.2): the best of the three on CATE estimation.
- Discovery regressed slightly as a trade-off and has not yet been benchmarked across all 14 discovery datasets. **Use v2.1 or v2.2 for discovery**.
### Limitations
- Experimental tag: the standardisation change has not yet been battle-tested beyond estimation.
- Full benchmark matrix still pending.
---
## Usage
```python
from tcpfn import TemporalCausalAnalyzer

# Load ONE checkpoint, chosen by task:
#   models/temporal/final.pt  v2.1: general discovery (biology, cross-sectional)
#   models/v2.2/final.pt      v2.2: industrial / long-range discovery (lags in hours)
#   models/v3/final.pt        v3:   effect estimation (CATE trajectories, PEHE-sensitive work)
analyzer = TemporalCausalAnalyzer(temporal_model="models/temporal/final.pt")

# Full analysis: causal graph plus summary
report = analyzer.run("sensor_data.csv")
print(report.edges)      # causal graph with edge strengths and lags
print(report.summary())  # human-readable summary

# Root-cause analysis of a single event on one target variable
result = analyzer.explain_event(
    data_path="sensor_data.csv",
    target_var="temperature_sensor",
    event_time="2025-11-15 14:15",
)
print(result.summary())  # ranked root causes + causal chains
```
## Cross-cutting limitations
- Regime classification is noisy (~0.68 accuracy, high eval variance). Judgment heads are **learned heuristics**, not formal guarantees.
- Low-dim cross-sectional data stitched into pseudo-timeseries is out-of-distribution for v2.2 and v3; use v2.1.
- v3 has not yet been run on the full discovery benchmark suite.
## Paper
Stalupula et al., "Temporal Causal Prior-Data Fitted Networks for Panel Data with Learned Reliability Signals"