---
license: mit
library_name: pytorch
tags:
  - causal-discovery
  - tabular
  - directed-acyclic-graphs
  - zero-shot
  - pytorch
  - arxiv:2605.07204
---

# ArrowFM Base

Arrow is a zero-shot foundation model for causal discovery from observational tabular data.

Given a numeric dataset, Arrow predicts edge-existence probabilities and node-order scores, then decodes them into a directed acyclic graph (DAG).
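This card does not describe how the decoding step works. One common way to turn pairwise edge probabilities plus a variable ordering into a graph that is guaranteed to be acyclic is to keep an edge only when it is probable enough and points forward in the ordering. The plain-Python sketch below illustrates that idea. The function name `decode_dag`, the 0.5 threshold, and the convention that a lower order score means an earlier position in the causal order are all assumptions for illustration, not the `arrowfm` implementation.

```python
def decode_dag(p_hat, order_scores, threshold=0.5):
    """Hypothetical decoder (not the arrowfm implementation).

    Keeps edge j -> k only if p_hat[j][k] > threshold AND j comes before k in
    the ordering implied by order_scores. Assumption: a lower score means an
    earlier position in the causal order; ties are broken by variable index.
    """
    n_vars = len(order_scores)
    # Sort variables by order score to get a total ordering.
    order = sorted(range(n_vars), key=lambda i: (order_scores[i], i))
    rank = [0] * n_vars
    for position, node in enumerate(order):
        rank[node] = position

    adj = [[False] * n_vars for _ in range(n_vars)]
    for j in range(n_vars):
        for k in range(n_vars):
            # Edges may only point forward in the ordering, so no cycle can form.
            if rank[j] < rank[k] and p_hat[j][k] > threshold:
                adj[j][k] = True
    return adj


p_hat = [
    [0.0, 0.9, 0.2],
    [0.8, 0.0, 0.7],
    [0.1, 0.3, 0.0],
]
adj = decode_dag(p_hat, order_scores=[0.1, 0.5, 0.9])
print(adj)
```

In this example, `p_hat[1][0] = 0.8` exceeds the threshold, but the edge 1 → 0 is dropped because variable 1 comes after variable 0 in the ordering. This is how an ordering-based decoder rules out the 0 ↔ 1 cycle that thresholding alone would produce.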

## Usage

Install the Python package:

```shell
python3 -m pip install "git+https://github.com/ryan-thompson/arrowfm.git"
```

Run inference:

```python
import torch
from arrowfm import ArrowPredictor

predictor = ArrowPredictor()
x = torch.randn(100, 10)  # n = 100 observations of p = 10 variables
adj = predictor.predict_adjacency(x)  # boolean p x p adjacency matrix
```

`x` should have shape `(n, p)` for a single dataset (n observations of p variables) or `(batch, n, p)` for batched input. The returned `adj` is a boolean adjacency matrix in which `adj[j, k]` is `True` when there is a directed edge from variable `j` to variable `k`.
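Under the `adj[j, k]` convention, row `j` lists the children of variable `j` and column `k` lists the parents of variable `k`. The sketch below reads a single-dataset result as an edge list and a parent lookup. It assumes `adj` has first been converted to nested Python lists with `adj.tolist()`, which works for both PyTorch tensors and NumPy arrays. The helper names `edges_from_adjacency` and `parents` are hypothetical and are not part of `arrowfm`.

```python
def edges_from_adjacency(adj, names=None):
    """Return (source, target) pairs, reading adj[j][k] == True as edge j -> k.

    `names` optionally maps variable indices to labels (e.g. column names).
    """
    n_vars = len(adj)
    labels = names if names is not None else list(range(n_vars))
    return [
        (labels[j], labels[k])
        for j in range(n_vars)
        for k in range(n_vars)
        if adj[j][k]
    ]


def parents(adj, k):
    """Indices j with an edge j -> k, i.e. the True entries in column k."""
    return [j for j in range(len(adj)) if adj[j][k]]


# Stand-in for predictor.predict_adjacency(x).tolist() on a 3-variable dataset.
adj = [
    [False, True, True],
    [False, False, True],
    [False, False, False],
]
print(edges_from_adjacency(adj, names=["a", "b", "c"]))  # [('a', 'b'), ('a', 'c'), ('b', 'c')]
print(parents(adj, 2))  # [0, 1]
```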

To also return edge probabilities:

```python
adj, p_hat = predictor.predict_adjacency(x, return_p_hat=True)
```
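The edge probabilities are useful for ranking candidate edges by confidence. The sketch below assumes that `p_hat` uses the same indexing as `adj` (so `p_hat[j][k]` is the probability of the edge j → k) and that it has been converted to nested lists with `.tolist()`. The helper `top_edges` is hypothetical. The edges it returns are not guaranteed to form a DAG because they may include both j → k and k → j, so use `adj` whenever an acyclic graph is required.

```python
def top_edges(p_hat, top_k=5):
    """Return the top_k most probable directed edges as (source, target, prob).

    Assumes p_hat[j][k] is the probability of edge j -> k; self-loops
    (the diagonal) are skipped. The result is a ranking, not a DAG.
    """
    n_vars = len(p_hat)
    candidates = [
        (j, k, p_hat[j][k])
        for j in range(n_vars)
        for k in range(n_vars)
        if j != k
    ]
    # Highest probability first; sort is stable, so ties keep index order.
    candidates.sort(key=lambda edge: edge[2], reverse=True)
    return candidates[:top_k]


p_hat = [
    [0.0, 0.9, 0.2],
    [0.1, 0.0, 0.7],
    [0.4, 0.3, 0.0],
]
print(top_edges(p_hat, top_k=2))  # [(0, 1, 0.9), (1, 2, 0.7)]
```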

## Citation

```bibtex
@misc{thompson2026arrow,
  title         = {Arrow: A Foundation Model for Causal Discovery},
  author        = {Thompson, Ryan and Zhao, He and Steinberg, Daniel M. and Bonilla, Edwin V.},
  year          = {2026},
  eprint        = {2605.07204},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  url           = {https://arxiv.org/abs/2605.07204}
}
```