# Safetensors
# TEDDY / teddy/models/classification_heads.py
# soumyatghosh's upload via huggingface_hub (rev 4527b5f, verified)
"""
Module: classification_heads.py
This module defines various classification and decoder heads for use in transformer-based models,
specifically tailored for single-cell biology tasks. These heads are designed to handle tasks such as
classification, regression, and expression value prediction, and they integrate seamlessly with
transformer architectures.
Main Features:
- **ClsDecoder**: A simple decoder for classification tasks, supporting multiple layers and activations.
- **ClassificationHead**: A RoBERTa-style classification head for downstream tasks.
- **ClassificationHeadAnalysis**: An extended classification head that provides intermediate hidden states for analysis.
- **ClsDecoderAnalysis**: A classification decoder with support for hidden state extraction.
- **TrainingHead**: A dense layer with activation and normalization for training tasks.
- **AnnotationDecoderHead**: A lightweight decoder for annotation tasks with simplified weight initialization.
- **ExprDecoder**: A decoder for predicting gene expression values, with optional explicit zero probability prediction.
- **AffineExprDecoder**: A decoder for predicting gene expression values in an affine form (Ax + b), with support for
advanced features like adaptive bias and explicit zero probabilities.
Dependencies:
- PyTorch: For defining and training neural network components.
- Transformers: For activation functions and integration with transformer-based models.
Usage:
Import the desired classification or decoder head into your model:
```python
from teddy.models.classification_heads import ClsDecoder, ClassificationHead
```
"""
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch import Tensor
from transformers.activations import ACT2FN
class ClsDecoder(nn.Module):  # taken from scGPT. Delete when not needed any more.
    """
    Decoder head for classification tasks.

    Stacks ``nlayers - 1`` hidden blocks of Linear -> activation -> LayerNorm
    (all at width ``d_model``) followed by a final projection to ``n_cls``
    logits.

    Args:
        d_model: Input embedding dimension.
        n_cls: Number of output classes.
        nlayers: Total number of linear layers; ``nlayers - 1`` hidden blocks
            plus the output projection. The default of 1 yields a single
            linear classifier.
        activation: Activation module class (not instance) placed between
            hidden linear layers.
    """

    def __init__(
        self,
        d_model: int,
        n_cls: int,
        nlayers: int = 1,
        activation: callable = nn.ReLU,
    ):
        super().__init__()
        # module list
        self._decoder = nn.ModuleList()
        for _i in range(nlayers - 1):
            self._decoder.append(nn.Linear(d_model, d_model))
            self._decoder.append(activation())
            self._decoder.append(nn.LayerNorm(d_model))
        self.out_layer = nn.Linear(d_model, n_cls)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """
        Args:
            x: Tensor, shape [batch_size, embsize]

        Returns:
            Dict with key ``"output"`` holding logits of shape
            [batch_size, n_cls]. (Annotation fixed: the original hinted
            ``Tensor`` but a dict is returned.)
        """
        for layer in self._decoder:
            x = layer(x)
        return {"output": self.out_layer(x)}
class ClassificationHead(nn.Module):
    """RoBERTa-style classification head"""

    def __init__(self, config, n_cls, nlayers):
        super().__init__()
        # One shared activation instance reused between all hidden blocks.
        self.activation = nn.GELU() if config.layer_activation != "relu" else nn.ReLU()
        modules = []
        for _ in range(nlayers):
            modules.extend(
                [
                    nn.Dropout(config.dropout),
                    nn.Linear(config.d_model, config.d_model),
                    self.activation,
                ]
            )
        # Final dropout + projection to the class logits.
        modules.append(nn.Dropout(config.dropout))
        modules.append(nn.Linear(config.d_model, n_cls))
        self._decoder = nn.ModuleList(modules)

    def forward(self, x):
        """Pass the input through every stacked module in order."""
        for submodule in self._decoder:
            x = submodule(x)
        return {"output": x}
