from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel


# Factories (not instances) so every call yields a fresh module: nn.PReLU owns a
# learnable parameter that must not be shared between layers.
_ACTIVATIONS = {
    "prelu": nn.PReLU,
    "relu": nn.ReLU,
    "gelu": nn.GELU,
    "tanh": nn.Tanh,
}


def _get_activation(activation: str) -> nn.Module:
    """Return a new activation module for the name *activation*.

    Raises:
        ValueError: if *activation* is not one of "prelu", "relu", "gelu", "tanh".
    """
    try:
        factory = _ACTIVATIONS[activation]
    except KeyError:
        raise ValueError(f"Unsupported activation: {activation!r}") from None
    return factory()


class MLP(nn.Module):
    """Feed-forward encoder built from ``Linear -> BatchNorm1d -> activation`` blocks.

    Layout of ``self.network``. The indices are part of the state_dict keys, so
    this structure must stay stable for checkpoint compatibility:

    * ``network[0]``: Linear(input_dim, hidden_dim[0]) -> BatchNorm1d -> activation
    * ``network[i]`` for ``1 <= i < len(hidden_dim)``:
      Dropout -> Linear(hidden_dim[i-1], hidden_dim[i]) -> BatchNorm1d -> activation
    * ``network[-1]``: Linear(hidden_dim[-1], output_dim)

    The input is expected to be a 2-D ``(batch, input_dim)`` tensor. BatchNorm1d
    normalises over dim 1, and in training mode it rejects a batch of one sample.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Width of the produced embedding (also stored as ``latent_dim``).
        hidden_dim: Widths of the hidden blocks; defaults to ``[512, 512]``.
        dropout: Dropout probability applied at the start of every hidden block
            except the first (the raw input is never dropped).
        residual: If True, add a skip connection around every hidden->hidden block.
        activation: Activation name understood by ``_get_activation``.

    Raises:
        ValueError: if ``hidden_dim`` is empty, if ``residual`` is requested with
            unequal hidden widths, or if ``activation`` is unknown.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 512,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        activation: str = "prelu",
    ):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = [512, 512]
        # Explicit exceptions rather than `assert`, which is stripped under `python -O`.
        if not hidden_dim:
            raise ValueError("hidden_dim must contain at least one layer size")
        if residual and len(set(hidden_dim)) != 1:
            raise ValueError("Residual connections require all hidden dims to be equal")

        self.latent_dim = output_dim
        self.residual = residual
        self.network = nn.ModuleList()

        # First block: maps the raw input to the first hidden width, no dropout.
        self.network.append(
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim[0]),
                nn.BatchNorm1d(hidden_dim[0]),
                _get_activation(activation),
            )
        )
        # Remaining hidden blocks, one per consecutive pair of widths.
        for in_width, out_width in zip(hidden_dim[:-1], hidden_dim[1:]):
            self.network.append(
                nn.Sequential(
                    nn.Dropout(p=dropout),
                    nn.Linear(in_width, out_width),
                    nn.BatchNorm1d(out_width),
                    _get_activation(activation),
                )
            )
        # Final projection to the embedding width (no norm / activation).
        self.network.append(nn.Linear(hidden_dim[-1], output_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode *x* and return a ``(batch, latent_dim)`` tensor."""
        last = len(self.network) - 1
        for i, layer in enumerate(self.network):
            # Skip connections wrap only the hidden->hidden blocks, whose input and
            # output widths match. The first block changes width from input_dim and
            # the last layer projects to output_dim.
            if self.residual and 0 < i < last:
                x = layer(x) + x
            else:
                x = layer(x)
        return x


