---
license: mit
language:
- en
library_name: pytorch
tags:
- scRNA-seq
- single-cell
- self-supervised-learning
- JEPA
- biology
datasets:
- vevotx/Tahoe-100M
pretty_name: GeneJEPA (Perceiver JEPA for scRNA-seq)
pipeline_tag: feature-extraction
---

# GeneJEPA: A Predictive World Model of the Transcriptome

**GeneJEPA** is a Joint-Embedding Predictive Architecture (JEPA) trained for self-supervised representation learning on single-cell RNA-seq (scRNA-seq) data. It uses a Perceiver-style encoder to handle sparse, high-dimensional gene count vectors and a Fourier-feature tokenizer to encode continuous expression values as tokens.

**Why?** To produce compact cell embeddings for clustering, transfer learning, linear probes, perturbation prediction, and other downstream biological tasks.

---

## Repository contents

This model repo intentionally contains **artifacts only** (no training code):

- **`genejepa-epoch=49.ckpt`**: final PyTorch Lightning checkpoint (student encoder, predictor, EMA state, etc.).
- **`gene_metadata.parquet`**: mapping between foundation token IDs and gene identifiers, used to build the embedding vocabulary.
- **`global_stats.json`**: global `log1p(counts)` normalization statistics (`mean`, `std`) computed over a large sample of the training data.
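
The normalization statistics could be applied roughly as follows. This is a minimal sketch, not the project's preprocessing code: it assumes `global_stats.json` has top-level `mean` and `std` entries (scalars or per-gene lists) and that inputs are z-scored as `(log1p(counts) - mean) / std`. The helper names `load_global_stats` and `normalize_counts` are hypothetical; check the training code on GitHub for the exact pipeline.

```python
import json

import numpy as np


def load_global_stats(path):
    """Read the (assumed) top-level `mean` and `std` entries from global_stats.json."""
    with open(path) as f:
        stats = json.load(f)
    # np.asarray handles both scalar stats and per-gene lists.
    mean = np.asarray(stats["mean"], dtype=np.float32)
    std = np.asarray(stats["std"], dtype=np.float32)
    return mean, std


def normalize_counts(counts, mean, std, eps=1e-8):
    """log1p-transform raw counts, then standardize with the global stats (assumed scheme)."""
    x = np.log1p(np.asarray(counts, dtype=np.float32))
    return (x - mean) / (std + eps)
```

For example, `mean, std = load_global_stats(stats_path)` followed by `normalize_counts(counts, mean, std)`, where `stats_path` is the local path returned by `hf_hub_download`.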

---

## Model summary

- **Backbone:** Perceiver-style encoder over tokenized genes (gene identity + Fourier features of the expression value)
- **Latents:** 512
- **Dimensionality:** 768
- **Blocks:** 24 transformer blocks on the latent array
- **Heads:** 12
- **Masking:** stochastic, block-wise target masks; the context is the complement of the targets
- **Predictor:** BYOL-style MLP head
- **EMA teacher:** maintained during training to produce the target embeddings
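
The masking entry can be illustrated with a small sketch. It assumes that "block-wise" means contiguous spans over a cell's token sequence and that the context is exactly the complement of the union of target blocks; `sample_block_masks`, `num_blocks`, and `block_frac` are hypothetical names and values, not the settings used to train this checkpoint.

```python
import numpy as np


def sample_block_masks(n_tokens, num_blocks=4, block_frac=0.1, rng=None):
    """Return boolean (target_mask, context_mask) arrays of length n_tokens.

    Hypothetical scheme: `num_blocks` contiguous target blocks, each about
    `block_frac` of the sequence long, placed at random start positions.
    Blocks may overlap; the context mask is the complement of their union.
    """
    rng = np.random.default_rng() if rng is None else rng
    block_len = max(1, int(round(block_frac * n_tokens)))
    target = np.zeros(n_tokens, dtype=bool)
    for _ in range(num_blocks):
        # Upper bound is exclusive; clamp so there is always at least one valid start.
        start = rng.integers(0, max(1, n_tokens - block_len + 1))
        target[start:start + block_len] = True
    context = ~target  # every token is either context or target, never both
    return target, context
```

In a JEPA-style setup, the student would encode only the context tokens while the EMA teacher encodes the targets, so the predictor cannot simply copy its input.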

> Default tokenizer Fourier settings: `N_f=64`, `min_freq=0.1`, `max_freq=100.0`, `freq_scale=1.0`.
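
One plausible reading of these settings is sketched below. The log spacing of frequencies, the multiplication by `freq_scale`, and the `2*pi*f*x` phase convention are assumptions rather than the tokenizer's confirmed implementation; `fourier_features` is a hypothetical helper, and `n_freqs` plays the role of `N_f`.

```python
import numpy as np


def fourier_features(values, n_freqs=64, min_freq=0.1, max_freq=100.0, freq_scale=1.0):
    """Map scalar expression values to concatenated [sin, cos] Fourier features.

    Assumed scheme: n_freqs frequencies log-spaced in [min_freq, max_freq],
    scaled by freq_scale, used as phases 2*pi*f*x.
    Output shape: (*values.shape, 2 * n_freqs).
    """
    x = np.asarray(values, dtype=np.float32)[..., None]  # add a trailing frequency axis
    freqs = freq_scale * np.geomspace(min_freq, max_freq, n_freqs, dtype=np.float32)
    phases = 2.0 * np.pi * x * freqs
    return np.concatenate([np.sin(phases), np.cos(phases)], axis=-1)
```

With the defaults each expression value becomes a 128-dimensional vector, which an encoder would typically project to the model width and combine with the gene identity embedding.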

## Download artifacts

```python
from huggingface_hub import hf_hub_download

REPO_ID = "elonlit/GeneJEPA"

# Each call downloads the file (or reuses the local cache) and returns its local path.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="genejepa-epoch=49.ckpt")
meta_path = hf_hub_download(repo_id=REPO_ID, filename="gene_metadata.parquet")
stats_path = hf_hub_download(repo_id=REPO_ID, filename="global_stats.json")
```
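
To see what the checkpoint holds before wiring it into a model, a sketch like the following could help. It relies only on the standard Lightning layout (a top-level `state_dict` entry); the helper names are hypothetical, and module prefixes such as `student` are illustrative guesses, so inspect `summarize_prefixes` output for the real names.

```python
from collections import Counter


def load_lightning_state_dict(ckpt_path):
    """Load the `state_dict` entry of a PyTorch Lightning checkpoint onto CPU."""
    import torch  # imported lazily so the pure-Python helpers below work without torch

    # Lightning checkpoints can contain non-tensor objects (e.g. hyperparameters),
    # which the weights_only=True loader may reject; only unpickle files you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    return ckpt["state_dict"]


def summarize_prefixes(state_dict):
    """Count entries by top-level module prefix (text before the first '.')."""
    return Counter(key.split(".", 1)[0] for key in state_dict)


def strip_prefix(state_dict, prefix):
    """Keep entries under `prefix.` with that prefix removed, e.g. for load_state_dict."""
    dotted = prefix + "."
    return {k[len(dotted):]: v for k, v in state_dict.items() if k.startswith(dotted)}
```

Typical use: `sd = load_lightning_state_dict(ckpt_path)`, then `print(summarize_prefixes(sd))`, then `strip_prefix(sd, "<prefix>")` for whichever submodule you want to load.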

## Paper and code

- Paper (bioRxiv): [GeneJEPA: A Predictive World Model of the Transcriptome](https://doi.org/10.1101/2025.10.14.682378)
- GitHub: [BiostateAI/GeneJEPA](https://github.com/BiostateAI/GeneJEPA)

## Citation

```bibtex
@article{GeneJEPA2025,
  title     = {GeneJEPA: A Predictive World Model of the Transcriptome},
  author    = {Litman, E. and Myers, T. and Agarwal, V. and Gopinath, A. and Li, O. and Mittal, E. and Kassis, T.},
  journal   = {bioRxiv},
  year      = {2025},
  publisher = {Cold Spring Harbor Laboratory},
  note      = {preprint},
  doi       = {10.1101/2025.10.14.682378},
}
```