---
tags:
- biology
- genomics
- bulk-rna-seq
- patient-embedding
library_name: transformers
license: apache-2.0
---

# Virtual Cell: Distilled Bulk Encoder

A bulk RNA-seq encoder distilled from [ConvergeBio/virtual-cell-patient](https://huggingface.co/ConvergeBio/virtual-cell-patient). It maps bulk gene expression directly into the same 512-dimensional patient embedding space as its single-cell teacher, making single-cell-trained representations accessible when only bulk data is available.

## Model architecture

```
input [batch, 18301 genes] → MLP encoder (Linear → BN → PReLU)² → [batch, 512]
```

Training objective: a cosine distillation loss. The teacher embeddings are produced by `virtual-cell-patient` from matched single-cell RNA-seq data of the same patients, and the encoder learns to reproduce them from bulk expression alone.

## Relationship to virtual-cell-patient

| | [virtual-cell-patient](https://huggingface.co/ConvergeBio/virtual-cell-patient) | virtual-cell-distil-bulk |
|---|---|---|
| Input | `[batch, n_cells, 18301]` single-cell matrix | `[batch, 18301]` bulk expression vector |
| Output | `[batch, 512]` patient embedding + class logits | `[batch, 512]` patient embedding |
| Requires single-cell data | Yes | No |

Both models use the same 18,301-gene vocabulary (`gene_names.txt`) and produce embeddings in the same 512-dimensional space.

## Installation

```bash
pip install -r requirements.txt
```

`wandb` is optional and only needed when training with `--wandb_project`.

## Quick start

### Inference: extract embeddings

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "ConvergeBio/virtual-cell-distil-bulk",
    trust_remote_code=True,
).eval()

# Random placeholder; real inputs are log1p-normalised bulk expression
# aligned to gene_names.txt (see "Preparing your data").
x = torch.randn(4, 18_301)  # [batch, num_genes]

with torch.no_grad():
    out = model(input_ids=x)

print(out["embeddings"].shape)  # [4, 512]
```

> **Note:** the model uses BatchNorm, so always call `.eval()` before inference. In training mode, BatchNorm normalises with the statistics of the current batch, so a sample's embedding would depend on the other samples in that batch.
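To make the architecture and training objective above more concrete, here is a minimal PyTorch sketch. It is not the repository's implementation (that lives in `modeling_virtual_cell_distil.py`), and several details are assumptions: the hidden width `HIDDEN_DIM = 1024` is hypothetical, `BulkEncoderSketch` and `cosine_distillation_loss` are illustrative names, the loss is assumed to be the batch mean of `1 - cosine_similarity` between student and teacher embeddings, and random tensors stand in for real bulk profiles and `virtual-cell-patient` embeddings. The last part shows why `.eval()` matters: once BatchNorm uses its running statistics, a sample's embedding no longer depends on the rest of the batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

N_GENES = 18_301   # size of the gene vocabulary in gene_names.txt
EMBED_DIM = 512    # patient embedding size shared with virtual-cell-patient
HIDDEN_DIM = 1024  # hypothetical width; check config.json for the real value


class BulkEncoderSketch(nn.Module):
    """Illustrative (Linear -> BatchNorm1d -> PReLU) x 2 encoder, not the shipped model."""

    def __init__(self, n_genes: int = N_GENES, hidden_dim: int = HIDDEN_DIM,
                 embed_dim: int = EMBED_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_genes, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.PReLU(),
            nn.Linear(hidden_dim, embed_dim), nn.BatchNorm1d(embed_dim), nn.PReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)  # [batch, n_genes] -> [batch, embed_dim]


def cosine_distillation_loss(student: torch.Tensor, teacher: torch.Tensor) -> torch.Tensor:
    """Assumed form: mean of (1 - cosine similarity); the teacher is a fixed target."""
    return (1.0 - F.cosine_similarity(student, teacher.detach(), dim=-1)).mean()


torch.manual_seed(0)
student = BulkEncoderSketch()
bulk = torch.randn(8, N_GENES)           # stand-in for normalised bulk profiles
teacher_emb = torch.randn(8, EMBED_DIM)  # stand-in for virtual-cell-patient embeddings

# One training step's loss (train mode: BatchNorm uses batch statistics).
loss = cosine_distillation_loss(student(bulk), teacher_emb)
loss.backward()

# Eval mode: BatchNorm switches to running statistics, so a sample's
# embedding is the same whether it is embedded alone or inside a batch.
student.eval()
with torch.no_grad():
    alone = student(bulk[:1])
    in_batch = student(bulk)[:1]
print(loss.item(), (alone - in_batch).abs().max().item())
```

In training mode, `nn.BatchNorm1d` also rejects a batch containing a single sample, which is another reason to call `.eval()` before embedding individual patients.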
### Inference on real data

```python
from datasets import load_dataset
import torch
from transformers import AutoModel

ds = load_dataset("ConvergeBio/virtual-cell-distil-bulk-example", split="validation")

model = AutoModel.from_pretrained(
    "ConvergeBio/virtual-cell-distil-bulk",
    trust_remote_code=True,
).eval()

# Add a batch dimension to a single patient's expression vector.
sample = torch.tensor(ds[0]["bulk_expression"], dtype=torch.float32).unsqueeze(0)  # [1, 18301]

with torch.no_grad():
    out = model(input_ids=sample)

print(out["embeddings"].shape)  # [1, 512]
```

> **Note:** `ConvergeBio/virtual-cell-distil-bulk-example` is a minimal sample dataset
> intended only to verify the data format and run a quick end-to-end check.
> Metrics produced from this dataset are not meaningful and should not be interpreted.

## Fine-tuning for classification

The pretrained encoder can be fine-tuned on a bulk RNA-seq classification task. A linear head is added on top; the encoder weights are initialised from the distilled checkpoint and can optionally be frozen.

```python
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "ConvergeBio/virtual-cell-distil-bulk",
    num_labels=2,
    ignore_mismatched_sizes=True,  # classification head is randomly initialised
    trust_remote_code=True,
)
```

**Binary classification (e.g. disease vs. healthy) with a frozen encoder:**

```bash
python train.py \
  --dataset_path <dataset_path> \
  --num_classes 2 \
  --freeze_encoder \
  --output_dir ./my_binary_model
```

**Multi-class fine-tuning:**

```bash
python train.py \
  --dataset_path <dataset_path> \
  --num_classes <num_classes> \
  --output_dir ./my_finetuned_model \
  --num_train_epochs 15 \
  --learning_rate 1e-4
```

## Preparing your data

`train.py` expects a Hugging Face dataset with `train` (and optionally `validation`) splits.
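Before raw count matrices can go into such a dataset, they need the normalisation and gene alignment described for the `bulk_expression` column in the table below. The following NumPy sketch shows one way to do that under stated assumptions: `prepare_bulk_expression` is a hypothetical helper (not part of `train.py`), counts arrive as a `[n_samples, n_measured_genes]` array with a matching list of gene symbols, and library size is computed over all measured genes before the vocabulary is applied (this README does not specify the order, so check the example dataset's guide).

```python
import numpy as np

TARGET_SUM = 10_000  # library-size target stated in this README


def prepare_bulk_expression(counts, sample_genes, vocab):
    """Hypothetical helper: raw counts -> log1p(counts scaled to TARGET_SUM), aligned to `vocab`.

    counts:       [n_samples, n_measured_genes] raw counts
    sample_genes: gene symbols for the columns of `counts`
    vocab:        ordered gene symbols, e.g. the lines of gene_names.txt
    returns:      float32 array of shape [n_samples, len(vocab)]
    """
    counts = np.asarray(counts, dtype=np.float64)
    # Library-size normalise each sample to TARGET_SUM (assumed: over all measured genes).
    lib_size = counts.sum(axis=1, keepdims=True)
    lib_size[lib_size == 0] = 1.0  # avoid division by zero for empty samples
    normed = np.log1p(counts / lib_size * TARGET_SUM)

    # Map each vocabulary gene to its column in `counts` (-1 if not measured).
    # If a symbol appears twice in sample_genes, the last column wins.
    col_of = {gene: i for i, gene in enumerate(sample_genes)}
    idx = np.array([col_of.get(gene, -1) for gene in vocab])
    present = idx >= 0

    out = np.zeros((counts.shape[0], len(vocab)), dtype=np.float32)  # missing genes stay 0
    out[:, present] = normed[:, idx[present]]  # genes outside vocab are never copied
    return out


# Toy example with a 3-gene stand-in for gene_names.txt.
vocab = ["GENE_A", "GENE_B", "GENE_C"]
sample_genes = ["GENE_C", "GENE_A", "EXTRA", "GENE_X"]
counts = np.array([[10, 30, 60, 0],
                   [0, 5, 5, 90]])
x = prepare_bulk_expression(counts, sample_genes, vocab)
print(x.shape)  # (2, 3)
```

With the real vocabulary, `vocab` could be read as `[line.strip() for line in open("gene_names.txt")]`, assuming one symbol per line. The resulting rows, paired with integer `labels`, can then be wrapped with `datasets.Dataset.from_dict` and grouped into a `datasets.DatasetDict` with `train` and `validation` splits.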
Each row represents one patient sample:

| Column | Shape | Type | Description |
|---|---|---|---|
| `bulk_expression` | `[18301]` | float32 | Log-normalised bulk gene expression, aligned to `gene_names.txt` |
| `labels` | scalar | int | Class index |

Input expression should be library-size normalised (target sum 10,000) and log1p transformed. The gene axis must be aligned to the 18,301 genes in `gene_names.txt`: genes missing from your data are zero-filled, and genes not in the vocabulary are dropped.

For a guide on building this dataset from raw count matrices, see the [example dataset](https://huggingface.co/datasets/ConvergeBio/virtual-cell-distil-bulk-example).

## Repository contents

| File | Description |
|---|---|
| `modeling_virtual_cell_distil.py` | Full model implementation |
| `config.json` | Architecture config |
| `gene_names.txt` | Ordered list of 18,301 HGNC gene symbols |
| `train.py` | Classification fine-tuning script |
| `requirements.txt` | Python dependencies |
| `model.safetensors` | Pretrained encoder weights |

## Citation

If you use this model, please cite:

```bibtex
@article{convergecell2026,
  author = {ConvergeBio},
  title  = {ConvergeCELL: An end-to-end platform from patient transcriptomics to therapeutic hypotheses},
  year   = {2026},
  note   = {Preprint available on bioRxiv},
}
```

## License

Apache 2.0. See [LICENSE](LICENSE) and [NOTICE](NOTICE).