Yuto2007 committed on
Commit
692054f
·
1 Parent(s): cbc9ed5

Aggiunta definizione classe UnifiedCellClassifier

Browse files
Files changed (1) hide show
  1. unified_cell_classifier.py +163 -0
unified_cell_classifier.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import json
4
+ import os
5
+ from typing import Dict, Optional, Tuple
6
+
7
class UnifiedCellClassifier(nn.Module):
    """Two-level (macro -> sub) classifier built from separately saved models.

    A main classifier predicts a macro category for every input row. When the
    predicted macro index is routed (via ``macro_to_sub``) to a sub-classifier
    that was actually loaded, that sub-classifier refines the prediction and
    the final label becomes ``"<macro_label>_<sub_label>"``; otherwise the
    macro label alone is the final label.
    """

    def __init__(self, models_base_path: str, sub_classifier_names: list):
        """Load the main classifier and every available sub-classifier.

        Args:
            models_base_path: Base directory where the models are stored.
            sub_classifier_names: Folder names of the sub-classifiers. Names
                whose folder does not exist under ``models_base_path`` are
                skipped silently.

        Expected layout under ``models_base_path``:
            - main_classifier/complete_model.pth + id2label.json
            - B_cells_classifier/complete_model.pth + id2label.json
            - T_cells_classifier/complete_model.pth + id2label.json
            - ...
        """
        super().__init__()

        # Kept so that from_pretrained() can iterate over the requested names.
        self.sub_classifier_names = list(sub_classifier_names)

        # Load the main classifier.
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects; only point this at checkpoints from a trusted source.
        main_path = os.path.join(models_base_path, "main_classifier")
        self.main_classifier = torch.load(
            os.path.join(main_path, "complete_model.pth"),
            map_location='cpu',
            weights_only=False,
        )
        self.main_labels = self._read_json(os.path.join(main_path, "id2label.json"))

        # Load the sub-classifiers whose folder is present.
        self.sub_classifiers = nn.ModuleDict()
        self.sub_labels = {}
        for sub_name in self.sub_classifier_names:
            sub_path = os.path.join(models_base_path, sub_name)
            if not os.path.exists(sub_path):
                continue
            # NOTE(security): same pickle caveat as for the main classifier.
            self.sub_classifiers[sub_name] = torch.load(
                os.path.join(sub_path, "complete_model.pth"),
                map_location='cpu',
                weights_only=False,
            )
            self.sub_labels[sub_name] = self._read_json(
                os.path.join(sub_path, "id2label.json")
            )

        # Macro index -> sub-classifier name routing table (to be configured).
        self.macro_to_sub = self._build_macro_to_sub_mapping()

        # Inference only: put every loaded module in eval mode.
        self.eval()

    @staticmethod
    def _read_json(path: str):
        """Parse the JSON file at ``path``, closing the file afterwards."""
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    def _build_macro_to_sub_mapping(self):
        """Return the default routing from macro index (as str) to sub-classifier name."""
        # NOTE(review): every value must match a name in sub_classifier_names
        # to take effect. "TRAV1_2_CD8plus_T_cells" lacks the "_classifier"
        # suffix and "gd_T_cells_classfier" looks misspelled -- confirm against
        # the real folder names before changing either one.
        return {
            "0": "B_cells_classifier",
            "1": "CD4plus_T_cells_classifier",
            "4": "Myeloid_cells_classifier",
            "5": "NK_cells_classifier",
            "7": "TRAV1_2_CD8plus_T_cells",
            "8": "gd_T_cells_classfier"
        }

    def _build_submodule(self, name: str) -> nn.Module:
        """Return the module that will receive the state dict for ``name``.

        The default implementation reuses the architecture that ``__init__``
        loaded for ``name``. Subclasses may override this to construct the
        module differently.

        Raises:
            ValueError: If no sub-classifier called ``name`` was loaded.
        """
        if name not in self.sub_classifiers:
            raise ValueError(
                f"No architecture available for sub-classifier '{name}': "
                "its folder was not found under models_base_path"
            )
        return self.sub_classifiers[name]

    def forward(self, x: torch.Tensor, return_probabilities: bool = False):
        """Hierarchical forward pass.

        Args:
            x: Input embeddings [batch_size, embedding_dim].
            return_probabilities: If True, also return the probabilities.

        Returns:
            Dict with:
                - 'macro_predictions': list of macro labels, one per row.
                - 'final_predictions': list of final labels, one per row.
                - 'macro_probabilities' (only if requested): softmax output
                  of the main classifier for the whole batch.
                - 'sub_probabilities' (only if requested): list with, per
                  row, the sub-classifier softmax vector or None when no
                  sub-classifier was applied.
        """
        # Main (macro) classification.
        with torch.no_grad():
            main_logits = self.main_classifier(x)
            main_probs = torch.softmax(main_logits, dim=-1)
            main_pred = torch.argmax(main_logits, dim=-1)

        # One host transfer for the whole batch instead of .item() per row.
        macro_indices = [str(idx) for idx in main_pred.tolist()]
        macro_labels = [
            self.main_labels.get(idx, f"unknown_{idx}") for idx in macro_indices
        ]

        # Default outcome: the macro label is the final label.
        final_predictions = list(macro_labels)
        sub_probabilities = [None] * len(macro_indices) if return_probabilities else None

        # Group rows by the sub-classifier they are routed to, so each
        # sub-classifier runs once on a batch rather than once per row.
        rows_by_sub = {}
        for row, idx in enumerate(macro_indices):
            sub_name = self.macro_to_sub.get(idx)
            if sub_name is not None and sub_name in self.sub_classifiers:
                rows_by_sub.setdefault(sub_name, []).append(row)

        # Secondary (sub) classification.
        for sub_name, rows in rows_by_sub.items():
            with torch.no_grad():
                sub_logits = self.sub_classifiers[sub_name](x[rows])
                sub_probs = torch.softmax(sub_logits, dim=-1)
                sub_pred = torch.argmax(sub_logits, dim=-1)

            labels = self.sub_labels[sub_name]
            for pos, (row, pred) in enumerate(zip(rows, sub_pred.tolist())):
                sub_idx = str(pred)
                sub_label = labels.get(sub_idx, f"unknown_{sub_idx}")
                final_predictions[row] = f"{macro_labels[row]}_{sub_label}"
                if return_probabilities:
                    sub_probabilities[row] = sub_probs[pos]

        result = {
            'macro_predictions': macro_labels,
            'final_predictions': final_predictions
        }

        if return_probabilities:
            result['macro_probabilities'] = main_probs
            result['sub_probabilities'] = sub_probabilities

        return result

    def predict(self, x: torch.Tensor):
        """Simplified prediction: return only the list of final labels."""
        return self.forward(x, return_probabilities=False)['final_predictions']

    @classmethod
    def from_pretrained(cls, repo_path: str, **kwargs):
        """Build a classifier from a local copy of the HF repo layout.

        Expected files under ``repo_path``:
            - config.json: keyword arguments for ``__init__``
            - main_classifier.bin: state dict of the main classifier
            - id2label_main.json
            - macro_to_sub.json
            - sub_classifiers/<name>.bin + sub_classifiers/<name>_id2label.json

        The instance is first created through ``__init__`` (which loads the
        model architectures from ``models_base_path``); the weights, labels
        and routing table found in ``repo_path`` then replace the loaded ones.

        Args:
            repo_path: Directory containing the files listed above.
            **kwargs: Overrides for entries of config.json.

        Returns:
            The loaded model, in eval mode.

        Raises:
            ValueError: If a configured sub-classifier has no architecture
                available (see ``_build_submodule``).
        """
        # 1. Read the config; explicit keyword arguments take precedence.
        config = cls._read_json(os.path.join(repo_path, "config.json"))
        config.update(kwargs)

        # 2. Instantiate the object.
        model = cls(**config)

        # 3. Load the main classifier weights and labels.
        main_sd = torch.load(
            os.path.join(repo_path, "main_classifier.bin"),
            map_location="cpu",
            weights_only=True,
        )
        model.main_classifier.load_state_dict(main_sd)
        model.main_labels = cls._read_json(os.path.join(repo_path, "id2label_main.json"))

        # 4. Load the sub-classifier weights and labels.
        sub_dir = os.path.join(repo_path, "sub_classifiers")
        sub_classifiers = nn.ModuleDict()
        sub_labels = {}
        for name in model.sub_classifier_names:
            module = model._build_submodule(name)
            module.load_state_dict(
                torch.load(
                    os.path.join(sub_dir, f"{name}.bin"),
                    map_location="cpu",
                    weights_only=True,
                )
            )
            sub_classifiers[name] = module
            sub_labels[name] = cls._read_json(
                os.path.join(sub_dir, f"{name}_id2label.json")
            )
        model.sub_classifiers = sub_classifiers
        model.sub_labels = sub_labels

        model.macro_to_sub = cls._read_json(os.path.join(repo_path, "macro_to_sub.json"))
        model.eval()
        return model