Yuto2007 committed on
Commit
8cd3669
·
verified ·
1 Parent(s): 5314854

Update unified_cell_classifier.py

Browse files
Files changed (1) hide show
  1. unified_cell_classifier.py +249 -163
unified_cell_classifier.py CHANGED
@@ -1,163 +1,249 @@
1
- import torch
2
- import torch.nn as nn
3
- import json
4
- import os
5
- from typing import Dict, Optional, Tuple
6
-
7
class UnifiedCellClassifier(nn.Module):
    """Hierarchical cell classifier: one main model plus optional sub-models.

    The main classifier picks a macro category for each input row. When
    ``macro_to_sub`` maps that category to a loaded sub-classifier, the
    sub-classifier refines the prediction and the final label becomes
    ``"<macro_label>_<sub_label>"``; otherwise the macro label is used alone.
    """

    def __init__(self, models_base_path: str, sub_classifier_names: list):
        """Load the main classifier and every available sub-classifier.

        Args:
            models_base_path: Base directory where the models are stored.
            sub_classifier_names: Names of the sub-classifier folders.

        Expected layout under ``models_base_path``:
            - main_classifier/complete_model.pth + id2label.json
            - B_cells_classifier/complete_model.pth + id2label.json
            - T_cells_classifier/complete_model.pth + id2label.json
            - ...

        Sub-classifier folders that do not exist are skipped silently.
        """
        super().__init__()

        # Remembered so from_pretrained() knows which sub-classifiers to refresh.
        self.sub_classifier_names = list(sub_classifier_names)

        # SECURITY: weights_only=False unpickles complete module objects and can
        # therefore execute arbitrary code. Only load trusted checkpoints.
        main_path = os.path.join(models_base_path, "main_classifier")
        self.main_classifier = torch.load(
            os.path.join(main_path, "complete_model.pth"),
            map_location='cpu',
            weights_only=False,
        )
        with open(os.path.join(main_path, "id2label.json")) as f:
            self.main_labels = json.load(f)

        # Load the sub-classifiers that are present on disk.
        self.sub_classifiers = nn.ModuleDict()
        self.sub_labels = {}
        for sub_name in self.sub_classifier_names:
            sub_path = os.path.join(models_base_path, sub_name)
            if not os.path.exists(sub_path):
                continue
            self.sub_classifiers[sub_name] = torch.load(
                os.path.join(sub_path, "complete_model.pth"),
                map_location='cpu',
                weights_only=False,
            )
            with open(os.path.join(sub_path, "id2label.json")) as f:
                self.sub_labels[sub_name] = json.load(f)

        # Macro-category index -> sub-classifier name.
        self.macro_to_sub = self._build_macro_to_sub_mapping()

        # Inference only.
        self.eval()

    def _build_macro_to_sub_mapping(self):
        """Return the built-in mapping from macro index (as str) to sub-classifier name."""
        # NOTE(review): "gd_T_cells_classfier" looks misspelled but presumably
        # matches the folder name on disk - verify before changing it.
        return {
            "0": "B_cells_classifier",
            "1": "CD4plus_T_cells_classifier",
            "4": "Myeloid_cells_classifier",
            "5": "NK_cells_classifier",
            "7": "TRAV1_2_CD8plus_T_cells",
            "8": "gd_T_cells_classfier"
        }

    def forward(self, x: torch.Tensor, return_probabilities: bool = False):
        """Run the hierarchical forward pass.

        Args:
            x: Input embeddings, indexed by sample along the first dimension.
            return_probabilities: If True, also return the softmax outputs.

        Returns:
            Dict with ``'macro_predictions'`` and ``'final_predictions'``
            (lists of label strings). When ``return_probabilities`` is True it
            also holds ``'macro_probabilities'`` (tensor) and
            ``'sub_probabilities'`` (one tensor per sample, or ``None`` for
            samples that were not refined by a sub-classifier).
        """
        # Main classification.
        with torch.no_grad():
            main_logits = self.main_classifier(x)
            main_probs = torch.softmax(main_logits, dim=-1)
            main_pred = torch.argmax(main_logits, dim=-1)

        # Label files are JSON, so class indices are looked up as strings.
        macro_indices = [str(idx) for idx in main_pred.tolist()]
        macro_labels = [
            self.main_labels.get(idx, f"unknown_{idx}") for idx in macro_indices
        ]

        # Secondary classification, one sample at a time.
        final_predictions = []
        sub_probabilities = [] if return_probabilities else None
        for i, (macro_idx, macro_label) in enumerate(zip(macro_indices, macro_labels)):
            sub_name = self.macro_to_sub.get(macro_idx)
            if sub_name is not None and sub_name in self.sub_classifiers:
                with torch.no_grad():
                    sub_logits = self.sub_classifiers[sub_name](x[i:i + 1])
                    sub_probs = torch.softmax(sub_logits, dim=-1)
                    sub_idx = str(torch.argmax(sub_logits, dim=-1).item())
                sub_label = self.sub_labels[sub_name].get(sub_idx, f"unknown_{sub_idx}")
                final_predictions.append(f"{macro_label}_{sub_label}")
                sample_probs = sub_probs[0]
            else:
                # No mapping or no loaded sub-classifier: keep the macro label.
                final_predictions.append(macro_label)
                sample_probs = None

            if return_probabilities:
                sub_probabilities.append(sample_probs)

        result = {
            'macro_predictions': macro_labels,
            'final_predictions': final_predictions
        }
        if return_probabilities:
            result['macro_probabilities'] = main_probs
            result['sub_probabilities'] = sub_probabilities
        return result

    def predict(self, x: torch.Tensor):
        """Return only the final label of every sample in ``x``."""
        return self.forward(x, return_probabilities=False)['final_predictions']

    @classmethod
    def from_pretrained(cls, repo_path: str, **kwargs):
        """Build a model from a local Hugging Face style repository folder.

        Expected files under ``repo_path``:
            - config.json (keyword arguments for ``__init__``)
            - main_classifier.bin
            - id2label_main.json
            - macro_to_sub.json
            - sub_classifiers/<name>.bin + sub_classifiers/<name>_id2label.json

        The module architectures are the ones ``__init__`` loads from
        ``models_base_path``; the ``.bin`` state dicts and the label files of
        the repository then replace their weights and labels.

        Raises:
            FileNotFoundError: If a configured sub-classifier has no
                architecture folder under ``models_base_path``, or a
                repository file is missing.
        """
        def read_json(*parts):
            with open(os.path.join(repo_path, *parts)) as f:
                return json.load(f)

        def read_state_dict(*parts):
            # State dicts hold only tensors, so the restricted unpickler is enough.
            return torch.load(
                os.path.join(repo_path, *parts), map_location="cpu", weights_only=True
            )

        # 1. Read the config and instantiate the object.
        model = cls(**read_json("config.json"))

        # 2. Main classifier weights and labels.
        model.main_classifier.load_state_dict(read_state_dict("main_classifier.bin"))
        model.main_labels = read_json("id2label_main.json")

        # 3. Sub-classifier weights and labels.
        for name in model.sub_classifier_names:
            if name not in model.sub_classifiers:
                raise FileNotFoundError(
                    f"No architecture was loaded for sub-classifier '{name}'"
                )
            model.sub_classifiers[name].load_state_dict(
                read_state_dict("sub_classifiers", f"{name}.bin")
            )
            model.sub_labels[name] = read_json("sub_classifiers", f"{name}_id2label.json")

        # 4. Mapping and inference mode.
        model.macro_to_sub = read_json("macro_to_sub.json")
        model.eval()
        return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import json
