warming666 committed on
Commit
435caaf
·
verified ·
1 Parent(s): 0cadf02

Update modeling_scdiva.py

Browse files
Files changed (1) hide show
  1. modeling_scdiva.py +39 -292
modeling_scdiva.py CHANGED
@@ -1,298 +1,45 @@
1
  """
2
- ScDiVa: A Foundation Model for Single-cell Genomics
3
- Model Architecture Definition
4
-
5
- This file contains the core architecture definition of ScDiVa.
6
- It allows loading pre-trained weights for inference.
7
  """
8
-
9
  import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- from typing import Optional, Dict, Tuple, Union
13
- import math
14
- import os
15
-
16
class ScDiVaConfig:
    """Hyper-parameter container for the ScDiVa architecture.

    Every keyword argument becomes an instance attribute. Unknown extra
    keyword arguments are accepted for forward compatibility but ignored.
    """

    def __init__(
        self,
        num_genes: int = 41818,  # matches paper (Table 4)
        hidden_size: int = 512,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 8,
        intermediate_size: int = 2048,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 1200,
        layer_norm_eps: float = 1e-5,
        latent_dim: int = 128,
        num_cell_types: int = 100,
        use_variational: bool = True,
        **kwargs
    ):
        # Bulk-assign the named hyper-parameters; **kwargs is deliberately
        # dropped, mirroring the original behaviour.
        self.__dict__.update(
            num_genes=num_genes,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            max_position_embeddings=max_position_embeddings,
            layer_norm_eps=layer_norm_eps,
            latent_dim=latent_dim,
            num_cell_types=num_cell_types,
            use_variational=use_variational,
        )
46
class GeneEmbedding(nn.Module):
    """Projects a raw expression vector (batch, num_genes) into hidden space.

    Pipeline: linear projection -> LayerNorm -> dropout.
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.gene_projection = nn.Linear(config.num_genes, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, gene_expression: torch.Tensor) -> torch.Tensor:
        """Return (batch, hidden_size) normalized, regularized embeddings."""
        return self.dropout(self.layer_norm(self.gene_projection(gene_expression)))
59
class MultiHeadAttention(nn.Module):
    """Post-LayerNorm multi-head self-attention with a residual connection.

    Fix vs. original: validates that ``hidden_size`` is divisible by
    ``num_attention_heads``. The original silently floor-divided, so an
    incompatible config would corrupt the projection shapes instead of
    failing fast.

    Raises:
        ValueError: if ``hidden_size % num_attention_heads != 0``.
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # NOTE(review): this dropout (attention_probs_dropout_prob) is reused
        # on the output projection below; most transformer implementations use
        # hidden_dropout_prob there. Kept as-is to preserve behaviour — confirm
        # against the training code before changing.
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, S, H) -> (B, heads, S, head_size) for per-head attention."""
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Scaled dot-product self-attention; ``attention_mask`` is additive.

        Returns a tensor of the same shape as ``hidden_states``.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Mask values are added pre-softmax (large negatives suppress positions).
            attention_scores = attention_scores + attention_mask

        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # (B, heads, S, head_size) -> (B, S, H)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_shape)

        attention_output = self.dense(context_layer)
        attention_output = self.dropout(attention_output)
        # Residual connection + post-LayerNorm.
        attention_output = self.layer_norm(attention_output + hidden_states)

        return attention_output
104
class FeedForward(nn.Module):
    """Position-wise feed-forward sublayer with residual connection + LayerNorm.

    Pipeline: expand to intermediate_size -> GELU -> contract -> dropout,
    then add the input back and normalize.
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return a tensor of the same shape as ``hidden_states``."""
        projected = self.dense2(F.gelu(self.dense1(hidden_states)))
        return self.layer_norm(self.dropout(projected) + hidden_states)
121
class TransformerLayer(nn.Module):
    """One encoder layer: self-attention followed by the feed-forward sublayer."""

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return a tensor of the same shape as ``hidden_states``."""
        return self.feed_forward(self.attention(hidden_states, attention_mask))
