danielle-miller-sayag commited on
Commit
1e9aaa3
·
verified ·
1 Parent(s): 572c29f

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Virtual Cell — Distilled Bulk Encoder
2
+
3
+ A bulk RNA-seq encoder distilled from
4
+ [ConvergeBio/virtual-cell-patient](https://huggingface.co/ConvergeBio/virtual-cell-patient).
5
+ It maps bulk gene expression directly into the same 512-dimensional patient embedding space,
6
+ making single-cell-trained representations accessible when only bulk data is available.
7
+
8
+ ## Model architecture
9
+
10
+ ```
11
+ input [batch, 18301 genes]
12
+ → MLP encoder (Linear → BN → PReLU)² → [batch, 512]
13
+ ```
14
+
15
+ Training objective: cosine distillation loss, with teacher embeddings produced by
16
+ `virtual-cell-patient` on matched single-cell RNA-seq data from the same patients.
17
+
18
+ ## Relationship to virtual-cell-patient
19
+
20
+ | | [virtual-cell-patient](https://huggingface.co/ConvergeBio/virtual-cell-patient) | virtual-cell-distil-bulk |
21
+ |---|---|---|
22
+ | Input | `[batch, n_cells, 18301]` single-cell matrix | `[batch, 18301]` bulk expression vector |
23
+ | Output | `[batch, 512]` patient embedding + class logits | `[batch, 512]` patient embedding |
24
+ | Requires single-cell data | Yes | No |
25
+
26
+ Both models use the same 18,301-gene vocabulary (`gene_names.txt`) and produce embeddings
27
+ in the same 512-dimensional space.
28
+
29
+ ## Installation
30
+
31
+ ```bash
32
+ pip install -r requirements.txt
33
+ ```
34
+
35
+ `wandb` is optional and only needed when training with `--wandb_project`.
36
+
37
+ ## Quick start
38
+
39
+ ### Inference — extract embeddings
40
+
41
+ ```python
42
+ import torch
43
+ from transformers import AutoModel
44
+
45
+ model = AutoModel.from_pretrained(
46
+ "ConvergeBio/virtual-cell-distil-bulk",
47
+ trust_remote_code=True,
48
+ ).eval()
49
+
50
+ x = torch.randn(4, 18_301) # [batch, num_genes]
51
+ with torch.no_grad():
52
+ out = model(input_ids=x)
53
+
54
+ print(out["embeddings"].shape) # [4, 512]
55
+ ```
56
+
57
+ > **Note:** the model uses BatchNorm — always call `.eval()` for inference.
58
+
59
+ ### Inference on real data
60
+
61
+ ```python
62
+ from datasets import load_dataset
63
+ import torch
64
+ from transformers import AutoModel
65
+
66
+ ds = load_dataset("ConvergeBio/virtual-cell-distil-bulk-example", token="...", split="validation")
67
+
68
+ model = AutoModel.from_pretrained(
69
+ "ConvergeBio/virtual-cell-distil-bulk",
70
+ trust_remote_code=True,
71
+ ).eval()
72
+
73
+ sample = torch.tensor(ds[0]["bulk_expression"]).unsqueeze(0) # [1, 18301]
74
+ with torch.no_grad():
75
+ out = model(input_ids=sample)
76
+
77
+ print(out["embeddings"].shape) # [1, 512]
78
+ ```
79
+
80
+ > **Note:** `ConvergeBio/virtual-cell-distil-bulk-example` is a minimal sample dataset
81
+ > intended only to verify the data format and run a quick end-to-end check.
82
+ > Metrics produced from this dataset should not be interpreted.
83
+
84
+ ## Fine-tuning for classification
85
+
86
+ The pretrained encoder can be fine-tuned on any bulk RNA-seq classification task.
87
+ A linear head is added on top; the encoder weights are initialised from the distilled
88
+ checkpoint and optionally frozen.
89
+
90
+ ```python
91
+ from transformers import AutoModelForSequenceClassification
92
+
93
+ model = AutoModelForSequenceClassification.from_pretrained(
94
+ "ConvergeBio/virtual-cell-distil-bulk",
95
+ num_labels=2,
96
+ ignore_mismatched_sizes=True, # classification head is randomly initialised
97
+ trust_remote_code=True,
98
+ )
99
+ ```
100
+
101
+ **Binary classification (e.g. disease vs. healthy) with frozen encoder:**
102
+
103
+ ```bash
104
+ python train.py \
105
+ --dataset_path <your_dataset> \
106
+ --num_classes 2 \
107
+ --freeze_encoder \
108
+ --output_dir ./my_binary_model
109
+ ```
110
+
111
+ **Multi-class fine-tuning:**
112
+
113
+ ```bash
114
+ python train.py \
115
+ --dataset_path <your_dataset> \
116
+ --num_classes <N> \
117
+ --output_dir ./my_finetuned_model \
118
+ --num_train_epochs 15 \
119
+ --learning_rate 1e-4
120
+ ```
121
+
122
+ ## Preparing your data
123
+
124
+ `train.py` expects a HuggingFace dataset with `train` (and optionally `validation`) splits.
125
+ Each row represents one patient sample:
126
+
127
+ | Column | Shape | Type | Description |
128
+ |---|---|---|---|
129
+ | `bulk_expression` | [18301] | float32 | Log-normalised bulk gene expression, aligned to `gene_names.txt` |
130
+ | `labels` | scalar | int | Class index |
131
+
132
+ Input expression should be library-size normalised (target sum 10,000) and log1p
133
+ transformed. The gene axis must be aligned to the 18,301 genes in `gene_names.txt` —
134
+ missing genes are zero-filled, extra genes are dropped.
135
+
136
+ For a guide on building this dataset from raw count matrices, see the
137
+ [example dataset](https://huggingface.co/datasets/ConvergeBio/virtual-cell-distil-bulk-example).
138
+
139
+ ## Repository contents
140
+
141
+ | File | Description |
142
+ |---|---|
143
+ | `modeling_virtual_cell_distil.py` | Full model implementation |
144
+ | `config.json` | Architecture config |
145
+ | `gene_names.txt` | Ordered list of 18,301 HGNC gene symbols |
146
+ | `train.py` | Classification fine-tuning script |
147
+ | `requirements.txt` | Python dependencies |
148
+ | `model.safetensors` | Pretrained encoder weights |
149
+
150
+ ## Citation
151
+
152
+ If you use this model, please cite:
153
+
154
+ ```bibtex
155
+ @article{convergecell2026,
156
+ author = {ConvergeBio},
157
+ title = {ConvergeCELL: An end-to-end platform from patient transcriptomics to therapeutic hypotheses},
158
+ year = {2026},
159
+ note = {Preprint available on bioRxiv},
160
+ }
161
+ ```
162
+
163
+ ## License
164
+
165
+ [TBD]
__pycache__/modeling_virtual_cell_distil.cpython-312.pyc ADDED
Binary file (9.05 kB). View file
 
config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "virtual_cell_distil",
3
+ "n_genes": 18301,
4
+ "output_dim": 512,
5
+ "hidden_dim": [512, 512],
6
+ "dropout": 0.2044838332376416,
7
+ "residual": false,
8
+ "activation": "prelu",
9
+ "num_labels": 2,
10
+ "classifier_dropout": 0.1,
11
+ "auto_map": {
12
+ "AutoConfig": "modeling_virtual_cell_distil.VirtualCellDistilConfig",
13
+ "AutoModel": "modeling_virtual_cell_distil.VirtualCellDistilModel",
14
+ "AutoModelForSequenceClassification": "modeling_virtual_cell_distil.VirtualCellDistilForSequenceClassification"
15
+ }
16
+ }
gene_names.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b41cdc6ccc9caded37f4fb68e9d511aaeea77ef0e2f685eabc53edc7cdd060b8
3
+ size 39601856
modeling_virtual_cell_distil.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Virtual Cell — Distilled Bulk Encoder — HuggingFace release.
3
+
4
+ Encodes bulk RNA-seq gene expression into the same 512-d patient embedding
5
+ space as ConvergeBio/virtual-cell-patient, without requiring single-cell data.
6
+ Trained by cosine distillation against patient model embeddings.
7
+
8
+ Two classes are provided:
9
+
10
+ VirtualCellDistilModel
11
+ Pure encoder. Returns 512-d embeddings for each sample.
12
+ Use this for clustering, visualisation, or as a frozen backbone.
13
+
14
+ VirtualCellDistilForSequenceClassification
15
+ Adds a dropout + linear classification head on top of the encoder.
16
+ Load the pretrained encoder weights and fine-tune on your labels.
17
+
18
+ Usage — inference:
19
+ from transformers import AutoModel
20
+ model = AutoModel.from_pretrained(
21
+ "ConvergeBio/virtual-cell-distil-bulk", trust_remote_code=True
22
+ ).eval()
23
+ out = model(input_ids=x) # out["embeddings"]: [batch, 512]
24
+
25
+ Usage — classification fine-tuning:
26
+ from transformers import AutoModelForSequenceClassification
27
+ model = AutoModelForSequenceClassification.from_pretrained(
28
+ "ConvergeBio/virtual-cell-distil-bulk",
29
+ num_labels=2,
30
+ ignore_mismatched_sizes=True, # head is randomly initialised
31
+ trust_remote_code=True,
32
+ )
33
+ out = model(input_ids=x, labels=y)
34
+ # out["loss"], out["logits"], out["embeddings"]
35
+
36
+ Note: the model contains BatchNorm layers — always call .eval() for inference.
37
+ """
38
+
39
+ from typing import List, Optional
40
+
41
+ import torch
42
+ import torch.nn as nn
43
+ import torch.nn.functional as F
44
+ from transformers import PretrainedConfig, PreTrainedModel
45
+
46
+
47
+ def _get_activation(activation: str) -> nn.Module:
48
+ if activation == "prelu":
49
+ return nn.PReLU()
50
+ elif activation == "relu":
51
+ return nn.ReLU()
52
+ elif activation == "gelu":
53
+ return nn.GELU()
54
+ elif activation == "tanh":
55
+ return nn.Tanh()
56
+ raise ValueError(f"Unsupported activation: {activation!r}")
57
+
58
+
59
+ class MLP(nn.Module):
60
+ def __init__(
61
+ self,
62
+ input_dim: int,
63
+ output_dim: int = 512,
64
+ hidden_dim: Optional[List[int]] = None,
65
+ dropout: float = 0.0,
66
+ residual: bool = False,
67
+ activation: str = "prelu",
68
+ ):
69
+ super().__init__()
70
+ if hidden_dim is None:
71
+ hidden_dim = [512, 512]
72
+ self.latent_dim = output_dim
73
+ self.residual = residual
74
+ self.network = nn.ModuleList()
75
+
76
+ if residual:
77
+ assert len(set(hidden_dim)) == 1, "Residual connections require all hidden dims to be equal"
78
+
79
+ for i in range(len(hidden_dim)):
80
+ if i == 0:
81
+ self.network.append(nn.Sequential(
82
+ nn.Linear(input_dim, hidden_dim[i]),
83
+ nn.BatchNorm1d(hidden_dim[i]),
84
+ _get_activation(activation),
85
+ ))
86
+ else:
87
+ self.network.append(nn.Sequential(
88
+ nn.Dropout(p=dropout),
89
+ nn.Linear(hidden_dim[i - 1], hidden_dim[i]),
90
+ nn.BatchNorm1d(hidden_dim[i]),
91
+ _get_activation(activation),
92
+ ))
93
+ self.network.append(nn.Linear(hidden_dim[-1], output_dim))
94
+
95
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
96
+ for i, layer in enumerate(self.network):
97
+ if self.residual and (0 < i < len(self.network) - 1):
98
+ x = layer(x) + x
99
+ else:
100
+ x = layer(x)
101
+ return x
102
+
103
+
104
+ class VirtualCellDistilConfig(PretrainedConfig):
105
+ model_type = "virtual_cell_distil"
106
+
107
+ def __init__(
108
+ self,
109
+ n_genes: int = 18301,
110
+ output_dim: int = 512,
111
+ hidden_dim: Optional[List[int]] = None,
112
+ dropout: float = 0.0,
113
+ residual: bool = False,
114
+ activation: str = "prelu",
115
+ num_labels: int = 2,
116
+ classifier_dropout: float = 0.1,
117
+ **kwargs,
118
+ ):
119
+ super().__init__(**kwargs)
120
+ self.n_genes = n_genes
121
+ self.output_dim = output_dim
122
+ self.hidden_dim = hidden_dim if hidden_dim is not None else [512, 512]
123
+ self.dropout = dropout
124
+ self.residual = residual
125
+ self.activation = activation
126
+ self.num_labels = num_labels
127
+ self.classifier_dropout = classifier_dropout
128
+
129
+
130
+ class VirtualCellDistilModel(PreTrainedModel):
131
+ """Pure encoder — returns 512-d patient embeddings from bulk expression."""
132
+ config_class = VirtualCellDistilConfig
133
+
134
+ def __init__(self, config: VirtualCellDistilConfig):
135
+ super().__init__(config)
136
+ self.encoder = MLP(
137
+ input_dim=config.n_genes,
138
+ output_dim=config.output_dim,
139
+ hidden_dim=config.hidden_dim,
140
+ dropout=config.dropout,
141
+ residual=config.residual,
142
+ activation=config.activation,
143
+ )
144
+
145
+ def forward(self, input_ids: torch.Tensor, **kwargs) -> dict:
146
+ return {"embeddings": self.encoder(input_ids)}
147
+
148
+
149
+ class VirtualCellDistilForSequenceClassification(PreTrainedModel):
150
+ """
151
+ Encoder + linear classification head.
152
+
153
+ The encoder is initialised from pretrained distilled weights.
154
+ The classification head is randomly initialised and trained on your labels.
155
+ Use ignore_mismatched_sizes=True when loading from the pretrained checkpoint.
156
+ """
157
+ config_class = VirtualCellDistilConfig
158
+
159
+ def __init__(self, config: VirtualCellDistilConfig):
160
+ super().__init__(config)
161
+ self.encoder = MLP(
162
+ input_dim=config.n_genes,
163
+ output_dim=config.output_dim,
164
+ hidden_dim=config.hidden_dim,
165
+ dropout=config.dropout,
166
+ residual=config.residual,
167
+ activation=config.activation,
168
+ )
169
+ self.dropout = nn.Dropout(config.classifier_dropout)
170
+ self.classifier = nn.Linear(config.output_dim, config.num_labels)
171
+
172
+ def forward(
173
+ self,
174
+ input_ids: torch.Tensor,
175
+ labels: Optional[torch.Tensor] = None,
176
+ **kwargs,
177
+ ) -> dict:
178
+ embeddings = self.encoder(input_ids)
179
+ logits = self.classifier(self.dropout(embeddings))
180
+ loss = None
181
+ if labels is not None:
182
+ loss = F.cross_entropy(logits, labels)
183
+ return {"loss": loss, "logits": logits, "embeddings": embeddings}
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0
2
+ transformers>=4.40,<5.0
3
+ accelerate>=0.26
4
+ datasets>=2.19
5
+ scikit-learn>=1.3
6
+ numpy>=1.24
7
+ safetensors>=0.4
8
+
9
+ # optional: only needed with --wandb_project
10
+ # wandb
train.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ from dataclasses import dataclass
5
+ from typing import Dict, List, Optional
6
+
7
+ import numpy as np
8
+ import torch
9
+ from datasets import DatasetDict, load_dataset
10
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
11
+ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
12
+ from transformers.trainer_utils import EvalPrediction
13
+
14
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15
+ from modeling_virtual_cell_distil import (
16
+ VirtualCellDistilConfig,
17
+ VirtualCellDistilForSequenceClassification,
18
+ )
19
+
20
+
21
+ @dataclass
22
+ class BulkCollator:
23
+ def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
24
+ return {
25
+ "input_ids": torch.stack([
26
+ torch.tensor(f["bulk_expression"], dtype=torch.float32) for f in features
27
+ ]),
28
+ "labels": torch.tensor([f["labels"] for f in features], dtype=torch.long),
29
+ }
30
+
31
+
32
+ def compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
33
+ logits = eval_pred.predictions
34
+ if isinstance(logits, tuple):
35
+ logits = logits[0]
36
+ labels = eval_pred.label_ids
37
+ preds = np.argmax(logits, axis=1)
38
+ return {
39
+ "accuracy": accuracy_score(labels, preds),
40
+ "f1_macro": f1_score(labels, preds, average="macro", zero_division=0),
41
+ "precision": precision_score(labels, preds, average="macro", zero_division=0),
42
+ "recall": recall_score(labels, preds, average="macro", zero_division=0),
43
+ }
44
+
45
+
46
+ def parse_args():
47
+ p = argparse.ArgumentParser()
48
+ p.add_argument("--dataset_path", required=True,
49
+ help="HF dataset ID or local path with train (and optionally validation) splits")
50
+ p.add_argument("--model_name_or_path", default="ConvergeBio/virtual-cell-distil-bulk")
51
+ p.add_argument("--hf_token", default=None)
52
+ p.add_argument("--output_dir", default="./vc_distil_output")
53
+ p.add_argument("--num_classes", type=int, default=None)
54
+ p.add_argument("--freeze_encoder", action="store_true",
55
+ help="Freeze the pretrained encoder and train the classification head only")
56
+ p.add_argument("--num_train_epochs", type=int, default=15)
57
+ p.add_argument("--per_device_train_batch_size", type=int, default=32)
58
+ p.add_argument("--per_device_eval_batch_size", type=int, default=32)
59
+ p.add_argument("--learning_rate", type=float, default=1e-4)
60
+ p.add_argument("--weight_decay", type=float, default=0.05)
61
+ p.add_argument("--warmup_ratio", type=float, default=0.1)
62
+ p.add_argument("--lr_scheduler_type", default="cosine")
63
+ p.add_argument("--patience", type=int, default=5)
64
+ p.add_argument("--num_workers", type=int, default=4)
65
+ p.add_argument("--prefetch_factor", type=int, default=2)
66
+ p.add_argument("--wandb_project", default=None)
67
+ p.add_argument("--run_name", default=None)
68
+ return p.parse_args()
69
+
70
+
71
+ def main():
72
+ args = parse_args()
73
+
74
+ if os.path.isdir(args.dataset_path):
75
+ ds = DatasetDict.load_from_disk(args.dataset_path)
76
+ else:
77
+ ds = load_dataset(args.dataset_path,
78
+ num_proc=args.num_workers or None,
79
+ token=args.hf_token or True)
80
+ train_ds = ds["train"]
81
+ val_ds: Optional[object] = ds.get("validation")
82
+
83
+ hf_kwargs = {"trust_remote_code": True}
84
+ if args.hf_token:
85
+ hf_kwargs["token"] = args.hf_token
86
+
87
+ config = VirtualCellDistilConfig.from_pretrained(args.model_name_or_path, **hf_kwargs)
88
+ if args.num_classes is not None:
89
+ config.num_labels = args.num_classes
90
+ config.id2label = {str(i): str(i) for i in range(args.num_classes)}
91
+ config.label2id = {str(i): i for i in range(args.num_classes)}
92
+
93
+ model = VirtualCellDistilForSequenceClassification.from_pretrained(
94
+ args.model_name_or_path,
95
+ config=config,
96
+ ignore_mismatched_sizes=True,
97
+ **hf_kwargs,
98
+ )
99
+
100
+ if args.freeze_encoder:
101
+ for param in model.encoder.parameters():
102
+ param.requires_grad = False
103
+
104
+ if args.wandb_project:
105
+ os.environ["WANDB_PROJECT"] = args.wandb_project
106
+
107
+ has_val = val_ds is not None
108
+ training_args = TrainingArguments(
109
+ output_dir=args.output_dir,
110
+ num_train_epochs=args.num_train_epochs,
111
+ per_device_train_batch_size=args.per_device_train_batch_size,
112
+ per_device_eval_batch_size=args.per_device_eval_batch_size,
113
+ learning_rate=args.learning_rate,
114
+ weight_decay=args.weight_decay,
115
+ warmup_ratio=args.warmup_ratio,
116
+ lr_scheduler_type=args.lr_scheduler_type,
117
+ eval_strategy="epoch" if has_val else "no",
118
+ save_strategy="epoch",
119
+ load_best_model_at_end=has_val,
120
+ metric_for_best_model="eval_loss" if has_val else None,
121
+ greater_is_better=False,
122
+ report_to="wandb" if args.wandb_project else "none",
123
+ run_name=args.run_name,
124
+ dataloader_num_workers=args.num_workers,
125
+ remove_unused_columns=False,
126
+ )
127
+
128
+ callbacks = [EarlyStoppingCallback(args.patience)] if has_val else []
129
+
130
+ trainer = Trainer(
131
+ model=model,
132
+ args=training_args,
133
+ train_dataset=train_ds,
134
+ eval_dataset=val_ds,
135
+ data_collator=BulkCollator(),
136
+ compute_metrics=compute_metrics if has_val else None,
137
+ callbacks=callbacks,
138
+ )
139
+
140
+ trainer.train()
141
+ trainer.save_model(args.output_dir)
142
+
143
+
144
+ if __name__ == "__main__":
145
+ main()