warming666 committed on
Commit
3ab82d7
·
verified ·
1 Parent(s): 435caaf

Update modeling_scdiva.py

Browse files
Files changed (1) hide show
  1. modeling_scdiva.py +310 -39
modeling_scdiva.py CHANGED
@@ -1,45 +1,316 @@
1
  """
2
- ScDiVa Inference SDK
3
- High-level wrappers for single-cell analysis tasks.
 
 
 
4
  """
 
5
  import torch
6
- import numpy as np
7
- from modeling_scdiva import ScDiVaModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
class ScDiVaInference:
    """High-level inference wrapper around ScDiVaModel for single-cell tasks.

    Loads pretrained weights once at construction and exposes task-level
    helpers (`annotate`, `integrate_batches`) that accept AnnData-like
    objects (anything with an `.X` expression matrix).
    """

    def __init__(self, model_name: str = "warming666/ScDiVa", device: str = None):
        # Pick a device automatically unless the caller pins one.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        print(f"Initializing ScDiVa on {self.device}...")
        self.model = ScDiVaModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

    def _preprocess(self, adata) -> torch.Tensor:
        # Placeholder for preprocessing (normalization, etc.)
        # In real usage, this aligns genes and converts to tensor.
        # Sparse matrices expose `.toarray()`; dense arrays pass through.
        expr = adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X
        return torch.tensor(expr, dtype=torch.float32).to(self.device)

    def annotate(self, adata):
        """Return per-cell predicted cell-type indices as a numpy array."""
        data = self._preprocess(adata)
        with torch.no_grad():
            logits = self.model.predict(data, task="annotation")
        return torch.argmax(logits, dim=1).cpu().numpy()

    def integrate_batches(self, adata_list):
        """Encode each batch and return the concatenated latent embeddings."""
        # Placeholder for integration logic
        results = []
        for adata in adata_list:
            data = self._preprocess(adata)
            with torch.no_grad():
                latent = self.model.encode(data)["latent"]
            results.append(latent.cpu().numpy())
        return np.concatenate(results, axis=0)
 
1
  """
2
+ ScDiVa: A Foundation Model for Single-cell Genomics
3
+ Model Architecture Definition
4
+
5
+ This file contains the core architecture definition of ScDiVa.
6
+ It integrates SwiGLU, RoPE, and RMSNorm as described in the paper.
7
  """
8
+
9
  import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from typing import Optional, Dict, Tuple, Union
13
+ import math
14
+ import os
15
+
16
class ScDiVaConfig:
    """Hyperparameter container for ScDiVa.

    Covers the transformer trunk (width, depth, heads, RoPE base), the
    variational bottleneck (`latent_dim`, `use_variational`) and the
    annotation head (`num_cell_types`). Unknown keyword arguments are kept
    as attributes so configs serialized with extra fields round-trip
    instead of being silently dropped (the previous behavior).
    """

    def __init__(
        self,
        num_genes: int = 41818,
        hidden_size: int = 512,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 8,
        intermediate_size: int = 2048,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 1200,
        layer_norm_eps: float = 1e-5,
        latent_dim: int = 128,
        num_cell_types: int = 100,
        use_variational: bool = True,
        rope_theta: float = 10000.0,
        **kwargs,
    ):
        self.num_genes = num_genes
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.latent_dim = latent_dim
        self.num_cell_types = num_cell_types
        self.use_variational = use_variational
        self.rope_theta = rope_theta
        # Preserve unrecognized fields (forward compatibility) instead of
        # discarding **kwargs.
        for key, value in kwargs.items():
            setattr(self, key, value)
47
+
48
+ # =============================================================================
49
+ # Core Blocks (Adapted from blocks.py to match Paper)
50
+ # =============================================================================
51
+
52
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias).

    The statistics are computed in float32 for stability and the result is
    cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        xf = x.float()
        inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        scaled = xf * inv_rms * self.weight.float()
        return scaled.type_as(x)
