Safetensors
File size: 10,666 Bytes
4527b5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
"""
Module: classification_heads.py

This module defines various classification and decoder heads for use in transformer-based models,
specifically tailored for single-cell biology tasks. These heads are designed to handle tasks such as
classification, regression, and expression value prediction, and they integrate seamlessly with
transformer architectures.

Main Features:
- **ClsDecoder**: A simple decoder for classification tasks, supporting multiple layers and activations.
- **ClassificationHead**: A RoBERTa-style classification head for downstream tasks.
- **ClassificationHeadAnalysis**: An extended classification head that provides intermediate hidden states for analysis.
- **ClsDecoderAnalysis**: A classification decoder with support for hidden state extraction.
- **TrainingHead**: A dense layer with activation and normalization for training tasks.
- **AnnotationDecoderHead**: A lightweight decoder for annotation tasks with simplified weight initialization.
- **ExprDecoder**: A decoder for predicting gene expression values, with optional explicit zero probability prediction.
- **AffineExprDecoder**: A decoder for predicting gene expression values in an affine form (Ax + b), with support for
  advanced features like adaptive bias and explicit zero probabilities.

Dependencies:
- PyTorch: For defining and training neural network components.
- Transformers: For activation functions and integration with transformer-based models.

Usage:
Import the desired classification or decoder head into your model:
   ```python
   from teddy.models.classification_heads import ClsDecoder, ClassificationHead
   ```
"""

from typing import Dict, Optional

import torch
import torch.nn as nn
from torch import Tensor
from transformers.activations import ACT2FN


class ClsDecoder(nn.Module):  # taken from scGPT. Delete when not needed any more.
    """
    Decoder head for classification tasks.

    Builds ``nlayers - 1`` hidden blocks of ``Linear -> activation -> LayerNorm``
    followed by a final linear projection to ``n_cls`` logits.

    Args:
        d_model: Input embedding dimension.
        n_cls: Number of output classes.
        nlayers: Total layer count; ``nlayers - 1`` hidden blocks are built, so
            the default ``nlayers=1`` yields only the output projection.
        activation: Activation module class, instantiated once per hidden block.
    """

    def __init__(
        self,
        d_model: int,
        n_cls: int,
        nlayers: int = 1,
        activation: callable = nn.ReLU,
    ):
        super().__init__()
        # Hidden blocks are stored in a ModuleList and applied sequentially in forward.
        self._decoder = nn.ModuleList()
        for _i in range(nlayers - 1):
            self._decoder.append(nn.Linear(d_model, d_model))
            self._decoder.append(activation())
            self._decoder.append(nn.LayerNorm(d_model))
        self.out_layer = nn.Linear(d_model, n_cls)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """
        Args:
            x: Tensor, shape [batch_size, embsize]

        Returns:
            Dict with key "output": logits of shape [batch_size, n_cls].
        """
        for layer in self._decoder:
            x = layer(x)
        return {"output": self.out_layer(x)}


class ClassificationHead(nn.Module):
    """RoBERTa-style classification head: nlayers blocks of
    Dropout -> Linear -> activation -> Dropout, then a final projection
    to the class logits. Returns ``{"output": logits}``."""

    def __init__(self, config, n_cls, nlayers):
        super().__init__()
        # A single activation module instance is shared across all blocks.
        if config.layer_activation == "relu":
            self.activation = nn.ReLU()
        else:
            self.activation = nn.GELU()

        modules = []
        for _ in range(nlayers):
            modules += [
                nn.Dropout(config.dropout),
                nn.Linear(config.d_model, config.d_model),
                self.activation,
                nn.Dropout(config.dropout),
            ]
        modules.append(nn.Linear(config.d_model, n_cls))
        self._decoder = nn.ModuleList(modules)

    def forward(self, x):
        """Apply each decoder module in sequence and return the logits."""
        for layer in self._decoder:
            x = layer(x)
        return {"output": x}


