"""Virtual-cell patient classifier.

Pipeline: an MLP embeds each cell's gene-expression vector, an attention
module pools the cell embeddings into a single patient embedding, and a
linear head classifies the patient.  Packaged as a HuggingFace
``PreTrainedModel`` so checkpoints load via ``from_pretrained``.
"""

from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


def _get_activation(activation: str) -> nn.Module:
    """Return a freshly constructed activation module for ``activation``.

    A new instance is built on every call (PReLU has learnable parameters,
    so instances must not be shared between layers).

    Raises:
        ValueError: if ``activation`` is not one of prelu/relu/gelu/tanh.
    """
    factories = {
        "prelu": nn.PReLU,
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "tanh": nn.Tanh,
    }
    try:
        return factories[activation]()
    except KeyError:
        raise ValueError(f"Unsupported activation: {activation!r}") from None


class MLP(nn.Module):
    """Stack of ``Linear -> BatchNorm1d -> activation`` blocks plus a final
    linear projection to ``output_dim``.

    With ``residual=True`` the interior hidden blocks (all but the first
    block and the output projection) add a skip connection, which requires
    every hidden width to be identical.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 128,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        activation: str = "prelu",
    ):
        super().__init__()
        if hidden_dim is None:
            # Default kept out of the signature to avoid a mutable default arg.
            hidden_dim = [1024, 1024]
        self.input_dim = input_dim
        self.latent_dim = output_dim
        self.residual = residual
        self.dropout = dropout
        self.activation = activation
        # NOTE: child ordering inside each nn.Sequential is load-bearing —
        # it determines state-dict keys, which must match the checkpoint.
        self.network = nn.ModuleList()
        if residual:
            assert len(set(hidden_dim)) == 1, "Residual connections require all hidden dims to be equal"
        for i, width in enumerate(hidden_dim):
            if i == 0:
                block = nn.Sequential(
                    nn.Linear(input_dim, width),
                    nn.BatchNorm1d(width),
                    _get_activation(activation),
                )
            else:
                # Dropout only between hidden blocks, never on the raw input.
                block = nn.Sequential(
                    nn.Dropout(p=dropout),
                    nn.Linear(hidden_dim[i - 1], width),
                    nn.BatchNorm1d(width),
                    _get_activation(activation),
                )
            self.network.append(block)
        self.network.append(nn.Linear(hidden_dim[-1], output_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape ``[batch, input_dim]`` to ``[batch, latent_dim]``."""
        assert torch.is_tensor(x) and x.ndim == 2, (
            f"Expected 2D tensor, got {type(x).__name__} shape {getattr(x, 'shape', None)}"
        )
        # BatchNorm1d needs >1 sample only while computing batch statistics,
        # i.e. in training mode; eval mode uses running stats.  The previous
        # version asserted unconditionally, which made the workaround its own
        # error message suggested ("Use model.eval()") impossible.
        assert x.shape[0] > 1 or not self.training, (
            f"BatchNorm requires batch size > 1, got {x.shape[0]}. "
            "Use model.eval() for single-sample inference."
        )
        for i, layer in enumerate(self.network):
            # Skip connections only on interior hidden blocks: the first block
            # changes dimensionality and the last layer is the projection.
            if self.residual and 0 < i < len(self.network) - 1:
                x = layer(x) + x
            else:
                x = layer(x)
        return x


class MLPCellEmbedder(nn.Module):
    # Thin wrapper that preserves the .encoder attribute name required
    # for state-dict key compatibility with the checkpoint.

    def __init__(
        self,
        n_genes: int,
        output_dim: int = 128,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.1,
        residual: bool = False,
        activation: str = "prelu",
    ):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = [1024, 1024]
        self.n_genes = n_genes
        self.output_dim = output_dim
        self.encoder = MLP(
            input_dim=n_genes,
            output_dim=output_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
            residual=residual,
            activation=activation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed cells: ``[num_cells, n_genes] -> [num_cells, output_dim]``."""
        assert torch.is_tensor(x) and x.ndim == 2, (
            f"Expected 2D tensor, got {type(x).__name__} shape {getattr(x, 'shape', None)}"
        )
        return self.encoder(x)


class AttentionAggregator(nn.Module):
    """Softmax attention pooling over the cell dimension."""

    def __init__(self, embedding_dim: int, hidden_dim: int = 128):
        super().__init__()
        # Scores each cell embedding with a 2-layer MLP -> scalar logit.
        self.attention_net = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def aggregate(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: [batch, num_cells, embedding_dim]
            mask: [batch, num_cells] — 1=valid, 0=ignore (optional)
        Returns:
            [batch, embedding_dim]
        """
        if mask is not None:
            # A fully-masked row would softmax over all -inf and produce NaN.
            assert mask.sum(dim=1).min() > 0, "All samples must have at least one valid cell"
        scores = self.attention_net(x).squeeze(-1)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=1).unsqueeze(-1)
        return (x * weights).sum(dim=1)


