|
|
""" |
|
|
ScDiVa: A Foundation Model for Single-cell Genomics |
|
|
Model Architecture Definition |
|
|
|
|
|
This file contains the core architecture definition of ScDiVa. |
|
|
It integrates SwiGLU, RoPE, and RMSNorm as described in the paper. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Optional, Dict, Tuple, Union |
|
|
import math |
|
|
import os |
|
|
|
|
|
class ScDiVaConfig:
    """Plain container for the hyperparameters of the ScDiVa architecture.

    Every named argument is stored unchanged as an attribute of the same name.

    Args:
        num_genes: Length of the gene-expression vector (input width of the
            gene projection and output width of the reconstruction decoder).
        hidden_size: Width of the transformer hidden states.
        num_hidden_layers: Number of stacked transformer blocks.
        num_attention_heads: Number of attention heads; the per-head width is
            ``hidden_size // num_attention_heads``.
        intermediate_size: Inner width of the SwiGLU feed-forward layer.
        hidden_dropout_prob: Dropout probability applied to hidden activations.
        attention_probs_dropout_prob: Dropout probability passed to attention.
        max_position_embeddings: Maximum sequence length handed to the rotary
            embedding.
        layer_norm_eps: Epsilon used by the RMSNorm layers.
        latent_dim: Width of the variational latent space.
        num_cell_types: Number of classes produced by the annotation head.
        use_variational: Stored as given.
            NOTE(review): no module in this file reads this flag - confirm
            the intended use.
        rope_theta: Base of the rotary-embedding frequency schedule.
        **kwargs: Additional options. They are kept as attributes so that
            settings unknown to this class are not silently lost.
    """

    def __init__(
        self,
        num_genes: int = 41818,
        hidden_size: int = 512,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 8,
        intermediate_size: int = 2048,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 1200,
        layer_norm_eps: float = 1e-5,
        latent_dim: int = 128,
        num_cell_types: int = 100,
        use_variational: bool = True,
        rope_theta: float = 10000.0,
        **kwargs
    ):
        self.num_genes = num_genes
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.latent_dim = latent_dim
        self.num_cell_types = num_cell_types
        self.use_variational = use_variational
        self.rope_theta = rope_theta

        # Extra options cannot collide with the names above (those are bound
        # by the signature), so keeping them is purely additive.
        for key, value in kwargs.items():
            setattr(self, key, value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    Args:
        dim: Size of the last axis; one gain value is learned per element.
        eps: Constant added to the mean square before the reciprocal root.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalise ``x`` along its last axis and return it in ``x``'s type."""
        # The arithmetic is carried out in float32; only the final result is
        # cast back to the type of the input.
        upcast = x.float()
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.eps)
        gain = self.weight.float()
        return (normalised * gain).type_as(x)
