| --- |
| license: mit |
| library_name: pytorch |
| tags: |
| - causal-discovery |
| - tabular |
| - directed-acyclic-graphs |
| - zero-shot |
| - pytorch |
| - arxiv:2605.07204 |
| --- |
| |
| # ArrowFM Base |
|
|
| Arrow is a zero-shot foundation model for causal discovery from observational tabular data. |
|
|
| Given a numeric dataset, Arrow predicts edge-existence probabilities and node-order scores, then decodes them into a directed acyclic graph (DAG). |
|
|
| - Paper: [Arrow: A Foundation Model for Causal Discovery](https://arxiv.org/abs/2605.07204) |
| - Code: [github.com/ryan-thompson/arrowfm](https://github.com/ryan-thompson/arrowfm) |
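To make the decoding step described above more concrete, here is a minimal sketch of one simple way to turn edge probabilities and node-order scores into a DAG. It is an illustration only and is not claimed to be Arrow's actual decoder: the function name `decode_dag`, the `threshold` parameter, and the convention that a lower order score means "earlier in the causal order" are all assumptions made for this example.

```python
# Hypothetical illustration, not the arrowfm implementation.
# Assumes edge_prob[j][k] is the probability of an edge j -> k and that
# a lower order_score places a variable earlier in the causal order.
def decode_dag(edge_prob, order_score, threshold=0.5):
    p = len(edge_prob)
    # Keep j -> k only if it is probable enough AND consistent with the order.
    # Edges always point from a strictly lower to a strictly higher score,
    # so the resulting graph cannot contain a cycle.
    return [
        [
            edge_prob[j][k] > threshold and order_score[j] < order_score[k]
            for k in range(p)
        ]
        for j in range(p)
    ]


# Toy inputs for three variables (made-up numbers, not model output).
toy_edge_prob = [
    [0.0, 0.9, 0.6],
    [0.8, 0.0, 0.7],
    [0.1, 0.2, 0.0],
]
toy_order_score = [0.1, 0.5, 0.9]

toy_dag = decode_dag(toy_edge_prob, toy_order_score)
for row in toy_dag:
    print(row)
```

In this toy case the edge from variable 1 to variable 0 has probability 0.8 but is dropped because it contradicts the order scores, which is how an ordering can rule out cycles that thresholding alone would allow.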
|
|
| ## Usage |
|
|
| Install the Python package: |
|
|
| ```bash |
| python3 -m pip install "git+https://github.com/ryan-thompson/arrowfm.git" |
| ``` |
|
|
| Run inference: |
|
|
| ```python |
| import torch |
| from arrowfm import ArrowPredictor |
| |
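| # Create the predictor |
| predictor = ArrowPredictor() |
| # Placeholder data: 100 rows (n) by 10 columns (p) |
| x = torch.randn(100, 10) |
| # Boolean adjacency matrix of the predicted DAG |
| adj = predictor.predict_adjacency(x) |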
| ``` |
|
|
| `x` should have shape `(n, p)` (`n` samples of `p` variables) for a single dataset, or `(batch, n, p)` for batched input. The returned `adj` is a boolean adjacency matrix in which `adj[j, k]` is `True` when there is a directed edge from variable `j` to variable `k`. |
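The indexing convention above can be turned into a readable edge list. The sketch below uses a hypothetical helper, `edges_from_adjacency`, which is not part of `arrowfm`; it works on plain nested lists, so a single-dataset tensor would first be converted with `adj.tolist()`. The toy matrix stands in for a model output.

```python
# Hypothetical helper, not part of arrowfm: list the directed edges encoded
# in a boolean adjacency matrix given as nested lists (e.g. adj.tolist()).
def edges_from_adjacency(adj, names=None):
    p = len(adj)
    if names is None:
        names = [f"x{i}" for i in range(p)]
    # adj[j][k] being True means an edge from variable j to variable k.
    return [(names[j], names[k]) for j in range(p) for k in range(p) if adj[j][k]]


# Toy 3-variable matrix standing in for a model output: x0 -> x1 -> x2.
toy_adj = [
    [False, True, False],
    [False, False, True],
    [False, False, False],
]
edges = edges_from_adjacency(toy_adj)
print(edges)  # [('x0', 'x1'), ('x1', 'x2')]
```

Passing `names` (for example, the column names of the original table) gives edges labelled with variable names instead of indices.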
|
|
| To also return edge probabilities: |
|
|
| ```python |
| adj, p_hat = predictor.predict_adjacency(x, return_p_hat=True) |
| ``` |
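Edge probabilities are useful for seeing which candidate edges the model is most confident about. The sketch below ranks edges by probability. It assumes, without confirmation from the source, that for a single dataset `p_hat` is a `(p, p)` matrix with `p_hat[j, k]` the probability of an edge from variable `j` to variable `k`, mirroring the layout of `adj`. The helper `rank_edges` is hypothetical and takes nested lists (for example `p_hat.tolist()`).

```python
# Hypothetical helper, not part of arrowfm. Assumes p_hat[j][k] is the
# probability of a directed edge from variable j to variable k.
def rank_edges(p_hat, top=None):
    p = len(p_hat)
    # Collect (probability, source, target) for every off-diagonal pair.
    scored = [(p_hat[j][k], j, k) for j in range(p) for k in range(p) if j != k]
    # Most probable edges first.
    scored.sort(key=lambda item: item[0], reverse=True)
    return scored[:top]


# Toy probabilities for three variables (made-up numbers, not model output).
toy_p_hat = [
    [0.0, 0.9, 0.2],
    [0.05, 0.0, 0.7],
    [0.1, 0.3, 0.0],
]
top_edges = rank_edges(toy_p_hat, top=2)
print(top_edges)  # [(0.9, 0, 1), (0.7, 1, 2)]
```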
|
|
| ## Citation |
|
|
| ```bibtex |
| @misc{thompson2026arrow, |
| title = {Arrow: A Foundation Model for Causal Discovery}, |
| author = {Thompson, Ryan and Zhao, He and Steinberg, Daniel M. and Bonilla, Edwin V.}, |
| year = {2026}, |
| eprint = {2605.07204}, |
| archivePrefix = {arXiv}, |
| primaryClass = {cs.LG}, |
| url = {https://arxiv.org/abs/2605.07204} |
| } |
| ``` |
|
|