File size: 6,899 Bytes
692054f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import torch
import torch.nn as nn
import json
import os
from typing import Dict, Optional, Tuple

class UnifiedCellClassifier(nn.Module):
    """Two-level (hierarchical) cell-type classifier.

    A main classifier predicts a macro category for each input row. When a
    sub-classifier is registered (and loaded) for that macro category, it
    refines the prediction and the final label becomes ``"<macro>_<sub>"``;
    otherwise the final label is just the macro label.
    """

    def __init__(self, models_base_path: str, sub_classifier_names: list):
        """Load the main classifier and every available sub-classifier.

        Args:
            models_base_path: Base directory where the models are stored.
            sub_classifier_names: Names of the sub-classifier folders. Folders
                that do not exist under ``models_base_path`` are skipped
                silently.

        Expected layout (each ``complete_model.pth`` is a whole pickled
        ``nn.Module``, not just a state dict):

            - main_classifier/complete_model.pth + id2label.json
            - B_cells_classifier/complete_model.pth + id2label.json
            - T_cells_classifier/complete_model.pth + id2label.json
            - ...

        Security: ``weights_only=False`` unpickles arbitrary objects, so only
        load model files from a trusted source.
        """
        super().__init__()
        # Stored so that from_pretrained() can iterate the configured names.
        self.sub_classifier_names = list(sub_classifier_names)

        # Main classifier.
        main_path = os.path.join(models_base_path, "main_classifier")
        self.main_classifier = torch.load(
            os.path.join(main_path, "complete_model.pth"),
            map_location='cpu',
            weights_only=False,
        )
        self.main_labels = self._read_json(os.path.join(main_path, "id2label.json"))

        # Sub-classifiers; missing folders are skipped.
        self.sub_classifiers = nn.ModuleDict()
        self.sub_labels = {}
        for sub_name in self.sub_classifier_names:
            sub_path = os.path.join(models_base_path, sub_name)
            if not os.path.exists(sub_path):
                continue
            self.sub_classifiers[sub_name] = torch.load(
                os.path.join(sub_path, "complete_model.pth"),
                map_location='cpu',
                weights_only=False,
            )
            self.sub_labels[sub_name] = self._read_json(os.path.join(sub_path, "id2label.json"))

        # Macro index -> sub-classifier folder name.
        self.macro_to_sub = self._build_macro_to_sub_mapping()

        # Inference-only module: put it in eval mode right away.
        self.eval()

    @staticmethod
    def _read_json(path: str):
        """Parse the UTF-8 JSON file at ``path`` and close it before returning."""
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    def _build_macro_to_sub_mapping(self):
        """Return the hard-coded macro-index -> sub-classifier-folder mapping.

        Keys are main-classifier output indices as strings (``forward`` looks
        them up with ``str(prediction)``). Macro indices without an entry are
        never refined.
        """
        # NOTE(review): "gd_T_cells_classfier" looks like a typo of
        # "classifier", and "TRAV1_2_CD8plus_T_cells" lacks the "_classifier"
        # suffix used elsewhere. They must match the on-disk folder names
        # exactly, so confirm before renaming.
        return {
            "0": "B_cells_classifier",
            "1": "CD4plus_T_cells_classifier",
            "4": "Myeloid_cells_classifier",
            "5": "NK_cells_classifier",
            "7": "TRAV1_2_CD8plus_T_cells",
            "8": "gd_T_cells_classfier"
        }

    def _build_submodule(self, name: str) -> nn.Module:
        """Return the module instance into which ``name``'s state dict is loaded.

        A state dict alone cannot rebuild an architecture, so this reuses the
        complete model that ``__init__`` loaded from ``models_base_path``.

        Raises:
            FileNotFoundError: If that sub-classifier's folder was missing when
                the object was constructed.
        """
        if name not in self.sub_classifiers:
            raise FileNotFoundError(
                f"No architecture available for sub-classifier {name!r}: its folder "
                f"was not found under models_base_path when the model was built"
            )
        return self.sub_classifiers[name]

    def forward(self, x: torch.Tensor, return_probabilities: bool = False):
        """Hierarchical forward pass.

        Args:
            x: Input embeddings; row ``i`` is fed to its sub-classifier as
                ``x[i:i+1]`` (presumably ``[batch_size, embedding_dim]``).
            return_probabilities: If True, also return softmax outputs.

        Returns:
            Dict with:
              - ``'macro_predictions'``: macro label per row.
              - ``'final_predictions'``: ``"<macro>_<sub>"`` per row, or the
                bare macro label when no loaded sub-classifier applies.
              - ``'macro_probabilities'`` (only if ``return_probabilities``):
                softmax of the main logits over the last dimension.
              - ``'sub_probabilities'`` (only if ``return_probabilities``):
                per row, the sub-classifier softmax vector or ``None``.
        """
        # Main classification.
        with torch.no_grad():
            main_logits = self.main_classifier(x)
            main_probs = torch.softmax(main_logits, dim=-1)
            main_pred = torch.argmax(main_logits, dim=-1)

        # One host transfer instead of an .item() sync per row.
        macro_indices = [str(idx) for idx in main_pred.tolist()]
        macro_labels = [self.main_labels.get(idx, f"unknown_{idx}") for idx in macro_indices]

        # Secondary classification, one row at a time.
        # NOTE(review): grouping rows per sub-classifier would allow one
        # batched call each; kept per-row to preserve exact numerics.
        final_predictions = []
        sub_probabilities = [] if return_probabilities else None
        for i, (macro_idx, macro_label) in enumerate(zip(macro_indices, macro_labels)):
            sub_name = self.macro_to_sub.get(macro_idx)
            if sub_name is None or sub_name not in self.sub_classifiers:
                # No sub-classifier for this macro (or it was not loaded):
                # keep the macro label only.
                final_predictions.append(macro_label)
                if return_probabilities:
                    sub_probabilities.append(None)
                continue

            with torch.no_grad():
                sub_logits = self.sub_classifiers[sub_name](x[i:i + 1])
                sub_probs = torch.softmax(sub_logits, dim=-1)
                sub_pred = torch.argmax(sub_logits, dim=-1)

            sub_idx = str(sub_pred.item())
            sub_label = self.sub_labels[sub_name].get(sub_idx, f"unknown_{sub_idx}")
            final_predictions.append(f"{macro_label}_{sub_label}")
            if return_probabilities:
                sub_probabilities.append(sub_probs[0])

        result = {
            'macro_predictions': macro_labels,
            'final_predictions': final_predictions,
        }
        if return_probabilities:
            result['macro_probabilities'] = main_probs
            result['sub_probabilities'] = sub_probabilities
        return result

    def predict(self, x: torch.Tensor):
        """Return only the final label per row (see ``forward``)."""
        return self.forward(x, return_probabilities=False)['final_predictions']

    @classmethod
    def from_pretrained(cls, repo_path: str, **kwargs):
        """Build a classifier from an exported (Hugging Face style) folder.

        Expected contents of ``repo_path``:

          - config.json: keyword arguments for ``__init__``
            (``models_base_path``, ``sub_classifier_names``)
          - main_classifier.bin: state dict of the main classifier
          - id2label_main.json
          - macro_to_sub.json
          - sub_classifiers/<name>.bin + sub_classifiers/<name>_id2label.json

        The ``.bin`` files hold weights only. The architectures come from the
        complete models that ``__init__`` loads from ``models_base_path``.

        Args:
            repo_path: Folder containing the files above.
            **kwargs: Overrides for ``config.json`` entries (for example
                ``models_base_path``).

        Returns:
            The loaded classifier, in eval mode.

        Raises:
            FileNotFoundError: If a file is missing, or a configured
                sub-classifier has no architecture under ``models_base_path``.
        """
        # 1. Read config and apply caller overrides.
        config = cls._read_json(os.path.join(repo_path, "config.json"))
        config.update(kwargs)

        # 2. Instantiate (loads the complete models that provide architectures).
        model = cls(**config)

        # 3. Main classifier weights and labels.
        main_sd = torch.load(
            os.path.join(repo_path, "main_classifier.bin"),
            map_location="cpu",
            weights_only=True,
        )
        model.main_classifier.load_state_dict(main_sd)
        model.main_labels = cls._read_json(os.path.join(repo_path, "id2label_main.json"))

        # 4. Sub-classifier weights and labels.
        sub_dir = os.path.join(repo_path, "sub_classifiers")
        sub_classifiers = nn.ModuleDict()
        sub_labels = {}
        for name in model.sub_classifier_names:
            module = model._build_submodule(name)
            module.load_state_dict(torch.load(
                os.path.join(sub_dir, f"{name}.bin"),
                map_location="cpu",
                weights_only=True,
            ))
            sub_classifiers[name] = module
            sub_labels[name] = cls._read_json(os.path.join(sub_dir, f"{name}_id2label.json"))
        model.sub_classifiers = sub_classifiers
        model.sub_labels = sub_labels

        model.macro_to_sub = cls._read_json(os.path.join(repo_path, "macro_to_sub.json"))
        model.eval()
        return model