---
license: mit
library_name: pytorch
tags:
- causal-discovery
- tabular
- directed-acyclic-graphs
- zero-shot
- pytorch
- arxiv:2605.07204
---
# ArrowFM Base
Arrow is a zero-shot foundation model for causal discovery from observational tabular data.
Given a numeric dataset, Arrow predicts edge-existence probabilities and node-order scores, then decodes them into a directed acyclic graph (DAG).
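The decoding step described above can be illustrated with a minimal sketch. It uses a hypothetical scheme that is not necessarily Arrow's own decoder. Nodes are sorted by their order scores, with a lower score assumed to mean earlier in the causal order. An edge from `j` to `k` is kept only when its probability exceeds a threshold and `j` precedes `k`. Because every kept edge points forward in the order, the result is always acyclic. The function name `decode_dag`, the score convention, and the 0.5 threshold are illustrative assumptions.

```python
from typing import List


def decode_dag(
    p_hat: List[List[float]],
    order_scores: List[float],
    threshold: float = 0.5,
) -> List[List[bool]]:
    """Hypothetical decoder: keep confident edges that respect a node order.

    p_hat[j][k] is the (assumed) probability of a directed edge j -> k.
    order_scores[j] is the (assumed) order score of node j; lower = earlier.
    """
    p = len(order_scores)
    # Sort nodes by order score; rank[j] is node j's position in that order.
    order = sorted(range(p), key=lambda j: order_scores[j])
    rank = [0] * p
    for position, node in enumerate(order):
        rank[node] = position
    # Keep edge j -> k only if it is likely AND j comes before k,
    # so every edge points "forward" and no cycle can form.
    return [
        [p_hat[j][k] > threshold and rank[j] < rank[k] for k in range(p)]
        for j in range(p)
    ]


# Three variables whose probabilities alone would give a 2-cycle 0 <-> 1.
p_hat = [
    [0.0, 0.9, 0.2],
    [0.8, 0.0, 0.7],
    [0.1, 0.3, 0.0],
]
order_scores = [0.1, 0.5, 0.9]  # implied order: 0, 1, 2
adj = decode_dag(p_hat, order_scores)
print(adj)  # [[False, True, False], [False, False, True], [False, False, False]]
```

The order breaks the tie between `0 -> 1` (0.9) and `1 -> 0` (0.8): only the edge consistent with the order survives. The result is the chain `0 -> 1 -> 2`.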
- Paper: [Arrow: A Foundation Model for Causal Discovery](https://arxiv.org/abs/2605.07204)
- Code: [github.com/ryan-thompson/arrowfm](https://github.com/ryan-thompson/arrowfm)
## Usage
Install the Python package:
```bash
python3 -m pip install "git+https://github.com/ryan-thompson/arrowfm.git"
```
Run inference:
```python
import torch
from arrowfm import ArrowPredictor

predictor = ArrowPredictor()
x = torch.randn(100, 10)  # one dataset: n = 100 samples of p = 10 variables
adj = predictor.predict_adjacency(x)  # boolean (p, p) adjacency matrix
```
`x` should have shape `(n, p)` for one dataset or `(batch, n, p)` for batched input, where `n` is the number of samples and `p` is the number of variables. The returned `adj` is a boolean adjacency matrix in which `adj[j, k]` is `True` when there is a directed edge from variable `j` to variable `k`.
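To inspect a predicted graph, the boolean matrix can be converted to an edge list and checked for cycles. The sketch below works on nested Python lists, and a `torch.Tensor` can be converted with `adj.tolist()`. For batched input, it assumes each `adj[b]` is a single `(p, p)` matrix. The helper names `edge_list` and `is_acyclic` are hypothetical and not part of the `arrowfm` package.

```python
from collections import deque
from typing import List, Tuple


def edge_list(adj: List[List[bool]]) -> List[Tuple[int, int]]:
    """Return every directed edge (j, k) for which adj[j][k] is True."""
    return [
        (j, k)
        for j, row in enumerate(adj)
        for k, has_edge in enumerate(row)
        if has_edge
    ]


def is_acyclic(adj: List[List[bool]]) -> bool:
    """Kahn's algorithm: the graph is a DAG iff all nodes can be removed
    by repeatedly deleting nodes that have no remaining incoming edges."""
    p = len(adj)
    # Count incoming edges for each node k (a self-loop counts too).
    indegree = [sum(1 for j in range(p) if adj[j][k]) for k in range(p)]
    queue = deque(k for k in range(p) if indegree[k] == 0)
    removed = 0
    while queue:
        j = queue.popleft()
        removed += 1
        # Deleting j removes its outgoing edges j -> k.
        for k in range(p):
            if adj[j][k]:
                indegree[k] -= 1
                if indegree[k] == 0:
                    queue.append(k)
    # Nodes on a cycle never reach indegree 0, so they are never removed.
    return removed == p


# Chain 0 -> 1 -> 2 in the same adj[j][k] convention.
adj = [
    [False, True, False],
    [False, False, True],
    [False, False, False],
]
print(edge_list(adj))   # [(0, 1), (1, 2)]
print(is_acyclic(adj))  # True
```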
To also return the predicted edge-existence probabilities `p_hat`:
```python
adj, p_hat = predictor.predict_adjacency(x, return_p_hat=True)
```
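Edge probabilities let you rank candidate edges by confidence instead of relying only on the decoded graph. This sketch assumes that, for a single dataset, `p_hat` has the same `(p, p)` layout as `adj`, with `p_hat[j, k]` the probability of an edge from `j` to `k`. That layout and the helper `top_edges` are assumptions. A tensor can be passed after `p_hat.tolist()`. Unlike the decoded `adj`, edges picked from `p_hat` on their own are not guaranteed to form an acyclic graph, as the example output shows.

```python
from typing import List, Tuple


def top_edges(
    p_hat: List[List[float]], top_k: int
) -> List[Tuple[int, int, float]]:
    """Return the top_k off-diagonal entries of p_hat as
    (source, target, probability) triples, most confident first."""
    candidates = [
        (j, k, prob)
        for j, row in enumerate(p_hat)
        for k, prob in enumerate(row)
        if j != k  # ignore self-loops
    ]
    candidates.sort(key=lambda edge: edge[2], reverse=True)
    return candidates[:top_k]


p_hat = [
    [0.0, 0.9, 0.2],
    [0.8, 0.0, 0.7],
    [0.1, 0.3, 0.0],
]
print(top_edges(p_hat, 2))  # [(0, 1, 0.9), (1, 0, 0.8)]
```

The two most confident edges here point in opposite directions between the same pair of variables. This is why a decoding step that enforces acyclicity is needed on top of the raw probabilities.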
## Citation
```bibtex
@misc{thompson2026arrow,
title = {Arrow: A Foundation Model for Causal Discovery},
author = {Thompson, Ryan and Zhao, He and Steinberg, Daniel M. and Bonilla, Edwin V.},
year = {2026},
eprint = {2605.07204},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2605.07204}
}
```