|
|
|
|
|
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``down(silu(gate(x)) * up(x))``.

    Args:
        dim: Input and output width.
        hidden_dim: Width of the gated inner representation.
    """

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # All three projections are bias-free.
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x`` (last axis of width ``dim``)."""
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Produces the cosine/sine tables used for rotary position embeddings.

    Args:
        dim: Width of the vectors to be rotated; ``dim // 2`` frequencies are
            generated (for even ``dim``) and duplicated to span ``dim``.
        max_seq_len: Stored on the module; ``forward`` does not read it.
        base: Base of the geometric frequency schedule.
    """

    def __init__(self, dim, max_seq_len=4096, base=10000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        # Non-persistent: the buffer follows the module across devices but is
        # left out of the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len = max_seq_len

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each shaped ``(1, seq_len, dim)``.

        Args:
            x: Tensor used for its device and, when ``seq_len`` is omitted,
                for ``x.shape[1]`` as the sequence length.
            seq_len: Number of positions to generate.
        """
        length = x.shape[1] if seq_len is None else seq_len
        positions = torch.arange(length, device=x.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        return table.cos().unsqueeze(0), table.sin().unsqueeze(0)
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate ``q`` and ``k`` with the given cosine/sine tables.

    ``cos`` and ``sin`` receive an extra axis at dim 2 and are then broadcast
    against ``q`` and ``k``.
    NOTE(review): with ``(1, seq, dim)`` tables this broadcast lines up with a
    ``(batch, seq, heads, dim)`` layout of ``q``/``k`` - confirm at call sites.

    Returns:
        Tuple ``(q_embed, k_embed)`` of the rotated tensors.
    """

    def rotate_half(x):
        # Swap the two halves of the last axis, negating the upper half.
        half = x.shape[-1] // 2
        lower, upper = x[..., :half], x[..., half:]
        return torch.cat((-upper, lower), dim=-1)

    cos = cos.unsqueeze(2)
    sin = sin.unsqueeze(2)
    rotated = [(t * cos) + (rotate_half(t) * sin) for t in (q, k)]
    return rotated[0], rotated[1]
|
|
|
|
|
class RoPESDPAAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Queries and keys are rotated with RoPE and attention is evaluated with
    ``F.scaled_dot_product_attention`` (non-causal).

    Args:
        config: Supplies ``hidden_size``, ``num_attention_heads``,
            ``max_position_embeddings``, ``rope_theta`` and
            ``attention_probs_dropout_prob``.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of
            ``num_attention_heads``.
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        # Kept on the module because forward() has no access to ``config``.
        self.hidden_size = config.hidden_size
        self.nhead = config.num_attention_heads
        self.head_dim = config.hidden_size // self.nhead
        if self.head_dim * self.nhead != self.hidden_size:
            # Fail here with a clear message instead of inside a later view().
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.nhead})"
            )

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        self.rope = RotaryEmbedding(self.head_dim, max_seq_len=config.max_position_embeddings, base=config.rope_theta)
        self.dropout = config.attention_probs_dropout_prob

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run self-attention over ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, hidden_size)``.
            attn_mask: Optional mask forwarded unchanged to
                ``F.scaled_dot_product_attention``.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        B, L, _ = x.shape

        # Split heads but keep the (B, L, H, D) layout for now.
        # apply_rotary_pos_emb unsqueezes cos/sin at dim 2, which broadcasts
        # over the head axis only in this layout. Rotating after the
        # transpose would pair positions with heads for any L > 1.
        q = self.q_proj(x).view(B, L, self.nhead, self.head_dim)
        k = self.k_proj(x).view(B, L, self.nhead, self.head_dim)
        v = self.v_proj(x).view(B, L, self.nhead, self.head_dim).transpose(1, 2)

        cos, sin = self.rope(x, seq_len=L)
        # Match the projection dtype so q/k keep the same dtype as v.
        cos = cos.to(dtype=q.dtype)
        sin = sin.to(dtype=q.dtype)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # scaled_dot_product_attention expects (B, H, L, D).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)

        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False
        )

        # Merge heads back: (B, H, L, D) -> (B, L, hidden_size).
        out = out.transpose(1, 2).contiguous().view(B, L, self.hidden_size)
        return self.o_proj(out)
|
|
|
|
|
class ScDiVaBlock(nn.Module):
    """Pre-norm transformer block: attention then SwiGLU, each with a residual.

    Args:
        config: Supplies ``hidden_size``, ``layer_norm_eps``,
            ``intermediate_size`` and ``hidden_dropout_prob``; it is also
            passed on to the attention layer.
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_eps
        self.norm1 = RMSNorm(width, eps=eps)
        self.attn = RoPESDPAAttention(config)
        self.norm2 = RMSNorm(width, eps=eps)
        self.mlp = SwiGLU(width, config.intermediate_size)
        # One dropout module is shared by both residual branches.
        self.drop = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Apply the block to ``x``; ``attn_mask`` is handed to the attention layer."""
        # Attention sub-layer: normalise, attend, add back to the input.
        attended = self.attn(self.norm1(x), attn_mask=attn_mask)
        x = x + self.drop(attended)

        # Feed-forward sub-layer: normalise, transform, add back.
        transformed = self.mlp(self.norm2(x))
        return x + self.drop(transformed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GeneEmbedding(nn.Module):
    """Maps a gene-expression vector to the hidden width.

    A single linear layer takes ``num_genes`` values to ``hidden_size``; the
    result is RMS-normalised and passed through dropout.
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.gene_projection = nn.Linear(config.num_genes, config.hidden_size)
        self.layer_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, gene_expression: torch.Tensor) -> torch.Tensor:
        """Embed ``gene_expression`` (last axis of width ``num_genes``)."""
        projected = self.gene_projection(gene_expression)
        return self.dropout(self.layer_norm(projected))
