# virtual-cell-patient / modeling_virtual_cell.py
# Uploaded by danielle-miller-sayag with huggingface_hub (commit 8b65612, verified).
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
def _get_activation(activation: str) -> nn.Module:
    """Create a new activation module from its name.

    Supported names are ``"prelu"``, ``"relu"``, ``"gelu"`` and ``"tanh"``.
    A fresh module is built on every call, so each layer owns its own
    instance (PReLU has a learnable parameter).

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    # Ordered (name, factory) pairs compared with ``==``, mirroring an
    # if/elif chain exactly.
    table = (
        ("prelu", nn.PReLU),
        ("relu", nn.ReLU),
        ("gelu", nn.GELU),
        ("tanh", nn.Tanh),
    )
    for name, factory in table:
        if activation == name:
            return factory()
    raise ValueError(f"Unsupported activation: {activation!r}")
class MLP(nn.Module):
    """Multi-layer perceptron of ``Linear -> BatchNorm1d -> activation`` blocks.

    ``self.network`` holds one ``nn.Sequential`` per hidden layer, followed by
    a final ``nn.Linear(hidden_dim[-1], output_dim)``. Every hidden layer after
    the first is prefixed with ``nn.Dropout``. The index layout of
    ``self.network`` and of each ``nn.Sequential`` determines the state-dict
    keys, so it is kept stable for checkpoint compatibility.

    Args:
        input_dim: Number of input features per row.
        output_dim: Width of the final linear projection (stored as ``latent_dim``).
        hidden_dim: Hidden layer widths; ``None`` selects ``[1024, 1024]``.
        dropout: Dropout probability applied before every hidden layer except
            the first.
        residual: If True, add a skip connection around every hidden layer
            except the first; requires all hidden widths to be equal.
        activation: Activation name passed to ``_get_activation``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 128,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.0,
        residual: bool = False,
        activation: str = "prelu",
    ):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = [1024, 1024]
        self.input_dim = input_dim
        self.latent_dim = output_dim
        self.residual = residual
        self.dropout = dropout
        self.activation = activation
        self.network = nn.ModuleList()
        if residual:
            assert len(set(hidden_dim)) == 1, "Residual connections require all hidden dims to be equal"
        # Input width of layer i is input_dim for i == 0, else hidden_dim[i - 1].
        in_dims = [input_dim, *hidden_dim[:-1]]
        for i, (d_in, d_out) in enumerate(zip(in_dims, hidden_dim)):
            # Dropout goes first (index 0) for i > 0; this ordering is part of
            # the state-dict key layout (Linear at .1, BatchNorm at .2).
            layers = [] if i == 0 else [nn.Dropout(p=dropout)]
            layers += [
                nn.Linear(d_in, d_out),
                nn.BatchNorm1d(d_out),
                _get_activation(activation),
            ]
            self.network.append(nn.Sequential(*layers))
        self.network.append(nn.Linear(hidden_dim[-1], output_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network to a 2-D ``[batch, input_dim]`` tensor.

        Raises:
            AssertionError: If ``x`` is not a 2-D tensor, or if the module is in
                training mode and the batch has a single row.
        """
        assert torch.is_tensor(x) and x.ndim == 2, (
            f"Expected 2D tensor, got {type(x).__name__} shape {getattr(x, 'shape', None)}"
        )
        # BatchNorm1d needs >1 sample only when computing batch statistics
        # (training). In eval mode it uses running statistics, so single-sample
        # inference is valid and must not be rejected.
        assert not self.training or x.shape[0] > 1, (
            f"BatchNorm requires batch size > 1, got {x.shape[0]}. "
            "Use model.eval() for single-sample inference."
        )
        last = len(self.network) - 1
        for i, layer in enumerate(self.network):
            # Skip connections only on hidden layers whose in/out widths match:
            # not the first (input_dim -> hidden) nor the final projection.
            if self.residual and 0 < i < last:
                x = layer(x) + x
            else:
                x = layer(x)
        return x
class MLPCellEmbedder(nn.Module):
    """Embed individual cells (rows of gene expression) with an ``MLP``.

    Thin wrapper: the MLP is stored under the attribute name ``encoder``
    because checkpoint state-dict keys are prefixed with it.

    Args:
        n_genes: Number of input genes per cell.
        output_dim: Size of each cell embedding.
        hidden_dim: Hidden widths for the MLP; ``None`` selects ``[1024, 1024]``.
        dropout: Dropout probability inside the MLP.
        residual: Whether the MLP uses residual connections.
        activation: Activation name for the MLP.
    """

    def __init__(
        self,
        n_genes: int,
        output_dim: int = 128,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.1,
        residual: bool = False,
        activation: str = "prelu",
    ):
        super().__init__()
        self.n_genes = n_genes
        self.output_dim = output_dim
        self.encoder = MLP(
            input_dim=n_genes,
            output_dim=output_dim,
            hidden_dim=[1024, 1024] if hidden_dim is None else hidden_dim,
            dropout=dropout,
            residual=residual,
            activation=activation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a 2-D ``[num_cells, n_genes]`` tensor to ``[num_cells, output_dim]``."""
        is_2d_tensor = torch.is_tensor(x) and x.ndim == 2
        assert is_2d_tensor, (
            f"Expected 2D tensor, got {type(x).__name__} shape {getattr(x, 'shape', None)}"
        )
        return self.encoder(x)
class AttentionAggregator(nn.Module):
    """Pool a set of cell embeddings into one vector using learned attention.

    A small scoring network (``Linear -> ReLU -> Linear``) produces one score
    per cell. Scores are softmax-normalised over the cell axis, with masked
    cells excluded, and used as weights in a weighted sum of the embeddings.

    Args:
        embedding_dim: Size of each cell embedding.
        hidden_dim: Hidden width of the scoring network.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.attention_net = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Alias for :meth:`aggregate`.

        Lets the module be called like any ``nn.Module`` (``agg(x, mask)``),
        which also runs registered forward hooks.
        """
        return self.aggregate(x, mask)

    def aggregate(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: [batch, num_cells, embedding_dim]
            mask: [batch, num_cells] — 1=valid, 0=ignore (optional)
        Returns:
            [batch, embedding_dim]
        Raises:
            AssertionError: If any sample in a non-empty batch has no valid cell
                (its softmax would be all -inf and produce NaNs).
        """
        # .min() of an empty tensor raises, so only check non-empty batches.
        if mask is not None and mask.shape[0] > 0:
            assert mask.sum(dim=1).min() > 0, "All samples must have at least one valid cell"
        scores = self.attention_net(x).squeeze(-1)
        if mask is not None:
            # -inf scores receive exactly zero weight after softmax.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=1).unsqueeze(-1)
        return (x * weights).sum(dim=1)
class PatientEmbedder(nn.Module):
    """Embed a patient (a set of cells) as a single vector.

    Each cell is embedded independently by ``cell_embedder``. The per-cell
    embeddings are then pooled by ``aggregator.aggregate(embeddings, mask)``.

    Args:
        cell_embedder: Module mapping ``[n, num_genes]`` to ``[n, embedding_dim]``;
            ``get_embedding_dim`` reads its ``output_dim`` attribute.
        aggregator: Object exposing ``aggregate(embeddings, mask)``.
    """

    def __init__(self, cell_embedder: nn.Module, aggregator: nn.Module):
        super().__init__()
        self.cell_embedder = cell_embedder
        self.aggregator = aggregator

    def forward(
        self, cell_matrix: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            cell_matrix: [batch, num_cells, num_genes]
            mask: [batch, num_cells] — optional
        Returns:
            [batch, embedding_dim]
        """
        batch_size, num_cells, num_genes = cell_matrix.shape
        # reshape rather than view: the input may be non-contiguous (e.g. a
        # transposed or strided slice), where .view() raises. reshape copies
        # only when a view is impossible.
        flat = cell_matrix.reshape(-1, num_genes)
        embeddings_flat = self.cell_embedder(flat)
        # Use the explicit embedding width instead of -1 so a zero-size
        # batch/cell dimension does not make the shape ambiguous.
        embeddings = embeddings_flat.reshape(
            batch_size, num_cells, embeddings_flat.shape[-1]
        )
        return self.aggregator.aggregate(embeddings, mask)

    def get_embedding_dim(self) -> int:
        """Return ``cell_embedder.output_dim``."""
        return self.cell_embedder.output_dim
class CrossEntropyLossViews(nn.Module):
    """Cross-entropy averaged first within each entity, then across entities.

    With ``entity_ids``, the per-sample losses of all augmented views that
    share an id are averaged. The result is the plain mean of those
    per-entity averages, so entities with many views do not dominate.
    Without ``entity_ids`` it is the mean of the per-sample losses.

    Args:
        class_weights: Optional per-class weights forwarded to
            ``nn.CrossEntropyLoss(weight=...)``.
    """

    def __init__(self, class_weights: Optional[torch.Tensor] = None):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss(weight=class_weights, reduction="none")

    def forward(
        self,
        predictions: torch.Tensor,
        labels: torch.Tensor,
        entity_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        per_sample = self.ce_loss(predictions, labels)
        if entity_ids is not None:
            groups, group_index, group_sizes = torch.unique(
                entity_ids, return_inverse=True, return_counts=True
            )
            # Sum losses per entity (same dtype/device as the losses) ...
            totals = per_sample.new_zeros(len(groups))
            totals.scatter_add_(0, group_index, per_sample)
            # ... then divide by each entity's view count and average.
            return (totals / group_sizes.float()).mean()
        return per_sample.mean()
class VirtualCellPatientConfig(PretrainedConfig):
    """Configuration for the virtual-cell patient classifier.

    Args:
        n_genes: Number of genes (input features) per cell.
        embed_dim: Size of the cell and patient embeddings.
        hidden_dim: Hidden widths of the per-cell MLP; ``None`` selects
            ``[4096, 1024]``.
        dropout: Dropout probability inside the per-cell MLP.
        residual: Whether the per-cell MLP uses residual connections.
        activation: Activation name for the per-cell MLP.
        attention_hidden_dim: Hidden width of the attention scoring network.
        num_classes: Number of output classes of the classification head.
        classifier_dropout: Dropout applied before the classification head.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "virtual_cell_patient"

    def __init__(
        self,
        n_genes: int = 18301,
        embed_dim: int = 512,
        hidden_dim: Optional[List[int]] = None,
        dropout: float = 0.1,
        residual: bool = False,
        activation: str = "prelu",
        attention_hidden_dim: int = 512,
        num_classes: int = 10,
        classifier_dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # None sentinel avoids a shared mutable default list.
        if hidden_dim is None:
            hidden_dim = [4096, 1024]
        # Model-specific fields, assigned after the base initializer runs.
        self.n_genes = n_genes
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.residual = residual
        self.activation = activation
        self.attention_hidden_dim = attention_hidden_dim
        self.num_classes = num_classes
        self.classifier_dropout = classifier_dropout
class VirtualCellPatientModel(PreTrainedModel):
    """Patient-level classifier built from per-cell expression profiles.

    Pipeline: per-cell MLP embedder -> attention pooling over cells ->
    dropout + linear classification head.
    """

    config_class = VirtualCellPatientConfig

    def __init__(self, config: VirtualCellPatientConfig):
        super().__init__(config)
        # Sub-modules are constructed in the same order as the checkpoint's
        # layout: cell embedder, aggregator, then classifier.
        self.patient_embedder = PatientEmbedder(
            MLPCellEmbedder(
                n_genes=config.n_genes,
                output_dim=config.embed_dim,
                hidden_dim=config.hidden_dim,
                dropout=config.dropout,
                residual=config.residual,
                activation=config.activation,
            ),
            AttentionAggregator(
                embedding_dim=config.embed_dim,
                hidden_dim=config.attention_hidden_dim,
            ),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(config.classifier_dropout),
            nn.Linear(config.embed_dim, config.num_classes),
        )
        self.loss_fn = CrossEntropyLossViews()

    def _init_weights(self, module):
        # Deliberately a no-op. Presumably this keeps transformers'
        # weight-initialization hook from altering PyTorch's default init or
        # loaded checkpoint weights — NOTE(review): confirm intent.
        pass

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        entity_id: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """
        Args:
            input_ids: [batch, num_cells, num_genes] log-normalized float32 expression
            attention_mask: [batch, num_cells] 1=valid, 0=ignore (optional)
            labels: [batch] integer class indices (optional, for loss)
            entity_id: [batch] patient IDs grouping augmented views (optional)
            **kwargs: Accepted and ignored.
        Returns:
            SequenceClassifierOutput with .loss (when labels given) and .logits [batch, num_classes]
        """
        logits = self.classifier(self.patient_embedder(input_ids, attention_mask))
        if labels is None:
            loss = None
        elif entity_id is None:
            loss = F.cross_entropy(logits, labels)
        else:
            # Average over augmented views per patient before averaging patients.
            loss = self.loss_fn(logits, labels, entity_id)
        return SequenceClassifierOutput(loss=loss, logits=logits)