|
|
--- |
|
|
license: apache-2.0 |
|
|
library_name: pytorch |
|
|
tags: |
|
|
- biology |
|
|
- genomics |
|
|
- single-cell |
|
|
- transformer |
|
|
- diffusion |
|
|
- foundation-model |
|
|
pipeline_tag: feature-extraction |
|
|
--- |
|
|
|
|
|
<div align="center"> |
|
|
|
|
|
# ScDiVa: Masked Discrete Diffusion for Joint Modeling of Single-Cell Identity and Expression |
|
|
|
|
|
<p align="center"> |
|
|
<img src="https://huggingface.co/warming666/ScDiVa/resolve/main/assets/scDiVa.png" alt="ScDiVa Architecture" width="800"/> |
|
|
</p> |
|
|
|
|
|
[**arXiv Paper**](https://arxiv.org/abs/2602.03477) | [**GitHub Repository**](https://github.com/wangmingxuan666/ScDiVa) | [**Dataset**](https://huggingface.co/datasets/warming666/ScDiVa)
|
|
|
|
|
</div> |
|
|
|
|
|
## Model Summary
|
|
|
|
|
**ScDiVa** (Single-cell Deep Variational Analysis) is a **94.5M parameter** foundation model pre-trained on **59 million** single-cell transcriptomes. It utilizes a novel **Masked Discrete Diffusion** framework to model gene expression as an unordered set, effectively capturing the complex topology of gene regulatory networks. |
|
|
|
|
|
Unlike traditional autoregressive models, ScDiVa employs a bidirectional Transformer encoder with **SwiGLU** activations, **Rotary Positional Embeddings (RoPE)**, and **RMSNorm**, optimized for: |
|
|
|
|
|
* **Reconstruction** |
|
|
* **Cell Type Annotation** |
|
|
* **Multi-batch Integration** |
|
|
* **Gene Perturbation Prediction** |
|
|
* **Gene Regulatory Network (GRN) Inference** |
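To make the dual denoising objective concrete, here is a minimal, illustrative PyTorch sketch of one training step: a random subset of gene tokens is masked, and the model is trained to recover both the gene identity (classification) and its expression value (regression). All tensor shapes and the stand-in prediction heads below are hypothetical simplifications for illustration, not the actual ScDiVa API.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: batch of 2 cells, 16 gene tokens, vocabulary of 41,818 genes
B, L, V = 2, 16, 41818
gene_ids = torch.randint(0, V, (B, L))  # gene identity tokens
values = torch.rand(B, L)               # normalized expression values

# Corruption step: mask a random subset of positions
mask = torch.rand(B, L) < 0.3
mask[0, 0] = True  # ensure at least one masked position

# Stand-ins for the model's two output heads (illustrative only)
id_logits = torch.randn(B, L, V, requires_grad=True)  # identity classification head
value_pred = torch.randn(B, L, requires_grad=True)    # value regression head

# Dual denoising loss, computed only on masked positions
id_loss = F.cross_entropy(id_logits[mask], gene_ids[mask])
value_loss = F.mse_loss(value_pred[mask], values[mask])
loss = id_loss + value_loss
loss.backward()
```

The key point is that both loss terms are evaluated only at masked positions, so the model must denoise identity and value jointly rather than predicting genes in a fixed order.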
|
|
|
|
|
## Model Specifications
|
|
|
|
|
| Attribute | Value | |
|
|
| :--- | :--- | |
|
|
| **Parameters** | ~94.5M | |
|
|
| **Layers** | 12 | |
|
|
| **Hidden Size** | 512 | |
|
|
| **Attention Heads** | 8 | |
|
|
| **Max Sequence Length** | 1,200 genes | |
|
|
| **Vocabulary** | 41,818 genes | |
|
|
| **Training Objective** | Dual Denoising (Identity Classification + Value Regression) | |
|
|
|
|
|
--- |
|
|
|
|
|
## Quick Start
|
|
|
|
|
To use ScDiVa, you need the `modeling_scdiva.py` file (included in this repository). |
|
|
|
|
|
### 1. Installation |
|
|
|
|
|
```bash
pip install torch numpy huggingface_hub
```
|
|
|
|
|
### 2. Loading the Pre-trained Model |
|
|
|
|
|
You can load the model directly using the `from_pretrained` method defined in `modeling_scdiva.py`.
|
|
|
|
|
```python
import torch

from modeling_scdiva import ScDiVaModel

# Load the model directly from Hugging Face.
# This automatically downloads model.safetensors and the config.
model = ScDiVaModel.from_pretrained("warming666/ScDiVa")
model.eval()

# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

print(f"ScDiVa loaded successfully on {device}")
```
|
|
|
|
|
### 3. Basic Inference Example |
|
|
|
|
|
```python
# Create a dummy input (batch size: 2, num genes: 41,818).
# In practice, replace this with your normalized gene expression matrix.
input_data = torch.randn(2, 41818).to(device)

with torch.no_grad():
    # Get latent embeddings (for clustering/integration)
    outputs = model.encode(input_data)
    embeddings = outputs['latent']
    print(f"Latent Embedding Shape: {embeddings.shape}")  # [2, 128]

    # Get annotation logits
    predictions = model.predict(input_data, task="annotation")
    print(f"Annotation Logits Shape: {predictions.shape}")  # [2, 100]
```
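The latent embeddings from `model.encode` plug into standard clustering and integration workflows. As a minimal illustration (using random dummy embeddings in place of real model output), pairwise cosine similarity between cells can be computed as:

```python
import torch
import torch.nn.functional as F

# Dummy stand-in for embeddings returned by model.encode (shape [cells, 128])
embeddings = torch.randn(4, 128)

# L2-normalize each row, then a matrix product gives pairwise cosine similarity
normed = F.normalize(embeddings, dim=1)
similarity = normed @ normed.T  # shape [4, 4]; diagonal entries are 1.0

print(similarity.shape)
```

The resulting similarity matrix can feed into graph-based clustering (e.g., building a k-nearest-neighbor graph over cells).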
|
|
|
|
|
--- |
|
|
|
|
|
## Repository Structure
|
|
|
|
|
This repository contains the core pre-trained weights and fine-tuned checkpoints for downstream tasks. |
|
|
|
|
|
```text
warming666/ScDiVa
├── config.json              # Model configuration
├── model.safetensors        # Pre-trained base weights (94.5M)
├── modeling_scdiva.py       # Model architecture definition code
└── downstream/              # Fine-tuned checkpoints
    ├── Multi-batch_Integration/
    │   ├── immune.pt
    │   ├── pbmc12k.pt
    │   └── ...
    ├── Annotation_FT/       # Fine-tuned for specific tissues
    │   ├── hpancreas.pt
    │   └── ms.pt
    ├── Annotation_Zeroshot/ # Weights for zero-shot projection
    └── Perturbation/        # Weights for gene perturbation tasks
```
|
|
|
|
|
To load a specific downstream checkpoint (e.g., for batch integration on the Immune dataset), download the corresponding `.pt` file from the `downstream` folder and load it with `torch.load()`.
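For example, the checkpoint can be fetched with `hf_hub_download` from the `huggingface_hub` package and restored via the usual `torch.load` / `load_state_dict` pattern. The snippet below demonstrates that pattern with a locally created dummy checkpoint, since the exact contents of each `.pt` file (plain state dict vs. wrapped dict) are not documented here.

```python
import torch

# In practice, fetch the checkpoint from the Hub first, e.g.:
#   from huggingface_hub import hf_hub_download
#   ckpt_path = hf_hub_download(repo_id="warming666/ScDiVa",
#                               filename="downstream/Multi-batch_Integration/immune.pt")

# Dummy model and checkpoint, just to illustrate the loading pattern
model = torch.nn.Linear(8, 4)
torch.save(model.state_dict(), "immune_dummy.pt")

# map_location="cpu" makes loading work without a GPU;
# weights_only=True avoids executing arbitrary pickled code
state_dict = torch.load("immune_dummy.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()
```

If a checkpoint turns out to wrap its weights (e.g., under a `"state_dict"` key), unwrap it before calling `load_state_dict`.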
|
|
|
|
|
--- |
|
|
|
|
|
## Benchmarks
|
|
|
|
|
ScDiVa achieves state-of-the-art performance across multiple benchmarks: |
|
|
|
|
|
* **Batch Integration**: Top-tier performance on PBMC12k (Avg-Bio: **0.9566**) and BMMC datasets. |
|
|
* **Annotation**: **98.6%** accuracy on hPancreas fine-tuning; **91.4%** average accuracy on zero-shot tasks. |
|
|
* **Perturbation**: Pearson correlation of **0.837** on Adamson dataset. |
|
|
|
|
|
For detailed results, please refer to our [arXiv paper](https://arxiv.org/abs/2602.03477).
|
|
|
|
|
--- |
|
|
|
|
|
## Limitations & Bias
|
|
|
|
|
* **Input Normalization**: The model expects log-normalized gene expression data. Raw counts may lead to suboptimal performance. |
|
|
* **Gene Vocabulary**: Inputs must be aligned to the specific 41,818 gene vocabulary used during pre-training. |
|
|
* **Not for Clinical Use**: This model is for research purposes only and has not been validated for clinical diagnosis or treatment. |
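As a minimal sketch of the expected preprocessing (check the GitHub repository for the exact normalization used in pre-training — the scale factor here is an assumption), the standard log-normalization of raw counts in single-cell workflows looks like:

```python
import numpy as np

# Raw counts: 2 cells x 5 genes (toy example)
counts = np.array([[0, 3, 10, 0, 7],
                   [5, 0, 2, 1, 12]], dtype=np.float64)

# Scale each cell to a fixed total (1e4 is a common convention),
# then apply log1p to get log-normalized expression
totals = counts.sum(axis=1, keepdims=True)
log_norm = np.log1p(counts / totals * 1e4)

print(log_norm.shape)  # (2, 5)
```

Zero counts stay zero under this transform, and each cell's library size is normalized away before the log.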
|
|
|
|
|
--- |
|
|
|
|
|
## Citation
|
|
|
|
|
If you use ScDiVa in your research, please cite: |
|
|
|
|
|
```bibtex
@article{wang2026scdiva,
  title={ScDiVa: Masked Discrete Diffusion for Joint Modeling of Single-Cell Identity and Expression},
  author={Wang, Mingxuan and Chen, Cheng and Jiang, Gaoyang and Ren, Zijia and Zhao, Chuangxin and Shi, Lu and Ma, Yanbiao},
  journal={arXiv preprint arXiv:2602.03477},
  year={2026}
}
```
|
|
|
|
|
<div align="center"> |
|
|
<sub>Thank you to everyone who has helped me.</sub> |
|
|
</div> |
|
|
|
|
|