62
+
63
class SwiGLU(nn.Module):
    """Gated feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
72
+
73
class RotaryEmbedding(nn.Module):
    """Produces cos/sin tables for rotary position embeddings (RoPE).

    Returns tensors of shape [1, seq_len, dim]; the frequency halves are
    duplicated so the tables line up with rotate-half application.
    """

    def __init__(self, dim, max_seq_len=4096, base=10000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        # inv_freq[i] = base ** -(2i/dim); not saved in checkpoints.
        self.register_buffer("inv_freq", base ** -exponents, persistent=False)
        self.max_seq_len = max_seq_len

    def forward(self, x, seq_len=None):
        length = x.shape[1] if seq_len is None else seq_len
        positions = torch.arange(length, device=x.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        return table.cos()[None, :, :], table.sin()[None, :, :]
87
+
88
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query/key tensors by their positional cos/sin tables.

    Args:
        q, k: attention tensors laid out [batch, num_heads, seq_len, head_dim]
            (as produced by RoPESDPAAttention after transpose(1, 2)).
        cos, sin: [1, seq_len, head_dim] tables from RotaryEmbedding.

    Returns:
        (q_embed, k_embed) with the same shapes as q and k.
    """
    def rotate_half(x):
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat((-x2, x1), dim=-1)

    # BUGFIX: q/k are [batch, heads, seq, dim], so the tables must broadcast
    # as [1, 1, seq_len, head_dim]. The previous `unsqueeze(2)` produced
    # [1, seq_len, 1, head_dim], which only broadcast correctly when
    # seq_len == 1 (or coincidentally when seq_len == num_heads).
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
100
+
101
class RoPESDPAAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings, backed by
    PyTorch's fused `scaled_dot_product_attention`.

    All projections are bias-free; attention dropout is applied only in
    training mode.
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.nhead = config.num_attention_heads
        self.head_dim = config.hidden_size // self.nhead

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        self.rope = RotaryEmbedding(self.head_dim, max_seq_len=config.max_position_embeddings, base=config.rope_theta)
        self.dropout = config.attention_probs_dropout_prob

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, L, _ = x.shape

        # [B, L, hidden] -> [B, nhead, L, head_dim]
        q = self.q_proj(x).view(B, L, self.nhead, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.nhead, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.nhead, self.head_dim).transpose(1, 2)

        cos, sin = self.rope(v, seq_len=L)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Use PyTorch's efficient SDPA
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False
        )

        # BUGFIX: the original read `config.hidden_size` here, but `config`
        # is not in scope inside forward — every call raised NameError.
        # nhead * head_dim equals the projected hidden size.
        out = out.transpose(1, 2).contiguous().view(B, L, self.nhead * self.head_dim)
        return self.o_proj(out)
135
+
136
class ScDiVaBlock(nn.Module):
    """Pre-norm transformer block.

    Layout: RMSNorm -> RoPE attention -> dropout -> residual, then
    RMSNorm -> SwiGLU MLP -> dropout -> residual.
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.norm1 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attn = RoPESDPAAttention(config)
        self.norm2 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SwiGLU(config.hidden_size, config.intermediate_size)
        self.drop = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        residual = x
        attended = self.attn(self.norm1(x), attn_mask=attn_mask)
        x = residual + self.drop(attended)

        residual = x
        x = residual + self.drop(self.mlp(self.norm2(x)))
        return x
156
+
157
+ # =============================================================================
158
+ # Outer Model Architecture
159
+ # =============================================================================
160
+
161
class GeneEmbedding(nn.Module):
    """Projects a per-cell gene-expression vector into the hidden space.

    A single linear layer over the full gene axis, followed by RMSNorm
    (matching the transformer blocks) and dropout.
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.gene_projection = nn.Linear(config.num_genes, config.hidden_size)
        # RMSNorm rather than LayerNorm, for consistency with the trunk.
        self.layer_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, gene_expression: torch.Tensor) -> torch.Tensor:
        projected = self.gene_projection(gene_expression)
        return self.dropout(self.layer_norm(projected))