|
|
|
|
|
class TransformerEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ScDiVa blocks applied in order."""

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        depth = config.num_hidden_layers
        self.layers = nn.ModuleList(ScDiVaBlock(config) for _ in range(depth))

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Feed ``hidden_states`` through every block, passing the same mask to each."""
        for block in self.layers:
            hidden_states = block(hidden_states, attention_mask)
        return hidden_states
|
|
|
|
|
class VariationalLayer(nn.Module):
    """Projects hidden states to a Gaussian latent and draws a sample from it.

    Two linear layers map ``hidden_size`` to ``latent_dim``: one yields the
    mean, the other the log-variance.
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.mu_projection = nn.Linear(config.hidden_size, config.latent_dim)
        self.logvar_projection = nn.Linear(config.hidden_size, config.latent_dim)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Return ``mu + eps * std`` with ``eps`` drawn from a standard normal.

        Noise is drawn on every call; the module's training flag is not
        consulted.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(z, mu, logvar)`` for ``hidden_states``."""
        mu = self.mu_projection(hidden_states)
        logvar = self.logvar_projection(hidden_states)
        return self.reparameterize(mu, logvar), mu, logvar
|
|
|
|
|
class AnnotationHead(nn.Module):
    """Classifier mapping a latent vector to ``num_cell_types`` logits.

    Structure: linear (``latent_dim`` -> ``hidden_size``), GELU, dropout,
    linear (``hidden_size`` -> ``num_cell_types``).
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.dense = nn.Linear(config.latent_dim, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_cell_types)

    def forward(self, latent_representation: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class scores for ``latent_representation``."""
        activated = F.gelu(self.dense(latent_representation))
        return self.classifier(self.dropout(activated))
|
|
|
|
|
class BatchIntegrationHead(nn.Module):
    """Decoder mapping a latent vector back to ``num_genes`` values.

    Structure: linear (``latent_dim`` -> ``hidden_size``), GELU,
    linear (``hidden_size`` -> ``num_genes``).
    """

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.dense = nn.Linear(config.latent_dim, config.hidden_size)
        self.decoder = nn.Linear(config.hidden_size, config.num_genes)

    def forward(self, latent_representation: torch.Tensor) -> torch.Tensor:
        """Return the reconstruction for ``latent_representation``."""
        activated = F.gelu(self.dense(latent_representation))
        return self.decoder(activated)
