ArrowFM Base

Arrow is a zero-shot foundation model for causal discovery from observational tabular data.

Given a numeric dataset, Arrow predicts edge-existence probabilities and node-order scores, then decodes them into a directed acyclic graph (DAG).
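The decoding step can be pictured with a minimal sketch. This is not Arrow's actual decoder, which this card does not describe. It assumes that a lower order score places a variable earlier in the causal order and that an edge is kept when its probability exceeds 0.5. The function name decode_dag, the threshold, and the toy numbers are all hypothetical.

```python
def decode_dag(p_hat, order_scores, threshold=0.5):
    """Illustrative decoder (an assumption, not Arrow's implementation).

    Keeps edge j -> k only if its probability exceeds `threshold` and
    j precedes k in the ordering implied by `order_scores` (lower = earlier).
    Orienting every edge along one total order guarantees acyclicity.
    """
    p = len(order_scores)
    # Rank variables by score; ties are broken by index so the order is total.
    order = sorted(range(p), key=lambda i: (order_scores[i], i))
    rank = {node: pos for pos, node in enumerate(order)}
    return [
        [p_hat[j][k] > threshold and rank[j] < rank[k] for k in range(p)]
        for j in range(p)
    ]


# Toy inputs for three variables (made-up numbers).
p_hat = [
    [0.0, 0.9, 0.2],
    [0.8, 0.0, 0.7],
    [0.1, 0.6, 0.0],
]
order_scores = [0.1, 0.5, 0.9]
adj = decode_dag(p_hat, order_scores)
```

In this toy case both 0 -> 1 and 1 -> 0 have high probability, and the ordering resolves the conflict by keeping only the direction consistent with the scores.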

Usage

Install the Python package:

python3 -m pip install "git+https://github.com/ryan-thompson/arrowfm.git"

Run inference:

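import torch
from arrowfm import ArrowPredictor

# Create the predictor with its default settings
predictor = ArrowPredictor()

# Synthetic input: 100 samples of 10 variables
x = torch.randn(100, 10)

# Predict the adjacency matrix of the DAG
adj = predictor.predict_adjacency(x)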

The input x should have shape (n, p) for a single dataset of n samples and p variables, or (batch, n, p) for batched input. The returned adj is a boolean adjacency matrix in which adj[j, k] indicates a directed edge from variable j to variable k.
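To make the indexing convention concrete, the sketch below lists the edges encoded by a small hand-written adjacency matrix. The helper edges_from_adjacency is hypothetical and works on nested lists. If adj is a PyTorch tensor (an assumption about the return type), adj.tolist() gives the same structure.

```python
def edges_from_adjacency(adj):
    """Return (j, k) pairs for every directed edge j -> k, i.e. where adj[j][k] is truthy."""
    return [
        (j, k)
        for j, row in enumerate(adj)
        for k, has_edge in enumerate(row)
        if has_edge
    ]


# Hand-written example encoding the edges 0 -> 1, 0 -> 2, and 1 -> 2.
example_adj = [
    [False, True, True],
    [False, False, True],
    [False, False, False],
]
edges = edges_from_adjacency(example_adj)

# Parents of variable 2 are the rows j with an edge into column 2.
parents_of_2 = [j for j, k in edges if k == 2]
```

Rows index the source variable and columns index the target, so reading down a column gives a variable's parents and reading across a row gives its children.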

To also return edge probabilities:

adj, p_hat = predictor.predict_adjacency(x, return_p_hat=True)
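Edge probabilities are useful for ranking candidate edges by confidence. The sketch below uses a hypothetical helper, top_edges, and assumes that p_hat is a p x p matrix in which p_hat[j][k] is the probability of an edge from variable j to variable k. That mirrors the adj convention, but this card does not state it explicitly for p_hat. The values are made up, and nested lists stand in for whatever type the predictor returns.

```python
def top_edges(p_hat, n_top=3):
    """Return the n_top most probable directed edges as (j, k, probability)."""
    candidates = [
        (j, k, prob)
        for j, row in enumerate(p_hat)
        for k, prob in enumerate(row)
        if j != k  # skip self-loops, which a DAG cannot contain
    ]
    # Sort by probability, highest first.
    candidates.sort(key=lambda edge: edge[2], reverse=True)
    return candidates[:n_top]


# Made-up probabilities for three variables.
p_hat_example = [
    [0.0, 0.9, 0.2],
    [0.05, 0.0, 0.7],
    [0.1, 0.3, 0.0],
]
ranked = top_edges(p_hat_example, n_top=2)
```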

Citation

@misc{thompson2026arrow,
  title         = {Arrow: A Foundation Model for Causal Discovery},
  author        = {Thompson, Ryan and Zhao, He and Steinberg, Daniel M. and Bonilla, Edwin V.},
  year          = {2026},
  eprint        = {2605.07204},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  url           = {https://arxiv.org/abs/2605.07204}
}