GEARS trained on Norman 2019

Produced as part of the sc-interp single-cell model comparison repo.

Provenance

Base model

Trained from scratch. GEARS is a graph neural network for perturbation prediction (Roohani, Huang, Leskovec 2024, Nature Biotechnology). It propagates gene embeddings over a gene co-expression graph and perturbation embeddings over a Gene Ontology (GO) graph, each with a 1-layer SGC. There is no foundation checkpoint: the entire model is trained per dataset from random initialisation. hidden_size=64, ~7.5 MB checkpoint (config.pkl + model.pt).
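To make "1-layer SGC" concrete, the sketch below applies one simplified graph convolution (SGC) propagation step in plain Python. It adds self-loops, symmetrically normalises the weighted adjacency, and multiplies the result into the node embeddings. This illustrates the technique under stated assumptions and is not GEARS's code. In particular, it omits the learned linear projection that follows propagation. The names sgc_propagate and the toy graph are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def sgc_propagate(adj: Matrix, x: Matrix, k: int = 1) -> Matrix:
    """Apply k SGC propagation steps: (D^-1/2 (A + I) D^-1/2)^k @ x.

    adj: symmetric n x n edge-weight matrix (0.0 where there is no edge).
    x:   n x d node embedding matrix.
    The learned linear weight that SGC applies afterwards is omitted here.
    """
    n = len(adj)
    # Add self-loops so every node keeps part of its own embedding.
    a = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    # Weighted degree of each node (row sums of A + I); never zero thanks to self-loops.
    deg = [sum(row) for row in a]
    # Symmetric normalisation: a_hat[i][j] = a[i][j] / sqrt(deg[i] * deg[j]).
    a_hat = [[a[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]
    h = [row[:] for row in x]
    d = len(x[0])
    for _ in range(k):
        # One hop: each node becomes a degree-normalised sum over itself and its neighbours.
        h = [[sum(a_hat[i][m] * h[m][c] for m in range(n)) for c in range(d)] for i in range(n)]
    return h


if __name__ == "__main__":
    # Toy 3-node graph: nodes 0 and 1 are linked, node 2 is isolated.
    adj = [[0.0, 1.0, 0.0],
           [1.0, 0.0, 0.0],
           [0.0, 0.0, 0.0]]
    x = [[1.0], [3.0], [5.0]]
    print(sgc_propagate(adj, x, k=1))  # [[2.0], [2.0], [5.0]]
```

The linked nodes end up sharing information, while the isolated node keeps its own embedding. With k=1, as here, each node only sees its direct neighbours.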

Training

  • Task: perturb-GEP, control cells as input, matched perturbed cells as target
  • Runner: the sc-interp dispatcher, python -m scripts.run gears --dataset norman
  • Split: GEARS simulation split with seed 42 (147 train / 33 val / 97 test perturbations). 7 perturbations whose gene symbols are absent from the GO annotation graph were dropped by GEARS's internal filter (pertdata.py:192-194). See issue #3.
  • Recipe: paper-faithful defaults from GEARS_misc/paper/fig2_train.py (Roohani et al.)
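  • Loss: GEARS's autofocus direction-aware loss (error raised to the power 2+γ with γ=2, i.e. fourth-power rather than squared error) plus direction regularisation (lambda=0.1)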
  • Optimiser: Adam, lr 1e-3, weight_decay 5e-4, StepLR gamma 0.5 per epoch
  • GNN: 1-layer SGC on both GO and co-expression graphs, hidden_size=64
  • default_pert_graph=False (builds perturbation graph from dataset genes, not a generic default)
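The Loss line combines an error term with a penalty for predicting the wrong direction of change. Below is a simplified single-cell sketch in plain Python. It assumes the GEARS formulation as we read it:

- an "autofocus" error |pred - obs|^(2+γ) with γ=2 (γ=0 would give plain MSE);
- plus λ·(sign(obs - ctrl) - sign(pred - ctrl))², averaged over genes with λ=0.1.

direction_aware_loss is a hypothetical name. The actual GEARS code works on torch tensors, groups cells by perturbation and averages the per-perturbation losses.

```python
from typing import Sequence


def _sign(v: float) -> int:
    """Return -1, 0 or 1 for a scalar (like torch.sign)."""
    return (v > 0) - (v < 0)


def direction_aware_loss(
    pred: Sequence[float],
    target: Sequence[float],
    ctrl: Sequence[float],
    gamma: float = 2.0,
    direction_lambda: float = 0.1,
) -> float:
    """Simplified per-cell sketch of an autofocus + direction-aware loss.

    pred, target and ctrl hold one value per gene: predicted expression,
    observed perturbed expression, and the control baseline.
    """
    n = len(pred)
    # Autofocus term: |error| ** (2 + gamma); gamma = 0 reduces this to plain MSE.
    autofocus = sum(abs(p - t) ** (2 + gamma) for p, t in zip(pred, target)) / n
    # Direction term: 0 when predicted and observed shifts from control share a sign,
    # 4 when they point in opposite directions.
    direction = sum(
        (_sign(t - c) - _sign(p - c)) ** 2 for p, t, c in zip(pred, target, ctrl)
    ) / n
    return autofocus + direction_lambda * direction


if __name__ == "__main__":
    ctrl = [1.0, 1.0]
    target = [2.0, 0.5]  # gene 0 goes up, gene 1 goes down
    pred = [1.5, 1.5]    # gene 0 direction right, gene 1 direction wrong
    print(direction_aware_loss(pred, target, ctrl))
```

Raising the error to a higher power weights large misses more heavily than MSE does. The sign term adds a fixed penalty per gene whose direction is wrong, regardless of magnitude.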

Budget and stopping

epochs trained          15 / 15
batch size              32 (train and test)
wall clock              22.4 min (H100 PCIe)
best val Overall MSE    0.0045 (epoch 10)
best val Top 20 DE MSE  0.1804 (epoch 15)
stopping reason         max_epochs (no early stop in GEARS)
model size              7.5 MB (config.pkl + model.pt)
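The Training section lists Adam with lr 1e-3 and StepLR gamma 0.5 per epoch, so the learning rate halves after every epoch. The minimal sketch below tabulates that schedule over the 15-epoch budget. It assumes the scheduler steps once at the end of each epoch (step_size=1) and that epoch 1 trains at the base rate. step_lr is a hypothetical helper that mirrors the StepLR formula; it does not call PyTorch.

```python
def step_lr(base_lr: float, gamma: float, epoch: int, step_size: int = 1) -> float:
    """Learning rate in effect during a 0-indexed epoch under a StepLR-style schedule.

    StepLR multiplies the rate by `gamma` once every `step_size` scheduler steps,
    so after `epoch` steps the rate is base_lr * gamma ** (epoch // step_size).
    """
    return base_lr * gamma ** (epoch // step_size)


if __name__ == "__main__":
    base_lr, gamma, epochs = 1e-3, 0.5, 15  # values from the Training section
    for epoch in range(epochs):
        print(f"epoch {epoch + 1:2d}: lr = {step_lr(base_lr, gamma, epoch):.2e}")
```

Under these assumptions, epoch 7 trains at 1e-3 / 64 ≈ 1.6e-5 and epoch 15 at about 6.1e-8. This may help explain why validation MSE changes little after the first handful of epochs.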

Test set metrics (cell-eval)

metric                        mean        median      max
pearson_delta                 0.4551      0.4939      0.8580
mse                           0.0047      0.0040      0.0181
mae                           0.0296      0.0277      0.0741
mse_delta                     0.0047      0.0040      0.0181
mae_delta                     0.0296      0.0277      0.0741
de_direction_match            0.7570      0.7689      0.9394
de_sig_genes_recall           0.8532      0.8684      0.9638
de_spearman_sig               0.1054      0.1054      0.1054
de_spearman_lfc_sig           0.8026      0.8274      0.9605
pr_auc                        0.0804      0.0785      0.2242
roc_auc                       0.3950      0.3915      0.5215
de_nsig_counts_real           488.5464    501.0000    1122.0000
de_nsig_counts_pred           4642.5258   4694.0000   4787.0000
overlap_at_N                  0.0266      0.0223      0.1141
overlap_at_50                 0.0262      0.0200      0.1200
overlap_at_100                0.0234      0.0200      0.1200
overlap_at_200                0.0246      0.0200      0.0800
overlap_at_500                0.0258      0.0222      0.0940
precision_at_N                0.0903      0.0909      0.2339
precision_at_50               0.0262      0.0200      0.1200
precision_at_100              0.0234      0.0200      0.1200
precision_at_200              0.0249      0.0200      0.0800
precision_at_500              0.0266      0.0240      0.0940
discrimination_score_l1       0.7094      0.7732      1.0000
discrimination_score_l2       0.7002      0.7629      1.0000
discrimination_score_cosine   0.6443      0.7423      1.0000
pearson_edistance             0.8242      0.8242      0.8242
clustering_agreement          0.3159      0.3159      0.3159

The GEARS paper (Roohani et al. 2024, Table S4) reports pearson and MSE on Norman using the simulation split with seeds 1-5. It does not report pearson_delta directly (it reports pearson_de and MSE_de). Our seed-42 pearson_delta of 0.4551 therefore appears to be in the expected range, but an exact comparison requires aligning the metrics.

In our cross-model benchmark, the pearson_delta ranking is CellFlow (0.606) > scGPT (0.507) > GEARS (0.455), consistent with the CellFlow paper's headline claim. GEARS leads on de_direction_match (0.757 vs 0.716/0.709), which may reflect its GO-graph-informed design. GEARS is evaluated on 97 test perturbations and the other two models on 99 (see Known limitations), so these numbers are not computed on identical perturbation sets.
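pearson_delta and de_direction_match are the two metrics this comparison leans on. The minimal sketch below makes two assumptions:

- pearson_delta is the Pearson correlation, across genes, between predicted and observed mean shifts from the control mean for one perturbation;
- de_direction_match is the fraction of the observed DE genes whose predicted shift has the same sign.

cell-eval's own implementation may differ, for example in which genes are included and how the deltas are formed. All function names here are hypothetical.

```python
import math
from typing import Sequence


def pearson(a: Sequence[float], b: Sequence[float]) -> float:
    """Pearson correlation; NaN if either input has zero variance."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    va = sum((x - ma) ** 2 for x in a)
    vb = sum((y - mb) ** 2 for y in b)
    if va == 0 or vb == 0:
        return float("nan")
    return cov / math.sqrt(va * vb)


def pearson_delta(pred_mean, obs_mean, ctrl_mean) -> float:
    """Correlate predicted and observed per-gene shifts away from the control mean."""
    pred_delta = [p - c for p, c in zip(pred_mean, ctrl_mean)]
    obs_delta = [o - c for o, c in zip(obs_mean, ctrl_mean)]
    return pearson(pred_delta, obs_delta)


def direction_match(pred_mean, obs_mean, ctrl_mean, de_genes) -> float:
    """Fraction of DE gene indices whose predicted shift has the observed sign."""
    hits = sum(
        1
        for g in de_genes
        if (pred_mean[g] - ctrl_mean[g]) * (obs_mean[g] - ctrl_mean[g]) > 0
    )
    return hits / len(de_genes)


if __name__ == "__main__":
    ctrl = [1.0, 1.0, 1.0, 1.0]
    obs = [2.0, 0.0, 1.0, 3.0]   # observed shifts: +1, -1, 0, +2
    pred = [1.5, 0.5, 1.0, 2.0]  # predicted shifts: half as large, same signs
    print(pearson_delta(pred, obs, ctrl))                # 1.0
    print(direction_match(pred, obs, ctrl, [0, 1, 3]))   # 1.0
    # Predicting the control profile gives no shift at all, so pearson_delta is undefined.
    print(pearson_delta(ctrl, obs, ctrl))                # nan
```

The last call shows why a delta-based score is informative. A model that simply reproduces control expression can still correlate well with raw perturbed expression, but it predicts no shift to correlate with the observed one.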

Known limitations

  • 7 Norman perturbations are dropped by the vendored GEARS GO filter and are absent from training and evaluation (see issue #3). scGPT and CellFlow are evaluated on 99 test perturbations, GEARS on 97, so cross-model comparisons are not on identical perturbation sets.
  • No early stopping. GEARS's train() runs the full epoch budget and selects best_model by minimum val MSE. Val MSE plateaued by epoch 6-7.
  • Trained with default_pert_graph=False (Norman-specific). Other datasets (Replogle) use default_pert_graph=True; switching datasets requires updating this flag.
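GEARS is scored on 97 test perturbations, while scGPT and CellFlow are scored on 99. A like-for-like comparison would therefore restrict every model to the perturbations they all share before averaging. The minimal sketch below assumes per-perturbation scores are available as dictionaries keyed by perturbation name. That layout is hypothetical, not the sc-interp output format, and the values in the example are toy numbers.

```python
from statistics import mean
from typing import Dict


def shared_mean_scores(scores: Dict[str, Dict[str, float]]) -> Dict[str, float]:
    """Average each model's per-perturbation metric over perturbations every model scored."""
    # Perturbations present in every model's results.
    shared = set.intersection(*(set(per_pert) for per_pert in scores.values()))
    if not shared:
        raise ValueError("models share no scored perturbations")
    return {model: mean(per_pert[p] for p in sorted(shared)) for model, per_pert in scores.items()}


if __name__ == "__main__":
    # Toy values only: model_b has no score for perturbation "C".
    toy = {
        "model_a": {"A": 0.5, "B": 0.7, "C": 0.9},
        "model_b": {"A": 0.4, "B": 0.6},
    }
    print(shared_mean_scores(toy))
```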

Files

  • config.pkl: GEARS model config (hidden_size, graph structures, hyperparams). Loaded by GEARS.load_pretrained(path).
  • model.pt: best_model state dict (selected by min val MSE across 15 epochs).
  • training_stats.json: unified sc-interp TrainStats schema: wall_clock_s, wandb_run_url, reason, details.

Usage

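from huggingface_hub import hf_hub_download
from gears import GEARS, PertData
# Download checkpoint (config.pkl + model.pt) into ./ckpt
for f in ["config.pkl", "model.pt"]:
    hf_hub_download(repo_id="matthewshu/gears-norman", filename=f, local_dir="ckpt")

# Load into GEARS; assumes GEARS's built-in "norman" data matches the training data
pert_data = PertData("./data", default_pert_graph=False)
pert_data.load(data_name="norman")
pert_data.prepare_split(split="simulation", seed=42)
pert_data.get_dataloader(batch_size=32, test_batch_size=32)
model = GEARS(pert_data, device="cpu")
model.load_pretrained("ckpt")

# Or reproduce from source (runs in the gears venv):
#   python -m scripts.run gears --dataset norman --hf-repo matthewshu/gears-norman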

Citation

Dataset: Norman et al. 2019 (Science). Model: Roohani, Huang, Leskovec 2024 (Nature Biotechnology). See the GEARS repo for BibTeX.