class ClassificationHeadAnalysis(nn.Module):
    """RoBERTa-style classification head that additionally exposes the
    output of every Linear layer, for downstream analysis of the
    intermediate representations."""

    def __init__(self, config, n_cls, nlayers):
        super().__init__()
        # One shared dropout / activation instance reused in every block.
        self.dropout = nn.Dropout(config.dropout)
        self.activation = nn.ReLU() if config.layer_activation == "relu" else nn.GELU()

        modules = []
        for _ in range(nlayers):
            modules += [
                self.dropout,
                nn.Linear(config.d_model, config.d_model),
                self.activation,
                self.dropout,
            ]
        modules.append(nn.Linear(config.d_model, n_cls))
        self._decoder = nn.ModuleList(modules)

    def forward(self, x):
        """Run x through the stack, recording each Linear layer's output."""
        hidden_states = []
        for module in self._decoder:
            x = module(x)
            if isinstance(module, nn.Linear):
                hidden_states.append(x)
        return {"output": x, "hidden_states": hidden_states}


class ClsDecoderAnalysis(nn.Module):
    """
    Decoder head for classification tasks that also returns every
    intermediate hidden state (after each Linear / activation / LayerNorm)
    for analysis.

    Args:
        d_model: Input embedding dimension.
        n_cls: Number of output classes.
        nlayers: Total layer count; ``nlayers - 1`` hidden blocks of
            ``Linear -> activation -> LayerNorm`` are built.
        activation: Activation module class, instantiated once per hidden block.
    """

    def __init__(
        self,
        d_model: int,
        n_cls: int,
        nlayers: int = 3,
        activation: callable = nn.ReLU,
    ):
        super().__init__()
        # Hidden blocks are stored in a ModuleList and applied sequentially in forward.
        self._decoder = nn.ModuleList()
        for _i in range(nlayers - 1):
            self._decoder.append(nn.Linear(d_model, d_model))
            self._decoder.append(activation())
            self._decoder.append(nn.LayerNorm(d_model))
        self.out_layer = nn.Linear(d_model, n_cls)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """
        Args:
            x: Tensor, shape [batch_size, embsize]

        Returns:
            Dict with "output" (logits, shape [batch_size, n_cls]) and
            "hidden_states" (a list with the output of every module in the
            hidden stack, i.e. 3 * (nlayers - 1) tensors).
        """
        hidden_states = []
        for layer in self._decoder:
            x = layer(x)
            hidden_states.append(x)
        return {"output": self.out_layer(x), "hidden_states": hidden_states}


