| from typing import List, Optional |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import PretrainedConfig, PreTrainedModel |
| from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
|
|
def _get_activation(activation: str) -> nn.Module:
    """Return a fresh activation module for the given name."""
    if activation == "prelu":
| return nn.PReLU() |
| elif activation == "relu": |
| return nn.ReLU() |
| elif activation == "gelu": |
| return nn.GELU() |
| elif activation == "tanh": |
| return nn.Tanh() |
| raise ValueError(f"Unsupported activation: {activation!r}") |
|
|
|
|
class MLP(nn.Module):
    """Feed-forward encoder: Linear + BatchNorm1d + activation blocks, optional
    residual connections between equal-width hidden layers, and a final Linear
    projection to output_dim."""
| def __init__( |
| self, |
| input_dim: int, |
| output_dim: int = 128, |
| hidden_dim: Optional[List[int]] = None, |
| dropout: float = 0.0, |
| residual: bool = False, |
| activation: str = "prelu", |
| ): |
| super().__init__() |
| if hidden_dim is None: |
| hidden_dim = [1024, 1024] |
| self.input_dim = input_dim |
| self.latent_dim = output_dim |
| self.residual = residual |
| self.dropout = dropout |
| self.activation = activation |
| self.network = nn.ModuleList() |
|
|
| if residual: |
| assert len(set(hidden_dim)) == 1, "Residual connections require all hidden dims to be equal" |
|
|
        for i in range(len(hidden_dim)):
            # The first block maps input_dim to hidden_dim[0]; later blocks are
            # preceded by dropout and map between consecutive hidden widths.
            if i == 0:
| self.network.append( |
| nn.Sequential( |
| nn.Linear(input_dim, hidden_dim[i]), |
| nn.BatchNorm1d(hidden_dim[i]), |
| _get_activation(activation), |
| ) |
| ) |
| else: |
| self.network.append( |
| nn.Sequential( |
| nn.Dropout(p=dropout), |
| nn.Linear(hidden_dim[i - 1], hidden_dim[i]), |
| nn.BatchNorm1d(hidden_dim[i]), |
| _get_activation(activation), |
| ) |
| ) |
| self.network.append(nn.Linear(hidden_dim[-1], output_dim)) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| assert torch.is_tensor(x) and x.ndim == 2, ( |
| f"Expected 2D tensor, got {type(x).__name__} shape {getattr(x, 'shape', None)}" |
| ) |
        if self.training:
            # BatchNorm1d needs more than one sample to compute batch statistics;
            # in eval mode the running statistics are used, so a single sample is fine.
            assert x.shape[0] > 1, (
                f"BatchNorm requires batch size > 1 in training mode, got {x.shape[0]}. "
                "Use model.eval() for single-sample inference."
            )
        for i, layer in enumerate(self.network):
            # Apply skip connections only to the equal-width middle blocks; the
            # first block changes dimensionality and the final Linear projects
            # to output_dim, so neither can be summed with its input.
            if self.residual and 0 < i < len(self.network) - 1:
                x = layer(x) + x
            else:
                x = layer(x)
| return x |
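

# A minimal usage sketch for MLP (illustrative sizes, not values required by
# this module): encode 8 samples of a 2,000-dim input into 128 dimensions.
def _example_mlp() -> None:
    mlp = MLP(input_dim=2000, output_dim=128, hidden_dim=[1024, 1024], dropout=0.1)
    x = torch.randn(8, 2000)  # batch > 1 so BatchNorm has batch statistics
    z = mlp(x)
    assert z.shape == (8, 128)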
|
|
|
|
class MLPCellEmbedder(nn.Module):
    """Embeds per-cell gene-expression vectors [batch, n_genes] into [batch, output_dim]."""

| def __init__( |
| self, |
| n_genes: int, |
| output_dim: int = 128, |
| hidden_dim: Optional[List[int]] = None, |
| dropout: float = 0.1, |
| residual: bool = False, |
| activation: str = "prelu", |
| ): |
| super().__init__() |
| if hidden_dim is None: |
| hidden_dim = [1024, 1024] |
| self.n_genes = n_genes |
| self.output_dim = output_dim |
| self.encoder = MLP( |
| input_dim=n_genes, |
| output_dim=output_dim, |
| hidden_dim=hidden_dim, |
| dropout=dropout, |
| residual=residual, |
| activation=activation, |
| ) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| assert torch.is_tensor(x) and x.ndim == 2, ( |
| f"Expected 2D tensor, got {type(x).__name__} shape {getattr(x, 'shape', None)}" |
| ) |
| return self.encoder(x) |
|
|
|
|
class AttentionAggregator(nn.Module):
    """Attention pooling over cells: a small MLP scores each cell, and the output
    is the softmax-weighted sum of cell embeddings (masked cells get zero weight)."""
| def __init__(self, embedding_dim: int, hidden_dim: int = 128): |
| super().__init__() |
| self.attention_net = nn.Sequential( |
| nn.Linear(embedding_dim, hidden_dim), |
| nn.ReLU(), |
| nn.Linear(hidden_dim, 1), |
| ) |
|
|
| def aggregate( |
| self, x: torch.Tensor, mask: Optional[torch.Tensor] = None |
| ) -> torch.Tensor: |
| """ |
| Args: |
| x: [batch, num_cells, embedding_dim] |
| mask: [batch, num_cells] — 1=valid, 0=ignore (optional) |
| Returns: |
| [batch, embedding_dim] |
| """ |
| if mask is not None: |
| assert mask.sum(dim=1).min() > 0, "All samples must have at least one valid cell" |
| scores = self.attention_net(x).squeeze(-1) |
| if mask is not None: |
| scores = scores.masked_fill(mask == 0, float("-inf")) |
| weights = torch.softmax(scores, dim=1).unsqueeze(-1) |
| return (x * weights).sum(dim=1) |
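

# A masking sketch for AttentionAggregator (illustrative shapes): the second
# sample has only 3 valid cells; padded positions receive zero attention weight.
def _example_attention_aggregator() -> None:
    agg = AttentionAggregator(embedding_dim=128, hidden_dim=64)
    x = torch.randn(2, 5, 128)  # [batch, num_cells, embedding_dim]
    mask = torch.ones(2, 5)
    mask[1, 3:] = 0  # mark the last two cells of sample 1 as padding
    pooled = agg.aggregate(x, mask)
    assert pooled.shape == (2, 128)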
|
|
|
|
class PatientEmbedder(nn.Module):
    """Composes a cell-level embedder with an aggregator to produce one
    fixed-size embedding per patient."""
| def __init__(self, cell_embedder: nn.Module, aggregator: nn.Module): |
| super().__init__() |
| self.cell_embedder = cell_embedder |
| self.aggregator = aggregator |
|
|
| def forward( |
| self, cell_matrix: torch.Tensor, mask: Optional[torch.Tensor] = None |
| ) -> torch.Tensor: |
| """ |
| Args: |
| cell_matrix: [batch, num_cells, num_genes] |
| mask: [batch, num_cells] — optional |
| Returns: |
| [batch, embedding_dim] |
| """ |
        batch_size, num_cells, num_genes = cell_matrix.shape
        # Embed every cell in one flat pass; reshape (rather than view) also
        # handles non-contiguous inputs such as sliced or transposed tensors.
        flat = cell_matrix.reshape(-1, num_genes)
| embeddings_flat = self.cell_embedder(flat) |
| embeddings = embeddings_flat.view(batch_size, num_cells, -1) |
| return self.aggregator.aggregate(embeddings, mask) |
|
|
| def get_embedding_dim(self) -> int: |
| return self.cell_embedder.output_dim |
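

# A composition sketch (illustrative sizes): 4 patients, 50 cells each over
# 2,000 genes, pooled into one 128-d embedding per patient.
def _example_patient_embedder() -> None:
    embedder = PatientEmbedder(
        cell_embedder=MLPCellEmbedder(n_genes=2000, output_dim=128),
        aggregator=AttentionAggregator(embedding_dim=128),
    )
    cells = torch.randn(4, 50, 2000)  # [batch, num_cells, num_genes]
    patient_emb = embedder(cells)
    assert patient_emb.shape == (4, embedder.get_embedding_dim())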
|
|
|
|
| class CrossEntropyLossViews(nn.Module): |
| """Cross-entropy loss that averages per-entity (patient) across augmented views.""" |
|
|
| def __init__(self, class_weights: Optional[torch.Tensor] = None): |
| super().__init__() |
| self.ce_loss = nn.CrossEntropyLoss(weight=class_weights, reduction="none") |
|
|
| def forward( |
| self, |
| predictions: torch.Tensor, |
| labels: torch.Tensor, |
| entity_ids: Optional[torch.Tensor] = None, |
| ) -> torch.Tensor: |
| sample_losses = self.ce_loss(predictions, labels) |
| if entity_ids is None: |
| return torch.mean(sample_losses) |
        unique_entities, inverse_indices, counts = torch.unique(
            entity_ids, return_inverse=True, return_counts=True
        )
        # Sum each entity's per-view losses, normalize by its view count, and
        # average over entities so heavily augmented patients are not over-weighted.
        entity_sums = torch.zeros(
            len(unique_entities), device=sample_losses.device, dtype=sample_losses.dtype
        )
        entity_sums.scatter_add_(0, inverse_indices, sample_losses)
        return torch.mean(entity_sums / counts.to(sample_losses.dtype))
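

# A sketch of the per-entity averaging (illustrative values): patient 0
# contributes three augmented views and patient 1 a single view, yet both
# patients weigh equally in the final loss.
def _example_views_loss() -> None:
    loss_fn = CrossEntropyLossViews()
    logits = torch.randn(4, 10)
    labels = torch.tensor([2, 2, 2, 7])
    entity_ids = torch.tensor([0, 0, 0, 1])  # first three rows are one patient
    loss = loss_fn(logits, labels, entity_ids)
    assert loss.ndim == 0  # scalar, ready for backward()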
|
|
|
|
class VirtualCellPatientConfig(PretrainedConfig):
    """Hyperparameters for VirtualCellPatientModel: cell-encoder widths,
    attention-pooling size, and the classification head."""

    model_type = "virtual_cell_patient"
|
|
| def __init__( |
| self, |
| n_genes: int = 18301, |
| embed_dim: int = 512, |
| hidden_dim: Optional[List[int]] = None, |
| dropout: float = 0.1, |
| residual: bool = False, |
| activation: str = "prelu", |
| attention_hidden_dim: int = 512, |
| num_classes: int = 10, |
| classifier_dropout: float = 0.1, |
| **kwargs, |
| ): |
| super().__init__(**kwargs) |
| self.n_genes = n_genes |
| self.embed_dim = embed_dim |
| self.hidden_dim = hidden_dim if hidden_dim is not None else [4096, 1024] |
| self.dropout = dropout |
| self.residual = residual |
| self.activation = activation |
| self.attention_hidden_dim = attention_hidden_dim |
| self.num_classes = num_classes |
| self.classifier_dropout = classifier_dropout |
|
|
|
|
class VirtualCellPatientModel(PreTrainedModel):
    """Patient-level classifier: an MLP embeds each cell, attention pooling
    aggregates cells into a patient embedding, and a linear head classifies it."""

    config_class = VirtualCellPatientConfig
|
|
| def __init__(self, config: VirtualCellPatientConfig): |
| super().__init__(config) |
| cell_embedder = MLPCellEmbedder( |
| n_genes=config.n_genes, |
| output_dim=config.embed_dim, |
| hidden_dim=config.hidden_dim, |
| dropout=config.dropout, |
| residual=config.residual, |
| activation=config.activation, |
| ) |
| aggregator = AttentionAggregator( |
| embedding_dim=config.embed_dim, |
| hidden_dim=config.attention_hidden_dim, |
| ) |
| self.patient_embedder = PatientEmbedder(cell_embedder, aggregator) |
| self.classifier = nn.Sequential( |
| nn.Dropout(config.classifier_dropout), |
| nn.Linear(config.embed_dim, config.num_classes), |
| ) |
| self.loss_fn = CrossEntropyLossViews() |
|
|
    def _init_weights(self, module):
        # PreTrainedModel expects this hook; submodules keep PyTorch's default
        # per-layer initialization, so no custom scheme is applied.
        pass
|
|
| def forward( |
| self, |
| input_ids: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| labels: Optional[torch.Tensor] = None, |
| entity_id: Optional[torch.Tensor] = None, |
| **kwargs, |
| ) -> SequenceClassifierOutput: |
| """ |
| Args: |
| input_ids: [batch, num_cells, num_genes] log-normalized float32 expression |
            attention_mask: [batch, num_cells], 1=valid, 0=ignore (optional)
| labels: [batch] integer class indices (optional, for loss) |
| entity_id: [batch] patient IDs grouping augmented views (optional) |
| Returns: |
| SequenceClassifierOutput with .loss (when labels given) and .logits [batch, num_classes] |
| """ |
| embeddings = self.patient_embedder(input_ids, attention_mask) |
| logits = self.classifier(embeddings) |
|
|
        loss = None
        if labels is not None:
            # CrossEntropyLossViews already falls back to a plain mean when
            # entity_id is None, so one code path covers both cases.
            loss = self.loss_fn(logits, labels, entity_id)
|
|
| return SequenceClassifierOutput(loss=loss, logits=logits) |
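

# An end-to-end sketch (illustrative sizes, much smaller than the config
# defaults): build a config and model, run a labeled forward pass, and read
# the loss and logits off the SequenceClassifierOutput.
def _example_model_forward() -> None:
    config = VirtualCellPatientConfig(
        n_genes=2000,
        embed_dim=64,
        hidden_dim=[256, 256],
        attention_hidden_dim=64,
        num_classes=5,
    )
    model = VirtualCellPatientModel(config)
    input_ids = torch.randn(2, 30, 2000)  # [batch, num_cells, num_genes]
    attention_mask = torch.ones(2, 30)
    labels = torch.tensor([0, 3])
    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    assert out.logits.shape == (2, 5) and out.loss is not None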
|
|