|
|
""" |
|
|
Module: classification_heads.py |
|
|
|
|
|
This module defines various classification and decoder heads for use in transformer-based models, |
|
|
specifically tailored for single-cell biology tasks. These heads are designed to handle tasks such as |
|
|
classification, regression, and expression value prediction, and they integrate seamlessly with |
|
|
transformer architectures. |
|
|
|
|
|
Main Features: |
|
|
- **ClsDecoder**: A simple decoder for classification tasks, supporting multiple layers and activations. |
|
|
- **ClassificationHead**: A RoBERTa-style classification head for downstream tasks. |
|
|
- **ClassificationHeadAnalysis**: An extended classification head that provides intermediate hidden states for analysis. |
|
|
- **ClsDecoderAnalysis**: A classification decoder with support for hidden state extraction. |
|
|
- **TrainingHead**: A dense layer with activation and normalization for training tasks. |
|
|
- **AnnotationDecoderHead**: A lightweight decoder for annotation tasks with simplified weight initialization. |
|
|
- **ExprDecoder**: A decoder for predicting gene expression values, with optional explicit zero probability prediction. |
|
|
- **AffineExprDecoder**: A decoder for predicting gene expression values in an affine form (Ax + b), with support for |
|
|
advanced features like adaptive bias and explicit zero probabilities. |
|
|
|
|
|
Dependencies: |
|
|
- PyTorch: For defining and training neural network components. |
|
|
- Transformers: For activation functions and integration with transformer-based models. |
|
|
|
|
|
Usage: |
|
|
Import the desired classification or decoder head into your model: |
|
|
```python |
|
|
from teddy.models.classification_heads import ClsDecoder, ClassificationHead |
|
|
``` |
|
|
""" |
|
|
|
|
|
from typing import Dict, Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch import Tensor |
|
|
from transformers.activations import ACT2FN |
|
|
|
|
|
|
|
|
class ClsDecoder(nn.Module):
    """Decoder head for classification tasks.

    Stacks ``nlayers - 1`` hidden blocks of ``Linear -> activation ->
    LayerNorm`` (all at width ``d_model``) followed by a final linear
    projection to ``n_cls`` logits.

    Args:
        d_model: Input embedding dimension.
        n_cls: Number of output classes.
        nlayers: Total number of linear layers; ``nlayers - 1`` hidden
            blocks plus the output projection. With the default of 1 the
            head is a single linear layer.
        activation: Activation module class instantiated between hidden
            linear layers.
    """

    def __init__(
        self,
        d_model: int,
        n_cls: int,
        nlayers: int = 1,
        activation: callable = nn.ReLU,
    ):
        super().__init__()

        self._decoder = nn.ModuleList()
        for _i in range(nlayers - 1):
            self._decoder.append(nn.Linear(d_model, d_model))
            self._decoder.append(activation())
            self._decoder.append(nn.LayerNorm(d_model))
        self.out_layer = nn.Linear(d_model, n_cls)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the decoder stack and project to class logits.

        Args:
            x: Tensor, shape [batch_size, embsize]

        Returns:
            Dict with key ``"output"`` holding the class logits of shape
            [batch_size, n_cls].
        """
        for layer in self._decoder:
            x = layer(x)
        return {"output": self.out_layer(x)}
|
|
|
|
|
|
|
|
class ClassificationHead(nn.Module):
    """RoBERTa-style classification head"""

    def __init__(self, config, n_cls, nlayers):
        super().__init__()
        self._decoder = nn.ModuleList()
        # Non-linearity shared by every hidden block; chosen by config name.
        if config.layer_activation == "relu":
            self.activation = nn.ReLU()
        else:
            self.activation = nn.GELU()

        # Each hidden block: dropout -> linear (d_model -> d_model) -> activation.
        for _ in range(nlayers):
            self._decoder.extend(
                [
                    nn.Dropout(config.dropout),
                    nn.Linear(config.d_model, config.d_model),
                    self.activation,
                ]
            )
        # Final projection to class logits, preceded by one more dropout.
        self._decoder.extend(
            [nn.Dropout(config.dropout), nn.Linear(config.d_model, n_cls)]
        )

    def forward(self, x):
        """Apply every module in sequence; returns {"output": logits}."""
        result = x
        for stage in self._decoder:
            result = stage(result)
        return {"output": result}
|
|
|
|
|
|
|
|
class ClassificationHeadAnalysis(nn.Module):
    """RoBERTa-style classification head"""

    def __init__(self, config, n_cls, nlayers):
        super().__init__()
        # One dropout module reused at every position; dropout is stateless,
        # so sharing a single instance is equivalent to separate ones.
        self.dropout = nn.Dropout(config.dropout)
        self._decoder = nn.ModuleList()
        self.activation = (
            nn.ReLU() if config.layer_activation == "relu" else nn.GELU()
        )

        # Hidden blocks: dropout -> linear -> activation, repeated nlayers times.
        for _ in range(nlayers):
            self._decoder.extend(
                [
                    self.dropout,
                    nn.Linear(config.d_model, config.d_model),
                    self.activation,
                ]
            )
        # Final dropout and projection to class logits.
        self._decoder.extend([self.dropout, nn.Linear(config.d_model, n_cls)])

    def forward(self, x):
        """Apply modules in order, recording the output of every Linear.

        Returns {"output": logits, "hidden_states": [...]}; the last entry
        of "hidden_states" is the logits tensor itself.
        """
        collected = []
        out = x
        for stage in self._decoder:
            out = stage(out)
            if isinstance(stage, nn.Linear):
                collected.append(out)
        return {"output": out, "hidden_states": collected}
|
|
|
|
|
|
|
|
class ClsDecoderAnalysis(nn.Module):
    """Decoder for classification task, also exposing intermediate
    hidden states for downstream analysis.

    Stacks ``nlayers - 1`` hidden blocks of ``Linear -> activation ->
    LayerNorm`` followed by a final projection to ``n_cls`` logits, and
    records the output of every intermediate module.

    Args:
        d_model: Input embedding dimension.
        n_cls: Number of output classes.
        nlayers: Total number of linear layers; ``nlayers - 1`` hidden
            blocks plus the output projection.
        activation: Activation module class used in hidden blocks.
    """

    def __init__(
        self,
        d_model: int,
        n_cls: int,
        nlayers: int = 3,
        activation: callable = nn.ReLU,
    ):
        super().__init__()

        self._decoder = nn.ModuleList()
        for _i in range(nlayers - 1):
            self._decoder.append(nn.Linear(d_model, d_model))
            self._decoder.append(activation())
            self._decoder.append(nn.LayerNorm(d_model))
        self.out_layer = nn.Linear(d_model, n_cls)

    def forward(self, x: Tensor) -> dict:
        """
        Args:
            x: Tensor, shape [batch_size, embsize]

        Returns:
            Dict with ``"output"`` (logits, shape [batch_size, n_cls]) and
            ``"hidden_states"`` — a list with the output of every module in
            the decoder stack (three entries per hidden block), excluding
            the final projection.
        """
        hidden_states = []
        for layer in self._decoder:
            x = layer(x)
            hidden_states.append(x)
        return {"output": self.out_layer(x), "hidden_states": hidden_states}
|
|
|
|
|
|
|
|
class TrainingHead(nn.Module):
    """Dense -> activation -> LayerNorm transform applied during training.

    Mirrors the BERT-style prediction-head transform: a linear layer at
    model width, the configured activation, then layer normalization.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.d_model, config.d_model)
        # Look up the activation function by its config name (e.g. "gelu").
        self.activation = ACT2FN[config.layer_activation]
        # Second positional argument of nn.LayerNorm is eps.
        self.LayerNorm = nn.LayerNorm(config.d_model, config.layer_norm_eps)

    def forward(self, hidden_states):
        """Return LayerNorm(activation(dense(hidden_states)))."""
        return self.LayerNorm(self.activation(self.dense(hidden_states)))
