warming666 committed on
Commit
0cadf02
·
verified ·
1 Parent(s): 0031d9e

Create modeling_scdiva.py

Browse files
Files changed (1) hide show
  1. modeling_scdiva.py +298 -0
modeling_scdiva.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ScDiVa: A Foundation Model for Single-cell Genomics
3
+ Model Architecture Definition
4
+
5
+ This file contains the core architecture definition of ScDiVa.
6
+ It allows loading pre-trained weights for inference.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from typing import Optional, Dict, Tuple, Union
13
+ import math
14
+ import os
15
+
16
+ class ScDiVaConfig:
17
+ def __init__(
18
+ self,
19
+ num_genes: int = 41818, # Updated to match paper (Table 4)
20
+ hidden_size: int = 512,
21
+ num_hidden_layers: int = 12,
22
+ num_attention_heads: int = 8,
23
+ intermediate_size: int = 2048,
24
+ hidden_dropout_prob: float = 0.1,
25
+ attention_probs_dropout_prob: float = 0.1,
26
+ max_position_embeddings: int = 1200,
27
+ layer_norm_eps: float = 1e-5,
28
+ latent_dim: int = 128,
29
+ num_cell_types: int = 100,
30
+ use_variational: bool = True,
31
+ **kwargs
32
+ ):
33
+ self.num_genes = num_genes
34
+ self.hidden_size = hidden_size
35
+ self.num_hidden_layers = num_hidden_layers
36
+ self.num_attention_heads = num_attention_heads
37
+ self.intermediate_size = intermediate_size
38
+ self.hidden_dropout_prob = hidden_dropout_prob
39
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
40
+ self.max_position_embeddings = max_position_embeddings
41
+ self.layer_norm_eps = layer_norm_eps
42
+ self.latent_dim = latent_dim
43
+ self.num_cell_types = num_cell_types
44
+ self.use_variational = use_variational
45
+
46
+ class GeneEmbedding(nn.Module):
47
+ def __init__(self, config: ScDiVaConfig):
48
+ super().__init__()
49
+ self.gene_projection = nn.Linear(config.num_genes, config.hidden_size)
50
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
51
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
52
+
53
+ def forward(self, gene_expression: torch.Tensor) -> torch.Tensor:
54
+ embeddings = self.gene_projection(gene_expression)
55
+ embeddings = self.layer_norm(embeddings)
56
+ embeddings = self.dropout(embeddings)
57
+ return embeddings
58
+
59
+ class MultiHeadAttention(nn.Module):
60
+ def __init__(self, config: ScDiVaConfig):
61
+ super().__init__()
62
+ self.num_attention_heads = config.num_attention_heads
63
+ self.attention_head_size = config.hidden_size // config.num_attention_heads
64
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
65
+
66
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
67
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
68
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
69
+
70
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
71
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
72
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
73
+
74
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
75
+ new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
76
+ x = x.view(*new_shape)
77
+ return x.permute(0, 2, 1, 3)
78
+
79
+ def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
80
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
81
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
82
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
83
+
84
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
85
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
86
+
87
+ if attention_mask is not None:
88
+ attention_scores = attention_scores + attention_mask
89
+
90
+ attention_probs = F.softmax(attention_scores, dim=-1)
91
+ attention_probs = self.dropout(attention_probs)
92
+
93
+ context_layer = torch.matmul(attention_probs, value_layer)
94
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
95
+ new_shape = context_layer.size()[:-2] + (self.all_head_size,)
96
+ context_layer = context_layer.view(*new_shape)
97
+
98
+ attention_output = self.dense(context_layer)
99
+ attention_output = self.dropout(attention_output)
100
+ attention_output = self.layer_norm(attention_output + hidden_states)
101
+
102
+ return attention_output
103
+
104
+ class FeedForward(nn.Module):
105
+ def __init__(self, config: ScDiVaConfig):
106
+ super().__init__()
107
+ self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size)
108
+ self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size)
109
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
110
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
111
+
112
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
113
+ residual = hidden_states
114
+ hidden_states = self.dense1(hidden_states)
115
+ hidden_states = F.gelu(hidden_states)
116
+ hidden_states = self.dense2(hidden_states)
117
+ hidden_states = self.dropout(hidden_states)
118
+ hidden_states = self.layer_norm(hidden_states + residual)
119
+ return hidden_states
120
+
121
+ class TransformerLayer(nn.Module):
122
+ def __init__(self, config: ScDiVaConfig):
123
+ super().__init__()
124
+ self.attention = MultiHeadAttention(config)
125
+ self.feed_forward = FeedForward(config)
126
+
127
+ def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
128
+ attention_output = self.attention(hidden_states, attention_mask)
129
+ layer_output = self.feed_forward(attention_output)
130
+ return layer_output
131
+
132
+ class TransformerEncoder(nn.Module):
133
+ def __init__(self, config: ScDiVaConfig):
134
+ super().__init__()
135
+ self.layers = nn.ModuleList([
136
+ TransformerLayer(config) for _ in range(config.num_hidden_layers)
137
+ ])
138
+
139
+ def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
140
+ for layer in self.layers:
141
+ hidden_states = layer(hidden_states, attention_mask)
142
+ return hidden_states
143
+
144
+ class VariationalLayer(nn.Module):
145
+ def __init__(self, config: ScDiVaConfig):
146
+ super().__init__()
147
+ self.mu_projection = nn.Linear(config.hidden_size, config.latent_dim)
148
+ self.logvar_projection = nn.Linear(config.hidden_size, config.latent_dim)
149
+
150
+ def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
151
+ std = torch.exp(0.5 * logvar)
152
+ eps = torch.randn_like(std)
153
+ return mu + eps * std
154
+
155
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
156
+ mu = self.mu_projection(hidden_states)
157
+ logvar = self.logvar_projection(hidden_states)
158
+ z = self.reparameterize(mu, logvar)
159
+ return z, mu, logvar
160
+
161
+ class AnnotationHead(nn.Module):
162
+ def __init__(self, config: ScDiVaConfig):
163
+ super().__init__()
164
+ self.dense = nn.Linear(config.latent_dim, config.hidden_size)
165
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
166
+ self.classifier = nn.Linear(config.hidden_size, config.num_cell_types)
167
+
168
+ def forward(self, latent_representation: torch.Tensor) -> torch.Tensor:
169
+ hidden = F.gelu(self.dense(latent_representation))
170
+ hidden = self.dropout(hidden)
171
+ logits = self.classifier(hidden)
172
+ return logits
173
+
174
+ class BatchIntegrationHead(nn.Module):
175
+ def __init__(self, config: ScDiVaConfig):
176
+ super().__init__()
177
+ self.dense = nn.Linear(config.latent_dim, config.hidden_size)
178
+ self.decoder = nn.Linear(config.hidden_size, config.num_genes)
179
+
180
+ def forward(self, latent_representation: torch.Tensor) -> torch.Tensor:
181
+ hidden = F.gelu(self.dense(latent_representation))
182
+ reconstructed = self.decoder(hidden)
183
+ return reconstructed
184
+
185
+ class ScDiVaModel(nn.Module):
186
+ """
187
+ ScDiVa: Single-cell Deep Variational Analysis Model
188
+ """
189
+ def __init__(self, config: ScDiVaConfig):
190
+ super().__init__()
191
+ self.config = config
192
+ self.gene_embedding = GeneEmbedding(config)
193
+ self.encoder = TransformerEncoder(config)
194
+ self.variational_layer = VariationalLayer(config)
195
+ self.annotation_head = AnnotationHead(config)
196
+ self.batch_integration_head = BatchIntegrationHead(config)
197
+
198
+ def encode(self, gene_expression: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
199
+ """
200
+ Input Shape: (batch_size, num_genes)
201
+ Returns: Dict containing latent, mu, logvar
202
+ """
203
+ embeddings = self.gene_embedding(gene_expression)
204
+ embeddings = embeddings.unsqueeze(1) # (B, 1, H)
205
+ encoded = self.encoder(embeddings, attention_mask) # (B, 1, H)
206
+ encoded = encoded.squeeze(1) # (B, H)
207
+ z, mu, logvar = self.variational_layer(encoded)
208
+ return {"latent": z, "mu": mu, "logvar": logvar}
209
+
210
+ def predict(self, gene_expression: torch.Tensor, task: str = "annotation", attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
211
+ """
212
+ Inference interface:
213
+ - task="annotation": returns classification logits
214
+ - task="batch_integration": returns reconstructed expression
215
+ """
216
+ encoding = self.encode(gene_expression, attention_mask)
217
+ latent = encoding["latent"]
218
+ if task == "annotation":
219
+ return self.annotation_head(latent)
220
+ elif task == "batch_integration":
221
+ return self.batch_integration_head(latent)
222
+ else:
223
+ raise ValueError(f"Unknown task: {task}")
224
+
225
+ @classmethod
226
+ def from_pretrained(
227
+ cls,
228
+ model_name_or_path: str,
229
+ map_location: Optional[str] = None,
230
+ strict: bool = True,
231
+ use_auth_token: Optional[str] = None,
232
+ ) -> "ScDiVaModel":
233
+ """
234
+ Load pre-trained model from local path or Hugging Face Hub.
235
+ Supports directly loading from 'warming666/ScDiVa'.
236
+ """
237
+ config = ScDiVaConfig()
238
+ model = cls(config)
239
+
240
+ if map_location is None:
241
+ map_location = "cpu"
242
+
243
+ ckpt_path: Optional[str] = None
244
+
245
+ # 1. Try Local File
246
+ if os.path.exists(model_name_or_path):
247
+ if os.path.isfile(model_name_or_path):
248
+ ckpt_path = model_name_or_path
249
+ elif os.path.isdir(model_name_or_path):
250
+ # Search for typical weights file
251
+ for name in ["pytorch_model.bin", "model.safetensors", "model.pt"]:
252
+ p = os.path.join(model_name_or_path, name)
253
+ if os.path.exists(p):
254
+ ckpt_path = p
255
+ break
256
+
257
+ # 2. Try Hugging Face Hub Download
258
+ if ckpt_path is None:
259
+ try:
260
+ from huggingface_hub import hf_hub_download
261
+ print(f"[ScDiVa] Attempting to download weights from HF: {model_name_or_path}")
262
+ # Try safetensors first, then bin
263
+ try:
264
+ ckpt_path = hf_hub_download(repo_id=model_name_or_path, filename="model.safetensors", token=use_auth_token)
265
+ except:
266
+ # Fallback to pytorch_model.bin
267
+ try:
268
+ ckpt_path = hf_hub_download(repo_id=model_name_or_path, filename="pytorch_model.bin", token=use_auth_token)
269
+ except:
270
+ pass
271
+ except ImportError:
272
+ print("[ScDiVa] Warning: `huggingface_hub` not installed. Cannot download from HF.")
273
+ except Exception as e:
274
+ print(f"[ScDiVa] Warning: HF download error (check network/repo ID): {e}")
275
+
276
+ # 3. Load or Fallback to Random Init (Demo Mode)
277
+ if ckpt_path is None:
278
+ print(f"[ScDiVa] Warning: No weights found at '{model_name_or_path}'. Using random initialization (DEMO MODE).")
279
+ return model
280
+
281
+ print(f"[ScDiVa] Loading weights from {ckpt_path}...")
282
+ try:
283
+ state = torch.load(ckpt_path, map_location=map_location)
284
+ # Support both raw state_dict and dictionary containing state_dict
285
+ state_dict = state["state_dict"] if isinstance(state, dict) and "state_dict" in state else state
286
+
287
+ missing, unexpected = model.load_state_dict(state_dict, strict=strict)
288
+ if missing:
289
+ print(f"[ScDiVa] Missing keys: {len(missing)}")
290
+ if unexpected:
291
+ print(f"[ScDiVa] Unexpected keys: {len(unexpected)}")
292
+ print("✅ Model weights loaded successfully.")
293
+
294
+ except Exception as e:
295
+ print(f"[ScDiVa] Error loading weights: {e}")
296
+ print("[ScDiVa] Model structure initialized with random weights.")
297
+
298
+ return model