danielle-miller-sayag committed on
Commit
6e5b188
·
verified ·
1 Parent(s): 4ff7bcb

initial: weights + modeling code + lean config

Browse files
__pycache__/modeling_virtual_cell.cpython-311.pyc ADDED
Binary file (16.7 kB). View file
 
config.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "virtual_cell_patient",
3
+ "architectures": ["VirtualCellPatientModel"],
4
+ "auto_map": {
5
+ "AutoConfig": "modeling_virtual_cell.VirtualCellPatientConfig",
6
+ "AutoModel": "modeling_virtual_cell.VirtualCellPatientModel"
7
+ },
8
+ "n_genes": 18301,
9
+ "embed_dim": 512,
10
+ "hidden_dim": [4096, 1024],
11
+ "dropout": 0.1,
12
+ "residual": false,
13
+ "activation": "prelu",
14
+ "attention_hidden_dim": 512,
15
+ "num_classes": 10,
16
+ "classifier_dropout": 0.1,
17
+ "id2label": {
18
+ "0": "oncological",
19
+ "1": "immune_inflammatory",
20
+ "2": "neurological",
21
+ "3": "metabolic_vascular",
22
+ "4": "gastrointestinal",
23
+ "5": "respiratory",
24
+ "6": "epithelial_barrier",
25
+ "7": "sensory_specialized",
26
+ "8": "healthy_control",
27
+ "9": "other"
28
+ },
29
+ "label2id": {
30
+ "oncological": 0,
31
+ "immune_inflammatory": 1,
32
+ "neurological": 2,
33
+ "metabolic_vascular": 3,
34
+ "gastrointestinal": 4,
35
+ "respiratory": 5,
36
+ "epithelial_barrier": 6,
37
+ "sensory_specialized": 7,
38
+ "healthy_control": 8,
39
+ "other": 9
40
+ }
41
+ }
gene_names.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59e8a629ccefe64f5d4123b7ef20d470e7c9197e8c4eb2564528e4df95d30a7d
3
+ size 319898572
modeling_virtual_cell.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Virtual Cell Patient Model — HuggingFace release.
3
+
4
+ Architecture: PaSCient (Cui et al., 2025). ConvergeBio contribution: training
5
+ recipe, data scale, and model parameters.
6
+
7
+ Usage:
8
+ from transformers import AutoModel
9
+ model = AutoModel.from_pretrained(
10
+ "ConvergeBio/virtual-cell-patient", trust_remote_code=True
11
+ )
12
+ # input_ids: [batch, num_cells, num_genes] float32 log-normalized expression
13
+ out = model(input_ids=x) # out.logits: [batch, num_classes]
14
+ """
15
+
16
+ from typing import List, Optional
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from transformers import PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import SequenceClassifierOutput
23
+
24
+
25
def _get_activation(activation: str) -> nn.Module:
    """Build a new activation module for the given name.

    Recognised names are ``"prelu"``, ``"relu"``, ``"gelu"`` and ``"tanh"``.
    A fresh module instance is created on every call.

    Raises:
        ValueError: if ``activation`` is not one of the recognised names.
    """
    # (name, constructor) pairs; constructors are invoked lazily so that
    # each caller receives its own module instance.
    known = (
        ("prelu", nn.PReLU),
        ("relu", nn.ReLU),
        ("gelu", nn.GELU),
        ("tanh", nn.Tanh),
    )
    for name, build in known:
        if activation == name:
            return build()
    raise ValueError(f"Unsupported activation: {activation!r}")
35
+
36
+
37
class MLP(nn.Module):
    """Multi-layer perceptron made of Linear -> BatchNorm1d -> activation blocks.

    ``self.network`` is a ``ModuleList`` holding one ``nn.Sequential`` per entry
    of ``hidden_dim`` followed by a final ``nn.Linear`` projecting to
    ``output_dim``. Every block after the first is preceded by dropout. The
    nesting of ``self.network`` determines the parameter names in the
    state dict, so its layout is kept stable.

    Args:
        input_dim: number of input features.
        output_dim: number of output features (stored as ``latent_dim``).
        hidden_dim: widths of the hidden blocks; defaults to ``[1024, 1024]``.
        dropout: dropout probability used before every hidden block except
            the first.
        residual: if true, add a skip connection around every hidden block
            except the first (requires all hidden widths to be equal).
        activation: activation name understood by ``_get_activation``.

    Raises:
        ValueError: if ``hidden_dim`` is empty, or if ``residual`` is set and
            the hidden widths differ.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 128,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        activation: str = "prelu",
    ):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = [1024, 1024]
        # Explicit raises instead of asserts: asserts are removed under
        # ``python -O`` and would let an invalid configuration through.
        if not hidden_dim:
            raise ValueError("hidden_dim must contain at least one layer width")
        if residual and len(set(hidden_dim)) != 1:
            raise ValueError("Residual connections require all hidden dims to be equal")

        self.input_dim = input_dim
        self.latent_dim = output_dim
        self.residual = residual
        self.dropout = dropout
        self.activation = activation
        self.network = nn.ModuleList()

        for i, width in enumerate(hidden_dim):
            if i == 0:
                block = nn.Sequential(
                    nn.Linear(input_dim, width),
                    nn.BatchNorm1d(width),
                    _get_activation(activation),
                )
            else:
                block = nn.Sequential(
                    nn.Dropout(p=dropout),
                    nn.Linear(hidden_dim[i - 1], width),
                    nn.BatchNorm1d(width),
                    _get_activation(activation),
                )
            self.network.append(block)
        self.network.append(nn.Linear(hidden_dim[-1], output_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a 2D ``[rows, input_dim]`` tensor.

        Raises:
            ValueError: if ``x`` is not a 2D tensor, or if the module is in
                training mode and ``x`` has fewer than two rows.
        """
        if not torch.is_tensor(x) or x.ndim != 2:
            raise ValueError(
                f"Expected 2D tensor, got {type(x).__name__} shape {getattr(x, 'shape', None)}"
            )
        # BatchNorm only needs more than one row while it is computing batch
        # statistics, i.e. in training mode. In eval mode it uses running
        # statistics, so single-row inference must be allowed (as the error
        # message itself advises).
        if self.training and x.shape[0] <= 1:
            raise ValueError(
                f"BatchNorm requires batch size > 1, got {x.shape[0]}. "
                "Use model.eval() for single-sample inference."
            )
        last = len(self.network) - 1
        for i, layer in enumerate(self.network):
            out = layer(x)
            # Skip connections apply only to the middle blocks: the first
            # block changes width from input_dim and the last is the output
            # projection.
            x = out + x if self.residual and 0 < i < last else out
        return x
94
+
95
+
96
class MLPCellEmbedder(nn.Module):
    """Per-cell embedder that delegates to an ``MLP`` stored as ``self.encoder``.

    The wrapper exists so that parameters live under the ``encoder``
    attribute name, which the checkpoint's state-dict keys rely on.

    Args:
        n_genes: number of input features per cell.
        output_dim: size of the produced cell embedding.
        hidden_dim: hidden widths forwarded to ``MLP``; defaults to
            ``[1024, 1024]``.
        dropout: dropout probability forwarded to ``MLP``.
        residual: residual flag forwarded to ``MLP``.
        activation: activation name forwarded to ``MLP``.
    """

    def __init__(
        self,
        n_genes: int,
        output_dim: int = 128,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.1,
        residual: bool = False,
        activation: str = "prelu",
    ):
        super().__init__()
        widths = [1024, 1024] if hidden_dim is None else hidden_dim
        self.n_genes = n_genes
        self.output_dim = output_dim
        # Attribute name must stay ``encoder`` for checkpoint compatibility.
        self.encoder = MLP(
            input_dim=n_genes,
            output_dim=output_dim,
            hidden_dim=widths,
            dropout=dropout,
            residual=residual,
            activation=activation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a 2D tensor of cells by passing it through ``self.encoder``."""
        is_matrix = torch.is_tensor(x) and x.ndim == 2
        assert is_matrix, (
            f"Expected 2D tensor, got {type(x).__name__} shape {getattr(x, 'shape', None)}"
        )
        embedded = self.encoder(x)
        return embedded
127
+
128
+
129
class AttentionAggregator(nn.Module):
    """Attention pooling that collapses the cell axis into one vector per sample.

    A two-layer scoring network (Linear -> ReLU -> Linear) assigns one scalar
    score to each cell; scores are softmax-normalised over the cell axis and
    used as weights for a weighted sum of the cell embeddings.

    Args:
        embedding_dim: size of each cell embedding.
        hidden_dim: width of the scoring network's hidden layer.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.attention_net = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def aggregate(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: [batch, num_cells, embedding_dim]
            mask: [batch, num_cells] — non-zero=valid, 0=ignore (optional)
        Returns:
            [batch, embedding_dim]
        Raises:
            ValueError: if any sample has no valid cell in ``mask``.
        """
        scores = self.attention_net(x).squeeze(-1)
        if mask is not None:
            valid = mask != 0
            # Raised explicitly (not asserted): with asserts stripped, a fully
            # masked row would make softmax see only -inf and return NaN.
            # The check uses the same "non-zero" criterion as the fill below.
            if not bool(valid.any(dim=1).all()):
                raise ValueError("All samples must have at least one valid cell")
            scores = scores.masked_fill(~valid, float("-inf"))
        weights = torch.softmax(scores, dim=1).unsqueeze(-1)
        return (x * weights).sum(dim=1)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Alias for :meth:`aggregate` so the module can be called directly."""
        return self.aggregate(x, mask)
155
+
156
+
157
class PatientEmbedder(nn.Module):
    """Embed each cell, then pool the cell embeddings into one vector per sample.

    Args:
        cell_embedder: module mapping ``[rows, num_genes]`` to
            ``[rows, embedding_dim]``; it must expose an ``output_dim``
            attribute (read by :meth:`get_embedding_dim`).
        aggregator: module exposing ``aggregate(embeddings, mask)``.
    """

    def __init__(self, cell_embedder: nn.Module, aggregator: nn.Module):
        super().__init__()
        self.cell_embedder = cell_embedder
        self.aggregator = aggregator

    def forward(
        self, cell_matrix: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            cell_matrix: [batch, num_cells, num_genes]
            mask: [batch, num_cells] — optional
        Returns:
            [batch, embedding_dim]
        Raises:
            ValueError: if ``cell_matrix`` is not a 3D tensor.
        """
        if cell_matrix.ndim != 3:
            raise ValueError(
                "Expected cell_matrix of shape [batch, num_cells, num_genes], "
                f"got shape {tuple(cell_matrix.shape)}"
            )
        batch_size, num_cells, num_genes = cell_matrix.shape
        # reshape (not view): inputs produced by transpose/slicing may be
        # non-contiguous, on which view() raises; reshape copies only when
        # it has to.
        flat = cell_matrix.reshape(-1, num_genes)
        embeddings_flat = self.cell_embedder(flat)
        embeddings = embeddings_flat.reshape(batch_size, num_cells, -1)
        return self.aggregator.aggregate(embeddings, mask)

    def get_embedding_dim(self) -> int:
        """Return the cell embedder's ``output_dim``."""
        return self.cell_embedder.output_dim
181
+
182
+
183
class CrossEntropyLossViews(nn.Module):
    """Cross-entropy where samples sharing an entity id are averaged first.

    Per-sample losses are computed with ``reduction="none"``. Without entity
    ids the result is their plain mean; with entity ids the losses are first
    averaged within each entity and those per-entity means are then averaged.

    Args:
        class_weights: optional per-class weights passed to
            ``nn.CrossEntropyLoss``.
    """

    def __init__(self, class_weights: Optional[torch.Tensor] = None):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss(weight=class_weights, reduction="none")

    def forward(
        self,
        predictions: torch.Tensor,
        labels: torch.Tensor,
        entity_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the scalar loss for ``predictions`` against ``labels``."""
        per_sample = self.ce_loss(predictions, labels)
        if entity_ids is None:
            return per_sample.mean()

        # group_of[i] is the position of sample i's entity in ``groups``;
        # group_sizes[g] is how many samples belong to entity g.
        groups, group_of, group_sizes = torch.unique(
            entity_ids, return_inverse=True, return_counts=True
        )
        totals = per_sample.new_zeros(len(groups))
        totals.scatter_add_(0, group_of, per_sample)
        per_entity = totals / group_sizes.float()
        return per_entity.mean()
207
+
208
+
209
class VirtualCellPatientConfig(PretrainedConfig):
    """Configuration for ``VirtualCellPatientModel``.

    Args:
        n_genes: number of input features per cell (default 18301).
        embed_dim: size of the cell / patient embedding (default 512).
        hidden_dim: hidden widths of the cell encoder; ``None`` selects
            ``[4096, 1024]``.
        dropout: dropout probability inside the cell encoder.
        residual: whether the cell encoder uses residual connections.
        activation: activation name used by the cell encoder.
        attention_hidden_dim: hidden width of the attention scoring network.
        num_classes: number of output classes of the classifier head.
        classifier_dropout: dropout probability before the classifier layer.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "virtual_cell_patient"

    def __init__(
        self,
        n_genes: int = 18301,
        embed_dim: int = 512,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.1,
        residual: bool = False,
        activation: str = "prelu",
        attention_hidden_dim: int = 512,
        num_classes: int = 10,
        classifier_dropout: float = 0.1,
        **kwargs,
    ):
        # Base-class initialisation runs first; model-specific fields are
        # assigned afterwards.
        super().__init__(**kwargs)

        if hidden_dim is None:
            hidden_dim = [4096, 1024]

        # Cell encoder
        self.n_genes = n_genes
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.residual = residual
        self.activation = activation

        # Attention pooling
        self.attention_hidden_dim = attention_hidden_dim

        # Classifier head
        self.num_classes = num_classes
        self.classifier_dropout = classifier_dropout
235
+
236
+
237
class VirtualCellPatientModel(PreTrainedModel):
    """Patient-level classifier over a set of cells.

    Cells are embedded by an ``MLPCellEmbedder``, pooled per sample by an
    ``AttentionAggregator`` (both wrapped in ``PatientEmbedder``), and the
    pooled vector is classified by a dropout + linear head.
    """

    config_class = VirtualCellPatientConfig

    def __init__(self, config: VirtualCellPatientConfig):
        super().__init__(config)
        # Attribute names below (patient_embedder, classifier, loss_fn)
        # define the state-dict key prefixes and must not be renamed.
        self.patient_embedder = PatientEmbedder(
            MLPCellEmbedder(
                n_genes=config.n_genes,
                output_dim=config.embed_dim,
                hidden_dim=config.hidden_dim,
                dropout=config.dropout,
                residual=config.residual,
                activation=config.activation,
            ),
            AttentionAggregator(
                embedding_dim=config.embed_dim,
                hidden_dim=config.attention_hidden_dim,
            ),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(config.classifier_dropout),
            nn.Linear(config.embed_dim, config.num_classes),
        )
        self.loss_fn = CrossEntropyLossViews()

    def _init_weights(self, module):
        # Deliberately a no-op: no custom weight initialisation is applied.
        pass

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        entity_id: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """
        Args:
            input_ids: [batch, num_cells, num_genes] log-normalized float32 expression
            attention_mask: [batch, num_cells] 1=valid, 0=ignore (optional)
            labels: [batch] integer class indices (optional, for loss)
            entity_id: [batch] patient IDs grouping augmented views (optional)
            **kwargs: accepted and ignored.
        Returns:
            SequenceClassifierOutput with .loss (when labels given) and .logits [batch, num_classes]
        """
        patient_vectors = self.patient_embedder(input_ids, attention_mask)
        logits = self.classifier(patient_vectors)

        if labels is None:
            loss = None
        elif entity_id is None:
            # No view grouping requested: ordinary mean cross-entropy.
            loss = F.cross_entropy(logits, labels)
        else:
            # Average within each entity first, then across entities.
            loss = self.loss_fn(logits, labels, entity_id)

        return SequenceClassifierOutput(loss=loss, logits=logits)