class TrainingHead(nn.Module):
    """Dense -> activation -> LayerNorm transform used during training.

    The activation function is looked up in the transformers ``ACT2FN``
    registry using the ``layer_activation`` key from the config.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.d_model, config.d_model)
        self.activation = ACT2FN[config.layer_activation]
        self.LayerNorm = nn.LayerNorm(config.d_model, config.layer_norm_eps)

    def forward(self, hidden_states):
        """Project, activate, then layer-normalize the hidden states."""
        return self.LayerNorm(self.activation(self.dense(hidden_states)))


class AnnotationDecoderHead(nn.Linear):
    """Bias-free linear projection head for annotation tasks.

    A thin nn.Linear subclass so that model-wide weight-initialization
    logic can target this head by type.
    """

    def __init__(self, d_model, n_token):
        super().__init__(in_features=d_model, out_features=n_token, bias=False)


class ExprDecoder(nn.Module):
    """
    Decoder predicting a scalar expression value per gene position.

    Args:
        d_model: Embedding dimension of the transformer output.
        explicit_zero_prob: If True, additionally predict a per-position
            probability of the value being zero via a parallel MLP.
        use_batch_labels: If True, the input is expected to be the embedding
            concatenated with a batch-label embedding (doubling the input dim).
    """

    def __init__(
        self,
        d_model: int,
        explicit_zero_prob: bool = False,
        use_batch_labels: bool = False,
    ):
        super().__init__()
        d_in = d_model * 2 if use_batch_labels else d_model
        # Three-layer MLP mapping each position's embedding to a scalar value.
        self.fc = nn.Sequential(
            nn.Linear(d_in, d_model),
            nn.LeakyReLU(),
            nn.Linear(d_model, d_model),
            nn.LeakyReLU(),
            nn.Linear(d_model, 1),
        )
        self.explicit_zero_prob = explicit_zero_prob
        if explicit_zero_prob:
            # Parallel MLP of identical shape producing the zero-probability logit.
            self.zero_logit = nn.Sequential(
                nn.Linear(d_in, d_model),
                nn.LeakyReLU(),
                nn.Linear(d_model, d_model),
                nn.LeakyReLU(),
                nn.Linear(d_model, 1),
            )

    def forward(self, x: Tensor, values: Optional[Tensor] = None) -> Dict[str, Tensor]:
        """x is the output of the transformer, (batch, seq_len, d_model).

        ``values`` is accepted for signature compatibility with other decoder
        heads but is not used here.

        Returns:
            Dict with "pred" of shape (batch, seq_len); when
            ``explicit_zero_prob`` is set, also "zero_probs" in [0, 1].
        """
        pred_value = self.fc(x).squeeze(-1)  # (batch, seq_len)

        if not self.explicit_zero_prob:
            return {"pred": pred_value}
        zero_logits = self.zero_logit(x).squeeze(-1)  # (batch, seq_len)
        zero_probs = torch.sigmoid(zero_logits)
        return {"pred": pred_value, "zero_probs": zero_probs}
        # TODO: note that the return currently is only for training. Since decoder
        # is not used in the test setting for the integration task, the experiments/inference
        # logic is not implemented yet. However, remember to implement it when
        # the decoder is used in any test setting. The inference logic will need
        # to sample from the bernoulli distribution with the zero_probs.


class AffineExprDecoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        explicit_zero_prob: bool = False,
        activation: Optional[str] = None,
        tanh_coeff: bool = False,
        adaptive_bias: bool = False,
    ):
        """
        Predict the expression value of each gene in an affine-like form Ax + b.
        This decoder uses two ExprDecoder modules internally to generate the
        coefficient A and the bias b.

        Args:
            d_model: The embedding dimension.
            explicit_zero_prob: If True, predict the probability of each gene being
                zero.
            activation: Name of the activation applied to the predicted coefficient
                A and bias b ("gelu", "relu", "tanh" or "sigmoid"); None disables it.
            tanh_coeff: If True, use tanh activation for the coefficient A.
                NOTE(review): currently unused — the corresponding transform in
                forward() is commented out.
            adaptive_bias: If True, scale the predicted bias by the per-sample
                mean of the non-zero input values.

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        super().__init__()
        self.explicit_zero_prob = explicit_zero_prob
        self.tanh_coeff = tanh_coeff
        self.adaptive_bias = adaptive_bias
        self.coeff_decoder = ExprDecoder(d_model, explicit_zero_prob=explicit_zero_prob)
        self.bias_decoder = ExprDecoder(d_model, explicit_zero_prob=explicit_zero_prob)

        if activation is None:
            self.activation = None
        else:
            # Map lowercase names directly to the activation module classes.
            # A ValueError (not assert) is raised so validation survives `-O`.
            activations_map = {
                "gelu": nn.GELU,
                "relu": nn.ReLU,
                "tanh": nn.Tanh,
                "sigmoid": nn.Sigmoid,
            }
            key = activation.lower()
            if key not in activations_map:
                raise ValueError(f"Unknown activation: {activation}")
            self.activation = activations_map[key]()

    def forward(self, x: Tensor, values: Tensor) -> Dict[str, Tensor]:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embsize]
            values: Tensor, shape [batch_size, seq_len]

        Returns:
            Dict with "pred" = coeff * values + bias, shape [batch_size, seq_len].
            When ``explicit_zero_prob`` is set, also "zero_probs" (taken from the
            coefficient decoder only; the bias decoder's are discarded).
        """
        coeff = self.coeff_decoder(x)
        bias = self.bias_decoder(x)

        if self.activation is not None:
            coeff["pred"] = self.activation(coeff["pred"])
            bias["pred"] = self.activation(bias["pred"])

        # if self.tanh_coeff:
        #     coeff["pred"] = 1 + torch.tanh(coeff["pred"])

        if self.adaptive_bias:
            # Scale the bias by the mean of each sample's non-zero values.
            # NOTE(review): division yields NaN/inf for samples whose values are
            # all zero — confirm upstream guarantees at least one non-zero entry.
            # bias["pred"] = bias["pred"] * values.mean(dim=1, keepdim=True)
            non_zero_value_mean = values.sum(dim=1, keepdim=True) / (values != 0).sum(dim=1, keepdim=True)
            bias["pred"] = bias["pred"] * non_zero_value_mean

        if self.explicit_zero_prob:
            return {
                "pred": coeff["pred"] * values + bias["pred"],
                "zero_probs": coeff["zero_probs"],
            }

        return {"pred": coeff["pred"] * values + bias["pred"]}