---
license: mit
library_name: pytorch
tags:
- causal-discovery
- tabular
- directed-acyclic-graphs
- zero-shot
- pytorch
- arxiv:2605.07204
---
# ArrowFM Base
Arrow is a zero-shot foundation model for causal discovery from observational tabular data.
Given a numeric dataset, Arrow predicts edge-existence probabilities and node-order scores, then decodes them into a directed acyclic graph (DAG).
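The decoding step is not described in detail here. As a rough illustration only, one simple way to turn edge probabilities and a node ordering into a DAG is to keep an edge `j -> k` when its probability exceeds a threshold and `j` comes before `k` in the ordering. Because every kept edge then points forward in a fixed order, the result cannot contain a cycle. The function name `decode_dag`, the 0.5 threshold, and the convention that a lower score means an earlier position are hypothetical assumptions, not Arrow's actual decoder.

```python
from typing import List, Sequence


def decode_dag(
    p_hat: Sequence[Sequence[float]],
    order_scores: Sequence[float],
    threshold: float = 0.5,
) -> List[List[bool]]:
    """Hypothetical decoder (not Arrow's): keep edge j -> k only if its
    probability exceeds `threshold` and j precedes k in the node order.

    Assumption: a lower score means a node comes earlier in the order.
    """
    p = len(order_scores)
    # Sort node indices by score; rank[v] is v's position in that order.
    order = sorted(range(p), key=lambda v: order_scores[v])
    rank = [0] * p
    for position, node in enumerate(order):
        rank[node] = position
    # Kept edges only point from earlier to later nodes, so no cycle can form.
    return [
        [p_hat[j][k] > threshold and rank[j] < rank[k] for k in range(p)]
        for j in range(p)
    ]


p_hat = [[0.0, 0.9, 0.8], [0.7, 0.0, 0.6], [0.1, 0.2, 0.0]]
adj = decode_dag(p_hat, order_scores=[0.1, 0.5, 0.9])
print(adj)  # [[False, True, True], [False, False, True], [False, False, False]]
```

Note that the edge `1 -> 0` is dropped despite its probability of 0.7, because it would point backwards in the assumed order.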
## Usage
Install the Python package:
```bash
python3 -m pip install "git+https://github.com/ryan-thompson/arrowfm.git"
```
Run inference:
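```python
import torch
from arrowfm import ArrowPredictor

# Create a predictor with default settings
predictor = ArrowPredictor()
# One dataset: n = 100 observations of p = 10 variables
x = torch.randn(100, 10)
# Boolean adjacency matrix of the predicted DAG
adj = predictor.predict_adjacency(x)
```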
`x` should have shape `(n, p)` for a single dataset (`n` observations of `p` variables) or `(batch, n, p)` for batched input. The returned `adj` is a boolean adjacency matrix in which `adj[j, k]` is `True` if there is a directed edge from variable `j` to variable `k`.
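Because `adj[j, k]` encodes an edge `j -> k`, the matrix can be post-processed with ordinary graph code. The sketch below, using the hypothetical helper names `edge_list` and `is_acyclic`, lists the directed edges and checks acyclicity with Kahn's algorithm. It relies only on `len` and `adj[j][k]` indexing, so we expect it to accept a nested list or, for a single dataset, a 2-D boolean tensor; that the latter works with `predict_adjacency` output is an assumption we have not verified against the package.

```python
from collections import deque
from typing import List, Sequence, Tuple


def edge_list(adj: Sequence[Sequence[bool]]) -> List[Tuple[int, int]]:
    """Return (j, k) pairs where adj[j][k] is true, i.e. edges j -> k."""
    p = len(adj)
    return [(j, k) for j in range(p) for k in range(p) if bool(adj[j][k])]


def is_acyclic(adj: Sequence[Sequence[bool]]) -> bool:
    """Kahn's algorithm: the graph is a DAG iff every node can be removed
    in topological order (repeatedly dropping nodes with no parents)."""
    p = len(adj)
    indegree = [0] * p
    children: List[List[int]] = [[] for _ in range(p)]
    for j, k in edge_list(adj):
        indegree[k] += 1
        children[j].append(k)
    queue = deque(v for v in range(p) if indegree[v] == 0)
    visited = 0
    while queue:
        v = queue.popleft()
        visited += 1
        for child in children[v]:
            indegree[child] -= 1
            if indegree[child] == 0:
                queue.append(child)
    # Nodes on a cycle never reach indegree 0, so they are never visited.
    return visited == p


example = [[False, True, False], [False, False, True], [False, False, False]]
print(edge_list(example))   # [(0, 1), (1, 2)]
print(is_acyclic(example))  # True
```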
To also return edge probabilities:
```python
adj, p_hat = predictor.predict_adjacency(x, return_p_hat=True)
```
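If `p_hat` is a `p x p` matrix aligned with `adj` (so that `p_hat[j, k]` is the probability of the edge `j -> k`; this shape is an assumption, as it is not documented here), it can be used to rank candidate edges by confidence. The hypothetical helper `top_edges` below returns the most probable off-diagonal entries. It converts each entry with `float`, so it should also accept a 2-D tensor.

```python
from typing import List, Sequence, Tuple


def top_edges(
    p_hat: Sequence[Sequence[float]], n_edges: int
) -> List[Tuple[int, int, float]]:
    """Return the n_edges most probable directed edges as (j, k, prob).

    Assumes p_hat[j][k] is the probability of an edge j -> k; the diagonal
    (self-loops) is skipped.
    """
    p = len(p_hat)
    candidates = [
        (j, k, float(p_hat[j][k]))
        for j in range(p)
        for k in range(p)
        if j != k
    ]
    # Highest probability first; the sort is stable, so ties keep row-major order.
    candidates.sort(key=lambda edge: edge[2], reverse=True)
    return candidates[:n_edges]


p_hat = [[0.0, 0.9, 0.3], [0.2, 0.0, 0.7], [0.6, 0.1, 0.0]]
print(top_edges(p_hat, 3))  # [(0, 1, 0.9), (1, 2, 0.7), (2, 0, 0.6)]
```

In this toy matrix the three most probable edges form the cycle `0 -> 1 -> 2 -> 0`, a reminder that edge probabilities on their own need not describe a DAG; Arrow's `adj` is obtained by decoding the probabilities together with node-order scores.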
## Citation
```bibtex
@misc{thompson2026arrow,
title = {Arrow: A Foundation Model for Causal Discovery},
author = {Thompson, Ryan and Zhao, He and Steinberg, Daniel M. and Bonilla, Edwin V.},
year = {2026},
eprint = {2605.07204},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2605.07204}
}
```