class VirtualCellDistilConfig(PretrainedConfig):
    """Configuration for the distilled MLP encoder and its task heads.

    Args:
        n_genes: Input width (number of gene features per sample).
        output_dim: Embedding width produced by the encoder.
        hidden_dim: Hidden-layer widths of the encoder; defaults to ``[512, 512]``.
        dropout: Dropout used inside the encoder (see ``MLP``).
        residual: Whether the encoder adds skip connections between hidden blocks.
        activation: Encoder activation name ("prelu", "relu", "gelu" or "tanh").
        num_labels: Number of classifier outputs. ``None`` (the default) keeps the
            count implied by an ``id2label`` mapping when one is supplied (e.g.
            when reloading a saved config), otherwise 2.
        classifier_dropout: Dropout applied to embeddings before the classifier.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "virtual_cell_distil"

    def __init__(
        self,
        n_genes: int = 18301,
        output_dim: int = 512,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        activation: str = "prelu",
        num_labels: Optional[int] = None,
        classifier_dropout: float = 0.1,
        **kwargs,
    ):
        # Hand num_labels to PretrainedConfig instead of assigning it afterwards.
        # The base-class `num_labels` setter regenerates id2label/label2id whenever
        # the count differs from len(id2label). A hard-coded post-init assignment of
        # 2 would therefore overwrite the label mapping restored from a saved
        # config.json. The base class itself defaults to 2 labels when neither
        # argument is given.
        if num_labels is not None:
            kwargs["num_labels"] = num_labels
        super().__init__(**kwargs)
        self.n_genes = n_genes
        self.output_dim = output_dim
        # Copy so the config never aliases a caller-owned (possibly mutable) sequence.
        self.hidden_dim = list(hidden_dim) if hidden_dim is not None else [512, 512]
        self.dropout = dropout
        self.residual = residual
        self.activation = activation
        self.classifier_dropout = classifier_dropout


def _build_encoder(config: VirtualCellDistilConfig) -> MLP:
    """Build the MLP encoder described by *config* (shared by every model class)."""
    return MLP(
        input_dim=config.n_genes,
        output_dim=config.output_dim,
        hidden_dim=config.hidden_dim,
        dropout=config.dropout,
        residual=config.residual,
        activation=config.activation,
    )


class VirtualCellDistilPreTrainedModel(PreTrainedModel):
    """Common base that binds the config class and defines weight initialisation.

    transformers expects subclasses to override ``_init_weights``.
    ``from_pretrained`` uses it to initialise modules whose weights are missing
    from the checkpoint, such as a freshly added classification head. Here it
    restores PyTorch's default initialisation for each layer type used by this file.
    """

    config_class = VirtualCellDistilConfig

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, (nn.Linear, nn.BatchNorm1d)):
            module.reset_parameters()
        elif isinstance(module, nn.PReLU):
            # 0.25 is nn.PReLU()'s default. Older torch versions have neither
            # PReLU.reset_parameters() nor the `init` attribute.
            nn.init.constant_(module.weight, getattr(module, "init", 0.25))


class VirtualCellDistilModel(VirtualCellDistilPreTrainedModel):
    """Pure encoder — returns patient embeddings (``config.output_dim`` wide, 512 by
    default) from bulk expression."""

    config_class = VirtualCellDistilConfig

    def __init__(self, config: VirtualCellDistilConfig):
        super().__init__(config)
        self.encoder = _build_encoder(config)
        # Initialise weights and apply transformers' final model set-up.
        self.post_init()

    def forward(self, input_ids: torch.Tensor, **kwargs) -> dict:
        """Return ``{"embeddings": Tensor(batch, output_dim)}``.

        ``input_ids`` is the 2-D ``(batch, n_genes)`` float input. The HF-style name
        is presumably kept for Trainer/pipeline compatibility. Extra keyword
        arguments are accepted and ignored.
        """
        return {"embeddings": self.encoder(input_ids)}


class VirtualCellDistilForSequenceClassification(VirtualCellDistilPreTrainedModel):
    """
    Encoder + linear classification head.

    The encoder is initialised from pretrained distilled weights.
    The classification head is randomly initialised and trained on your labels.
    Use ignore_mismatched_sizes=True when loading from the pretrained checkpoint.

    Loss selection when ``labels`` are passed:

    * ``config.problem_type == "regression"``, or ``num_labels == 1`` with no
      problem_type set: mean-squared error.
    * ``config.problem_type == "multi_label_classification"``: binary
      cross-entropy with logits.
    * otherwise: ``F.cross_entropy(logits, labels)``.
    """

    config_class = VirtualCellDistilConfig

    def __init__(self, config: VirtualCellDistilConfig):
        super().__init__(config)
        self.encoder = _build_encoder(config)
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.output_dim, config.num_labels)
        # Initialise weights and apply transformers' final model set-up.
        self.post_init()

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Pick the loss from ``config.problem_type`` / ``num_labels`` (see class doc)."""
        problem_type = getattr(self.config, "problem_type", None)
        num_labels = self.config.num_labels
        # Cross-entropy over a single class is identically zero, so one output
        # is treated as regression (the transformers convention).
        if problem_type is None and num_labels == 1:
            problem_type = "regression"

        if problem_type == "regression":
            target = labels.to(logits.dtype)
            if num_labels == 1:
                return F.mse_loss(logits.view(-1), target.view(-1))
            return F.mse_loss(logits, target)
        if problem_type == "multi_label_classification":
            return F.binary_cross_entropy_with_logits(logits, labels.to(logits.dtype))
        return F.cross_entropy(logits, labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        """Return a dict with ``loss`` (``None`` without labels), ``logits`` and ``embeddings``.

        ``input_ids`` is the 2-D ``(batch, n_genes)`` float input. Extra keyword
        arguments are accepted and ignored.
        """
        embeddings = self.encoder(input_ids)
        logits = self.classifier(self.dropout(embeddings))
        loss = self._compute_loss(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits, "embeddings": embeddings}