---
license: apache-2.0
tags:
- Causal-Inference
- CausalPFN
library_name: causalpfn
---
# CausalPFN Model

This repository contains the model weights for CausalPFN, a transformer-based in-context learning model for causal effect estimation.

## Model Description

CausalPFN is a pre-trained model for amortized causal effect estimation via in-context learning. It provides accurate estimates of conditional average treatment effects (CATE) and average treatment effects (ATE) on new datasets without retraining the model for each one.

The model uses a transformer architecture and includes uncertainty quantification and calibration.
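The two estimands can be made concrete with a small simulation. In potential-outcomes notation, the CATE at covariate value x is tau(x) = E[Y(1) - Y(0) | X = x], and the ATE is E[Y(1) - Y(0)] = E[tau(X)], i.e. the CATE averaged over the covariate distribution. The sketch below is a hypothetical illustration: the outcome functions `mu0` and `mu1` are chosen by hand so the true effects are known, which is never the case with real observational data. It does not use CausalPFN itself.

```python
import numpy as np

# Hypothetical simulation: potential-outcome means are chosen by hand,
# so the true CATE and ATE are known exactly.
rng = np.random.default_rng(0)
n = 10_000
x = rng.uniform(-1.0, 1.0, size=n)  # a single covariate, X ~ Uniform(-1, 1)


def mu0(x):
    """E[Y(0) | X = x]: expected outcome without treatment."""
    return 2.0 * x


def mu1(x):
    """E[Y(1) | X = x]: expected outcome with treatment."""
    return 2.0 * x + 1.0 + x ** 2


def cate(x):
    """CATE: tau(x) = E[Y(1) - Y(0) | X = x], here equal to 1 + x^2."""
    return mu1(x) - mu0(x)


# ATE = E[tau(X)], approximated by averaging tau over the sampled covariates.
ate = cate(x).mean()

print(f"tau(0.5) = {cate(0.5):.3f}")  # prints tau(0.5) = 1.250
print(f"ATE ~= {ate:.3f}")  # close to 1 + E[X^2] = 4/3
```

Because the treatment effect varies with x, the CATE carries more information than the single ATE number; averaging the CATE recovers the ATE.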

## Requirements

- Python 3.10+
- PyTorch 2.3+
- NumPy
- scikit-learn
- tqdm
- faiss-cpu
- huggingface_hub

## Installation

To use this model, install the CausalPFN library:

```bash
pip install causalpfn
```

## Usage

You can use this model with the CausalPFN library:

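```python
import numpy as np
import torch
from causalpfn import CATEEstimator, ATEEstimator

# Toy observational data for illustration only; replace with your own arrays.
# X: covariates (n, d), T: binary treatment (n,), Y: observed outcome (n,)
# (NumPy arrays are assumed as the input format here.)
rng = np.random.default_rng(0)
n, d = 1000, 5
X = rng.normal(size=(n, d))
# Treatment probability depends on X[:, 0], so treatment is confounded
T = rng.binomial(1, 1.0 / (1.0 + np.exp(-X[:, 0])))
Y = X[:, 0] + T * (1.0 + X[:, 1]) + rng.normal(size=n)
X_train, T_train, Y_train = X[:800], T[:800], Y[:800]
X_test = X[800:]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a CATE estimator
causalpfn_cate = CATEEstimator(device=device, verbose=True)

# Fit the model on the training data
causalpfn_cate.fit(X_train, T_train, Y_train)

# Estimate CATE for new covariates
cate_hat = causalpfn_cate.estimate_cate(X_test)

# Create an ATE estimator
causalpfn_ate = ATEEstimator(device=device, verbose=True)

# Fit and estimate ATE on the full dataset
causalpfn_ate.fit(X, T, Y)
ate_hat = causalpfn_ate.estimate_ate()
```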
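The estimators take covariates, a binary treatment, and an observed outcome. A small pre-flight check can catch shape or encoding mistakes before calling `fit`. The helper below, `check_causal_inputs`, is hypothetical and not part of the `causalpfn` package; the shapes it enforces ((n, d), (n,), (n,)) are an assumption, so adjust them to the library's documentation if it accepts other formats.

```python
import numpy as np


def check_causal_inputs(X, T, Y):
    """Hypothetical helper: validate (X, T, Y) before passing them to an estimator.

    Assumes X is a 2-D covariate matrix (n, d), T is a binary (0/1) treatment
    vector (n,), and Y is an outcome vector (n,). Returns the inputs as NumPy arrays.
    """
    X, T, Y = np.asarray(X), np.asarray(T), np.asarray(Y)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D (n, d), got shape {X.shape}")
    if T.ndim != 1 or Y.ndim != 1:
        raise ValueError("T and Y must be 1-D vectors")
    n = X.shape[0]
    if n == 0:
        raise ValueError("inputs must contain at least one row")
    if T.shape[0] != n or Y.shape[0] != n:
        raise ValueError(
            f"length mismatch: X has {n} rows, T has {T.shape[0]}, Y has {Y.shape[0]}"
        )
    if not np.isin(T, (0, 1)).all():
        raise ValueError("T must be binary (0/1)")
    if not (np.isfinite(X).all() and np.isfinite(Y).all()):
        raise ValueError("X and Y must not contain NaN or infinite values")
    # With a single treatment arm there is no treated/control contrast to learn from.
    if T.min() == T.max():
        raise ValueError("T contains only one treatment arm")
    return X, T, Y


# Usage on toy data: passes silently when the inputs are well-formed.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(100, 3))
T_demo = rng.binomial(1, 0.5, size=100)
Y_demo = X_demo[:, 0] + T_demo + rng.normal(size=100)
X_demo, T_demo, Y_demo = check_causal_inputs(X_demo, T_demo, Y_demo)
```

Raising early with a descriptive message is usually easier to debug than an error from deep inside the model's forward pass.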

## Citations

If you use this model in your research, please cite:

```bibtex
@misc{balazadeh2025causalpfn,
      title={CausalPFN: Amortized Causal Effect Estimation via In-Context Learning},
      author={Vahid Balazadeh and Hamidreza Kamkari and Valentin Thomas and Benson Li and Junwei Ma and Jesse C. Cresswell and Rahul G. Krishnan},
      year={2025},
      eprint={2506.07918},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2506.07918},
}
```

## License

This model is licensed under [Apache-2.0](https://github.com/vdblm/CausalPFN/blob/main/LICENSE).