---
license: mit
language:
- en
library_name: pytorch
tags:
- scRNA-seq
- single-cell
- self-supervised-learning
- JEPA
- biology
datasets:
- vevotx/Tahoe-100M
pretty_name: GeneJEPA (Perceiver JEPA for scRNA-seq)
pipeline_tag: feature-extraction
---

# GeneJEPA: A Predictive World Model of the Transcriptome

**GeneJEPA** is a Joint-Embedding Predictive Architecture (JEPA) trained for self-supervised representation learning on single-cell RNA-seq (scRNA-seq) data. It uses a Perceiver-style encoder to handle sparse, high-dimensional gene count vectors and a Fourier-feature tokenizer to encode continuous expression values as tokens.

**Why?** To produce compact cell embeddings for clustering, transfer learning, linear probes, perturbation prediction, and other downstream biological tasks.

---

## Repository contents

This model repo intentionally contains **artifacts only** (no training code):

- **`genejepa-epoch=49.ckpt`** — final PyTorch Lightning checkpoint (student encoder + predictor + EMA state, etc.)
- **`gene_metadata.parquet`** — mapping between foundation token IDs and gene identifiers used to build the embedding vocab.
- **`global_stats.json`** — global `log1p(counts)` normalization stats (`mean`, `std`) computed over a large sample of training data.
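
A minimal sketch of applying these stats to raw counts, assuming `global_stats.json` stores scalar `mean` and `std` fields and that normalization is a plain z-score of `log1p(counts)`. The helper names `load_global_stats` and `normalize_counts` are hypothetical and not part of any released API; check the GeneJEPA repository for the exact preprocessing.

```python
import json

import numpy as np


def load_global_stats(path: str) -> tuple[float, float]:
    # Assumption: the JSON file holds scalar "mean" and "std" entries.
    with open(path) as f:
        stats = json.load(f)
    return float(stats["mean"]), float(stats["std"])


def normalize_counts(counts: np.ndarray, mean: float, std: float,
                     eps: float = 1e-8) -> np.ndarray:
    # log1p-transform raw counts, then standardize with the global stats.
    # eps guards against division by a zero std.
    return (np.log1p(counts.astype(np.float64)) - mean) / (std + eps)
```

With `stats_path` from the download snippet, `normalize_counts(counts, *load_global_stats(stats_path))` would return standardized values for a vector of raw counts.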

---

## Model summary

- **Backbone:** Perceiver-style encoder over tokenized genes (identity + Fourier features of expression value)
- **Latents:** 512  
- **Dimensionality:** 768  
- **Blocks:** 24 transformer blocks on the latent array  
- **Heads:** 12  
- **Masking:** stochastic, block-wise target masks; the context is the complement of the targets  
- **Predictor:** BYOL-style MLP head  
- **EMA teacher:** maintained during training (for targets)
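
To make the Perceiver and EMA items concrete, here is a toy NumPy sketch, not the GeneJEPA implementation. It assumes a single attention head with no normalization layers or MLPs (the 12 heads and 24 latent blocks listed above are omitted) and a hypothetical EMA rate `tau=0.996`. The point of the cross-attention step is that M latent queries attend over N gene tokens, so the score matrix is M × N and cost grows linearly with the number of genes rather than quadratically.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attend(latents: np.ndarray, tokens: np.ndarray,
                 w_q: np.ndarray, w_k: np.ndarray, w_v: np.ndarray) -> np.ndarray:
    """Single-head cross-attention: M latent queries read from N gene tokens."""
    q = latents @ w_q                          # (M, D)
    k = tokens @ w_k                           # (N, D)
    v = tokens @ w_v                           # (N, D)
    scores = q @ k.T / np.sqrt(q.shape[-1])    # (M, N): linear in N
    return softmax(scores, axis=-1) @ v        # (M, D)


def ema_update(teacher: dict, student: dict, tau: float = 0.996) -> dict:
    """EMA teacher update: teacher <- tau * teacher + (1 - tau) * student."""
    return {name: tau * teacher[name] + (1.0 - tau) * student[name]
            for name in teacher}


rng = np.random.default_rng(0)
M, N, D = 8, 100, 16   # toy sizes; the card lists 512 latents and 768 dims
latents = rng.normal(size=(M, D))
tokens = rng.normal(size=(N, D))
w_q, w_k, w_v = (rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(3))
out = cross_attend(latents, tokens, w_q, w_k, w_v)
print(out.shape)  # (8, 16)
```

Because the latent array has a fixed size, the subsequent latent self-attention blocks cost the same regardless of how many genes a cell expresses.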

> Default tokenizer Fourier settings: `N_f=64`, `min_freq=0.1`, `max_freq=100.0`, `freq_scale=1.0`.
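
One plausible reading of these settings, sketched in NumPy: each scalar expression value is mapped to sine and cosine features at `N_f` frequencies spanning `min_freq` to `max_freq`. The log spacing, the `2π` factor, and treating `freq_scale` as a multiplier on the frequencies are assumptions, and how the resulting `2 * N_f` features are combined with the gene-identity embedding is not shown. `fourier_features` is a hypothetical helper name.

```python
import numpy as np


def fourier_features(values: np.ndarray, n_f: int = 64, min_freq: float = 0.1,
                     max_freq: float = 100.0, freq_scale: float = 1.0) -> np.ndarray:
    """Map scalar expression values to sin/cos features at n_f frequencies.

    Assumptions: log-spaced frequencies, scaled by freq_scale, used as 2*pi*f*x.
    """
    freqs = freq_scale * np.geomspace(min_freq, max_freq, n_f)       # (n_f,)
    angles = 2.0 * np.pi * values[..., None] * freqs                 # (..., n_f)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)  # (..., 2*n_f)


x = np.array([0.0, 0.5, 1.2])   # e.g. normalized expression values (assumed input)
feats = fourier_features(x)
print(feats.shape)  # (3, 128)
```

Low frequencies vary slowly across the expression range while high frequencies resolve small differences, which is the usual motivation for multi-scale Fourier features.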

## Download artifacts

```python
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="elonlit/GeneJEPA",
                            filename="genejepa-epoch=49.ckpt")
meta_path = hf_hub_download(repo_id="elonlit/GeneJEPA",
                            filename="gene_metadata.parquet")
stats_path = hf_hub_download(repo_id="elonlit/GeneJEPA",
                             filename="global_stats.json")
```

## Paper and code

- Paper (bioRxiv): [GeneJEPA: A Predictive World Model of the Transcriptome](https://doi.org/10.1101/2025.10.14.682378)
- GitHub: [BiostateAI/GeneJEPA](https://github.com/BiostateAI/GeneJEPA)

## Citation

```bibtex
@article{GeneJEPA2025,
  title     = {GeneJEPA: A Predictive World Model of the Transcriptome},
  author    = {Litman, E. and Myers, T. and Agarwal, V. and Gopinath, A. and Li, O. and Mittal, E. and Kassis, T.},
  journal   = {bioRxiv},
  year      = {2025},
  publisher = {Cold Spring Harbor Laboratory},
  note      = {preprint},
  doi       = {10.1101/2025.10.14.682378},
}
```