import torch
import torch.nn as nn
import json
import os
from typing import Dict, Optional, Tuple
class UnifiedCellClassifier(nn.Module):
    """Hierarchical cell-type classifier.

    A main classifier assigns each input embedding to a macro cell category;
    when a dedicated sub-classifier exists for that category, it refines the
    prediction into a ``"<macro>_<sub>"`` label, otherwise the macro label is
    returned as-is.
    """

    def __init__(self, models_base_path: str, sub_classifier_names: list):
        """
        Args:
            models_base_path: Base path where the models are stored.
            sub_classifier_names: List of sub-classifier folder names.

        Expected layout::

            - main_classifier/complete_model.pth + id2label.json
            - B_cells_classifier/complete_model.pth + id2label.json
            - T_cells_classifier/complete_model.pth + id2label.json
            - ...
        """
        super().__init__()
        # Keep the names so from_pretrained() can re-create the sub-classifiers.
        # (Previously this attribute was never stored, which made
        # from_pretrained() fail with AttributeError.)
        self.sub_classifier_names = list(sub_classifier_names)

        # Load the main (macro-category) classifier.
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load checkpoints from a trusted source.
        main_path = os.path.join(models_base_path, "main_classifier")
        self.main_classifier = torch.load(
            os.path.join(main_path, "complete_model.pth"),
            map_location='cpu',
            weights_only=False,
        )
        with open(os.path.join(main_path, "id2label.json")) as f:
            # Maps stringified class index -> macro label.
            self.main_labels: Dict[str, str] = json.load(f)

        # Load the sub-classifiers; missing folders are silently skipped so a
        # partial model set still works (forward() falls back to macro labels).
        self.sub_classifiers = nn.ModuleDict()
        self.sub_labels: Dict[str, Dict[str, str]] = {}
        for sub_name in self.sub_classifier_names:
            sub_path = os.path.join(models_base_path, sub_name)
            if os.path.exists(sub_path):
                self.sub_classifiers[sub_name] = torch.load(
                    os.path.join(sub_path, "complete_model.pth"),
                    map_location='cpu',
                    weights_only=False,
                )
                with open(os.path.join(sub_path, "id2label.json")) as f:
                    self.sub_labels[sub_name] = json.load(f)

        # Macro-category index -> sub-classifier name (configure as needed).
        self.macro_to_sub = self._build_macro_to_sub_mapping()

        # Inference-only model: default to eval mode.
        self.eval()

    def _build_macro_to_sub_mapping(self) -> Dict[str, str]:
        """Return the mapping from macro-class index (as str) to sub-classifier name.

        NOTE(review): "gd_T_cells_classfier" looks like a typo of "classifier"
        and "TRAV1_2_CD8plus_T_cells" lacks the "_classifier" suffix — kept
        verbatim because they must match the on-disk folder names; verify.
        """
        return {
            "0": "B_cells_classifier",
            "1": "CD4plus_T_cells_classifier",
            "4": "Myeloid_cells_classifier",
            "5": "NK_cells_classifier",
            "7": "TRAV1_2_CD8plus_T_cells",
            "8": "gd_T_cells_classfier"
        }

    def forward(self, x: torch.Tensor, return_probabilities: bool = False):
        """Hierarchical forward pass.

        Args:
            x: Input embeddings, shape ``[batch_size, embedding_dim]``.
            return_probabilities: If True, also return softmax probabilities.

        Returns:
            Dict with ``macro_predictions`` (list of macro labels) and
            ``final_predictions`` (list of final, possibly refined labels).
            With ``return_probabilities=True`` it also contains
            ``macro_probabilities`` (tensor) and ``sub_probabilities``
            (per-sample tensor or None when no sub-classifier applied).
        """
        # Single no_grad scope for the whole inference pass (the original
        # re-entered no_grad for every sample inside the loop).
        with torch.no_grad():
            # Macro-level classification for the full batch.
            main_logits = self.main_classifier(x)
            main_probs = torch.softmax(main_logits, dim=-1)
            main_pred = torch.argmax(main_logits, dim=-1)

            # Sub-level classification, per sample (each sample may route to a
            # different sub-classifier).
            batch_size = x.shape[0]
            sub_predictions = []
            sub_probabilities = [] if return_probabilities else None
            for i in range(batch_size):
                macro_idx = str(main_pred[i].item())
                macro_label = self.main_labels.get(macro_idx, f"unknown_{macro_idx}")

                sub_classifier_name = self.macro_to_sub.get(macro_idx)
                if sub_classifier_name is not None and sub_classifier_name in self.sub_classifiers:
                    # Refine with the dedicated sub-classifier.
                    sub_logits = self.sub_classifiers[sub_classifier_name](x[i:i + 1])
                    sub_probs = torch.softmax(sub_logits, dim=-1)
                    sub_pred = torch.argmax(sub_logits, dim=-1)
                    sub_idx = str(sub_pred.item())
                    sub_label = self.sub_labels[sub_classifier_name].get(
                        sub_idx, f"unknown_{sub_idx}")
                    final_prediction = f"{macro_label}_{sub_label}"
                    if return_probabilities:
                        sub_probabilities.append(sub_probs[0])
                else:
                    # No (loaded) sub-classifier for this macro class: fall
                    # back to the macro label alone.
                    final_prediction = macro_label
                    if return_probabilities:
                        sub_probabilities.append(None)
                sub_predictions.append(final_prediction)

        result = {
            'macro_predictions': [self.main_labels.get(str(idx.item()), f"unknown_{idx.item()}")
                                  for idx in main_pred],
            'final_predictions': sub_predictions
        }
        if return_probabilities:
            result['macro_probabilities'] = main_probs
            result['sub_probabilities'] = sub_probabilities
        return result

    def predict(self, x: torch.Tensor):
        """Convenience wrapper: return only the final label per sample."""
        return self.forward(x, return_probabilities=False)['final_predictions']

    @classmethod
    def from_pretrained(cls, repo_path: str, **kwargs):
        """Load the model from a Hugging Face-style repo layout.

        Expected files:
            - config.json                      (kwargs for ``__init__``)
            - main_classifier.bin              (main state dict)
            - id2label_main.json
            - macro_to_sub.json
            - sub_classifiers/<name>.bin + sub_classifiers/<name>_id2label.json
        """
        # 1. Read the config and instantiate the object.
        with open(os.path.join(repo_path, "config.json")) as f:
            config = json.load(f)
        model = cls(**config)

        # 2. Load the main classifier weights and labels.
        main_sd = torch.load(os.path.join(repo_path, "main_classifier.bin"), map_location="cpu")
        model.main_classifier.load_state_dict(main_sd)
        with open(os.path.join(repo_path, "id2label_main.json")) as f:
            model.main_labels = json.load(f)

        # 3. Rebuild and load each sub-classifier.
        model.sub_classifiers = nn.ModuleDict()
        model.sub_labels = {}
        for name in model.sub_classifier_names:
            bin_path = os.path.join(repo_path, "sub_classifiers", f"{name}.bin")
            # NOTE(review): _build_submodule is not defined on this class; it
            # must be provided by a subclass, otherwise this raises
            # AttributeError — confirm where it lives.
            model.sub_classifiers[name] = model._build_submodule(name)
            model.sub_classifiers[name].load_state_dict(torch.load(bin_path, map_location="cpu"))
            with open(os.path.join(repo_path, "sub_classifiers", f"{name}_id2label.json")) as f:
                model.sub_labels[name] = json.load(f)

        # 4. Load the macro -> sub routing table and switch to eval mode.
        with open(os.path.join(repo_path, "macro_to_sub.json")) as f:
            model.macro_to_sub = json.load(f)
        model.eval()
        return model