4
+ import os
5
+ from typing import Dict, Optional, Tuple
6
+ from huggingface_hub import hf_hub_download
7
+
8
class UnifiedCellClassifier(nn.Module):
    """Hierarchical cell classifier: one main model plus optional sub-models.

    The main classifier picks a macro category for each input row. When
    ``macro_to_sub`` maps that category to a loaded sub-classifier, the
    sub-classifier refines the prediction and the final label becomes
    ``"<macro_label>_<sub_label>"``; otherwise the macro label is used alone.

    ``__init__`` only stores configuration; weights are attached by
    ``from_pretrained`` (or by assigning the modules manually).
    """

    def __init__(self,
                 main_classifier_config: Optional[Dict] = None,
                 sub_classifiers_config: Optional[Dict] = None,
                 sub_classifier_names: Optional[list] = None,
                 **kwargs):
        """Store the configuration and create empty placeholders.

        Args:
            main_classifier_config: Configuration of the main classifier.
            sub_classifiers_config: Per-name configurations of the sub-classifiers.
            sub_classifier_names: Names of the sub-classifiers.
            **kwargs: Accepted and ignored, so ``from_pretrained`` can pass a
                whole ``config.json`` even if it carries additional keys.
        """
        super().__init__()

        # Configuration (None is normalised to an empty container).
        self.sub_classifier_names = sub_classifier_names or []
        self.main_classifier_config = main_classifier_config or {}
        self.sub_classifiers_config = sub_classifiers_config or {}

        # Placeholders, filled in by from_pretrained().
        self.main_classifier = None
        self.sub_classifiers = nn.ModuleDict()
        self.main_labels = {}
        self.sub_labels = {}

        # Macro-category index -> sub-classifier name.
        self.macro_to_sub = self._build_default_macro_to_sub_mapping()

    def _build_default_macro_to_sub_mapping(self):
        """Return the default mapping; macro_to_sub.json can override it."""
        # NOTE(review): "gd_T_cells_classfier" looks misspelled but presumably
        # matches the published sub-classifier name - verify before changing it.
        return {
            "0": "B_cells_classifier",
            "1": "CD4plus_T_cells_classifier",
            "4": "Myeloid_cells_classifier",
            "5": "NK_cells_classifier",
            "7": "TRAV1_2_CD8plus_T_cells",
            "8": "gd_T_cells_classfier"
        }

    def _create_classifier_from_config(self, config: Dict):
        """Create a Linear -> ReLU -> Dropout -> Linear classifier from ``config``.

        Keys read (with defaults): ``input_dim`` (512), ``hidden_dim`` (256),
        ``num_classes`` (10), ``dropout`` (0.1).
        """
        input_dim = config.get('input_dim', 512)
        hidden_dim = config.get('hidden_dim', 256)
        num_classes = config.get('num_classes', 10)
        dropout = config.get('dropout', 0.1)

        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x: torch.Tensor, return_probabilities: bool = False):
        """Run the hierarchical forward pass.

        Args:
            x: Input embeddings [batch_size, embedding_dim].
            return_probabilities: If True, also return the softmax outputs.

        Returns:
            Dict with ``'macro_predictions'`` and ``'final_predictions'``
            (lists of label strings). When ``return_probabilities`` is True it
            also holds ``'macro_probabilities'`` (tensor) and
            ``'sub_probabilities'`` (one tensor per sample, or ``None`` for
            samples that were not refined by a sub-classifier).

        Raises:
            RuntimeError: If no main classifier has been loaded.
        """
        if self.main_classifier is None:
            raise RuntimeError("Modello non caricato. Usa from_pretrained() per caricare il modello.")

        # Main classification.
        with torch.no_grad():
            main_logits = self.main_classifier(x)
            main_probs = torch.softmax(main_logits, dim=-1)
            main_pred = torch.argmax(main_logits, dim=-1)

        # Label files are JSON, so class indices are looked up as strings.
        macro_indices = [str(idx) for idx in main_pred.tolist()]
        macro_labels = [
            self.main_labels.get(idx, f"unknown_{idx}") for idx in macro_indices
        ]

        # Secondary classification, one sample at a time.
        final_predictions = []
        sub_probabilities = [] if return_probabilities else None
        for i, (macro_idx, macro_label) in enumerate(zip(macro_indices, macro_labels)):
            sub_name = self.macro_to_sub.get(macro_idx)
            if sub_name is not None and sub_name in self.sub_classifiers:
                with torch.no_grad():
                    sub_logits = self.sub_classifiers[sub_name](x[i:i + 1])
                    sub_probs = torch.softmax(sub_logits, dim=-1)
                    sub_idx = str(torch.argmax(sub_logits, dim=-1).item())
                # A manually attached sub-classifier may have no label table.
                sub_label = self.sub_labels.get(sub_name, {}).get(sub_idx, f"unknown_{sub_idx}")
                final_predictions.append(f"{macro_label}_{sub_label}")
                sample_probs = sub_probs[0]
            else:
                # No mapping or no loaded sub-classifier: keep the macro label.
                final_predictions.append(macro_label)
                sample_probs = None

            if return_probabilities:
                sub_probabilities.append(sample_probs)

        result = {
            'macro_predictions': macro_labels,
            'final_predictions': final_predictions
        }
        if return_probabilities:
            result['macro_probabilities'] = main_probs
            result['sub_probabilities'] = sub_probabilities
        return result

    def predict(self, x: torch.Tensor):
        """Return only the final label of every sample in ``x``."""
        return self.forward(x, return_probabilities=False)['final_predictions']

    @classmethod
    def from_pretrained(cls, repo_id_or_path: str, **kwargs):
        """Load the model from the Hugging Face Hub or from a local folder.

        Args:
            repo_id_or_path: Hub repository id, or path of a local directory
                with the layout written by ``save_pretrained``.

        Returns:
            The loaded model, switched to eval mode.

        The main classifier files are mandatory. A sub-classifier whose weights
        or labels cannot be loaded is reported and skipped as a whole, and a
        missing or unreadable ``macro_to_sub.json`` falls back to the default
        mapping.
        """
        # An existing directory is treated as local; anything else as a Hub id.
        is_local = os.path.isdir(repo_id_or_path)

        def get_file_path(filename):
            if is_local:
                return os.path.join(repo_id_or_path, filename)
            return hf_hub_download(repo_id=repo_id_or_path, filename=filename)

        def load_json(filename):
            with open(get_file_path(filename), encoding="utf-8") as f:
                return json.load(f)

        def load_weights(filename):
            # State dicts hold only tensors, so the restricted unpickler is
            # enough and avoids executing code from a downloaded file.
            return torch.load(get_file_path(filename), map_location="cpu", weights_only=True)

        # 1. Configuration and empty model.
        config = load_json("config.json")
        model = cls(**config)

        # 2. Main classifier: architecture, weights, labels.
        #    The normalised attributes are used so a null section in
        #    config.json behaves like an empty one.
        model.main_classifier = model._create_classifier_from_config(model.main_classifier_config)
        model.main_classifier.load_state_dict(load_weights("main_classifier.bin"))
        model.main_labels = load_json("id2label_main.json")

        # 3. Sub-classifiers.
        model.sub_classifiers = nn.ModuleDict()
        model.sub_labels = {}
        for sub_name in model.sub_classifier_names:
            try:
                sub_config = model.sub_classifiers_config.get(sub_name, {})
                classifier = model._create_classifier_from_config(sub_config)
                classifier.load_state_dict(load_weights(f"sub_classifiers/{sub_name}.bin"))
                labels = load_json(f"sub_classifiers/{sub_name}_id2label.json")
            except Exception as e:
                print(f"Errore nel caricamento del sub-classificatore {sub_name}: {e}")
                continue
            # Register only after everything loaded, so a half-loaded
            # (randomly initialised or unlabeled) classifier is never used.
            model.sub_classifiers[sub_name] = classifier
            model.sub_labels[sub_name] = labels

        # 4. Optional mapping override.
        try:
            model.macro_to_sub = load_json("macro_to_sub.json")
        except Exception as e:
            print(f"File macro_to_sub.json non trovato, uso mapping di default ({e})")

        model.eval()
        return model

    def save_pretrained(self, save_directory: str):
        """Save the model in the layout read by ``from_pretrained``.

        Writes ``config.json``, ``macro_to_sub.json``, the main classifier
        weights and labels (only when a main classifier is present), and one
        ``.bin`` + ``_id2label.json`` pair per sub-classifier under
        ``sub_classifiers/``.
        """
        os.makedirs(save_directory, exist_ok=True)

        def dump_json(payload, *parts):
            with open(os.path.join(save_directory, *parts), 'w', encoding="utf-8") as f:
                json.dump(payload, f, indent=2)

        # Configuration.
        dump_json({
            'sub_classifier_names': self.sub_classifier_names,
            'main_classifier_config': self.main_classifier_config,
            'sub_classifiers_config': self.sub_classifiers_config
        }, "config.json")

        # Main classifier.
        if self.main_classifier is not None:
            torch.save(self.main_classifier.state_dict(),
                       os.path.join(save_directory, "main_classifier.bin"))
            dump_json(self.main_labels, "id2label_main.json")

        # Sub-classifiers.
        sub_classifiers_dir = os.path.join(save_directory, "sub_classifiers")
        os.makedirs(sub_classifiers_dir, exist_ok=True)
        for name, classifier in self.sub_classifiers.items():
            torch.save(classifier.state_dict(),
                       os.path.join(sub_classifiers_dir, f"{name}.bin"))
            # An unlabeled sub-classifier is saved with an empty label table.
            dump_json(self.sub_labels.get(name, {}), "sub_classifiers", f"{name}_id2label.json")

        # Mapping.
        dump_json(self.macro_to_sub, "macro_to_sub.json")
249
+