---
license: mit
language:
- en
library_name: pytorch
tags:
- scRNA-seq
- single-cell
- self-supervised-learning
- JEPA
- biology
datasets:
- vevotx/Tahoe-100M
pretty_name: GeneJEPA (Perceiver JEPA for scRNA-seq)
pipeline_tag: feature-extraction
---
# GeneJEPA: A Predictive World Model of the Transcriptome
**GeneJEPA** is a Joint-Embedding Predictive Architecture (JEPA) trained for self-supervised representation learning on scRNA-seq data. It uses a Perceiver-style encoder to handle sparse, high-dimensional gene count vectors and a Fourier-feature tokenizer to encode each gene's expression value numerically.
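To show why a Perceiver-style encoder suits long gene vectors, the sketch below implements a single-head cross-attention step in which a small array of latents queries the gene tokens. It is a simplified NumPy illustration, not GeneJEPA's implementation. It assumes one head, random weights, and omits layer norms, MLPs, and residuals. The function names are hypothetical.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def perceiver_cross_attend(latents, tokens, w_q, w_k, w_v):
    """Single-head cross-attention: latents (L, d) attend over gene tokens (N, d).

    The score matrix is (L, N), so cost grows linearly with the number of
    gene tokens instead of quadratically as in full self-attention.
    """
    q = latents @ w_q                          # (L, d) queries come from the latents
    k = tokens @ w_k                           # (N, d) keys come from the gene tokens
    v = tokens @ w_v                           # (N, d) values come from the gene tokens
    scores = q @ k.T / np.sqrt(q.shape[-1])    # (L, N) scaled dot-product scores
    return softmax(scores, axis=-1) @ v        # (L, d) updated latents

rng = np.random.default_rng(0)
d, n_latents, n_genes = 32, 8, 1000  # toy sizes; the released model uses 512 latents of width 768
latents = rng.normal(size=(n_latents, d))
tokens = rng.normal(size=(n_genes, d))
w_q, w_k, w_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
out = perceiver_cross_attend(latents, tokens, w_q, w_k, w_v)
print(out.shape)  # (8, 32)
```

In a Perceiver-style design, the subsequent self-attention blocks run only on the latent array, so their cost does not depend on how many genes a cell has.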
**Why?** To produce compact cell embeddings for clustering, transfer learning, linear probes, perturbation prediction, and other downstream biological tasks.
---
## Repository contents
This model repo intentionally contains **artifacts only** (no training code):
- **`genejepa-epoch=49.ckpt`** — final PyTorch Lightning checkpoint (student encoder + predictor + EMA state, etc.)
- **`gene_metadata.parquet`** — mapping between foundation token IDs and gene identifiers used to build the embedding vocab.
- **`global_stats.json`** — global `log1p(counts)` normalization stats (`mean`, `std`) computed over a large sample of training data.
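The sketch below shows one plausible way the two data artifacts could be used together: map a cell's genes to token IDs, then z-score `log1p(counts)` with the global `mean` and `std`. Several details are assumptions rather than documented behavior: that the stats are scalars, that zero-count genes are dropped, and the helper names. In practice `gene_to_token` would be built from `gene_metadata.parquet` (for example with `pandas.read_parquet`). Its column names are not documented here, so a toy dictionary stands in for it.

```python
import json
import numpy as np

def load_global_stats(path: str) -> tuple[float, float]:
    # global_stats.json holds `mean` and `std` of log1p(counts); assumed to be scalars here.
    with open(path) as f:
        stats = json.load(f)
    return float(stats["mean"]), float(stats["std"])

def tokenize_cell(counts, gene_ids, gene_to_token, mean, std):
    """Hypothetical helper: raw counts for one cell -> (token_ids, normalized values).

    Keeps only expressed genes present in the vocabulary, then applies
    (log1p(count) - mean) / std using the global statistics.
    """
    counts = np.asarray(counts, dtype=np.float64)
    keep = [i for i, g in enumerate(gene_ids) if counts[i] > 0 and g in gene_to_token]
    token_ids = np.array([gene_to_token[gene_ids[i]] for i in keep], dtype=np.int64)
    values = (np.log1p(counts[keep]) - mean) / std
    return token_ids, values

# Toy example: GENE_B has zero counts and GENE_X is not in the vocabulary, so both are dropped.
gene_to_token = {"GENE_A": 0, "GENE_B": 1, "GENE_C": 2}
ids, vals = tokenize_cell([3, 0, 7, 5], ["GENE_A", "GENE_B", "GENE_C", "GENE_X"],
                          gene_to_token, mean=1.0, std=2.0)
```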
---
## Model summary
- **Backbone:** Perceiver-style encoder over tokenized genes (identity + Fourier features of expression value)
- **Latents:** 512
- **Dimensionality:** 768
- **Blocks:** 24 transformer blocks on the latent array
- **Heads:** 12
- **Masking:** stochastic, block-wise target masks; the context is the complement of the targets
- **Predictor:** BYOL-style MLP head
- **EMA teacher:** maintained during training (for targets)
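The masking, predictor, and EMA entries above describe a standard JEPA/BYOL-style training step. The NumPy sketch below illustrates three of its pieces: sampling contiguous target blocks over token positions with the context as their complement, a BYOL-style normalized MSE loss, and an exponential-moving-average teacher update. Block counts, block length, the momentum value, and treating token order as the axis for "blocks" are all assumptions. The helpers are hypothetical and not taken from the GeneJEPA code.

```python
import numpy as np

def block_mask(n_tokens, n_blocks, block_len, rng):
    """Sample target indices as contiguous blocks; context indices are the complement.

    Assumes block_len <= n_tokens; blocks may overlap.
    """
    target = np.zeros(n_tokens, dtype=bool)
    for _ in range(n_blocks):
        start = rng.integers(0, n_tokens - block_len + 1)  # upper bound is exclusive
        target[start:start + block_len] = True
    return np.flatnonzero(~target), np.flatnonzero(target)

def byol_loss(pred, target):
    # BYOL-style normalized MSE: mean over rows of ||p - t||^2 = 2 - 2 * cos(p, t).
    p = pred / np.linalg.norm(pred, axis=-1, keepdims=True)
    t = target / np.linalg.norm(target, axis=-1, keepdims=True)
    return float(np.mean(2.0 - 2.0 * np.sum(p * t, axis=-1)))

def ema_update(teacher, student, momentum=0.996):
    """In-place EMA over a dict of parameter arrays: teacher <- m * teacher + (1 - m) * student."""
    for name, w in student.items():
        teacher[name] *= momentum
        teacher[name] += (1.0 - momentum) * w

rng = np.random.default_rng(0)
ctx, tgt = block_mask(n_tokens=100, n_blocks=4, block_len=10, rng=rng)

student = {"w": np.ones(3)}
teacher = {"w": np.zeros(3)}
ema_update(teacher, student, momentum=0.75)  # teacher["w"] is now [0.25, 0.25, 0.25]
```

In a real training loop, the student would encode only the `ctx` tokens, the predictor would map its output toward the teacher's embedding of the `tgt` tokens, and only the student and predictor would receive gradients.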
> Default tokenizer Fourier settings: `N_f=64`, `min_freq=0.1`, `max_freq=100.0`, `freq_scale=1.0`.
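The sketch below shows how Fourier features of a scalar expression value could be computed with the default settings listed above. The exact formula is an assumption, since the README gives only the parameter values. It assumes geometrically spaced frequencies between `min_freq` and `max_freq`, a 2π factor, `freq_scale` as a global multiplier, and `N_f` as the number of frequencies, so each value yields `2 * N_f` sin/cos features.

```python
import numpy as np

def fourier_features(values, n_freqs=64, min_freq=0.1, max_freq=100.0, freq_scale=1.0):
    """Map scalar values to [sin, cos] features at log-spaced frequencies (assumed formula)."""
    values = np.asarray(values, dtype=np.float64)[..., None]          # (..., 1)
    freqs = freq_scale * np.geomspace(min_freq, max_freq, n_freqs)    # (n_freqs,)
    angles = 2.0 * np.pi * values * freqs                             # (..., n_freqs)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)  # (..., 2 * n_freqs)

feats = fourier_features([0.0, 0.5, 2.3])
print(feats.shape)  # (3, 128)
```

Spanning frequencies from 0.1 to 100 lets the encoding resolve both coarse and fine differences in expression. In a full tokenizer, these features would presumably be projected to the model width and combined with a learned gene-identity embedding. That combination step is not shown here.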
## Download artifacts
```python
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="elonlit/GeneJEPA",
                            filename="genejepa-epoch=49.ckpt")
meta_path = hf_hub_download(repo_id="elonlit/GeneJEPA",
                            filename="gene_metadata.parquet")
stats_path = hf_hub_download(repo_id="elonlit/GeneJEPA",
                             filename="global_stats.json")
```
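After downloading, a quick way to see what the Lightning checkpoint contains is to load it on CPU and group its `state_dict` keys by top-level module name. The helper names below are hypothetical, and the actual module prefixes inside this checkpoint are not documented here. Recent PyTorch releases default `torch.load` to `weights_only=True`, which can reject Lightning checkpoints that store non-tensor objects. The sketch therefore passes `weights_only=False`, which should only be used with files you trust.

```python
from collections import Counter

def load_lightning_checkpoint(path):
    """Load a PyTorch Lightning .ckpt onto the CPU (trusted files only)."""
    import torch  # deferred import so the helper below works without PyTorch installed
    return torch.load(path, map_location="cpu", weights_only=False)

def count_entries_by_prefix(state_dict):
    """Count state_dict entries per top-level module name (the text before the first '.')."""
    return Counter(key.split(".", 1)[0] for key in state_dict)
```

For example, `count_entries_by_prefix(load_lightning_checkpoint(ckpt_path)["state_dict"])` lists the top-level submodules (such as the student encoder, predictor, and EMA teacher) and how many tensors each holds. Lightning stores model weights under the `"state_dict"` key of the checkpoint.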
## Paper and code
- Paper (bioRxiv): [GeneJEPA: A Predictive World Model of the Transcriptome](https://doi.org/10.1101/2025.10.14.682378)
- GitHub: [BiostateAI/GeneJEPA](https://github.com/BiostateAI/GeneJEPA)
## Citation
```bibtex
@article{GeneJEPA2025,
title = {GeneJEPA: A Predictive World Model of the Transcriptome},
author = {Litman, E. and Myers, T. and Agarwal, V. and Gopinath, A. and Li, O. and Mittal, E. and Kassis, T.},
journal = {bioRxiv},
year = {2025},
publisher = {Cold Spring Harbor Laboratory},
note = {preprint},
doi = {10.1101/2025.10.14.682378},
}
```