Transformers
Safetensors
virtual_cell_distil
biology
genomics
bulk-rna-seq
patient-embedding
custom_code
Instructions to use ConvergeBio/virtual-cell-distil-bulk with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ConvergeBio/virtual-cell-distil-bulk with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("ConvergeBio/virtual-cell-distil-bulk", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| from typing import List, Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import PretrainedConfig, PreTrainedModel | |
| def _get_activation(activation: str) -> nn.Module: | |
| if activation == "prelu": | |
| return nn.PReLU() | |
| elif activation == "relu": | |
| return nn.ReLU() | |
| elif activation == "gelu": | |
| return nn.GELU() | |
| elif activation == "tanh": | |
| return nn.Tanh() | |
| raise ValueError(f"Unsupported activation: {activation!r}") | |
class MLP(nn.Module):
    """Feed-forward encoder built from Linear -> BatchNorm1d -> activation blocks.

    ``self.network`` is an ``nn.ModuleList`` holding one ``nn.Sequential`` per
    entry of ``hidden_dim``, followed by a final ``nn.Linear`` projection to
    ``output_dim``. Every hidden block except the first starts with
    ``nn.Dropout``.

    Args:
        input_dim: Number of input features seen by the first ``nn.Linear``.
        output_dim: Width of the final projection; also stored as ``latent_dim``.
        hidden_dim: Hidden layer widths, in order. ``None`` means ``[512, 512]``.
            Must contain at least one width.
        dropout: Dropout probability applied before hidden blocks 2..n.
        residual: If True, ``forward`` adds a skip connection around hidden
            blocks 2..n, which requires all hidden widths to be equal.
        activation: Activation name resolved by ``_get_activation``.

    Raises:
        ValueError: If ``hidden_dim`` is empty, or if ``residual`` is True and
            the hidden widths are not all equal.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 512,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        activation: str = "prelu",
    ):
        super().__init__()
        hidden_dim = [512, 512] if hidden_dim is None else list(hidden_dim)

        # Explicit checks instead of `assert`: asserts are stripped under
        # `python -O`, which would let a mismatched residual stack through and
        # fail later in `forward` with an opaque shape error.
        if not hidden_dim:
            raise ValueError("hidden_dim must contain at least one layer size")
        if residual and len(set(hidden_dim)) != 1:
            raise ValueError("Residual connections require all hidden dims to be equal")

        self.latent_dim = output_dim
        self.residual = residual
        self.network = nn.ModuleList()

        # The position of each module inside its Sequential is part of the
        # state_dict key (e.g. "network.1.1.weight"), so the layout
        # [Dropout,] Linear, BatchNorm1d, activation must not change.
        in_dims = [input_dim] + hidden_dim[:-1]
        for i, (in_dim, out_dim) in enumerate(zip(in_dims, hidden_dim)):
            layers = [] if i == 0 else [nn.Dropout(p=dropout)]
            layers += [
                nn.Linear(in_dim, out_dim),
                nn.BatchNorm1d(out_dim),
                _get_activation(activation),
            ]
            self.network.append(nn.Sequential(*layers))
        self.network.append(nn.Linear(hidden_dim[-1], output_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every block of ``self.network`` in order.

        Args:
            x: Input features; presumably shaped (batch, input_dim), since
                ``nn.BatchNorm1d`` normalises over dim 1 — TODO confirm for
                other ranks.

        Returns:
            Output of the final projection (last dimension ``output_dim``).
        """
        last = len(self.network) - 1
        for i, layer in enumerate(self.network):
            # The first block changes width (input_dim -> hidden) and the last
            # one is the output projection, so only inner hidden blocks can
            # take a skip connection.
            if self.residual and 0 < i < last:
                x = layer(x) + x
            else:
                x = layer(x)
        return x
class VirtualCellDistilConfig(PretrainedConfig):
    """Configuration for the virtual-cell distilled bulk-expression models.

    Args:
        n_genes: Number of input expression features (encoder input width).
        output_dim: Width of the encoder output embedding.
        hidden_dim: Encoder hidden layer widths; ``None`` means ``[512, 512]``.
        dropout: Dropout probability inside the encoder MLP.
        residual: Whether the encoder uses residual connections between
            hidden layers.
        activation: Encoder activation name ("prelu", "relu", "gelu", "tanh").
        num_labels: Number of classes for the classification head. When
            ``None`` (the default), the count comes from ``id2label`` if one
            is supplied in ``kwargs``, otherwise ``PretrainedConfig``'s
            default of 2 applies.
        classifier_dropout: Dropout applied to embeddings before the
            classification head.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "virtual_cell_distil"

    def __init__(
        self,
        n_genes: int = 18301,
        output_dim: int = 512,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        activation: str = "prelu",
        num_labels: Optional[int] = None,
        classifier_dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.n_genes = n_genes
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim if hidden_dim is not None else [512, 512]
        self.dropout = dropout
        self.residual = residual
        self.activation = activation
        # `num_labels` is a PretrainedConfig property whose setter rebuilds
        # `id2label` when the lengths differ. A saved config carries
        # `id2label`, so unconditionally assigning a default of 2 here would
        # shrink a reloaded N-class config back to 2 labels and break loading
        # of fine-tuned checkpoints. Only override when explicitly requested.
        if num_labels is not None:
            self.num_labels = num_labels
        self.classifier_dropout = classifier_dropout
class VirtualCellDistilModel(PreTrainedModel):
    """Encoder-only model mapping bulk expression profiles to patient embeddings.

    The ``encoder`` is an ``MLP`` configured from the config; its output width
    is ``config.output_dim`` (512 with the default config).
    """

    config_class = VirtualCellDistilConfig

    def __init__(self, config: VirtualCellDistilConfig):
        super().__init__(config)
        encoder_kwargs = dict(
            input_dim=config.n_genes,
            output_dim=config.output_dim,
            hidden_dim=config.hidden_dim,
            dropout=config.dropout,
            residual=config.residual,
            activation=config.activation,
        )
        self.encoder = MLP(**encoder_kwargs)

    def forward(self, input_ids: torch.Tensor, **kwargs) -> dict:
        """Encode a batch of expression profiles.

        Args:
            input_ids: Expression values passed straight to the encoder. They
                feed an ``nn.Linear``, so this is a floating-point feature
                tensor rather than token ids; presumably (batch, n_genes) —
                confirm against the checkpoint's preprocessing.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with key ``"embeddings"`` holding the encoder output.
        """
        embeddings = self.encoder(input_ids)
        return {"embeddings": embeddings}
class VirtualCellDistilForSequenceClassification(PreTrainedModel):
    """Bulk-expression encoder topped with a dropout + linear classification head.

    The ``encoder`` is meant to be initialised from the pretrained distilled
    weights, while the ``classifier`` head starts from random initialisation
    and is trained on your labels. Use ``ignore_mismatched_sizes=True`` when
    loading from the pretrained checkpoint.
    """

    config_class = VirtualCellDistilConfig

    def __init__(self, config: VirtualCellDistilConfig):
        super().__init__(config)
        # Submodules are registered in the order encoder -> dropout ->
        # classifier, which fixes parameter order and seeded initialisation.
        self.encoder = MLP(
            activation=config.activation,
            residual=config.residual,
            dropout=config.dropout,
            hidden_dim=config.hidden_dim,
            output_dim=config.output_dim,
            input_dim=config.n_genes,
        )
        head_in = config.output_dim
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(head_in, config.num_labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        """Embed ``input_ids`` and score each sample with the linear head.

        Args:
            input_ids: Expression values passed straight to the encoder
                (a floating-point feature tensor, not token ids).
            labels: Optional targets in any form accepted by
                ``torch.nn.functional.cross_entropy``; when given, the loss
                is computed.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with ``"loss"`` (``None`` when ``labels`` is ``None``),
            ``"logits"`` and ``"embeddings"``.
        """
        embeddings = self.encoder(input_ids)
        logits = self.classifier(self.dropout(embeddings))
        loss = None if labels is None else F.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits, "embeddings": embeddings}