metadata
license: mit
language:
  - en
library_name: pytorch
tags:
  - scRNA-seq
  - single-cell
  - self-supervised-learning
  - JEPA
  - biology
datasets:
  - vevotx/Tahoe-100M
pretty_name: GeneJEPA (Perceiver JEPA for scRNA-seq)
pipeline_tag: feature-extraction

GeneJEPA: A Predictive World Model of the Transcriptome

GeneJEPA is a Joint-Embedding Predictive Architecture (JEPA) trained for self-supervised representation learning on scRNA-seq data. It uses a Perceiver-style encoder to handle sparse, high-dimensional gene count vectors and a Fourier-feature tokenizer to encode continuous expression values.
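The reason a Perceiver-style encoder suits long gene vectors is that a small array of learned latents cross-attends to the gene tokens, so cost grows as (latents × tokens) rather than tokens². The following is a toy, single-head NumPy sketch of that cross-attention step only. It is not GeneJEPA's implementation. The sizes are scaled down, the weights are random, and the "one token per expressed gene" framing is an assumption for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def perceiver_cross_attend(latents, tokens, w_q, w_k, w_v):
    """Single-head cross-attention: latents (L, D) query gene tokens (M, D).

    Cost is O(L * M) instead of O(M^2) because gene tokens never attend
    to each other directly; only the L latents read from them.
    """
    q = latents @ w_q                          # (L, D)
    k = tokens @ w_k                           # (M, D)
    v = tokens @ w_v                           # (M, D)
    scores = q @ k.T / np.sqrt(q.shape[-1])    # (L, M)
    return softmax(scores, axis=-1) @ v        # (L, D)

rng = np.random.default_rng(0)
D, L, M = 64, 16, 1000  # toy sizes; the model card lists 768 dims and 512 latents
latents = rng.standard_normal((L, D))
tokens = rng.standard_normal((M, D))  # hypothetical: one token per expressed gene
w_q, w_k, w_v = (rng.standard_normal((D, D)) / np.sqrt(D) for _ in range(3))
out = perceiver_cross_attend(latents, tokens, w_q, w_k, w_v)
print(out.shape)  # (16, 64)
```

After this step, a Perceiver runs its self-attention blocks on the fixed-size latent array. The model summary below lists 24 such blocks.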

Why? To produce compact cell embeddings for clustering, transfer learning, linear probes, perturbation prediction, and other downstream biological tasks.


Repository contents

This model repo intentionally contains artifacts only (no training code):

  • genejepa-epoch=49.ckpt — final PyTorch Lightning checkpoint (student encoder + predictor + EMA state, etc.)
  • gene_metadata.parquet — mapping between foundation token IDs and gene identifiers used to build the embedding vocab.
  • global_stats.json — global log1p(counts) normalization stats (mean, std) computed over a large sample of training data.
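One way to apply the statistics in global_stats.json to new data is to log1p-transform the raw counts and then z-score them with the stored mean and std. This is a minimal sketch. It assumes the file has top-level "mean" and "std" keys (scalars or per-gene lists), and it assumes a plain z-score is the intended use. Inspect the actual file, and check the paper, before relying on either assumption.

```python
import json
import numpy as np

def load_global_stats(path):
    """Read normalization statistics from global_stats.json.

    ASSUMPTION: the file stores top-level "mean" and "std" entries
    (scalars or per-gene lists); adjust the keys if the file differs.
    """
    with open(path) as f:
        stats = json.load(f)
    mean = np.asarray(stats["mean"], dtype=np.float64)
    std = np.asarray(stats["std"], dtype=np.float64)
    return mean, std

def normalize_counts(counts, mean, std, eps=1e-8):
    # log1p-transform raw counts, then z-score with the global statistics.
    # eps guards against division by zero if any std entry is 0.
    return (np.log1p(np.asarray(counts, dtype=np.float64)) - mean) / (std + eps)

# Toy example with made-up statistics; real values come from global_stats.json.
z = normalize_counts([0.0, 1.0, 9.0], mean=1.0, std=2.0)
```

Because mean and std are loaded as arrays, the same function broadcasts whether the stored statistics are single scalars or per-gene vectors.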

Model summary

  • Backbone: Perceiver-style encoder over tokenized genes (identity + Fourier features of expression value)
  • Latents: 512
  • Dimensionality: 768
  • Blocks: 24 transformer blocks on the latent array
  • Heads: 12
  • Masking: stochastic, block-wise target masks; the context is their complement
  • Predictor: BYOL-style MLP head
  • EMA teacher: maintained during training (for targets)
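The masking and EMA-teacher entries describe the general JEPA training pattern. The following sketch shows that pattern in NumPy. Several target blocks are sampled, the context is the complement of their union, and the teacher weights track the student through an exponential moving average. Everything specific here is a hypothetical choice for illustration, not a documented GeneJEPA setting: a "block" as a contiguous run of token positions, the block count and size, and the momentum value.

```python
import numpy as np

def sample_block_masks(num_tokens, num_blocks=4, block_frac=0.15, rng=None):
    """Return (context_mask, target_mask) boolean arrays of length num_tokens.

    ASSUMPTION: a block is a contiguous run of token positions; GeneJEPA's
    actual block definition, count, and size are not documented here.
    """
    rng = np.random.default_rng() if rng is None else rng
    block_len = max(1, int(round(block_frac * num_tokens)))
    target = np.zeros(num_tokens, dtype=bool)
    for _ in range(num_blocks):
        # integers() excludes the upper bound, so each block fits in range.
        start = rng.integers(0, num_tokens - block_len + 1)
        target[start:start + block_len] = True
    context = ~target  # context is the complement of the union of target blocks
    return context, target

def ema_update(teacher, student, momentum=0.996):
    # teacher <- m * teacher + (1 - m) * student, parameter by parameter.
    # momentum=0.996 is a hypothetical value, not taken from GeneJEPA.
    for name, w_s in student.items():
        teacher[name] = momentum * teacher[name] + (1.0 - momentum) * w_s

rng = np.random.default_rng(0)
context, target = sample_block_masks(200, rng=rng)

student = {"w": np.ones(3)}
teacher = {"w": np.zeros(3)}
ema_update(teacher, student, momentum=0.9)
```

In a JEPA, the student encodes only the context tokens. The predictor is trained to match the teacher's embeddings of the target tokens, and the teacher receives no gradients, only EMA updates.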

Default tokenizer Fourier settings: N_f=64, min_freq=0.1, max_freq=100.0, freq_scale=1.0.
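The following sketch shows one common way to turn a scalar expression value into Fourier features using the settings above. It also shows how the result might be combined with a gene-identity embedding. It assumes several things that the model card does not specify: N_f frequencies log-spaced between min_freq and max_freq, multiplication by freq_scale, a sin and cos pair per frequency (so the output width is 2 × N_f), and addition of a learned projection to the identity embedding. The vocabulary size and weight tables are hypothetical.

```python
import numpy as np

def fourier_features(values, n_f=64, min_freq=0.1, max_freq=100.0, freq_scale=1.0):
    """Map scalar values to [sin(2*pi*f*x), cos(2*pi*f*x)] features.

    ASSUMPTIONS: n_f log-spaced frequencies in [min_freq, max_freq],
    multiplied by freq_scale; output width is 2 * n_f.
    """
    values = np.asarray(values, dtype=np.float64)
    freqs = freq_scale * np.geomspace(min_freq, max_freq, num=n_f)      # (n_f,)
    angles = 2.0 * np.pi * values[..., None] * freqs                    # (..., n_f)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)    # (..., 2*n_f)

expr = np.array([0.0, 0.5, 2.3])  # e.g. normalized expression values
feats = fourier_features(expr)
print(feats.shape)  # (3, 128)

# Hypothetical token construction: gene identity embedding + projected value features.
rng = np.random.default_rng(0)
d_model, vocab = 768, 1000                                # vocab size is hypothetical
gene_emb = rng.standard_normal((vocab, d_model)) * 0.02   # identity table (hypothetical)
w_proj = rng.standard_normal((feats.shape[-1], d_model)) * 0.02
gene_ids = np.array([5, 42, 7])
tokens = gene_emb[gene_ids] + feats @ w_proj
print(tokens.shape)  # (3, 768)
```

Spreading frequencies over three orders of magnitude lets the encoder distinguish both small and large differences in expression. A single raw scalar input would be dominated by its magnitude.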

Download artifacts

from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="elonlit/GeneJEPA",
                            filename="genejepa-epoch=49.ckpt")
meta_path = hf_hub_download(repo_id="elonlit/GeneJEPA",
                            filename="gene_metadata.parquet")
stats_path = hf_hub_download(repo_id="elonlit/GeneJEPA",
                             filename="global_stats.json")
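After downloading, the helpers below can inspect the artifacts without the training code. PyTorch Lightning checkpoints are dictionaries with the model weights under "state_dict". The names of GeneJEPA's submodules are not documented, so the sketch lists the top-level key prefixes first rather than guessing them. torch and pandas are imported lazily inside the loaders, and reading parquet needs pyarrow or fastparquet. These helper functions are illustrative, not part of any GeneJEPA API.

```python
def load_lightning_state_dict(ckpt_path):
    import torch  # imported lazily so the pure-Python helpers work without torch
    # weights_only=False because Lightning checkpoints can hold non-tensor
    # objects (hyperparameters, callback state); only do this for trusted files.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    return ckpt["state_dict"]

def top_level_prefixes(state_dict):
    # First component of each key, e.g. "encoder.blocks.0.w" -> "encoder".
    return sorted({k.split(".", 1)[0] for k in state_dict})

def submodule_state_dict(state_dict, prefix):
    # Keep keys under `prefix` and strip it, e.g. "encoder.x.weight" -> "x.weight".
    p = prefix if prefix.endswith(".") else prefix + "."
    return {k[len(p):]: v for k, v in state_dict.items() if k.startswith(p)}

def load_gene_metadata(meta_path):
    import pandas as pd  # requires pyarrow or fastparquet for parquet files
    return pd.read_parquet(meta_path)
```

For example, start with top_level_prefixes(load_lightning_state_dict(ckpt_path)) to discover the real module names. Then pass one of them to submodule_state_dict. Print load_gene_metadata(meta_path).columns to see which identifier columns the gene mapping actually uses.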

Paper and code

Citation

@article{GeneJEPA2025,
  title     = {GeneJEPA: A Predictive World Model of the Transcriptome},
  author    = {Litman, E. and Myers, T. and Agarwal, V. and Gopinath, A. and Li, O. and Mittal, E. and Kassis, T.},
  journal   = {bioRxiv},
  year      = {2025},
  publisher = {Cold Spring Harbor Laboratory},
  note      = {preprint},
  doi       = {10.1101/2025.10.14.682378},
}