---
tags:
- biology
- genomics
- single-cell-rna-seq
- patient-classification
library_name: transformers
license: apache-2.0
---
# Virtual Cell — Patient Model
A patient-level disease classification model trained on single-cell RNA-seq data.
Given a matrix of gene expression profiles (one row per cell), the model produces
a disease-category prediction for the patient.
## Model architecture
```
input [batch, num_cells, 18301 genes]
→ MLP cell embedder → [batch, num_cells, 512]
→ Attention aggregator → [batch, 512]
→ Dropout + Linear head → [batch, 10 classes]
```
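The shape flow above can be traced with a small NumPy forward pass. This is only a sketch with random weights and toy dimensions: the single hidden ReLU layer, the per-cell scoring vector `w_att`, and all parameter names are assumptions for illustration, not the contents of `modeling_virtual_cell.py`. Dropout is omitted because it is the identity at inference time.

```python
import numpy as np

rng = np.random.default_rng(0)


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def init_params(num_genes, hidden, embed, num_classes, scale=0.05):
    # Random weights purely for shape checking (hypothetical parameter names).
    return {
        "w1": rng.standard_normal((num_genes, hidden)) * scale,
        "b1": np.zeros(hidden),
        "w2": rng.standard_normal((hidden, embed)) * scale,
        "b2": np.zeros(embed),
        "w_att": rng.standard_normal(embed) * scale,
        "w_out": rng.standard_normal((embed, num_classes)) * scale,
        "b_out": np.zeros(num_classes),
    }


def forward(x, params, mask=None):
    """x: [batch, num_cells, num_genes] -> logits: [batch, num_classes]."""
    # MLP cell embedder, applied independently to every cell.
    h = np.maximum(x @ params["w1"] + params["b1"], 0.0)  # [B, C, hidden]
    emb = h @ params["w2"] + params["b2"]                 # [B, C, embed]
    # Attention aggregator: one score per cell, softmax over the cell axis.
    scores = emb @ params["w_att"]                        # [B, C]
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)          # masked cells get weight 0
    weights = softmax(scores, axis=1)                     # [B, C]
    pooled = np.einsum("bc,bce->be", weights, emb)        # [B, embed]
    # Linear classification head.
    return pooled @ params["w_out"] + params["b_out"]     # [B, num_classes]


# Toy sizes; the real model uses 18301 genes, 512-d embeddings and 10 classes.
params = init_params(num_genes=64, hidden=32, embed=16, num_classes=10)
x = rng.standard_normal((2, 50, 64))
logits = forward(x, params)
print(logits.shape)  # (2, 10)
```

Because the aggregator pools over the cell axis with normalized weights, the output shape does not depend on `num_cells`.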
## Pretrained classification task
The pretrained checkpoint classifies patients into **10 disease categories**:
`oncological`, `immune_inflammatory`, `neurological`, `metabolic_vascular`,
`gastrointestinal`, `respiratory`, `epithelial_barrier`, `sensory_specialized`,
`healthy_control`, `other`.
The pretrained embedder generalizes well to other classification tasks. Common
fine-tuning scenarios include binary sick vs. healthy or treatment response
prediction — see [Fine-tuning](#fine-tuning) below.
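To turn a logits vector into a readable label, something like the sketch below could be used. It assumes that class indices follow the order in which the categories are listed above; that ordering is not confirmed here, so check it against the repository's `config.json` or training code before relying on it. `CLASS_NAMES` and `top_prediction` are hypothetical names.

```python
import numpy as np

# Assumption: index i corresponds to the i-th category as listed in this card.
CLASS_NAMES = [
    "oncological",
    "immune_inflammatory",
    "neurological",
    "metabolic_vascular",
    "gastrointestinal",
    "respiratory",
    "epithelial_barrier",
    "sensory_specialized",
    "healthy_control",
    "other",
]


def top_prediction(logits):
    """Return (class_name, probability) for a single [num_classes] logits vector."""
    logits = np.asarray(logits, dtype=np.float64)
    probs = np.exp(logits - logits.max())  # numerically stable softmax
    probs /= probs.sum()
    idx = int(probs.argmax())
    return CLASS_NAMES[idx], float(probs[idx])


# Example with made-up logits; with the real model, pass out.logits[0].tolist().
name, prob = top_prediction([0.1, 2.3, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0])
print(name, round(prob, 3))
```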
## Installation
All repository files are required to run `train.py`. Download them all
(or clone the repo) and install dependencies:
```bash
pip install -r requirements.txt
```
`wandb` is optional and only needed when training with `--wandb_project`.
> **Tip:** `train.py` uses multiple workers for data loading. A machine with
> at least 8 CPU cores is recommended for good throughput — set
> `--num_workers` to match your core count.
## Quick start
### Verify the model loads
```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "ConvergeBio/virtual-cell-patient",
    trust_remote_code=True,
).eval()

x = torch.randn(1, 500, 18_301)  # [batch, num_cells, num_genes]
with torch.no_grad():
    out = model(input_ids=x)

print(out.logits.shape)  # [1, 10]
print(out.logits.softmax(-1))
```
### Inference on real data
```python
from datasets import load_dataset
import torch
from transformers import AutoModel

ds = load_dataset("ConvergeBio/virtual-cell-patient-example", split="validation")
model = AutoModel.from_pretrained(
    "ConvergeBio/virtual-cell-patient",
    trust_remote_code=True,
).eval()

sample = torch.tensor(ds[0]["input_ids"]).unsqueeze(0)  # [1, 500, 18_301]
with torch.no_grad():
    out = model(input_ids=sample)

print(out.logits.softmax(-1))
```
> **Note:** `ConvergeBio/virtual-cell-patient-example` is a minimal sample dataset
> intended only to verify the data format and run a quick end-to-end check. It
> contains a small number of patients and is not representative of a real training
> or evaluation distribution. Metrics produced from inference or training on this
> dataset should not be interpreted as estimates of model performance.
## Preparing your data
`train.py` expects a HuggingFace dataset with a `train` split and, optionally, a
`validation` split. Each row holds one random sample of 500 cells drawn from a
single patient, with the following required columns:
| Column | Shape | Type | Description |
|---|---|---|---|
| `input_ids` | [500, 18301] | float32 | Log-normalized gene expression matrix, aligned to `gene_names.txt` |
| `attention_mask` | [500] | bool | Cell mask (all ones for fixed cell count) |
| `labels` | scalar | int | Class index |
| `entity_id` | scalar | int | Patient identifier — groups augmented views of the same patient |
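As a rough illustration of the row schema above, the following sketch builds one row from a patient's expression matrix. The helper name `make_row`, the toy gene count, and the choice to sample with replacement when a patient has fewer than 500 cells are all assumptions; the example dataset's guide remains the reference for the actual preprocessing.

```python
import numpy as np

NUM_CELLS = 500  # cells per row, as in the table above


def make_row(expr, label, entity_id, rng, num_cells=NUM_CELLS):
    """Build one dataset row from a patient's expression matrix.

    expr: [n_patient_cells, num_genes] log-normalized matrix whose columns are
          already ordered as in gene_names.txt.
    """
    n = expr.shape[0]
    # Assumption: fall back to sampling with replacement for small patients.
    idx = rng.choice(n, size=num_cells, replace=n < num_cells)
    return {
        "input_ids": expr[idx].astype(np.float32),         # [num_cells, num_genes]
        "attention_mask": np.ones(num_cells, dtype=bool),  # every sampled cell is real
        "labels": int(label),
        "entity_id": int(entity_id),
    }


rng = np.random.default_rng(0)
# Toy matrix with 32 genes to keep the example light; real rows use 18301 genes.
expr = rng.random((1200, 32), dtype=np.float32)
row = make_row(expr, label=1, entity_id=42, rng=rng)
print(row["input_ids"].shape)  # (500, 32)
```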
**Augmentation is strongly encouraged** — multiple independent random cell samples
from the same patient should be included as separate rows sharing the same
`entity_id`. At inference, the model averages softmax probabilities across views
for a more robust prediction. A factor of 5 augmentations per patient is a good
default.
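The view averaging described above can be sketched as follows: apply softmax to each row's logits, then take the mean probability vector over all rows sharing an `entity_id`. The function name `aggregate_by_patient` is hypothetical, and this is not necessarily how `train.py` implements the step.

```python
from collections import defaultdict

import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def aggregate_by_patient(logits, entity_ids):
    """Average per-view softmax probabilities for each patient.

    logits: [num_rows, num_classes], one row per augmented view.
    entity_ids: [num_rows] patient identifier of each view.
    Returns {entity_id: mean probability vector}.
    """
    probs = softmax(np.asarray(logits, dtype=np.float64))
    groups = defaultdict(list)
    for p, eid in zip(probs, entity_ids):
        groups[int(eid)].append(p)
    return {eid: np.mean(views, axis=0) for eid, views in groups.items()}


# Two views of patient 7 and one view of patient 9 (made-up binary logits).
logits = np.array([[2.0, 0.0], [1.0, 0.0], [0.0, 3.0]])
patient_probs = aggregate_by_patient(logits, entity_ids=[7, 7, 9])
for eid, p in patient_probs.items():
    print(eid, int(p.argmax()))
```

Averaging probabilities rather than logits keeps each view's contribution bounded, so a single overconfident view cannot dominate the patient-level prediction.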
For a guide on building this dataset from raw scRNA-seq (h5ad) files, see the
[example dataset](https://huggingface.co/datasets/ConvergeBio/virtual-cell-patient-example).
## Fine-tuning
**Binary classification (e.g. sick vs. healthy):**
```bash
python train.py \
  --dataset_path <your_dataset> \
  --num_classes 2 \
  --freeze_embedder \
  --output_dir ./my_binary_model
```
`--freeze_embedder` keeps the pretrained cell embedder frozen and only trains
the new head — recommended when your dataset is small.
**Multi-class fine-tuning on a different label set:**
```bash
python train.py \
  --dataset_path <your_dataset> \
  --num_classes <N> \
  --output_dir ./my_finetuned_model \
  --num_train_epochs 15 \
  --learning_rate 1e-4
```
## Training from scratch
```bash
python train.py \
  --dataset_path <your_dataset> \
  --from_scratch \
  --output_dir ./my_scratch_model
```
## Repository contents
| File | Description |
|---|---|
| `modeling_virtual_cell.py` | Full model implementation |
| `config.json` | Architecture config |
| `gene_names.txt` | Ordered list of 18,301 HGNC gene symbols |
| `train.py` | Fine-tuning / training script |
| `requirements.txt` | Python dependencies |
| `model.safetensors` | Pretrained weights |
## Citation
If you use this model, please cite:
```bibtex
@article{convergecell2026,
author = {ConvergeBio},
title = {ConvergeCELL: An end-to-end platform from patient transcriptomics to therapeutic hypotheses},
year = {2026},
note = {Preprint available on bioRxiv},
}
```
The model architecture and data processing approach were inspired by:
```bibtex
@article{liu2026pascient,
author = {Liu, T. and De Brouwer, E. and Verma, A. and Missarova, A. and
Kuo, T. and others},
title = {Learning multi-cellular representations of single-cell transcriptomics
data enables characterization of patient-level disease states},
journal = {Cell Systems},
volume = {17},
pages = {101570},
year = {2026},
}
```
## License
Apache 2.0 — see [LICENSE](LICENSE) and [NOTICE](NOTICE).