174
+
175
class TransformerEncoder(nn.Module):
    """Sequential stack of `num_hidden_layers` ScDiVaBlock layers."""

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        blocks = [ScDiVaBlock(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        for block in self.layers:
            hidden_states = block(hidden_states, attention_mask)
        return hidden_states
186
+
187
class VariationalLayer(nn.Module):
    """VAE bottleneck: maps hidden states to (mu, logvar) and samples z."""

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.mu_projection = nn.Linear(config.hidden_size, config.latent_dim)
        self.logvar_projection = nn.Linear(config.hidden_size, config.latent_dim)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        # z = mu + sigma * eps with eps ~ N(0, I): keeps sampling differentiable.
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return noise * std + mu

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mu = self.mu_projection(hidden_states)
        logvar = self.logvar_projection(hidden_states)
        return self.reparameterize(mu, logvar), mu, logvar
203
+
204
class AnnotationHead(nn.Module):
    """Cell-type classifier: latent -> GELU MLP -> per-type logits."""

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.dense = nn.Linear(config.latent_dim, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_cell_types)

    def forward(self, latent_representation: torch.Tensor) -> torch.Tensor:
        activated = F.gelu(self.dense(latent_representation))
        return self.classifier(self.dropout(activated))
216
+
217
class BatchIntegrationHead(nn.Module):
    """Decoder head for batch integration: latent -> reconstructed gene space."""

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.dense = nn.Linear(config.latent_dim, config.hidden_size)
        self.decoder = nn.Linear(config.hidden_size, config.num_genes)

    def forward(self, latent_representation: torch.Tensor) -> torch.Tensor:
        activated = F.gelu(self.dense(latent_representation))
        return self.decoder(activated)
227
+
228
class ScDiVaModel(nn.Module):
    """
    ScDiVa: Single-cell Deep Variational Analysis Model

    Pipeline: gene expression -> GeneEmbedding -> TransformerEncoder ->
    VariationalLayer (z, mu, logvar) -> task head (cell-type annotation or
    batch-integration reconstruction).
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.config = config
        self.gene_embedding = GeneEmbedding(config)
        self.encoder = TransformerEncoder(config)
        self.variational_layer = VariationalLayer(config)
        self.annotation_head = AnnotationHead(config)
        self.batch_integration_head = BatchIntegrationHead(config)

    def encode(self, gene_expression: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Encode expression vectors; returns {"latent": z, "mu", "logvar"}."""
        embeddings = self.gene_embedding(gene_expression)
        # Add sequence dimension for Transformer [Batch, SeqLen=1, Dim]
        # Note: If input is token sequence, normalization should happen before calling encode
        embeddings = embeddings.unsqueeze(1)

        encoded = self.encoder(embeddings, attention_mask)
        encoded = encoded.squeeze(1)
        z, mu, logvar = self.variational_layer(encoded)
        return {"latent": z, "mu": mu, "logvar": logvar}

    def predict(self, gene_expression: torch.Tensor, task: str = "annotation", attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the encoder and the head selected by `task`.

        Raises:
            ValueError: if `task` is not "annotation" or "batch_integration".
        """
        encoding = self.encode(gene_expression, attention_mask)
        latent = encoding["latent"]
        if task == "annotation":
            return self.annotation_head(latent)
        elif task == "batch_integration":
            return self.batch_integration_head(latent)
        else:
            raise ValueError(f"Unknown task: {task}")

    @staticmethod
    def _load_checkpoint(ckpt_path: str, map_location: str):
        """Load a checkpoint file, dispatching on its extension.

        BUGFIX: `.safetensors` files cannot be read by ``torch.load``; the
        original code always failed on them and silently fell back to random
        weights. Use safetensors' loader for that format (ImportError
        propagates to the caller's handler, same fallback as before).
        """
        if ckpt_path.endswith(".safetensors"):
            from safetensors.torch import load_file  # optional, ships with HF stack
            return load_file(ckpt_path, device=map_location)
        # NOTE(security): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        return torch.load(ckpt_path, map_location=map_location)

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        map_location: Optional[str] = None,
        strict: bool = True,
        use_auth_token: Optional[str] = None,
    ) -> "ScDiVaModel":
        """Build a model and load weights from a local path or the HF Hub.

        Falls back to random initialization (with a printed warning) when no
        checkpoint can be located or loaded — demo mode, never raises for
        missing weights.
        """
        config = ScDiVaConfig()
        model = cls(config)
        if map_location is None:
            map_location = "cpu"

        ckpt_path: Optional[str] = None

        # 1. Try Local
        if os.path.exists(model_name_or_path):
            if os.path.isfile(model_name_or_path):
                ckpt_path = model_name_or_path
            elif os.path.isdir(model_name_or_path):
                for name in ["pytorch_model.bin", "model.safetensors", "model.pt"]:
                    p = os.path.join(model_name_or_path, name)
                    if os.path.exists(p):
                        ckpt_path = p
                        break

        # 2. Try Hugging Face
        if ckpt_path is None:
            try:
                from huggingface_hub import hf_hub_download
                print(f"[ScDiVa] Downloading weights from HF: {model_name_or_path}")
                try:
                    ckpt_path = hf_hub_download(repo_id=model_name_or_path, filename="model.safetensors", token=use_auth_token)
                except Exception:
                    # BUGFIX: was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit.
                    ckpt_path = hf_hub_download(repo_id=model_name_or_path, filename="pytorch_model.bin", token=use_auth_token)
            except ImportError:
                pass  # huggingface_hub not installed; fall through to demo mode
            except Exception as e:
                print(f"[ScDiVa] Warning: HF download failed: {e}")

        # 3. Load or Fallback
        if ckpt_path is None:
            print("[ScDiVa] Warning: No weights found. Using random initialization (DEMO MODE).")
            return model

        print(f"[ScDiVa] Loading weights from {ckpt_path}...")
        try:
            state = cls._load_checkpoint(ckpt_path, map_location)
            state_dict = state["state_dict"] if isinstance(state, dict) and "state_dict" in state else state
            missing, unexpected = model.load_state_dict(state_dict, strict=strict)
            if missing: print(f"Missing keys: {len(missing)}")
        except Exception as e:
            print(f"[ScDiVa] Error loading weights: {e}. Using random init.")

        return model