# virtual-cell-distil-bulk / modeling_virtual_cell_distil.py
# Uploaded by: danielle-miller-sayag
# Commit 9216523 (verified): "Upload modeling_virtual_cell_distil.py with huggingface_hub"
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
def _get_activation(activation: str) -> nn.Module:
    """Build a new activation module from its short name.

    Supported names are ``"prelu"``, ``"relu"``, ``"gelu"`` and ``"tanh"``.
    A fresh instance is created on every call, so learnable activations
    (PReLU) are never shared between layers.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    known = (
        ("prelu", nn.PReLU),
        ("relu", nn.ReLU),
        ("gelu", nn.GELU),
        ("tanh", nn.Tanh),
    )
    for label, factory in known:
        if activation == label:
            return factory()
    raise ValueError(f"Unsupported activation: {activation!r}")
class MLP(nn.Module):
    """Feed-forward encoder: hidden Linear -> BatchNorm1d -> activation blocks, then a linear head.

    Layout of ``self.network`` (indices determine ``state_dict`` keys, so keep it stable):
        * index 0: ``Sequential(Linear(input_dim, h0), BatchNorm1d(h0), act)``
        * index i (1 <= i < len(hidden_dim)):
          ``Sequential(Dropout, Linear(h[i-1], h[i]), BatchNorm1d(h[i]), act)``
        * last index: ``Linear(hidden_dim[-1], output_dim)``

    Args:
        input_dim: Number of input features.
        output_dim: Width of the final linear projection (also stored as ``latent_dim``).
        hidden_dim: Widths of the hidden blocks; defaults to ``[512, 512]``.
        dropout: Dropout probability applied at the start of every hidden block
            except the first.
        residual: If True, add a skip connection around every hidden block except
            the first (which changes width from ``input_dim``). Requires all hidden
            widths to be equal.
        activation: Activation name accepted by ``_get_activation``.

    Raises:
        ValueError: If ``hidden_dim`` is empty, or if ``residual`` is set and the
            hidden widths are not all equal.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 512,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        activation: str = "prelu",
    ):
        super().__init__()
        hidden_dim = [512, 512] if hidden_dim is None else list(hidden_dim)
        # Validate with real exceptions: ``assert`` is stripped under ``python -O``.
        if not hidden_dim:
            raise ValueError("hidden_dim must contain at least one layer width")
        if residual and len(set(hidden_dim)) != 1:
            raise ValueError("Residual connections require all hidden dims to be equal")

        self.latent_dim = output_dim
        self.residual = residual
        self.network = nn.ModuleList()

        in_dims = [input_dim] + hidden_dim[:-1]
        for i, (in_dim, out_dim) in enumerate(zip(in_dims, hidden_dim)):
            # The first block has no Dropout; later blocks start with one. Module order
            # inside each Sequential must stay fixed so pretrained checkpoints still load.
            layers = [] if i == 0 else [nn.Dropout(p=dropout)]
            layers += [
                nn.Linear(in_dim, out_dim),
                nn.BatchNorm1d(out_dim),
                _get_activation(activation),
            ]
            self.network.append(nn.Sequential(*layers))
        self.network.append(nn.Linear(hidden_dim[-1], output_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every block in order.

        With ``residual`` enabled, hidden blocks other than the first add their
        input to their output. The first block (width change) and the final
        linear head are never residual.
        """
        last = len(self.network) - 1
        for i, layer in enumerate(self.network):
            if self.residual and 0 < i < last:
                x = layer(x) + x
            else:
                x = layer(x)
        return x
class VirtualCellDistilConfig(PretrainedConfig):
    """Hyper-parameters for the VirtualCellDistil encoder and its classification head.

    Args:
        n_genes: Width of the expression input; used as the encoder MLP's ``input_dim``.
        output_dim: Width of the embedding produced by the encoder MLP.
        hidden_dim: Hidden-layer widths of the encoder MLP; defaults to ``[512, 512]``.
            A copy is stored, so later mutation of the caller's list does not leak in.
        dropout: Dropout probability inside the encoder MLP.
        residual: Whether the encoder MLP uses residual connections.
        activation: Activation name for the encoder MLP (e.g. ``"prelu"``).
        num_labels: Number of classification labels. Only used when ``id2label`` is not
            supplied; when it is, the label count follows ``id2label``, matching
            ``PretrainedConfig``'s own handling.
        classifier_dropout: Dropout applied to the embeddings before the classification head.
        **kwargs: Forwarded to ``PretrainedConfig`` (e.g. ``id2label``, ``label2id``).
    """

    model_type = "virtual_cell_distil"

    def __init__(
        self,
        n_genes: int = 18301,
        output_dim: int = 512,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        activation: str = "prelu",
        num_labels: int = 2,
        classifier_dropout: float = 0.1,
        **kwargs,
    ):
        # On PretrainedConfig, ``num_labels`` is a property derived from ``id2label``,
        # and its setter rebuilds ``id2label`` whenever the sizes differ. A reloaded
        # config carries ``id2label``, so assigning the default (2) after
        # ``super().__init__`` would silently collapse an N-label map back to 2 labels
        # and break the classifier shape. Let the base class resolve the label count,
        # and feed it ``num_labels`` only when no label map was provided.
        if kwargs.get("id2label") is None:
            kwargs["num_labels"] = num_labels
        super().__init__(**kwargs)
        self.n_genes = n_genes
        self.output_dim = output_dim
        self.hidden_dim = list(hidden_dim) if hidden_dim is not None else [512, 512]
        self.dropout = dropout
        self.residual = residual
        self.activation = activation
        self.classifier_dropout = classifier_dropout
class VirtualCellDistilModel(PreTrainedModel):
    """Bare encoder mapping bulk expression profiles to patient embeddings.

    Holds a single ``MLP`` built from the config. The embedding width is
    ``config.output_dim`` (512 with the default config). There is no task head.
    """

    config_class = VirtualCellDistilConfig

    def __init__(self, config: VirtualCellDistilConfig):
        super().__init__(config)
        mlp_kwargs = dict(
            input_dim=config.n_genes,
            output_dim=config.output_dim,
            hidden_dim=config.hidden_dim,
            dropout=config.dropout,
            residual=config.residual,
            activation=config.activation,
        )
        # The attribute name ``encoder`` fixes the state_dict prefix ("encoder.*"),
        # which must match the saved checkpoint.
        self.encoder = MLP(**mlp_kwargs)

    def forward(self, input_ids: torch.Tensor, **kwargs) -> dict:
        """Encode ``input_ids`` and return ``{"embeddings": tensor}``.

        ``input_ids`` is presumably a float expression matrix of shape
        (batch, n_genes), since the first layer is Linear(n_genes, ...) followed by
        BatchNorm1d (TODO confirm against callers). Extra keyword arguments are
        accepted and ignored.
        """
        embeddings = self.encoder(input_ids)
        return {"embeddings": embeddings}
class VirtualCellDistilForSequenceClassification(PreTrainedModel):
    """
    Encoder + linear classification head.

    The encoder is initialised from pretrained distilled weights.
    The classification head is randomly initialised and trained on your labels.
    Use ignore_mismatched_sizes=True when loading from the pretrained checkpoint.

    Loss (computed only when ``labels`` is given):
        * one output logit (``num_labels == 1``): mean-squared-error regression.
          Cross-entropy over a single class is identically zero and would never train.
        * integer ``labels``: cross-entropy on class indices. Labels are flattened and
          cast to long, so ``(batch,)`` and ``(batch, 1)`` shapes and int32 dtypes work.
        * floating-point ``labels``: passed to ``F.cross_entropy`` unchanged
          (e.g. ``(batch, num_labels)`` class-probability targets).
    """

    config_class = VirtualCellDistilConfig

    def __init__(self, config: VirtualCellDistilConfig):
        super().__init__(config)
        # Same attribute name/structure as VirtualCellDistilModel so "encoder.*"
        # checkpoint keys load into this model.
        self.encoder = MLP(
            input_dim=config.n_genes,
            output_dim=config.output_dim,
            hidden_dim=config.hidden_dim,
            dropout=config.dropout,
            residual=config.residual,
            activation=config.activation,
        )
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.output_dim, config.num_labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        """Return ``{"loss", "logits", "embeddings"}``; ``loss`` is None without labels.

        Extra keyword arguments are accepted and ignored.
        """
        embeddings = self.encoder(input_ids)
        logits = self.classifier(self.dropout(embeddings))
        loss = None
        if labels is not None:
            # Keep labels on the same device as the logits (e.g. CPU labels, GPU model).
            labels = labels.to(logits.device)
            if logits.size(-1) == 1:
                loss = F.mse_loss(logits.reshape(-1), labels.reshape(-1).to(logits.dtype))
            elif labels.is_floating_point():
                loss = F.cross_entropy(logits, labels)
            else:
                loss = F.cross_entropy(logits, labels.reshape(-1).long())
        return {"loss": loss, "logits": logits, "embeddings": embeddings}