|
|
|
|
|
class ScDiVaModel(nn.Module):
    """
    ScDiVa: Single-cell Deep Variational Analysis Model

    Pipeline: gene embedding -> transformer encoder -> variational layer. An
    annotation head and a batch-integration head both read the sampled latent.
    The variational layer is always applied; ``config.use_variational`` is not
    consulted by this class.
    """

    # File names probed, in this order, inside a local checkpoint directory.
    _LOCAL_CHECKPOINT_NAMES = ("pytorch_model.bin", "model.safetensors", "model.pt")

    def __init__(self, config: ScDiVaConfig):
        super().__init__()
        self.config = config
        self.gene_embedding = GeneEmbedding(config)
        self.encoder = TransformerEncoder(config)
        self.variational_layer = VariationalLayer(config)
        self.annotation_head = AnnotationHead(config)
        self.batch_integration_head = BatchIntegrationHead(config)

    def encode(self, gene_expression: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Encode expression profiles into the variational latent space.

        Args:
            gene_expression: Input to the gene embedding.
                NOTE(review): assumes shape (batch, num_genes) - confirm.
            attention_mask: Optional mask forwarded to every encoder block.

        Returns:
            Dict with keys ``"latent"`` (sampled z), ``"mu"`` and ``"logvar"``.
        """
        embeddings = self.gene_embedding(gene_expression)

        # The encoder operates on sequences, so each embedding is wrapped as a
        # length-1 sequence (axis 1) and unwrapped again afterwards.
        embeddings = embeddings.unsqueeze(1)
        encoded = self.encoder(embeddings, attention_mask)
        encoded = encoded.squeeze(1)

        z, mu, logvar = self.variational_layer(encoded)
        return {"latent": z, "mu": mu, "logvar": logvar}

    def predict(self, gene_expression: torch.Tensor, task: str = "annotation", attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encode the input and run the head selected by ``task``.

        Args:
            gene_expression: Input passed to :meth:`encode`.
            task: ``"annotation"`` for cell-type logits or
                ``"batch_integration"`` for the reconstruction.
            attention_mask: Optional mask passed to :meth:`encode`.

        Raises:
            ValueError: If ``task`` is not one of the two supported names.
        """
        encoding = self.encode(gene_expression, attention_mask)
        latent = encoding["latent"]
        if task == "annotation":
            return self.annotation_head(latent)
        elif task == "batch_integration":
            return self.batch_integration_head(latent)
        else:
            raise ValueError(f"Unknown task: {task}")

    @classmethod
    def _find_local_checkpoint(cls, path: str) -> Optional[str]:
        """Return a checkpoint file for ``path`` (a file or a directory), or None."""
        if os.path.isfile(path):
            return path
        if os.path.isdir(path):
            for name in cls._LOCAL_CHECKPOINT_NAMES:
                candidate = os.path.join(path, name)
                if os.path.exists(candidate):
                    return candidate
        return None

    @staticmethod
    def _download_checkpoint(repo_id: str, token: Optional[str]) -> Optional[str]:
        """Try to fetch weights from the Hugging Face Hub; return None on failure."""
        try:
            from huggingface_hub import hf_hub_download
        except ImportError:
            # Optional dependency: report it so the fallback is not a mystery.
            print("[ScDiVa] Warning: huggingface_hub is not installed; cannot download weights.")
            return None

        print(f"[ScDiVa] Downloading weights from HF: {repo_id}")
        try:
            try:
                return hf_hub_download(repo_id=repo_id, filename="model.safetensors", token=token)
            except Exception:
                # Not a bare ``except`` so Ctrl-C / SystemExit still propagate.
                return hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", token=token)
        except Exception as e:
            print(f"[ScDiVa] Warning: HF download failed: {e}")
            return None

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        map_location: Optional[str] = None,
        strict: bool = True,
        use_auth_token: Optional[str] = None,
        config: Optional[ScDiVaConfig] = None,
    ) -> "ScDiVaModel":
        """Build a model and load weights on a best-effort basis.

        Weights are looked up first as a local file or directory, then on the
        Hugging Face Hub. If nothing is found, or loading fails, the model is
        returned with its freshly initialised weights and a message is printed.

        Args:
            model_name_or_path: Local checkpoint file, local directory, or Hub
                repository id.
            map_location: Passed to ``torch.load``; defaults to ``"cpu"``.
            strict: Passed to ``load_state_dict``.
            use_auth_token: Token passed to ``hf_hub_download``.
            config: Architecture to instantiate; a default ``ScDiVaConfig()``
                is used when omitted.

        Returns:
            The constructed model.
        """
        if config is None:
            config = ScDiVaConfig()
        model = cls(config)
        if map_location is None:
            map_location = "cpu"

        ckpt_path: Optional[str] = cls._find_local_checkpoint(model_name_or_path)
        if ckpt_path is None:
            ckpt_path = cls._download_checkpoint(model_name_or_path, use_auth_token)

        if ckpt_path is None:
            print("[ScDiVa] Warning: No weights found. Using random initialization (DEMO MODE).")
            return model

        print(f"[ScDiVa] Loading weights from {ckpt_path}...")
        try:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from sources you trust.
            # NOTE(review): the same call is used for ``*.safetensors`` paths.
            # That format is presumably not readable by torch.load and would
            # end up in the except branch below - confirm and add a dedicated
            # loader if so.
            state = torch.load(ckpt_path, map_location=map_location)
            state_dict = state["state_dict"] if isinstance(state, dict) and "state_dict" in state else state
            missing, unexpected = model.load_state_dict(state_dict, strict=strict)
            if missing:
                print(f"Missing keys: {len(missing)}")
            if unexpected:
                print(f"Unexpected keys: {len(unexpected)}")
        except Exception as e:
            # Deliberate best effort: report the failure and return the model.
            print(f"[ScDiVa] Error loading weights: {e}. Using random init.")

        return model