vincenzocivale commited on
Commit
96c02e7
·
1 Parent(s): d37b3a2

Add: first version of model

Browse files
Files changed (4) hide show
  1. README.md +2 -0
  2. config.json +179 -0
  3. model.safetensors +3 -0
  4. modeling_unified.py +204 -0
README.md ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # UnifiedCellClassifier
2
+ Saved model and config.
config.json ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "UnifiedCellClassifier"
4
+ ],
5
+ "dtype": "float32",
6
+ "macro_to_sub": {
7
+ "0": "B_cells_classifier",
8
+ "1": "CD4plus_T_cells_classifier",
9
+ "4": "Myeloid_cells_classifier",
10
+ "5": "NK_cells_classifier",
11
+ "7": "TRAV1_2_CD8plus_T_cells_classifier",
12
+ "8": "gd_T_cells_classfier"
13
+ },
14
+ "main_classifier_config": {
15
+ "dropout_rate": 0.2,
16
+ "hidden_dims": [
17
+ 512,
18
+ 256
19
+ ],
20
+ "input_dim": 3072,
21
+ "output_dim": 9,
22
+ "use_residual_in_hidden": true
23
+ },
24
+ "main_labels": {
25
+ "0": "B cells",
26
+ "1": "CD4+ T cells",
27
+ "2": "DN T cells",
28
+ "3": "MAIT cells",
29
+ "4": "Myeloid cells",
30
+ "5": "NK cells",
31
+ "6": "Progenitor cells",
32
+ "7": "TRAV1-2- CD8+ T cells",
33
+ "8": "gd T cells"
34
+ },
35
+ "model_type": "unified-cell-classifier",
36
+ "sub_classifier_names": [
37
+ "B_cells_classifier",
38
+ "CD4plus_T_cells_classifier",
39
+ "Myeloid_cells_classifier",
40
+ "NK_cells_classifier",
41
+ "TRAV1_2_CD8plus_T_cells_classifier",
42
+ "gd_T_cells_classfier"
43
+ ],
44
+ "sub_classifiers_config": {
45
+ "B_cells_classifier": {
46
+ "dropout_rate": 0.2,
47
+ "hidden_dims": [
48
+ 3072,
49
+ 1536,
50
+ 768
51
+ ],
52
+ "input_dim": 3072,
53
+ "output_dim": 9,
54
+ "use_residual_in_hidden": true
55
+ },
56
+ "CD4plus_T_cells_classifier": {
57
+ "dropout_rate": 0.2,
58
+ "hidden_dims": [
59
+ 3072,
60
+ 1536,
61
+ 768
62
+ ],
63
+ "input_dim": 3072,
64
+ "output_dim": 16,
65
+ "use_residual_in_hidden": true
66
+ },
67
+ "Myeloid_cells_classifier": {
68
+ "dropout_rate": 0.2,
69
+ "hidden_dims": [
70
+ 3072,
71
+ 1536,
72
+ 768
73
+ ],
74
+ "input_dim": 3072,
75
+ "output_dim": 4,
76
+ "use_residual_in_hidden": true
77
+ },
78
+ "NK_cells_classifier": {
79
+ "dropout_rate": 0.2,
80
+ "hidden_dims": [
81
+ 3072,
82
+ 1536,
83
+ 768
84
+ ],
85
+ "input_dim": 3072,
86
+ "output_dim": 6,
87
+ "use_residual_in_hidden": true
88
+ },
89
+ "TRAV1_2_CD8plus_T_cells_classifier": {
90
+ "dropout_rate": 0.2,
91
+ "hidden_dims": [
92
+ 3072,
93
+ 1536,
94
+ 768
95
+ ],
96
+ "input_dim": 3072,
97
+ "output_dim": 12,
98
+ "use_residual_in_hidden": true
99
+ },
100
+ "gd_T_cells_classfier": {
101
+ "dropout_rate": 0.2,
102
+ "hidden_dims": [
103
+ 3072,
104
+ 1536,
105
+ 768
106
+ ],
107
+ "input_dim": 3072,
108
+ "output_dim": 5,
109
+ "use_residual_in_hidden": true
110
+ }
111
+ },
112
+ "sub_labels": {
113
+ "B_cells_classifier": {
114
+ "0": "Activated",
115
+ "1": "Atypical memory",
116
+ "2": "CD5+ B cells",
117
+ "3": "Naive",
118
+ "4": "Naive-IFN",
119
+ "5": "Non-switched memory",
120
+ "6": "Plasma cells",
121
+ "7": "Switched memory",
122
+ "8": "Transitional"
123
+ },
124
+ "CD4plus_T_cells_classifier": {
125
+ "0": "Exhausted-like memory",
126
+ "1": "HLA-DR+ memory",
127
+ "10": "Th2",
128
+ "11": "Th22",
129
+ "12": "Treg KLRB1+RORC+",
130
+ "13": "Treg cytotoxic",
131
+ "14": "Treg memory",
132
+ "15": "Treg naive",
133
+ "2": "Naive",
134
+ "3": "Naive-IFN",
135
+ "4": "Temra",
136
+ "5": "Terminal effector",
137
+ "6": "Tfh",
138
+ "7": "Th1",
139
+ "8": "Th1/Th17",
140
+ "9": "Th17"
141
+ },
142
+ "Myeloid_cells_classifier": {
143
+ "0": "Classical monocytes",
144
+ "1": "Non-classical monocytes",
145
+ "2": "cDCs",
146
+ "3": "pDCs"
147
+ },
148
+ "NK_cells_classifier": {
149
+ "0": "CD56bright",
150
+ "1": "CD56dim CD57+",
151
+ "2": "CD56dim CD57-",
152
+ "3": "CD56dim CD57int",
153
+ "4": "CD56dim CD57low",
154
+ "5": "Proliferative"
155
+ },
156
+ "TRAV1_2_CD8plus_T_cells_classifier": {
157
+ "0": "HLA-DR+",
158
+ "1": "NKT-like",
159
+ "10": "Tmem KLRC2+",
160
+ "11": "Trm",
161
+ "2": "Naive",
162
+ "3": "Naive-IFN",
163
+ "4": "Proliferative",
164
+ "5": "Tcm CCR4+",
165
+ "6": "Tcm CCR4-",
166
+ "7": "Tem GZMB+",
167
+ "8": "Tem GZMK+",
168
+ "9": "Temra"
169
+ },
170
+ "gd_T_cells_classfier": {
171
+ "0": "Vd1 GZMB+",
172
+ "1": "Vd1 GZMK+",
173
+ "2": "Vd2 GZMB+",
174
+ "3": "Vd2 GZMK+",
175
+ "4": "gd naive"
176
+ }
177
+ },
178
+ "transformers_version": "4.56.1"
179
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0adfc1d920ec2aefefbc62b6b1d0015331fb34bf0984332405e23e13ce3f66a7
3
+ size 376065164
modeling_unified.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# modeling_scBloodClassifier.py
import os
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
8
+
9
+
10
class MLPBlock(nn.Module):
    """One MLP stage: Linear -> BatchNorm1d -> GELU -> Dropout, optional skip.

    The skip (residual) connection is only active when it was requested AND the
    input and output widths are equal, because the identity cannot be added to
    a tensor of a different width.
    """

    def __init__(self, input_dim: int, output_dim: int, dropout_rate: float = 0.2, use_residual: bool = False):
        super().__init__()
        # Silently disable the residual path when shapes would not match.
        self.use_residual = use_residual and input_dim == output_dim
        self.linear = nn.Linear(input_dim, output_dim)
        self.bn = nn.BatchNorm1d(output_dim)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block; adds the input back when the residual path is on."""
        out = self.dropout(self.activation(self.bn(self.linear(x))))
        return out + x if self.use_residual else out
30
+
31
+
32
class MLPClassifier(nn.Module):
    """Feed-forward classifier: input BatchNorm, stacked MLPBlocks, linear head.

    When integer class ``labels`` are passed to :meth:`forward`, a loss is also
    computed (cross-entropy unless another ``loss_fn`` module is supplied).
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        dropout_rate: float = 0.2,
        use_residual_in_hidden: bool = True,
        loss_fn: Optional[nn.Module] = None
    ):
        super().__init__()
        self.initial_bn = nn.BatchNorm1d(input_dim)

        dims = [input_dim, *hidden_dims]
        blocks = []
        for src, dst in zip(dims, dims[1:]):
            blocks.append(
                MLPBlock(
                    input_dim=src,
                    output_dim=dst,
                    dropout_rate=dropout_rate,
                    # Residual links only make sense between equal-width layers.
                    use_residual=use_residual_in_hidden and src == dst,
                )
            )
        self.hidden_network = nn.Sequential(*blocks)
        self.output_projection = nn.Linear(dims[-1], output_dim)
        self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss()

        self._initialize_weights()

    def forward(self, x: torch.Tensor, labels: Optional[torch.Tensor] = None, return_dict: bool = True):
        """Classify ``x``; returns a SequenceClassifierOutput unless return_dict=False.

        NOTE(review): with ``return_dict=False`` the tuple order is
        ``(logits, loss)`` — the reverse of the usual transformers convention —
        kept as-is because existing callers may depend on it.
        """
        if x.ndim > 2:
            # Collapse any trailing dimensions into a single feature axis.
            x = x.view(x.size(0), -1)
        logits = self.output_projection(self.hidden_network(self.initial_bn(x)))
        loss = None if labels is None else self.loss_fn(logits, labels)

        if return_dict:
            return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=None, attentions=None)
        return (logits,) if loss is None else (logits, loss)

    def _initialize_weights(self):
        """Kaiming-normal init for linear layers; unit-gain init for batch norms."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
84
+
85
+
86
class scBloodClassifierConfig(PretrainedConfig):
    """Configuration for scBloodClassifier.

    Holds the macro classifier config, per-subtype classifier configs, the
    index-to-label maps, and the macro-index -> sub-classifier-name routing.

    NOTE(review): ``model_type`` here is "scBloodClassifier" while the shipped
    config.json uses "unified-cell-classifier" — confirm the Auto* registration
    matches before relying on automatic loading.
    """

    model_type = "scBloodClassifier"

    def __init__(
        self,
        sub_classifier_names: Optional[List[str]] = None,
        main_classifier_config: Optional[Dict] = None,
        sub_classifiers_config: Optional[Dict] = None,
        main_labels: Optional[Dict] = None,
        sub_labels: Optional[Dict] = None,
        macro_to_sub: Optional[Dict] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Fall back to empty containers so attribute access is always safe.
        self.sub_classifier_names = sub_classifier_names if sub_classifier_names else []
        self.main_classifier_config = main_classifier_config if main_classifier_config else {}
        self.sub_classifiers_config = sub_classifiers_config if sub_classifiers_config else {}
        self.main_labels = main_labels if main_labels else {}
        self.sub_labels = sub_labels if sub_labels else {}
        self.macro_to_sub = macro_to_sub if macro_to_sub else {}
108
+
109
+
110
class scBloodClassifier(PreTrainedModel):
    """Hierarchical classifier for single-cell RNA-seq blood data.

    A main (macro) classifier assigns a coarse cell type; for macro classes
    listed in ``config.macro_to_sub`` a dedicated sub-classifier then refines
    the prediction into a subtype, producing "<macro>_<sub>" labels.
    """

    config_class = scBloodClassifierConfig

    def __init__(self, config: scBloodClassifierConfig):
        super().__init__(config)
        self.config = config

        # Macro-level classifier over the coarse cell types.
        self.main_classifier = self._create_classifier(config.main_classifier_config)

        # One refinement head per macro class that has subtypes.
        # NOTE: .get(name, {}) means a name with no config reaches
        # _create_classifier with an empty dict and raises KeyError there.
        self.sub_classifiers = nn.ModuleDict({
            name: self._create_classifier(config.sub_classifiers_config.get(name, {}))
            for name in config.sub_classifier_names
        })

        # Copy the label mappings so later mutation does not alias the config.
        self.main_labels = dict(config.main_labels)
        self.sub_labels = dict(config.sub_labels)
        self.macro_to_sub = dict(config.macro_to_sub)

        self.post_init()  # required by transformers (weight init / tying hooks)

    def _create_classifier(self, cfg: Dict) -> MLPClassifier:
        """Build an MLPClassifier from a plain config dict.

        ``input_dim`` and ``output_dim`` are mandatory (KeyError if absent);
        the remaining keys fall back to the defaults shown below.
        """
        return MLPClassifier(
            input_dim=cfg['input_dim'],
            hidden_dims=cfg.get('hidden_dims', []),
            output_dim=cfg['output_dim'],
            dropout_rate=cfg.get('dropout_rate', 0.2),
            use_residual_in_hidden=cfg.get('use_residual_in_hidden', True)
        )

    def forward(self, x: torch.Tensor, return_dict: bool = True, **kwargs):
        """Return the main (macro) classifier output; extra kwargs are ignored."""
        return self.main_classifier(x, return_dict=return_dict)

    def predict_labels(self, x: torch.Tensor, return_probabilities: bool = False) -> Dict[str, Any]:
        """Predict hierarchical labels for a batch of inputs.

        Args:
            x: input features — assumed shape (batch, input_dim); not
               validated here (TODO confirm against callers).
            return_probabilities: also return the macro softmax probabilities
               and per-sample sub-classifier probabilities (``None`` in slots
               where no sub-classifier was applied).

        Returns:
            dict with ``final_predictions`` (list of label strings) plus,
            optionally, ``macro_probabilities`` and ``sub_probabilities``.
        """
        self.eval()
        with torch.no_grad():
            main_out = self.main_classifier(x, return_dict=True)
            main_logits = main_out.logits
            main_probs = torch.softmax(main_logits, dim=-1)
            main_pred = torch.argmax(main_logits, dim=-1)

            final_predictions = []
            sub_probs_list = [] if return_probabilities else None

            for i in range(x.shape[0]):
                macro_idx = str(int(main_pred[i].item()))
                macro_label = self.main_labels.get(macro_idx, f"unknown_{macro_idx}")

                # Refine only when a sub-classifier is both routed and loaded.
                sub_name = self.macro_to_sub.get(macro_idx)
                if sub_name is not None and sub_name in self.sub_classifiers:
                    sub_out = self.sub_classifiers[sub_name](x[i:i + 1], return_dict=True)
                    sub_logits = sub_out.logits
                    sub_idx = str(int(torch.argmax(sub_logits, dim=-1).item()))
                    sub_label = self.sub_labels.get(sub_name, {}).get(sub_idx, f"unknown_{sub_idx}")
                    final_label = f"{macro_label}_{sub_label}"
                    if return_probabilities:
                        sub_probs_list.append(torch.softmax(sub_logits, dim=-1)[0])
                else:
                    # No refinement: keep the macro label, mark the slot empty.
                    final_label = macro_label
                    if return_probabilities:
                        sub_probs_list.append(None)

                final_predictions.append(final_label)

        out = {"final_predictions": final_predictions}
        if return_probabilities:
            out["macro_probabilities"] = main_probs
            out["sub_probabilities"] = sub_probs_list
        return out

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model and config in Hugging Face format.

        Syncs the (possibly mutated) label mappings back into the config
        first. ``**kwargs`` are forwarded to
        ``PreTrainedModel.save_pretrained`` so options such as
        ``safe_serialization`` or ``push_to_hub`` keep working.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.config.main_labels = self.main_labels
        self.config.sub_labels = self.sub_labels
        self.config.macro_to_sub = self.macro_to_sub
        super().save_pretrained(save_directory, **kwargs)
        # Write a minimal README only when one does not already exist.
        readme_path = os.path.join(save_directory, "README.md")
        if not os.path.exists(readme_path):
            with open(readme_path, "w") as f:
                f.write("# scBloodClassifier\nSaved model and config.")