class ClassificationHeadAnalysis(nn.Module):
    """RoBERTa-style classification head that additionally returns the
    output of every linear layer for downstream analysis."""

    def __init__(self, config, n_cls, nlayers):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout)
        self.activation = nn.ReLU() if config.layer_activation == "relu" else nn.GELU()
        self._decoder = nn.ModuleList()
        for _ in range(nlayers):
            # Dropout and activation instances are shared; each Linear is new.
            for mod in (
                self.dropout,
                nn.Linear(config.d_model, config.d_model),
                self.activation,
            ):
                self._decoder.append(mod)
        self._decoder.append(self.dropout)
        self._decoder.append(nn.Linear(config.d_model, n_cls))

    def forward(self, x):
        """Apply the stack, capturing the output of every nn.Linear."""
        captured = []
        for mod in self._decoder:
            x = mod(x)
            if isinstance(mod, nn.Linear):
                captured.append(x)
        return {"output": x, "hidden_states": captured}
class ClsDecoderAnalysis(nn.Module):
    """
    Decoder for classification task.

    Identical in structure to ``ClsDecoder`` (``nlayers - 1`` blocks of
    Linear -> activation -> LayerNorm, then a projection to ``n_cls``),
    but ``forward`` also returns the intermediate activations of every
    module for analysis.

    Args:
        d_model: Input embedding dimension.
        n_cls: Number of output classes.
        nlayers: Total number of linear layers (``nlayers - 1`` hidden
            blocks plus the output projection).
        activation: Activation module class instantiated between hidden
            linear layers.
    """

    def __init__(
        self,
        d_model: int,
        n_cls: int,
        nlayers: int = 3,
        activation: callable = nn.ReLU,
    ):
        super().__init__()
        # module list
        self._decoder = nn.ModuleList()
        for _i in range(nlayers - 1):
            self._decoder.append(nn.Linear(d_model, d_model))
            self._decoder.append(activation())
            self._decoder.append(nn.LayerNorm(d_model))
        self.out_layer = nn.Linear(d_model, n_cls)

    def forward(self, x: Tensor) -> dict:
        """
        Args:
            x: Tensor, shape [batch_size, embsize]

        Returns:
            Dict with ``"output"`` (logits, [batch_size, n_cls]) and
            ``"hidden_states"`` (list with the output of every module in
            the hidden stack). (Annotation fixed: the original hinted
            ``Tensor`` but a dict is returned.)
        """
        hidden_states = []
        for layer in self._decoder:
            x = layer(x)
            hidden_states.append(x)
        return {"output": self.out_layer(x), "hidden_states": hidden_states}
