---
tags:
- biology
- genomics
- bulk-rna-seq
- patient-embedding
library_name: transformers
license: apache-2.0
---
# Virtual Cell — Distilled Bulk Encoder
A bulk RNA-seq encoder distilled from
[ConvergeBio/virtual-cell-patient](https://huggingface.co/ConvergeBio/virtual-cell-patient).
It maps bulk gene expression directly into the same 512-dimensional patient embedding space,
making single-cell-trained representations accessible when only bulk data is available.
## Model architecture
```
input [batch, 18301 genes]
→ MLP encoder (Linear → BN → PReLU)² → [batch, 512]
```
Training objective: cosine distillation loss, with teacher embeddings produced by
`virtual-cell-patient` on matched single-cell RNA-seq data from the same patients.
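
As a rough illustration of the objective named above, the sketch below computes a cosine distillation loss between student (bulk) and teacher (single-cell) embeddings. The exact form used in training is not stated here; this assumes the common `1 - cosine similarity` formulation averaged over the batch. It uses NumPy for brevity, and the function name `cosine_distillation_loss` is hypothetical.

```python
import numpy as np


def cosine_distillation_loss(student, teacher, eps=1e-8):
    """Mean (1 - cosine similarity) between matching rows.

    Assumed formulation, not necessarily the one used to train this model.
    student, teacher: arrays of shape [batch, dim].
    """
    student = np.asarray(student, dtype=np.float64)
    teacher = np.asarray(teacher, dtype=np.float64)
    # L2-normalise each row; eps guards against division by zero.
    s = student / np.maximum(np.linalg.norm(student, axis=1, keepdims=True), eps)
    t = teacher / np.maximum(np.linalg.norm(teacher, axis=1, keepdims=True), eps)
    cos = np.sum(s * t, axis=1)  # per-sample cosine similarity in [-1, 1]
    return float(np.mean(1.0 - cos))


# Toy embeddings: the first pair is aligned, the second points the opposite way.
student = np.array([[1.0, 0.0], [0.0, 1.0]])
teacher = np.array([[2.0, 0.0], [0.0, -3.0]])
loss = cosine_distillation_loss(student, teacher)
print(loss)
```

Under this formulation the loss depends only on the direction of each embedding, not its magnitude: it is 0 when student and teacher point the same way and 2 when they are opposite.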
## Relationship to virtual-cell-patient
| | [virtual-cell-patient](https://huggingface.co/ConvergeBio/virtual-cell-patient) | virtual-cell-distil-bulk |
|---|---|---|
| Input | `[batch, n_cells, 18301]` single-cell matrix | `[batch, 18301]` bulk expression vector |
| Output | `[batch, 512]` patient embedding + class logits | `[batch, 512]` patient embedding |
| Requires single-cell data | Yes | No |
Both models use the same 18,301-gene vocabulary (`gene_names.txt`) and produce embeddings
in the same 512-dimensional space.
## Installation
```bash
pip install -r requirements.txt
```
`wandb` is optional and only needed when training with `--wandb_project`.
## Quick start
### Inference — extract embeddings
```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "ConvergeBio/virtual-cell-distil-bulk",
    trust_remote_code=True,
).eval()

x = torch.randn(4, 18_301)  # [batch, num_genes]
with torch.no_grad():
    out = model(input_ids=x)
print(out["embeddings"].shape)  # [4, 512]
```
> **Note:** the model uses BatchNorm — always call `.eval()` for inference.
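
To see why the `.eval()` call matters, the sketch below uses a small stand-in block with the same `Linear → BN → PReLU` pattern. It is not the released encoder, and the layer sizes are arbitrary assumptions. In training mode BatchNorm normalises with the statistics of the current batch, so a sample's output depends on which other samples share its batch; in eval mode it uses fixed running statistics.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for one encoder block; sizes are arbitrary.
block = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4), nn.PReLU())
x = torch.randn(6, 8)

# Training mode: normalisation uses the statistics of the current batch.
block.train()
with torch.no_grad():
    train_full = block(x)      # statistics from all 6 rows
    train_half = block(x[:3])  # statistics from the first 3 rows only

# Eval mode: normalisation uses the stored running statistics.
block.eval()
with torch.no_grad():
    eval_full = block(x)
    eval_half = block(x[:3])

# The first three rows agree in eval mode but not in training mode.
same_in_train = torch.allclose(train_full[:3], train_half, atol=1e-6)
same_in_eval = torch.allclose(eval_full[:3], eval_half, atol=1e-6)
print(same_in_train, same_in_eval)
```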
### Inference on real data
```python
from datasets import load_dataset
import torch
from transformers import AutoModel

ds = load_dataset("ConvergeBio/virtual-cell-distil-bulk-example", split="validation")
model = AutoModel.from_pretrained(
    "ConvergeBio/virtual-cell-distil-bulk",
    trust_remote_code=True,
).eval()

sample = torch.tensor(ds[0]["bulk_expression"]).unsqueeze(0)  # [1, 18301]
with torch.no_grad():
    out = model(input_ids=sample)
print(out["embeddings"].shape)  # [1, 512]
```
> **Note:** `ConvergeBio/virtual-cell-distil-bulk-example` is a minimal sample dataset
> intended only to verify the data format and run a quick end-to-end check.
> Metrics computed on this dataset are not meaningful and should not be interpreted as model performance.
## Fine-tuning for classification
The pretrained encoder can be fine-tuned on any bulk RNA-seq classification task.
A linear head is added on top; the encoder weights are initialised from the distilled
checkpoint and optionally frozen.
```python
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "ConvergeBio/virtual-cell-distil-bulk",
    num_labels=2,
    ignore_mismatched_sizes=True,  # classification head is randomly initialised
    trust_remote_code=True,
)
```
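Conceptually, a frozen encoder with a linear head can be sketched as below. This is a minimal illustration, not the implementation in `modeling_virtual_cell_distil.py` or `train.py`: the class name `BulkClassifier`, the stand-in encoder, and the small layer sizes are all assumptions made so the example runs on its own.

```python
import torch
from torch import nn


class BulkClassifier(nn.Module):
    """Hypothetical wrapper: encoder followed by a linear classification head."""

    def __init__(self, encoder, embed_dim, num_labels, freeze_encoder=True):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Linear(embed_dim, num_labels)  # randomly initialised
        if freeze_encoder:
            # Exclude encoder weights from gradient updates.
            for param in self.encoder.parameters():
                param.requires_grad = False

    def forward(self, x):
        return self.head(self.encoder(x))


# Stand-in encoder with toy sizes (the real model maps 18,301 genes to 512 dims).
encoder = nn.Sequential(nn.Linear(32, 16), nn.BatchNorm1d(16), nn.PReLU())
model = BulkClassifier(encoder, embed_dim=16, num_labels=2)

# Freezing parameters does not stop BatchNorm from updating its running
# statistics in training mode, so the encoder is also put in eval mode here.
model.encoder.eval()

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
logits = model(torch.randn(4, 32))
print(trainable, logits.shape)
```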
**Binary classification (e.g. disease vs. healthy) with frozen encoder:**
```bash
python train.py \
  --dataset_path <your_dataset> \
  --num_classes 2 \
  --freeze_encoder \
  --output_dir ./my_binary_model
```
**Multi-class fine-tuning:**
```bash
python train.py \
  --dataset_path <your_dataset> \
  --num_classes <N> \
  --output_dir ./my_finetuned_model \
  --num_train_epochs 15 \
  --learning_rate 1e-4
```
## Preparing your data
`train.py` expects a Hugging Face dataset with `train` (and optionally `validation`) splits.
Each row represents one patient sample:
| Column | Shape | Type | Description |
|---|---|---|---|
| `bulk_expression` | [18301] | float32 | Log-normalised bulk gene expression, aligned to `gene_names.txt` |
| `labels` | scalar | int | Class index |
Input expression should be library-size normalised (target sum 10,000) and log1p-transformed.
The gene axis must be aligned to the 18,301 genes in `gene_names.txt`:
missing genes are zero-filled and extra genes are dropped.
For a guide on building this dataset from raw count matrices, see the
[example dataset](https://huggingface.co/datasets/ConvergeBio/virtual-cell-distil-bulk-example).
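
The normalisation and alignment steps described above can be sketched for a single sample as follows. This is an illustration under stated assumptions, not the pipeline from the example dataset: the function name `prepare_bulk_vector` is hypothetical, the toy three-gene vocabulary stands in for the contents of `gene_names.txt`, and library size is computed over all measured genes before extra genes are dropped (the order of these two steps is an assumption; defer to the linked guide where it differs).

```python
import numpy as np


def prepare_bulk_vector(counts, sample_genes, vocab_genes, target_sum=10_000.0):
    """Normalise raw counts and align them to a fixed gene vocabulary.

    counts:       raw counts for one sample, one value per entry in sample_genes
    sample_genes: gene symbols measured in this sample
    vocab_genes:  ordered model vocabulary (e.g. the lines of gene_names.txt)
    """
    counts = np.asarray(counts, dtype=np.float64)
    total = counts.sum()
    if total <= 0:
        raise ValueError("sample has no counts")

    # Library-size normalisation to target_sum, then log1p.
    # Assumption: the total is taken over all measured genes.
    normalised = np.log1p(counts / total * target_sum)

    # Align to the vocabulary: missing genes stay 0, extra genes are skipped.
    index = {gene: i for i, gene in enumerate(vocab_genes)}
    aligned = np.zeros(len(vocab_genes), dtype=np.float32)
    for gene, value in zip(sample_genes, normalised):
        position = index.get(gene)
        if position is not None:
            aligned[position] = value
    return aligned


# Toy example: "X" is not in the vocabulary and "C" is not measured.
vocab = ["A", "B", "C"]
vec = prepare_bulk_vector([30, 50, 20], ["B", "X", "A"], vocab)
print(vec.shape, vec.dtype)
```

With the real vocabulary, the returned vector has length 18,301 and can be stored in the `bulk_expression` column.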
## Repository contents
| File | Description |
|---|---|
| `modeling_virtual_cell_distil.py` | Full model implementation |
| `config.json` | Architecture config |
| `gene_names.txt` | Ordered list of 18,301 HGNC gene symbols |
| `train.py` | Classification fine-tuning script |
| `requirements.txt` | Python dependencies |
| `model.safetensors` | Pretrained encoder weights |
## Citation
If you use this model, please cite:
```bibtex
@article{convergecell2026,
author = {ConvergeBio},
title = {ConvergeCELL: An end-to-end platform from patient transcriptomics to therapeutic hypotheses},
year = {2026},
note = {Preprint available on bioRxiv},
}
```
## License
Apache 2.0 — see [LICENSE](LICENSE) and [NOTICE](NOTICE).