# modeling_scBloodClassifier.py
import os
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


class MLPBlock(nn.Module):
    """Single MLP block with optional residual connection."""

    def __init__(self, input_dim: int, output_dim: int, dropout_rate: float = 0.2, use_residual: bool = False):
        super().__init__()
        self.use_residual = use_residual and (input_dim == output_dim)
        self.linear = nn.Linear(input_dim, output_dim)
        self.bn = nn.BatchNorm1d(output_dim)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        x = self.linear(x)
        x = self.bn(x)
        x = self.activation(x)
        x = self.dropout(x)
        if self.use_residual:
            x = x + identity
        return x


class MLPClassifier(nn.Module):
    """MLP classifier with multiple hidden layers and optional residual connections."""

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        dropout_rate: float = 0.2,
        use_residual_in_hidden: bool = True,
        loss_fn: Optional[nn.Module] = None
    ):
        super().__init__()
        self.initial_bn = nn.BatchNorm1d(input_dim)

        all_dims = [input_dim] + hidden_dims
        layers = [
            MLPBlock(
                input_dim=all_dims[i],
                output_dim=all_dims[i + 1],
                dropout_rate=dropout_rate,
                use_residual=use_residual_in_hidden and (all_dims[i] == all_dims[i + 1])
            )
            for i in range(len(all_dims) - 1)
        ]
        self.hidden_network = nn.Sequential(*layers)
        self.output_projection = nn.Linear(all_dims[-1], output_dim)
        self.loss_fn = loss_fn or nn.CrossEntropyLoss()

        self._initialize_weights()

    def forward(self, x: torch.Tensor, labels: Optional[torch.Tensor] = None, return_dict: bool = True):
        if x.ndim > 2:
            x = x.view(x.size(0), -1)
        x = self.initial_bn(x)
        x = self.hidden_network(x)
        logits = self.output_projection(x)
        loss = self.loss_fn(logits, labels) if labels is not None else None

        if not return_dict:
            return (logits, loss) if loss is not None else (logits,)
        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=None, attentions=None)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # Kaiming init; the 'relu' gain is the standard stand-in for GELU,
                # which torch.nn.init.calculate_gain does not support directly.
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
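

# A minimal standalone sketch of driving MLPClassifier directly (all
# dimensions here are arbitrary illustrations, not values used by any
# released checkpoint):
#
#     clf = MLPClassifier(input_dim=512, hidden_dims=[256, 256], output_dim=10)
#     out = clf(torch.randn(4, 512), labels=torch.randint(0, 10, (4,)))
#     # out.logits has shape (4, 10); out.loss is a scalar cross-entropy loss.
#     # The 256 -> 256 block gets a residual connection; 512 -> 256 does not.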


class scBloodClassifierConfig(PretrainedConfig):
    """Configuration for scBloodClassifier."""

    model_type = "scBloodClassifier"

    def __init__(
        self,
        sub_classifier_names: Optional[List[str]] = None,
        main_classifier_config: Optional[Dict] = None,
        sub_classifiers_config: Optional[Dict] = None,
        main_labels: Optional[Dict] = None,
        sub_labels: Optional[Dict] = None,
        macro_to_sub: Optional[Dict] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.sub_classifier_names = sub_classifier_names or []
        self.main_classifier_config = main_classifier_config or {}
        self.sub_classifiers_config = sub_classifiers_config or {}
        self.main_labels = main_labels or {}
        self.sub_labels = sub_labels or {}
        self.macro_to_sub = macro_to_sub or {}
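

# Illustrative shape of the expected config fields (every name and size below
# is hypothetical, shown only to document the structure):
#
#     main_classifier_config = {"input_dim": 3000, "hidden_dims": [512, 512], "output_dim": 8}
#     sub_classifiers_config = {"t_cell": {"input_dim": 3000, "hidden_dims": [256], "output_dim": 4}}
#     main_labels  = {"0": "t_cell", "1": "b_cell", ...}
#     sub_labels   = {"t_cell": {"0": "cd4", "1": "cd8", ...}}
#     macro_to_sub = {"0": "t_cell"}   # macro class index -> sub-classifier name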


class scBloodClassifier(PreTrainedModel):
    """Hierarchical classifier for single-cell RNA-seq blood data."""

    config_class = scBloodClassifierConfig

    def __init__(self, config: scBloodClassifierConfig):
        super().__init__(config)
        self.config = config

        # Main classifier
        self.main_classifier = self._create_classifier(config.main_classifier_config)

        # Sub-classifiers
        self.sub_classifiers = nn.ModuleDict({
            name: self._create_classifier(config.sub_classifiers_config.get(name, {}))
            for name in config.sub_classifier_names
        })

        # Label mappings
        self.main_labels = dict(config.main_labels)
        self.sub_labels = dict(config.sub_labels)
        self.macro_to_sub = dict(config.macro_to_sub)

        self.post_init()  # required by transformers

    def _create_classifier(self, cfg: Dict) -> MLPClassifier:
        if 'input_dim' not in cfg or 'output_dim' not in cfg:
            raise ValueError("Classifier config must define 'input_dim' and 'output_dim'.")
        return MLPClassifier(
            input_dim=cfg['input_dim'],
            hidden_dims=cfg.get('hidden_dims', []),
            output_dim=cfg['output_dim'],
            dropout_rate=cfg.get('dropout_rate', 0.2),
            use_residual_in_hidden=cfg.get('use_residual_in_hidden', True)
        )

    def forward(self, x: torch.Tensor, labels: Optional[torch.Tensor] = None, return_dict: bool = True, **kwargs):
        """Run the main (macro-level) classifier; the loss is computed when labels are given."""
        return self.main_classifier(x, labels=labels, return_dict=return_dict)

    def predict_labels(self, x: torch.Tensor, return_probabilities: bool = False) -> Dict[str, Any]:
        """Predict hierarchical labels for a batch of inputs."""
        self.eval()
        with torch.no_grad():
            main_out = self.main_classifier(x, return_dict=True)
            main_logits = main_out.logits
            main_probs = torch.softmax(main_logits, dim=-1)
            main_pred = torch.argmax(main_logits, dim=-1)

            final_predictions = []
            sub_probs_list = [] if return_probabilities else None

            for i in range(x.shape[0]):
                macro_idx = str(int(main_pred[i].item()))
                macro_label = self.main_labels.get(macro_idx, f"unknown_{macro_idx}")

                # Check for sub-classifier
                if macro_idx in self.macro_to_sub:
                    sub_name = self.macro_to_sub[macro_idx]
                    if sub_name in self.sub_classifiers:
                        sub_out = self.sub_classifiers[sub_name](x[i:i+1], return_dict=True)
                        sub_logits = sub_out.logits
                        sub_pred = torch.argmax(sub_logits, dim=-1)
                        sub_idx = str(int(sub_pred.item()))
                        sub_label = self.sub_labels.get(sub_name, {}).get(sub_idx, f"unknown_{sub_idx}")
                        final_label = f"{macro_label}_{sub_label}"
                        if return_probabilities:
                            sub_probs_list.append(torch.softmax(sub_logits, dim=-1)[0])
                    else:
                        final_label = macro_label
                        if return_probabilities:
                            sub_probs_list.append(None)
                else:
                    final_label = macro_label
                    if return_probabilities:
                        sub_probs_list.append(None)

                final_predictions.append(final_label)

        out = {"final_predictions": final_predictions}
        if return_probabilities:
            out["macro_probabilities"] = main_probs
            out["sub_probabilities"] = sub_probs_list
        return out

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model and config in Hugging Face format."""
        os.makedirs(save_directory, exist_ok=True)
        # Sync (possibly updated) label mappings back into the config before saving.
        self.config.main_labels = self.main_labels
        self.config.sub_labels = self.sub_labels
        self.config.macro_to_sub = self.macro_to_sub
        super().save_pretrained(save_directory, **kwargs)
        # Optional README
        readme_path = os.path.join(save_directory, "README.md")
        if not os.path.exists(readme_path):
            with open(readme_path, "w") as f:
                f.write("# scBloodClassifier\n\nSaved model and config.\n")



# Register with the Auto classes so the model can be loaded through
# AutoConfig / AutoModel with trust_remote_code=True.
AutoConfig.register("scBloodClassifier", scBloodClassifierConfig)
AutoModel.register(scBloodClassifierConfig, scBloodClassifier)
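

if __name__ == "__main__":
    # Smoke-test / usage sketch. Every dimension and label name below is an
    # arbitrary illustration, not a value shipped with any released checkpoint.
    config = scBloodClassifierConfig(
        sub_classifier_names=["t_cell"],
        main_classifier_config={"input_dim": 64, "hidden_dims": [32], "output_dim": 3},
        sub_classifiers_config={"t_cell": {"input_dim": 64, "hidden_dims": [32], "output_dim": 2}},
        main_labels={"0": "t_cell", "1": "b_cell", "2": "monocyte"},
        sub_labels={"t_cell": {"0": "cd4", "1": "cd8"}},
        macro_to_sub={"0": "t_cell"},
    )
    model = scBloodClassifier(config)

    x = torch.randn(5, 64)  # 5 cells x 64 toy features
    preds = model.predict_labels(x, return_probabilities=True)
    print(preds["final_predictions"])          # e.g. ['t_cell_cd4', 'b_cell', ...]
    print(preds["macro_probabilities"].shape)  # torch.Size([5, 3])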