class TrainingHead(nn.Module):
    """Dense -> activation -> LayerNorm projection used as a training head.

    The activation function is resolved from the transformers ``ACT2FN``
    registry using the name in ``config.layer_activation``.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.d_model, config.d_model)
        self.activation = ACT2FN[config.layer_activation]
        self.LayerNorm = nn.LayerNorm(config.d_model, config.layer_norm_eps)

    def forward(self, hidden_states):
        """Project, activate, then normalize the hidden states."""
        return self.LayerNorm(self.activation(self.dense(hidden_states)))
class AnnotationDecoderHead(nn.Linear):
    """Bias-free linear projection for annotation decoding.

    Subclassing ``nn.Linear`` directly keeps weight initialization simple.
    """

    def __init__(self, d_model, n_token):
        super().__init__(in_features=d_model, out_features=n_token, bias=False)
class ExprDecoder(nn.Module):
    """Predict one expression value per position from transformer embeddings,
    optionally alongside an explicit probability of the value being zero."""

    def __init__(
        self,
        d_model: int,
        explicit_zero_prob: bool = False,
        use_batch_labels: bool = False,
    ):
        super().__init__()
        # Concatenated batch-label embeddings double the input width.
        d_in = d_model * 2 if use_batch_labels else d_model

        def build_mlp() -> nn.Sequential:
            # Two LeakyReLU hidden layers followed by a scalar output.
            return nn.Sequential(
                nn.Linear(d_in, d_model),
                nn.LeakyReLU(),
                nn.Linear(d_model, d_model),
                nn.LeakyReLU(),
                nn.Linear(d_model, 1),
            )

        self.fc = build_mlp()
        self.explicit_zero_prob = explicit_zero_prob
        if explicit_zero_prob:
            self.zero_logit = build_mlp()

    def forward(self, x: Tensor, values: Tensor = None) -> Dict[str, Tensor]:
        """x is the output of the transformer, (batch, seq_len, d_model)"""
        prediction = self.fc(x).squeeze(-1)  # (batch, seq_len)
        if not self.explicit_zero_prob:
            return {"pred": prediction}
        logits = self.zero_logit(x).squeeze(-1)  # (batch, seq_len)
        return {"pred": prediction, "zero_probs": torch.sigmoid(logits)}
        # TODO: note that the return currently is only for training. Since decoder
        # is not used in the test setting for the integration task, the experiments/inference
        # logic is not implemented yet. However, remember to implement it when
        # the decoder is used in any test setting. The inference logic will need
        # to sample from the bernoulli distribution with the zero_probs.
class AffineExprDecoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        explicit_zero_prob: bool = False,
        activation: Optional[str] = None,
        tanh_coeff: bool = False,
        adaptive_bias: bool = False,
    ):
        """
        Predict the expression value of each gene in an affine-like form of Ax + b.
        This decoder uses two ExprDecoder instances internally to generate the
        coefficient A and bias b.

        Args:
            d_model: The embedding dimension.
            explicit_zero_prob: If True, predict the probability of each gene being
                zero.
            activation: Name of the activation ("gelu", "relu", "tanh", "sigmoid")
                applied to both the coefficient A and the bias b; None disables it.
            tanh_coeff: If True, use tanh activation for the coefficient A.
                NOTE: currently inert — the corresponding transform in forward is
                commented out; the flag is only stored.
            adaptive_bias: If True, scale the predicted bias b by the per-row mean
                of the non-zero input values.
        """
        super().__init__()
        self.explicit_zero_prob = explicit_zero_prob
        self.tanh_coeff = tanh_coeff
        self.adaptive_bias = adaptive_bias
        self.coeff_decoder = ExprDecoder(d_model, explicit_zero_prob=explicit_zero_prob)
        self.bias_decoder = ExprDecoder(d_model, explicit_zero_prob=explicit_zero_prob)

        self.activation = activation
        if activation is not None:
            # Normalize activation name to lowercase for flexibility.
            activation = activation.lower()
            # Mapping of known activation names to their nn module class names.
            activations_map = {
                "gelu": "GELU",
                "relu": "ReLU",
                "tanh": "Tanh",
                "sigmoid": "Sigmoid",
            }
            assert activation in activations_map, f"Unknown activation: {activation}"
            assert hasattr(nn, activations_map[activation]), f"Unknown activation: {activation}"
            self.activation = getattr(nn, activations_map[activation])()

    def forward(self, x: Tensor, values: Tensor) -> Dict[str, Tensor]:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embsize]
            values: Tensor, shape [batch_size, seq_len]

        Returns:
            Dict with ``"pred"`` of shape [batch_size, seq_len]
            (A * values + b), and, when ``explicit_zero_prob`` is set,
            ``"zero_probs"`` from the coefficient decoder. (Annotation
            fixed: the original hinted ``Tensor`` but a dict is returned.)
        """
        coeff = self.coeff_decoder(x)
        bias = self.bias_decoder(x)
        if self.activation is not None:
            coeff["pred"] = self.activation(coeff["pred"])
            bias["pred"] = self.activation(bias["pred"])

        # if self.tanh_coeff:
        #     coeff["pred"] = 1 + torch.tanh(coeff["pred"])

        if self.adaptive_bias:
            # bias["pred"] = bias["pred"] * values.mean(dim=1, keepdim=True)
            # NOTE(review): a row of all-zero values yields a 0/0 -> NaN here;
            # presumably inputs always contain non-zero entries — confirm.
            non_zero_value_mean = values.sum(dim=1, keepdim=True) / (values != 0).sum(dim=1, keepdim=True)
            bias["pred"] = bias["pred"] * non_zero_value_mean

        if self.explicit_zero_prob:
            return {
                "pred": coeff["pred"] * values + bias["pred"],
                # Only the coefficient decoder's zero probabilities are exposed;
                # the bias decoder's are discarded.
                "zero_probs": coeff["zero_probs"],
            }

        return {"pred": coeff["pred"] * values + bias["pred"]}