danielle-miller-sayag committed on
Commit
7b4d0f9
·
verified ·
1 Parent(s): 40196cf

initial: weights + modeling code + lean config

Browse files
Files changed (2) hide show
  1. modeling_virtual_cell.py +17 -0
  2. train.py +130 -0
modeling_virtual_cell.py CHANGED
@@ -1,3 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import List, Optional
2
 
3
  import torch
@@ -79,6 +94,8 @@ class MLP(nn.Module):
79
 
80
 
81
  class MLPCellEmbedder(nn.Module):
 
 
82
  def __init__(
83
  self,
84
  n_genes: int,
 
1
+ """
2
+ Virtual Cell Patient Model — HuggingFace release.
3
+
4
+ Architecture: PaSCient (Cui et al., 2025). ConvergeBio contribution: training
5
+ recipe, data scale, and model parameters.
6
+
7
+ Usage:
8
+ from transformers import AutoModel
9
+ model = AutoModel.from_pretrained(
10
+ "ConvergeBio/virtual-cell-patient", trust_remote_code=True
11
+ )
12
+ # input_ids: [batch, num_cells, num_genes] float32 log-normalized expression
13
+ out = model(input_ids=x) # out.logits: [batch, num_classes]
14
+ """
15
+
16
  from typing import List, Optional
17
 
18
  import torch
 
94
 
95
 
96
  class MLPCellEmbedder(nn.Module):
97
+ # Thin wrapper that preserves the .encoder attribute name required
98
+ # for state-dict key compatibility with the checkpoint.
99
  def __init__(
100
  self,
101
  n_genes: int,
train.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ from dataclasses import dataclass
5
+ from typing import Dict, List, Optional
6
+
7
+ import torch
8
+ from datasets import load_dataset
9
+ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
10
+
11
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
12
+ from modeling_virtual_cell import VirtualCellPatientConfig, VirtualCellPatientModel
13
+
14
+
15
@dataclass
class PatientCollator:
    """Collate per-sample feature dicts into one batch of tensors.

    Every feature dict must contain the keys ``input_ids``,
    ``attention_mask``, ``labels`` and ``entity_id``.  ``input_ids`` and
    ``attention_mask`` are stacked along a new leading batch dimension, so
    they must have the same shape in every sample of the batch.
    """

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        """Build the batch dict passed to the model.

        Args:
            features: One dict per sample, as yielded by the dataset.

        Returns:
            A dict with ``input_ids`` (float32), ``attention_mask`` (bool),
            ``labels`` (int64) and ``entity_id`` (int64) tensors.
        """
        # torch.as_tensor accepts lists, NumPy arrays and tensors alike.
        # Unlike torch.tensor it neither warns nor makes an extra copy when a
        # feature is already a tensor (e.g. a dataset in torch format);
        # torch.stack allocates the batch tensor afterwards anyway.
        return {
            "input_ids": torch.stack([
                torch.as_tensor(f["input_ids"], dtype=torch.float32)
                for f in features
            ]),
            "attention_mask": torch.stack([
                torch.as_tensor(f["attention_mask"], dtype=torch.bool)
                for f in features
            ]),
            "labels": torch.tensor(
                [f["labels"] for f in features], dtype=torch.long
            ),
            "entity_id": torch.tensor(
                [f["entity_id"] for f in features], dtype=torch.long
            ),
        }
28
+
29
+
30
class PatientTrainer(Trainer):
    """Trainer variant that takes the loss directly from the model output."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Forward the whole batch to ``model`` and return its ``loss``.

        Every key of ``inputs`` is passed to the model as a keyword
        argument.  Extra keyword arguments supplied by the caller are
        accepted and ignored.

        Returns:
            ``(loss, outputs)`` when ``return_outputs`` is true, otherwise
            only ``loss``.
        """
        model_output = model(**inputs)
        loss = model_output.loss
        if return_outputs:
            return loss, model_output
        return loss
34
+
35
+
36
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Argument list to parse.  ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour when the script is
            run from the command line; passing a list allows programmatic
            use and testing.

    Returns:
        The populated ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser()

    # Data / model location.
    p.add_argument("--dataset_path", required=True,
                   help="HF dataset ID or local path with train (and optionally validation) splits")
    p.add_argument("--model_name_or_path", default="ConvergeBio/virtual-cell-patient")
    p.add_argument("--hf_token", default=None,
                   help="Access token forwarded to from_pretrained when set")
    p.add_argument("--output_dir", default="./vc_output")

    # Model initialisation.
    p.add_argument("--from_scratch", action="store_true",
                   help="Build the model from its config instead of loading pretrained weights")
    p.add_argument("--freeze_embedder", action="store_true",
                   help="Disable gradients for model.patient_embedder")
    p.add_argument("--num_classes", type=int, default=None,
                   help="Override the number of classes stored in the config")

    # Optimisation hyper-parameters.
    p.add_argument("--num_train_epochs", type=int, default=15)
    p.add_argument("--per_device_train_batch_size", type=int, default=32)
    p.add_argument("--per_device_eval_batch_size", type=int, default=32)
    p.add_argument("--learning_rate", type=float, default=1e-4)
    p.add_argument("--weight_decay", type=float, default=0.05)
    p.add_argument("--warmup_ratio", type=float, default=0.1)
    p.add_argument("--lr_scheduler_type", default="cosine")
    p.add_argument("--patience", type=int, default=5,
                   help="Early-stopping patience; only used when a validation split exists")
    p.add_argument("--num_workers", type=int, default=4)

    # Experiment tracking.
    p.add_argument("--wandb_project", default=None,
                   help="Enables Weights & Biases reporting when set")
    p.add_argument("--run_name", default=None)

    return p.parse_args(argv)
60
+
61
+
62
def main():
    """Train or fine-tune the patient model with the HF ``Trainer``.

    Steps: parse CLI options, load the dataset, build the config (optionally
    overriding the class count), create the model (fresh or pretrained),
    optionally freeze ``model.patient_embedder``, then train and save to
    ``--output_dir``.  Evaluation, best-model reloading and early stopping
    are enabled only when the dataset has a ``validation`` split.
    """
    args = parse_args()

    ds = load_dataset(args.dataset_path)
    train_ds = ds["train"]
    val_ds: Optional[object] = ds.get("validation")

    hf_kwargs = {"trust_remote_code": True}
    if args.hf_token:
        hf_kwargs["token"] = args.hf_token

    config = VirtualCellPatientConfig.from_pretrained(args.model_name_or_path, **hf_kwargs)

    # Remember the checkpoint's class count so we can tell whether the
    # override below actually changes the output size.
    pretrained_num_classes = getattr(config, "num_classes", None)
    head_resized = False
    if args.num_classes is not None:
        head_resized = args.num_classes != pretrained_num_classes
        config.num_classes = args.num_classes
        # transformers configs map integer ids to label names (id2label) and
        # label names to integer ids (label2id).
        config.id2label = {i: str(i) for i in range(args.num_classes)}
        config.label2id = {str(i): i for i in range(args.num_classes)}

    if args.from_scratch:
        model = VirtualCellPatientModel(config)
    else:
        # When the class count differs from the checkpoint, the saved
        # classifier weights no longer fit; ignore_mismatched_sizes lets
        # from_pretrained re-initialise them instead of raising.
        # NOTE(review): assumes the output head is sized by
        # config.num_classes — confirm in modeling_virtual_cell.py.
        model = VirtualCellPatientModel.from_pretrained(
            args.model_name_or_path,
            config=config,
            ignore_mismatched_sizes=head_resized,
            **hf_kwargs,
        )

    if args.freeze_embedder:
        for param in model.patient_embedder.parameters():
            param.requires_grad = False

    if args.wandb_project:
        os.environ["WANDB_PROJECT"] = args.wandb_project

    has_val = val_ds is not None
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        eval_strategy="epoch" if has_val else "no",
        save_strategy="epoch",
        load_best_model_at_end=has_val,
        metric_for_best_model="eval_loss" if has_val else None,
        greater_is_better=False,
        report_to="wandb" if args.wandb_project else "none",
        run_name=args.run_name,
        dataloader_num_workers=args.num_workers,
        # Keep every dataset column: the collator reads entity_id as well.
        remove_unused_columns=False,
    )

    # Early stopping needs an evaluation metric, so only enable it when a
    # validation split is available.
    callbacks = (
        [EarlyStoppingCallback(early_stopping_patience=args.patience)]
        if has_val
        else []
    )

    trainer = PatientTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=PatientCollator(),
        callbacks=callbacks,
    )

    trainer.train()
    trainer.save_model(args.output_dir)
127
+
128
+
129
+ if __name__ == "__main__":
130
+ main()