132
class TransformerEncoder(nn.Module):
    """A stack of ``num_hidden_layers`` TransformerLayer blocks applied in order."""

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            TransformerLayer(config) for _ in range(config.num_hidden_layers)
        )

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Thread ``hidden_states`` through every layer; the same mask is reused."""
        for block in self.layers:
            hidden_states = block(hidden_states, attention_mask)
        return hidden_states
144
class VariationalLayer(nn.Module):
    """Maps encoder output to a Gaussian latent via the reparameterization trick.

    Fix vs. original: in evaluation mode (``self.training is False``) the
    posterior mean ``mu`` is returned as the latent instead of a random
    sample, making inference deterministic — the original sampled noise
    unconditionally, so repeated ``predict`` calls on the same cell gave
    different results. Training behaviour is unchanged.
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.mu_projection = nn.Linear(config.hidden_size, config.latent_dim)
        self.logvar_projection = nn.Linear(config.hidden_size, config.latent_dim)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Differentiable sample z ~ N(mu, exp(logvar)): z = mu + eps * std."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(z, mu, logvar)``, each of shape (..., latent_dim).

        ``z`` is a reparameterized sample during training and exactly ``mu``
        during evaluation.
        """
        mu = self.mu_projection(hidden_states)
        logvar = self.logvar_projection(hidden_states)
        z = self.reparameterize(mu, logvar) if self.training else mu
        return z, mu, logvar
161
class AnnotationHead(nn.Module):
    """Cell-type classification head over the latent representation.

    Pipeline: linear -> GELU -> dropout -> linear classifier.
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.dense = nn.Linear(config.latent_dim, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_cell_types)

    def forward(self, latent_representation: torch.Tensor) -> torch.Tensor:
        """Return (batch, num_cell_types) unnormalized logits."""
        activated = F.gelu(self.dense(latent_representation))
        return self.classifier(self.dropout(activated))
174
class BatchIntegrationHead(nn.Module):
    """Decodes the latent representation back to full gene-expression space.

    Pipeline: linear -> GELU -> linear decoder to num_genes.
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.dense = nn.Linear(config.latent_dim, config.hidden_size)
        self.decoder = nn.Linear(config.hidden_size, config.num_genes)

    def forward(self, latent_representation: torch.Tensor) -> torch.Tensor:
        """Return (batch, num_genes) reconstructed expression values."""
        return self.decoder(F.gelu(self.dense(latent_representation)))
185
class ScDiVaModel(nn.Module):
    """ScDiVa: Single-cell Deep Variational Analysis Model.

    Pipeline: expression vector -> GeneEmbedding -> TransformerEncoder
    (as a length-1 token sequence) -> VariationalLayer -> task head.
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.config = config
        self.gene_embedding = GeneEmbedding(config)
        self.encoder = TransformerEncoder(config)
        self.variational_layer = VariationalLayer(config)
        self.annotation_head = AnnotationHead(config)
        self.batch_integration_head = BatchIntegrationHead(config)

    def encode(self, gene_expression: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Encode expression into the variational latent space.

        Args:
            gene_expression: (batch_size, num_genes) tensor.
            attention_mask: optional additive mask forwarded to the encoder.

        Returns:
            Dict with "latent" (z), "mu" and "logvar", each (batch, latent_dim).
        """
        embeddings = self.gene_embedding(gene_expression)
        # Each cell becomes a single-token sequence for the transformer.
        embeddings = embeddings.unsqueeze(1)                # (B, 1, H)
        encoded = self.encoder(embeddings, attention_mask)  # (B, 1, H)
        encoded = encoded.squeeze(1)                        # (B, H)
        z, mu, logvar = self.variational_layer(encoded)
        return {"latent": z, "mu": mu, "logvar": logvar}

    def predict(self, gene_expression: torch.Tensor, task: str = "annotation", attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Inference interface.

        - task="annotation": (batch, num_cell_types) classification logits.
        - task="batch_integration": (batch, num_genes) reconstructed expression.

        Raises:
            ValueError: for an unrecognized ``task``.
        """
        latent = self.encode(gene_expression, attention_mask)["latent"]
        if task == "annotation":
            return self.annotation_head(latent)
        if task == "batch_integration":
            return self.batch_integration_head(latent)
        raise ValueError(f"Unknown task: {task}")

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        map_location: Optional[str] = None,
        strict: bool = True,
        use_auth_token: Optional[str] = None,
    ) -> "ScDiVaModel":
        """Load pre-trained weights from a local path or the Hugging Face Hub.

        Supports directly loading from 'warming666/ScDiVa'. Falls back to
        random initialization (demo mode) when no weights can be located.

        Fixes vs. original:
          * bare ``except:`` clauses narrowed to ``except Exception:`` so
            KeyboardInterrupt/SystemExit are no longer swallowed;
          * ``.safetensors`` checkpoints are read with
            ``safetensors.torch.load_file`` — the original handed them to
            ``torch.load``, which cannot parse the safetensors format.
        """
        config = ScDiVaConfig()
        model = cls(config)

        if map_location is None:
            map_location = "cpu"

        ckpt_path: Optional[str] = None

        # 1. Try local file or directory.
        if os.path.isfile(model_name_or_path):
            ckpt_path = model_name_or_path
        elif os.path.isdir(model_name_or_path):
            # Search for typical weights files (original precedence kept).
            for name in ["pytorch_model.bin", "model.safetensors", "model.pt"]:
                p = os.path.join(model_name_or_path, name)
                if os.path.exists(p):
                    ckpt_path = p
                    break

        # 2. Try Hugging Face Hub download (safetensors preferred, then .bin).
        if ckpt_path is None:
            try:
                from huggingface_hub import hf_hub_download
                print(f"[ScDiVa] Attempting to download weights from HF: {model_name_or_path}")
                for filename in ("model.safetensors", "pytorch_model.bin"):
                    try:
                        ckpt_path = hf_hub_download(repo_id=model_name_or_path, filename=filename, token=use_auth_token)
                        break
                    except Exception:
                        continue  # try the next candidate filename
            except ImportError:
                print("[ScDiVa] Warning: `huggingface_hub` not installed. Cannot download from HF.")
            except Exception as e:
                print(f"[ScDiVa] Warning: HF download error (check network/repo ID): {e}")

        # 3. No weights anywhere: keep random init so demos still run.
        if ckpt_path is None:
            print(f"[ScDiVa] Warning: No weights found at '{model_name_or_path}'. Using random initialization (DEMO MODE).")
            return model

        print(f"[ScDiVa] Loading weights from {ckpt_path}...")
        try:
            if ckpt_path.endswith(".safetensors"):
                # Bug fix: safetensors files are not pickle archives and must
                # not go through torch.load.
                from safetensors.torch import load_file
                state_dict = load_file(ckpt_path, device=map_location)
            else:
                state = torch.load(ckpt_path, map_location=map_location)
                # Support both a raw state_dict and a {"state_dict": ...} wrapper.
                state_dict = state["state_dict"] if isinstance(state, dict) and "state_dict" in state else state

            missing, unexpected = model.load_state_dict(state_dict, strict=strict)
            if missing:
                print(f"[ScDiVa] Missing keys: {len(missing)}")
            if unexpected:
                print(f"[ScDiVa] Unexpected keys: {len(unexpected)}")
            print("✅ Model weights loaded successfully.")

        except Exception as e:
            print(f"[ScDiVa] Error loading weights: {e}")
            print("[ScDiVa] Model structure initialized with random weights.")

        return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ ScDiVa Inference SDK
