---
license: apache-2.0
tags:
- Causal-Inference
- CausalPFN
library_name: causalpfn
---

# CausalPFN Model

This repository contains the model weights for CausalPFN, a transformer-based in-context learning model for causal effect estimation.

## Model Description

CausalPFN is a pre-trained model for amortized causal effect estimation via in-context learning. It provides accurate estimates of conditional average treatment effects (CATE) and average treatment effects (ATE) without retraining the model for each new dataset. The model is built on a transformer architecture and provides calibrated uncertainty quantification.

## Requirements

- Python 3.10+
- PyTorch 2.3+
- NumPy
- scikit-learn
- tqdm
- faiss-cpu
- huggingface_hub

## Installation

To use this model, install the CausalPFN library:

```bash
pip install causalpfn
```

## Usage

You can use this model with the CausalPFN library:

```python
import torch
from causalpfn import ATEEstimator, CATEEstimator

# Run on a GPU if one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a CATE estimator
causalpfn_cate = CATEEstimator(device=device, verbose=True)

# Fit the model on your observational data:
# X_train: covariates, T_train: binary treatment, Y_train: observed outcome
causalpfn_cate.fit(X_train, T_train, Y_train)

# Estimate the CATE for new units with covariates X_test
cate_hat = causalpfn_cate.estimate_cate(X_test)

# Create an ATE estimator
causalpfn_ate = ATEEstimator(device=device, verbose=True)

# Fit on the full dataset (X: covariates, T: binary treatment, Y: observed outcome)
# and estimate the ATE
causalpfn_ate.fit(X, T, Y)
ate_hat = causalpfn_ate.estimate_ate()
```

## Citations

If you use this model in your research, please cite:

```
@misc{balazadeh2025causalpfn,
      title={CausalPFN: Amortized Causal Effect Estimation via In-Context Learning},
      author={Vahid Balazadeh and Hamidreza Kamkari and Valentin Thomas and Benson Li and Junwei Ma and Jesse C. Cresswell and Rahul G. Krishnan},
      year={2025},
      eprint={2506.07918},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2506.07918},
}
```

## License

This model is licensed under [Apache-2.0](https://github.com/vdblm/CausalPFN/blob/main/LICENSE).
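## Synthetic Data Sanity Check

The usage example for `CATEEstimator` and `ATEEstimator` assumes that `X_train`, `T_train`, `Y_train`, `X_test` (and `X`, `T`, `Y`) already exist. The sketch below uses NumPy to simulate a small confounded observational dataset with a known treatment effect, so that you can check the estimates against ground truth. `make_synthetic_data` and the data-generating process are hypothetical and for illustration only. The array layout is an assumption based on the usage comments, not a documented input specification:

- covariates as an `(n, d)` float array;
- treatment as a length-`n` 0/1 vector;
- outcome as a length-`n` float vector.

```python
import numpy as np


def make_synthetic_data(n=2000, d=5, seed=0):
    """Hypothetical helper: simulate confounded observational data with a known CATE."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, d))
    # Treatment probability depends on X[:, 0], so treatment assignment is confounded
    propensity = 1.0 / (1.0 + np.exp(-X[:, 0]))
    T = rng.binomial(1, propensity)
    # True conditional average treatment effect varies with X[:, 1]
    tau = 1.0 + 0.5 * X[:, 1]
    # Baseline outcome also depends on the confounder X[:, 0]
    mu0 = X[:, 0] + 0.5 * X[:, 2]
    Y = mu0 + tau * T + rng.normal(scale=0.1, size=n)
    return X, T, Y, tau


X, T, Y, true_cate = make_synthetic_data()

# Hold out the last 500 units as a test set for CATE estimation
X_train, T_train, Y_train = X[:-500], T[:-500], Y[:-500]
X_test, cate_test = X[-500:], true_cate[-500:]

# Ground-truth (sample) ATE, and a naive difference in means that ignores confounding
true_ate = true_cate.mean()
naive_ate = Y[T == 1].mean() - Y[T == 0].mean()
```

After running the usage example on these arrays, you can compare `cate_hat` against `cate_test` (for example, by mean squared error) and `ate_hat` against `true_ate`. Because the confounder `X[:, 0]` raises both the treatment probability and the baseline outcome, `naive_ate` is biased upward. A useful causal estimator should land closer to `true_ate` than `naive_ate` does.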