|
|
|
|
|
|
|
|
class AnnotationDecoderHead(nn.Linear):
    """Bias-free linear projection used as an annotation decoder.

    Subclassing ``nn.Linear`` (rather than wrapping one) lets model-wide
    weight initialization treat this head like any other linear layer.
    """

    def __init__(self, d_model, n_token):
        # Project d_model features to n_token logits without a bias term.
        super().__init__(in_features=d_model, out_features=n_token, bias=False)
|
|
|
|
|
|
|
|
class ExprDecoder(nn.Module):
    """Decoder that predicts a scalar expression value per position.

    Args:
        d_model: Transformer embedding dimension.
        explicit_zero_prob: If True, a second MLP additionally predicts a
            Bernoulli probability of each value being zero, returned under
            ``"zero_probs"``.
        use_batch_labels: If True, the input is expected to be the model
            embedding concatenated with a batch-label embedding, doubling
            the input width.
    """

    def __init__(
        self,
        d_model: int,
        explicit_zero_prob: bool = False,
        use_batch_labels: bool = False,
    ):
        super().__init__()
        d_in = d_model * 2 if use_batch_labels else d_model

        def _make_mlp() -> nn.Sequential:
            # Shared 3-layer MLP shape: d_in -> d_model -> d_model -> 1.
            # Both heads previously duplicated this construction verbatim;
            # the module layout (and hence state_dict keys) is unchanged.
            return nn.Sequential(
                nn.Linear(d_in, d_model),
                nn.LeakyReLU(),
                nn.Linear(d_model, d_model),
                nn.LeakyReLU(),
                nn.Linear(d_model, 1),
            )

        self.fc = _make_mlp()
        self.explicit_zero_prob = explicit_zero_prob
        if explicit_zero_prob:
            self.zero_logit = _make_mlp()

    def forward(self, x: Tensor, values: Optional[Tensor] = None) -> Dict[str, Tensor]:
        """x is the output of the transformer, (batch, seq_len, d_model).

        ``values`` is accepted for interface compatibility with sibling
        decoders but is unused here.

        Returns:
            {"pred": (batch, seq_len)} and, when ``explicit_zero_prob`` is
            set, {"zero_probs": (batch, seq_len)} — sigmoid-activated zero
            probabilities.
        """
        pred_value = self.fc(x).squeeze(-1)

        if not self.explicit_zero_prob:
            return {"pred": pred_value}
        zero_logits = self.zero_logit(x).squeeze(-1)
        zero_probs = torch.sigmoid(zero_logits)
        return {"pred": pred_value, "zero_probs": zero_probs}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AffineExprDecoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        explicit_zero_prob: bool = False,
        activation: Optional[str] = None,
        tanh_coeff: bool = False,
        adaptive_bias: bool = False,
    ):
        """
        Predict the expression value of each gene in an affine like form of Ax + b.
        This decoder takes two ExprDecoder intrinsically to generate the coefficient A and bias b.

        Args:
            d_model: The embedding dimension.
            explicit_zero_prob: If True, predict the probability of each gene being
                zero.
            activation: The activation function for the coefficient A and bias b.
            tanh_coeff: If True, use tanh activation for the coefficient A.
                NOTE(review): this flag is stored but never read in this class —
                confirm whether it is consumed elsewhere or is dead config.
            adaptive_bias: If True, use a learnable bias for the bias b.
        """
        super().__init__()
        self.explicit_zero_prob = explicit_zero_prob
        self.tanh_coeff = tanh_coeff
        self.adaptive_bias = adaptive_bias
        # Two independent ExprDecoders produce the per-gene coefficient A
        # and bias b of the affine prediction A * values + b.
        self.coeff_decoder = ExprDecoder(d_model, explicit_zero_prob=explicit_zero_prob)
        self.bias_decoder = ExprDecoder(d_model, explicit_zero_prob=explicit_zero_prob)
        self.activation = activation

        if activation is not None:

            activation = activation.lower()

            # Map the lowercase config name to the torch.nn module class name.
            activations_map = {
                "gelu": "GELU",
                "relu": "ReLU",
                "tanh": "Tanh",
                "sigmoid": "Sigmoid",
            }
            assert activation in activations_map, f"Unknown activation: {activation}"
            assert hasattr(nn, activations_map[activation]), f"Unknown activation: {activation}"
            # Replace the stored string with an instantiated activation module.
            self.activation = getattr(nn, activations_map[activation])()

    def forward(self, x: Tensor, values: Tensor) -> Dict[str, Tensor]:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embsize]
            values: Tensor, shape [batch_size, seq_len]

        Returns:
            Dict with key "pred" of shape [batch_size, seq_len] holding
            A * values + b; when explicit_zero_prob is set, also "zero_probs"
            (taken from the coefficient decoder only — the bias decoder's
            zero probabilities are discarded).
        """
        coeff = self.coeff_decoder(x)
        bias = self.bias_decoder(x)

        # Optionally squash both the coefficient and the bias predictions.
        if self.activation is not None:
            coeff["pred"] = self.activation(coeff["pred"])
            bias["pred"] = self.activation(bias["pred"])

        if self.adaptive_bias:
            # Scale the predicted bias by the mean of the non-zero input values.
            # NOTE(review): a row whose values are all zero divides by zero here,
            # yielding NaN/inf — confirm upstream guarantees non-empty rows.
            non_zero_value_mean = values.sum(dim=1, keepdim=True) / (values != 0).sum(dim=1, keepdim=True)
            bias["pred"] = bias["pred"] * non_zero_value_mean

        if self.explicit_zero_prob:
            return {
                "pred": coeff["pred"] * values + bias["pred"],
                "zero_probs": coeff["zero_probs"],
            }

        return {"pred": coeff["pred"] * values + bias["pred"]}
|
|
|