3
+ High-level wrappers for single-cell analysis tasks.
 
 
 
4
  """
 
5
  import torch
6
+ import numpy as np
7
+ from modeling_scdiva import ScDiVaModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
class ScDiVaInference:
    """High-level inference wrapper around a pre-trained ScDiVaModel.

    Tasks:
        annotate: per-cell cell-type index predictions.
        integrate_batches: stacked latent embeddings for a list of batches.

    Fix vs. original: the ``device`` parameter was annotated ``str = None``,
    which is an invalid contract; the annotation is removed and ``None``
    (auto-select) is documented instead.
    """

    def __init__(self, model_name: str = "warming666/ScDiVa", device=None):
        """Load the model and move it to ``device``.

        Args:
            model_name: local path or Hugging Face repo id.
            device: torch device string; ``None`` auto-selects CUDA when available.
        """
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        print(f"Initializing ScDiVa on {self.device}...")
        self.model = ScDiVaModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

    def _preprocess(self, adata) -> torch.Tensor:
        """Convert ``adata.X`` (dense or scipy-sparse) to a float32 tensor on self.device.

        NOTE(review): placeholder — real preprocessing (normalization, gene
        alignment to the model's vocabulary) is not implemented here.
        """
        # Sparse matrices expose .toarray(); dense arrays pass through.
        expr = adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X
        # np.asarray accepts lists/np.matrix without copying genuine ndarrays.
        return torch.as_tensor(np.asarray(expr), dtype=torch.float32).to(self.device)

    def annotate(self, adata):
        """Return a numpy int array of predicted cell-type indices, one per cell."""
        data = self._preprocess(adata)
        with torch.no_grad():
            logits = self.model.predict(data, task="annotation")
        return torch.argmax(logits, dim=1).cpu().numpy()

    def integrate_batches(self, adata_list):
        """Return (total_cells, latent_dim) latent embeddings for all batches.

        NOTE(review): placeholder — no explicit batch-correction logic beyond
        encoding each batch through the shared model.
        """
        embeddings = []
        for adata in adata_list:
            data = self._preprocess(adata)
            with torch.no_grad():
                latent = self.model.encode(data)["latent"]
            embeddings.append(latent.cpu().numpy())
        return np.concatenate(embeddings, axis=0)