""" Module: classification_heads.py This module defines various classification and decoder heads for use in transformer-based models, specifically tailored for single-cell biology tasks. These heads are designed to handle tasks such as classification, regression, and expression value prediction, and they integrate seamlessly with transformer architectures. Main Features: - **ClsDecoder**: A simple decoder for classification tasks, supporting multiple layers and activations. - **ClassificationHead**: A RoBERTa-style classification head for downstream tasks. - **ClassificationHeadAnalysis**: An extended classification head that provides intermediate hidden states for analysis. - **ClsDecoderAnalysis**: A classification decoder with support for hidden state extraction. - **TrainingHead**: A dense layer with activation and normalization for training tasks. - **AnnotationDecoderHead**: A lightweight decoder for annotation tasks with simplified weight initialization. - **ExprDecoder**: A decoder for predicting gene expression values, with optional explicit zero probability prediction. - **AffineExprDecoder**: A decoder for predicting gene expression values in an affine form (Ax + b), with support for advanced features like adaptive bias and explicit zero probabilities. Dependencies: - PyTorch: For defining and training neural network components. - Transformers: For activation functions and integration with transformer-based models. Usage: Import the desired classification or decoder head into your model: ```python from teddy.models.classification_heads import ClsDecoder, ClassificationHead ``` """ from typing import Dict, Optional import torch import torch.nn as nn from torch import Tensor from transformers.activations import ACT2FN class ClsDecoder(nn.Module): # taken from scGPT. Delete when not needed any more. """ Decoder for classification task. """ def __init__( self, d_model: int, n_cls: int, nlayers: int = 1, activation: callable = nn.ReLU, ): super().__init__() # module list self._decoder = nn.ModuleList() for _i in range(nlayers - 1): self._decoder.append(nn.Linear(d_model, d_model)) self._decoder.append(activation()) self._decoder.append(nn.LayerNorm(d_model)) self.out_layer = nn.Linear(d_model, n_cls) def forward(self, x: Tensor) -> Tensor: """ Args: x: Tensor, shape [batch_size, embsize] """ for layer in self._decoder: x = layer(x) return {"output": self.out_layer(x)} class ClassificationHead(nn.Module): """RoBERTa-style classification head""" def __init__(self, config, n_cls, nlayers): super().__init__() self._decoder = nn.ModuleList() self.activation = nn.ReLU() if config.layer_activation == "relu" else nn.GELU() for _i in range(nlayers): self._decoder.append(nn.Dropout(config.dropout)) self._decoder.append(nn.Linear(config.d_model, config.d_model)) self._decoder.append(self.activation) self._decoder.append(nn.Dropout(config.dropout)) self._decoder.append(nn.Linear(config.d_model, n_cls)) def forward(self, x): for module in self._decoder: x = module(x) return {"output": x} class ClassificationHeadAnalysis(nn.Module): """RoBERTa-style classification head""" def __init__(self, config, n_cls, nlayers): super().__init__() self.dropout = nn.Dropout(config.dropout) self._decoder = nn.ModuleList() self.activation = nn.ReLU() if config.layer_activation == "relu" else nn.GELU() for _i in range(nlayers): self._decoder.append(self.dropout) self._decoder.append(nn.Linear(config.d_model, config.d_model)) self._decoder.append(self.activation) self._decoder.append(self.dropout) self._decoder.append(nn.Linear(config.d_model, n_cls)) def forward(self, x): hidden_states = [] for module in self._decoder: x = module(x) if isinstance(module, nn.Linear): hidden_states.append(x) return {"output": x, "hidden_states": hidden_states} class ClsDecoderAnalysis(nn.Module): """ Decoder for classification task. """ def __init__( self, d_model: int, n_cls: int, nlayers: int = 3, activation: callable = nn.ReLU, ): super().__init__() # module list self._decoder = nn.ModuleList() for _i in range(nlayers - 1): self._decoder.append(nn.Linear(d_model, d_model)) self._decoder.append(activation()) self._decoder.append(nn.LayerNorm(d_model)) self.out_layer = nn.Linear(d_model, n_cls) def forward(self, x: Tensor) -> Tensor: """ Args: x: Tensor, shape [batch_size, embsize] """ hidden_states = [] for layer in self._decoder: x = layer(x) hidden_states.append(x) return {"output": self.out_layer(x), "hidden_states": hidden_states} class TrainingHead(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.d_model, config.d_model) self.activation = ACT2FN[config.layer_activation] self.LayerNorm = nn.LayerNorm(config.d_model, config.layer_norm_eps) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.activation(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states class AnnotationDecoderHead(nn.Linear): """Small class to make weight initialization easier""" def __init__(self, d_model, n_token): super().__init__(d_model, n_token, bias=False) class ExprDecoder(nn.Module): def __init__( self, d_model: int, explicit_zero_prob: bool = False, use_batch_labels: bool = False, ): super().__init__() d_in = d_model * 2 if use_batch_labels else d_model self.fc = nn.Sequential( nn.Linear(d_in, d_model), nn.LeakyReLU(), nn.Linear(d_model, d_model), nn.LeakyReLU(), nn.Linear(d_model, 1), ) self.explicit_zero_prob = explicit_zero_prob if explicit_zero_prob: self.zero_logit = nn.Sequential( nn.Linear(d_in, d_model), nn.LeakyReLU(), nn.Linear(d_model, d_model), nn.LeakyReLU(), nn.Linear(d_model, 1), ) def forward(self, x: Tensor, values: Tensor = None) -> Dict[str, Tensor]: """x is the output of the transformer, (batch, seq_len, d_model)""" pred_value = self.fc(x).squeeze(-1) # (batch, seq_len) if not self.explicit_zero_prob: return {"pred": pred_value} zero_logits = self.zero_logit(x).squeeze(-1) # (batch, seq_len) zero_probs = torch.sigmoid(zero_logits) return {"pred": pred_value, "zero_probs": zero_probs} # TODO: note that the return currently is only for training. Since decoder # is not used in the test setting for the integration task, the experiments/inference # logic is not implemented yet. However, remember to implement it when # the decoder is used in any test setting. The inference logic will need # to sample from the bernoulli distribution with the zero_probs. class AffineExprDecoder(nn.Module): def __init__( self, d_model: int, explicit_zero_prob: bool = False, activation: Optional[str] = None, tanh_coeff: bool = False, adaptive_bias: bool = False, ): """ Predict the expression value of each gene in an affine like form of Ax + b. This decoder takes two ExprDecoder intrinsically to genrate the coefficient A and bias b. Args: d_model: The embedding dimension. explicit_zero_prob: If True, predict the probability of each gene being zero. activation: The activation function for the coefficient A and bias b. tanh_coeff: If True, use tanh activation for the coefficient A. adaptive_bias: If True, use a learnable bias for the bias b. """ super().__init__() self.explicit_zero_prob = explicit_zero_prob self.tanh_coeff = tanh_coeff self.adaptive_bias = adaptive_bias self.coeff_decoder = ExprDecoder(d_model, explicit_zero_prob=explicit_zero_prob) self.bias_decoder = ExprDecoder(d_model, explicit_zero_prob=explicit_zero_prob) self.activation = activation if activation is not None: # Normalize activation name to lowercase for flexibility activation = activation.lower() # Mapping of known activation functions activations_map = { "gelu": "GELU", "relu": "ReLU", "tanh": "Tanh", "sigmoid": "Sigmoid", } assert activation in activations_map, f"Unknown activation: {activation}" assert hasattr(nn, activations_map[activation]), f"Unknown activation: {activation}" self.activation = getattr(nn, activations_map[activation])() def forward(self, x: Tensor, values: Tensor) -> Tensor: """ Args: x: Tensor, shape [batch_size, seq_len, embsize] values: Tensor, shape [batch_size, seq_len] Returns: output Tensor of shape [batch_size, seq_len] """ coeff = self.coeff_decoder(x) bias = self.bias_decoder(x) if self.activation is not None: coeff["pred"] = self.activation(coeff["pred"]) bias["pred"] = self.activation(bias["pred"]) # if self.tanh_coeff: # coeff["pred"] = 1 + torch.tanh(coeff["pred"]) if self.adaptive_bias: # bias["pred"] = bias["pred"] * values.mean(dim=1, keepdim=True) non_zero_value_mean = values.sum(dim=1, keepdim=True) / (values != 0).sum(dim=1, keepdim=True) bias["pred"] = bias["pred"] * non_zero_value_mean if self.explicit_zero_prob: return { "pred": coeff["pred"] * values + bias["pred"], "zero_probs": coeff["zero_probs"], } return {"pred": coeff["pred"] * values + bias["pred"]}