class PatientEmbedder(nn.Module):
    """Cell embedder + aggregator: a bag of cells in, one patient vector out."""

    def __init__(self, cell_embedder: nn.Module, aggregator: nn.Module):
        super().__init__()
        self.cell_embedder = cell_embedder
        self.aggregator = aggregator

    def forward(
        self, cell_matrix: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            cell_matrix: [batch, num_cells, num_genes]
            mask: [batch, num_cells] — optional
        Returns:
            [batch, embedding_dim]
        """
        batch_size, num_cells, num_genes = cell_matrix.shape
        # reshape (not view): tolerates non-contiguous inputs, e.g. slices.
        flat = cell_matrix.reshape(-1, num_genes)
        embeddings_flat = self.cell_embedder(flat)
        embeddings = embeddings_flat.view(batch_size, num_cells, -1)
        return self.aggregator.aggregate(embeddings, mask)

    def get_embedding_dim(self) -> int:
        return self.cell_embedder.output_dim


class CrossEntropyLossViews(nn.Module):
    """Cross-entropy loss that averages per-entity (patient) across augmented views.

    Each sample's loss is first averaged within its entity group, then the
    group means are averaged — so a patient with many augmented views does
    not dominate the batch loss.  With ``entity_ids=None`` this reduces to
    plain mean cross-entropy.
    """

    def __init__(self, class_weights: Optional[torch.Tensor] = None):
        super().__init__()
        # reduction="none" keeps per-sample losses for the grouped average.
        self.ce_loss = nn.CrossEntropyLoss(weight=class_weights, reduction="none")

    def forward(
        self,
        predictions: torch.Tensor,
        labels: torch.Tensor,
        entity_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        sample_losses = self.ce_loss(predictions, labels)
        if entity_ids is None:
            return torch.mean(sample_losses)
        unique_entities, inverse_indices, counts = torch.unique(
            entity_ids, return_inverse=True, return_counts=True
        )
        # Sum each entity's sample losses, then divide by its view count.
        entity_sums = torch.zeros(
            len(unique_entities), device=sample_losses.device, dtype=sample_losses.dtype
        )
        entity_sums.scatter_add_(0, inverse_indices, sample_losses)
        return torch.mean(entity_sums / counts.to(sample_losses.dtype))


class VirtualCellPatientConfig(PretrainedConfig):
    """Hyperparameters for :class:`VirtualCellPatientModel`."""

    model_type = "virtual_cell_patient"

    def __init__(
        self,
        n_genes: int = 18301,
        embed_dim: int = 512,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.1,
        residual: bool = False,
        activation: str = "prelu",
        attention_hidden_dim: int = 512,
        num_classes: int = 10,
        classifier_dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.n_genes = n_genes
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim if hidden_dim is not None else [4096, 1024]
        self.dropout = dropout
        self.residual = residual
        self.activation = activation
        self.attention_hidden_dim = attention_hidden_dim
        self.num_classes = num_classes
        self.classifier_dropout = classifier_dropout


class VirtualCellPatientModel(PreTrainedModel):
    """Patient-level classifier over bags of cell expression profiles."""

    config_class = VirtualCellPatientConfig

    def __init__(self, config: VirtualCellPatientConfig):
        super().__init__(config)
        cell_embedder = MLPCellEmbedder(
            n_genes=config.n_genes,
            output_dim=config.embed_dim,
            hidden_dim=config.hidden_dim,
            dropout=config.dropout,
            residual=config.residual,
            activation=config.activation,
        )
        aggregator = AttentionAggregator(
            embedding_dim=config.embed_dim,
            hidden_dim=config.attention_hidden_dim,
        )
        self.patient_embedder = PatientEmbedder(cell_embedder, aggregator)
        self.classifier = nn.Sequential(
            nn.Dropout(config.classifier_dropout),
            nn.Linear(config.embed_dim, config.num_classes),
        )
        self.loss_fn = CrossEntropyLossViews()

    def _init_weights(self, module):
        # Intentionally a no-op: weights always come from the checkpoint;
        # random re-initialization would clobber them.
        pass

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        entity_id: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """
        Args:
            input_ids: [batch, num_cells, num_genes] log-normalized float32 expression
            attention_mask: [batch, num_cells] 1=valid, 0=ignore (optional)
            labels: [batch] integer class indices (optional, for loss)
            entity_id: [batch] patient IDs grouping augmented views (optional)
        Returns:
            SequenceClassifierOutput with .loss (when labels given) and
            .logits [batch, num_classes]
        """
        embeddings = self.patient_embedder(input_ids, attention_mask)
        logits = self.classifier(embeddings)
        loss = None
        if labels is not None:
            # loss_fn falls back to plain mean cross-entropy when entity_id
            # is None, so a single code path covers both cases (previously
            # this forked between loss_fn and F.cross_entropy).
            loss = self.loss_fn(logits, labels, entity_id)
        return SequenceClassifierOutput(loss=loss, logits=logits)