---
license: apache-2.0
tags:
- Causal-Inference
- CausalPFN
library_name: causalpfn
---

# CausalPFN Model

This repository contains the model weights for CausalPFN, a transformer-based in-context learning model for causal effect estimation.

## Model Description

CausalPFN is a pre-trained model for amortized causal effect estimation via in-context learning. It does not train a new estimator for each dataset. Instead, the observed data are given to the model as context, and it estimates conditional average treatment effects (CATE) and average treatment effects (ATE) without retraining.

The model is built on a transformer architecture and produces calibrated uncertainty estimates for its effect predictions.
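
For reference, the two estimands follow the standard potential-outcomes definitions. Writing $Y(1)$ and $Y(0)$ for the potential outcomes under treatment and control and $X$ for the covariates (this notation is ours, not taken from the library):

$$
\tau(x) = \mathbb{E}\left[ Y(1) - Y(0) \mid X = x \right], \qquad
\mathrm{ATE} = \mathbb{E}\left[ Y(1) - Y(0) \right] = \mathbb{E}_X\left[ \tau(X) \right]
$$

The last equality follows from the law of total expectation: the ATE is the CATE averaged over the covariate distribution.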

## Requirements

- Python 3.10+
- PyTorch 2.3+
- NumPy
- scikit-learn
- tqdm
- faiss-cpu
- huggingface_hub
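
To confirm that an environment meets the requirements listed above, you can query installed package versions with the Python standard library. The sketch below is a hypothetical helper (`check_requirements` and `parse_major_minor` are not part of CausalPFN). It reports each distribution's installed version and checks only the Python 3.10+ and PyTorch 2.3+ minimums stated above.

```python
import re
import sys
from importlib import metadata
from typing import Dict, Optional, Tuple

# Distribution names for the requirements listed above, plus the library itself
REQUIRED = ["torch", "numpy", "scikit-learn", "tqdm", "faiss-cpu", "huggingface_hub", "causalpfn"]


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Parse the leading 'major.minor' of a version string, e.g. '2.3.1+cu121' -> (2, 3)."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def check_requirements() -> Dict[str, Optional[str]]:
    """Hypothetical helper: map each distribution to its installed version, or None if missing."""
    found: Dict[str, Optional[str]] = {}
    for name in REQUIRED:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    python_ok = sys.version_info >= (3, 10)
    print(f"Python {sys.version.split()[0]}: {'ok' if python_ok else 'needs 3.10+'}")
    for name, version in check_requirements().items():
        if version is None:
            print(f"{name}: MISSING")
        elif name == "torch" and parse_major_minor(version) < (2, 3):
            print(f"{name} {version}: needs 2.3+")
        else:
            print(f"{name} {version}: ok")
```

Run it as a script in the target environment. Any package reported as `MISSING` can be installed with `pip`.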

## Installation

To use this model, install the CausalPFN library:

```bash
pip install causalpfn
```
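
## Usage

Run the model through the CausalPFN library. The example below estimates CATE and ATE on a small synthetic dataset. Replace the synthetic arrays with your own observational data: covariates, a binary treatment, and the observed outcome.

```python
import numpy as np
import torch

from causalpfn import ATEEstimator, CATEEstimator

# Use a GPU if one is available; otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Synthetic observational data for illustration (assumption: NumPy arrays are accepted)
# X: covariates (n x d), T: binary treatment (n,), Y: observed outcome (n,)
rng = np.random.default_rng(0)
n, d = 2000, 5
X = rng.normal(size=(n, d)).astype(np.float32)
T = rng.binomial(1, 0.5, size=n)
Y = (X[:, 0] + T * (1.0 + X[:, 1]) + rng.normal(size=n)).astype(np.float32)

# Use the first 1500 units as context and query the CATE on the rest
X_train, T_train, Y_train = X[:1500], T[:1500], Y[:1500]
X_test = X[1500:]

# Create a CATE estimator
causalpfn_cate = CATEEstimator(device=device, verbose=True)

# Fit the model on the observed covariates, treatments, and outcomes
causalpfn_cate.fit(X_train, T_train, Y_train)

# Estimate the CATE for each row of X_test
cate_hat = causalpfn_cate.estimate_cate(X_test)

# Create an ATE estimator
causalpfn_ate = ATEEstimator(device=device, verbose=True)

# Fit on the full dataset and estimate the ATE
causalpfn_ate.fit(X, T, Y)
ate_hat = causalpfn_ate.estimate_ate()
```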
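
When you experiment on synthetic data where the true effects are known, you can score the estimates directly. The sketch below illustrates this under stated assumptions. The helper names `pehe` and `ate_error` are hypothetical (not part of the CausalPFN library), and the effect function τ(x) = 1 + x₁ is a made-up design. The sketch computes the root-mean-squared CATE error (commonly called PEHE) and the absolute ATE error.

```python
import numpy as np


def pehe(cate_hat, cate_true) -> float:
    """Root-mean-squared error between estimated and true CATE (often called PEHE)."""
    cate_hat = np.asarray(cate_hat, dtype=float).ravel()
    cate_true = np.asarray(cate_true, dtype=float).ravel()
    return float(np.sqrt(np.mean((cate_hat - cate_true) ** 2)))


def ate_error(ate_hat, ate_true) -> float:
    """Absolute error of an ATE estimate."""
    return abs(float(ate_hat) - float(ate_true))


# Hypothetical synthetic design with a known effect function: tau(x) = 1 + x[1]
rng = np.random.default_rng(0)
X_eval = rng.normal(size=(500, 5))
cate_true = 1.0 + X_eval[:, 1]
# The sample ATE is the average of the true CATE over the units
ate_true = float(cate_true.mean())

# Baseline: predict the ATE for every unit; its PEHE equals the std of the true CATE
baseline = np.full_like(cate_true, ate_true)
print(f"baseline PEHE: {pehe(baseline, cate_true):.3f}")
print(f"baseline ATE error: {ate_error(baseline.mean(), ate_true):.3f}")
```

Given ground-truth effects for the same units, `pehe(cate_hat, ...)` and `ate_error(ate_hat, ...)` score the estimator outputs. This assumes `cate_hat` can be converted with `np.asarray`; for example, a GPU tensor would first need to be moved to the CPU.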

## Citations

If you use this model in your research, please cite:

```bibtex
@misc{balazadeh2025causalpfn,
  title={CausalPFN: Amortized Causal Effect Estimation via In-Context Learning},
  author={Vahid Balazadeh and Hamidreza Kamkari and Valentin Thomas and Benson Li and Junwei Ma and Jesse C. Cresswell and Rahul G. Krishnan},
  year={2025},
  eprint={2506.07918},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2506.07918},
}
```

## License

This model is licensed under [Apache-2.0](https://github.com/vdblm/CausalPFN/blob/main/LICENSE).