---
tags:
- biology
- genomics
- single-cell-rna-seq
- patient-classification
library_name: transformers
license: apache-2.0
---

# Virtual Cell — Patient Model

A patient-level disease classification model trained on single-cell RNA-seq data. Given a matrix of gene expression profiles (one row per cell), the model produces a disease-category prediction for the patient.

## Model architecture

```
input [batch, num_cells, 18301 genes]
  → MLP cell embedder       → [batch, num_cells, 512]
  → Attention aggregator    → [batch, 512]
  → Dropout + Linear head   → [batch, 10 classes]
```

## Pretrained classification task

The pretrained checkpoint classifies patients into **10 disease categories**:

`oncological`, `immune_inflammatory`, `neurological`, `metabolic_vascular`, `gastrointestinal`, `respiratory`, `epithelial_barrier`, `sensory_specialized`, `healthy_control`, `other`.

The pretrained embedder generalizes well to other classification tasks. Common fine-tuning scenarios include binary sick vs. healthy or treatment response prediction — see [Fine-tuning](#fine-tuning) below.

## Installation

All repository files are required to run `train.py`. Download them all (or clone the repo) and install dependencies:

```bash
pip install -r requirements.txt
```

`wandb` is optional and only needed when training with `--wandb_project`.

> **Tip:** `train.py` uses multiple workers for data loading. A machine with
> at least 8 CPU cores is recommended for good throughput — set
> `--num_workers` to match your core count.

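
To choose a value for `--num_workers`, you can query how many CPU cores are available to the current process. The snippet below is a minimal sketch, not part of this repository: the helper name `suggested_num_workers` and its `reserve` parameter are hypothetical and exist only for illustration.

```python
import os


def suggested_num_workers(reserve: int = 0) -> int:
    """Return a worker count based on the CPU cores usable by this process.

    `reserve` (hypothetical parameter) leaves some cores free for the main
    training process; the result is never lower than 1.
    """
    try:
        # Linux: counts only the cores this process is allowed to run on
        # (respects CPU affinity masks).
        available = len(os.sched_getaffinity(0))
    except AttributeError:
        # Platforms without sched_getaffinity (e.g. macOS, Windows):
        # fall back to the total logical core count.
        available = os.cpu_count() or 1
    return max(1, available - reserve)


print(suggested_num_workers())
```

The printed number can then be passed to `train.py` as the `--num_workers` value.
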
## Quick start

### Verify the model loads

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "ConvergeBio/virtual-cell-patient",
    trust_remote_code=True,
).eval()

x = torch.randn(1, 500, 18_301)  # [batch, num_cells, num_genes]
with torch.no_grad():
    out = model(input_ids=x)

print(out.logits.shape)  # [1, 10]
print(out.logits.softmax(-1))
```

### Inference on real data

```python
from datasets import load_dataset
import torch
from transformers import AutoModel

ds = load_dataset("ConvergeBio/virtual-cell-patient-example", split="validation")

model = AutoModel.from_pretrained(
    "ConvergeBio/virtual-cell-patient",
    trust_remote_code=True,
).eval()

sample = torch.tensor(ds[0]["input_ids"]).unsqueeze(0)  # [1, 500, 18_301]
with torch.no_grad():
    out = model(input_ids=sample)

print(out.logits.softmax(-1))
```

> **Note:** `ConvergeBio/virtual-cell-patient-example` is a minimal sample dataset
> intended only to verify the data format and run a quick end-to-end check. It
> contains a small number of patients and is not representative of a real training
> or evaluation distribution. Metrics produced from inference or training on this
> dataset should not be interpreted.

## Preparing your data

`train.py` expects a HuggingFace dataset with `train` (and optionally `validation`) splits. Each row represents one cell sample for a patient, with the following required columns:

| Column | Shape | Type | Description |
|---|---|---|---|
| `input_ids` | [500, 18301] | float32 | Log-normalized gene expression matrix, aligned to `gene_names.txt` |
| `attention_mask` | [500] | bool | Cell mask (all ones for fixed cell count) |
| `labels` | scalar | int | Class index |
| `entity_id` | scalar | int | Patient identifier — groups augmented views of the same patient |

**Augmentation is strongly encouraged** — multiple independent random cell samples from the same patient should be included as separate rows sharing the same `entity_id`.
At inference, the model averages softmax probabilities across views for a more robust prediction. A factor of 5 augmentations per patient is a good default.

For a guide on building this dataset from raw scRNA-seq (h5ad) files, see the [example dataset](https://huggingface.co/datasets/ConvergeBio/virtual-cell-patient-example).

## Fine-tuning

**Binary classification (e.g. sick vs. healthy):**

```bash
python train.py \
    --dataset_path \
    --num_classes 2 \
    --freeze_embedder \
    --output_dir ./my_binary_model
```

`--freeze_embedder` keeps the pretrained cell embedder frozen and only trains the new head — recommended when your dataset is small.

**Multi-class fine-tuning on a different label set:**

```bash
python train.py \
    --dataset_path \
    --num_classes \
    --output_dir ./my_finetuned_model \
    --num_train_epochs 15 \
    --learning_rate 1e-4
```

## Training from scratch

```bash
python train.py \
    --dataset_path \
    --from_scratch \
    --output_dir ./my_scratch_model
```

## Repository contents

| File | Description |
|---|---|
| `modeling_virtual_cell.py` | Full model implementation |
| `config.json` | Architecture config |
| `gene_names.txt` | Ordered list of 18,301 HGNC gene symbols |
| `train.py` | Fine-tuning / training script |
| `requirements.txt` | Python dependencies |
| `model.safetensors` | Pretrained weights |

## Citation

If you use this model, please cite:

```bibtex
@article{convergecell2026,
  author = {ConvergeBio},
  title  = {ConvergeCELL: An end-to-end platform from patient transcriptomics to therapeutic hypotheses},
  year   = {2026},
  note   = {Preprint available on bioRxiv},
}
```

The model architecture and data processing approach were inspired by:

```bibtex
@article{liu2026pascient,
  author  = {Liu, T. and De Brouwer, E. and Verma, A. and Missarova, A. and Kuo, T. and others},
  title   = {Learning multi-cellular representations of single-cell transcriptomics data enables characterization of patient-level disease states},
  journal = {Cell Systems},
  volume  = {17},
  pages   = {101570},
  year    = {2026},
}
```

## License

Apache 2.0 — see [LICENSE](LICENSE) and [NOTICE](NOTICE).