vincenzocivale
committed on
Commit
·
9f27427
1
Parent(s):
649a4ee
Refactor from_pretrained method for robust model loading and improved error handling
Browse files- unified_cell_classifier.py +50 -50
unified_cell_classifier.py
CHANGED
|
@@ -221,78 +221,78 @@ class UnifiedCellClassifier(nn.Module):
|
|
| 221 |
"""Metodo semplificato per predizione"""
|
| 222 |
return self.forward(x, return_probabilities=False)['final_predictions']
|
| 223 |
|
|
|
|
| 224 |
@classmethod
|
| 225 |
def from_pretrained(cls, repo_id_or_path: str, **kwargs):
|
| 226 |
"""
|
| 227 |
-
Carica il modello da HuggingFace Hub o da path locale
|
| 228 |
-
|
| 229 |
-
Args:
|
| 230 |
-
repo_id_or_path: ID del repository HF o path locale
|
| 231 |
"""
|
| 232 |
-
# Determina se è un path locale o repo HF
|
| 233 |
-
is_local = os.path.exists(repo_id_or_path)
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
#
|
| 242 |
-
config_path =
|
| 243 |
with open(config_path) as f:
|
| 244 |
config = json.load(f)
|
| 245 |
|
| 246 |
-
#
|
| 247 |
model = cls(**config)
|
| 248 |
|
| 249 |
-
#
|
| 250 |
-
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
|
| 253 |
-
#
|
| 254 |
-
main_weights_path = os.path.join(
|
| 255 |
main_state_dict = load_file(main_weights_path)
|
| 256 |
model.main_classifier.load_state_dict(main_state_dict, strict=False)
|
| 257 |
|
| 258 |
-
|
| 259 |
-
main_labels_path = get_file_path("main_classifier/id2label_main.json")
|
| 260 |
with open(main_labels_path) as f:
|
| 261 |
model.main_labels = json.load(f)
|
| 262 |
-
|
| 263 |
-
#
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
# Carica le label del sub-classificatore
|
| 278 |
-
sub_labels_path = get_file_path(f"sub_classifiers/{sub_name}/{sub_name}_id2label.json")
|
| 279 |
-
with open(sub_labels_path) as f:
|
| 280 |
-
model.sub_labels[sub_name] = json.load(f)
|
| 281 |
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
with open(macro_to_sub_path) as f:
|
| 290 |
model.macro_to_sub = json.load(f)
|
| 291 |
-
|
| 292 |
-
print("File macro_to_sub.json non trovato, uso mapping di default")
|
| 293 |
|
| 294 |
model.eval()
|
| 295 |
return model
|
|
|
|
| 296 |
|
| 297 |
def save_pretrained(self, save_directory: str):
|
| 298 |
"""
|
|
|
|
| 221 |
"""Metodo semplificato per predizione"""
|
| 222 |
return self.forward(x, return_probabilities=False)['final_predictions']
|
| 223 |
|
| 224 |
+
@classmethod
def from_pretrained(cls, repo_id_or_path: str, **kwargs):
    """
    Load the model from the HuggingFace Hub or from a local directory.

    Args:
        repo_id_or_path: Path of an existing local directory; anything that
            is not an existing directory is passed to ``snapshot_download``
            as a repository id.
        **kwargs: Accepted for call compatibility; not used by this method.

    Returns:
        The instantiated model, after ``model.eval()`` has been called.

    Raises:
        FileNotFoundError: If ``config.json`` or
            ``main_classifier/id2label_main.json`` cannot be opened.
        Exception: Errors raised while loading the main classifier weights
            propagate to the caller. Errors raised while loading a single
            sub-classifier are printed and that sub-classifier is skipped.
    """

    def read_json(path):
        # Read JSON as UTF-8 explicitly so non-ASCII labels do not depend on
        # the platform's default text encoding.
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    # 1. Resolve a single local directory for both the local and the Hub case.
    if os.path.isdir(repo_id_or_path):
        local_model_path = repo_id_or_path
    else:
        local_model_path = snapshot_download(repo_id=repo_id_or_path)

    # 2. Load the general configuration.
    config = read_json(os.path.join(local_model_path, "config.json"))

    # 3. Instantiate the model "container".
    model = cls(**config)

    # 4. Load the main classifier.
    # a) Build the main classifier architecture from its own config.
    if 'main_classifier_config' in config:
        main_config = config['main_classifier_config']
        model.main_classifier = model._create_classifier_from_config(main_config)

    # b) Load the weights into model.main_classifier.
    # NOTE(review): when 'main_classifier_config' is absent this relies on
    # cls(**config) having already created model.main_classifier - confirm.
    main_weights_path = os.path.join(local_model_path, "main_classifier/main_classifier.safetensors")
    main_state_dict = load_file(main_weights_path)
    # strict=False: key mismatches between file and module do not raise here.
    model.main_classifier.load_state_dict(main_state_dict, strict=False)

    main_labels_path = os.path.join(local_model_path, "main_classifier/id2label_main.json")
    model.main_labels = read_json(main_labels_path)

    # 5. Load the sub-classifiers (best effort: a failure skips only that one).
    if 'sub_classifiers_config' in config:
        for sub_name in model.sub_classifier_names:
            try:
                # a) Build the sub-classifier architecture.
                sub_config = config['sub_classifiers_config'][sub_name]
                model.sub_classifiers[sub_name] = model._create_classifier_from_config(sub_config)

                # b) Load its weights.
                sub_weights_path = os.path.join(local_model_path, f"sub_classifiers/{sub_name}/{sub_name}.safetensors")
                sub_state_dict = load_file(sub_weights_path)
                model.sub_classifiers[sub_name].load_state_dict(sub_state_dict, strict=False)

                # c) Load its labels.
                sub_labels_path = os.path.join(local_model_path, f"sub_classifiers/{sub_name}/{sub_name}_id2label.json")
                model.sub_labels[sub_name] = read_json(sub_labels_path)

            except Exception as e:
                # Deliberately broad: report the problem and move on to the
                # next sub-classifier instead of aborting the whole load.
                print(f"⚠️ Avviso: impossibile caricare il sub-classificatore {sub_name}. Errore: {e}")

    # 6. Load the macro_to_sub mapping if the file exists; otherwise the
    # model keeps whatever mapping it already has.
    macro_to_sub_path = os.path.join(local_model_path, "macro_to_sub.json")
    if os.path.exists(macro_to_sub_path):
        model.macro_to_sub = read_json(macro_to_sub_path)
    else:
        print("File macro_to_sub.json non trovato, uso mapping di default.")

    model.eval()
    return model
|
| 295 |
+
|
| 296 |
|
| 297 |
def save_pretrained(self, save_directory: str):
|
| 298 |
"""
|