vincenzocivale commited on
Commit
9f27427
·
1 Parent(s): 649a4ee

Refactor from_pretrained method for robust model loading and improved error handling

Browse files
Files changed (1) hide show
  1. unified_cell_classifier.py +50 -50
unified_cell_classifier.py CHANGED
@@ -221,78 +221,78 @@ class UnifiedCellClassifier(nn.Module):
221
  """Metodo semplificato per predizione"""
222
  return self.forward(x, return_probabilities=False)['final_predictions']
223
 
 
224
  @classmethod
225
  def from_pretrained(cls, repo_id_or_path: str, **kwargs):
226
  """
227
- Carica il modello da HuggingFace Hub o da path locale
228
-
229
- Args:
230
- repo_id_or_path: ID del repository HF o path locale
231
  """
232
- # Determina se è un path locale o repo HF
233
- is_local = os.path.exists(repo_id_or_path)
234
 
235
- def get_file_path(filename):
236
- if is_local:
237
- return os.path.join(repo_id_or_path, filename)
238
- else:
239
- return hf_hub_download(repo_id=repo_id_or_path, filename=filename)
240
-
241
- # 1. Carica configurazione
242
- config_path = get_file_path("config.json")
243
  with open(config_path) as f:
244
  config = json.load(f)
245
 
246
- # 2. Istanzia il modello
247
  model = cls(**config)
248
 
249
- # 3. Carica il classificatore principale
250
- main_config = config['main_classifier_config']
251
- model.main_classifier = model._create_classifier_from_config(main_config)
 
 
 
 
252
 
253
- # Carica i pesi del main classifier
254
- main_weights_path = os.path.join(repo_id_or_path, "main_classifier/main_classifier.safetensors")
255
  main_state_dict = load_file(main_weights_path)
256
  model.main_classifier.load_state_dict(main_state_dict, strict=False)
257
 
258
- # Carica le label del main classifier
259
- main_labels_path = get_file_path("main_classifier/id2label_main.json")
260
  with open(main_labels_path) as f:
261
  model.main_labels = json.load(f)
262
-
263
- # 4. Carica i sub-classificatori
264
- model.sub_classifiers = nn.ModuleDict()
265
- model.sub_labels = {}
266
-
267
- for sub_name in model.sub_classifier_names:
268
- try:
269
- # Crea l'architettura del sub-classificatore
270
- sub_config = config['sub_classifiers_config'][sub_name]
271
- model.sub_classifiers[sub_name] = model._create_classifier_from_config(sub_config)
272
-
273
- sub_weights_path = os.path.join(repo_id_or_path, f"sub_classifiers/{sub_name}/{sub_name}.safetensors")
274
- sub_state_dict = load_file(sub_weights_path)
275
- model.sub_classifiers[sub_name].load_state_dict(sub_state_dict, strict=False)
276
-
277
- # Carica le label del sub-classificatore
278
- sub_labels_path = get_file_path(f"sub_classifiers/{sub_name}/{sub_name}_id2label.json")
279
- with open(sub_labels_path) as f:
280
- model.sub_labels[sub_name] = json.load(f)
281
 
282
- except Exception as e:
283
- print(f"Errore nel caricamento del sub-classificatore {sub_name}: {e}")
284
- continue
285
-
286
- # 5. Carica il mapping macro_to_sub se esiste
287
- try:
288
- macro_to_sub_path = get_file_path("macro_to_sub.json")
 
 
 
 
 
289
  with open(macro_to_sub_path) as f:
290
  model.macro_to_sub = json.load(f)
291
- except:
292
- print("File macro_to_sub.json non trovato, uso mapping di default")
293
 
294
  model.eval()
295
  return model
 
296
 
297
  def save_pretrained(self, save_directory: str):
298
  """
 
221
  """Metodo semplificato per predizione"""
222
  return self.forward(x, return_probabilities=False)['final_predictions']
223
 
224
+ @classmethod
225
  @classmethod
226
  def from_pretrained(cls, repo_id_or_path: str, **kwargs):
227
  """
228
+ Carica il modello da HuggingFace Hub o da un path locale in modo robusto.
 
 
 
229
  """
 
 
230
 
231
+ # 1. Ottieni un path locale unificato
232
+ if os.path.isdir(repo_id_or_path):
233
+ local_model_path = repo_id_or_path
234
+ else:
235
+ local_model_path = snapshot_download(repo_id=repo_id_or_path)
236
+
237
+ # 2. Carica la configurazione generale
238
+ config_path = os.path.join(local_model_path, "config.json")
239
  with open(config_path) as f:
240
  config = json.load(f)
241
 
242
+ # 3. Istanzia il "contenitore" del modello
243
  model = cls(**config)
244
 
245
+ # 4. Carica il classificatore principale
246
+ # --- PASSAGGIO MANCANTE INSERITO QUI ---
247
+ # a) Crea l'architettura del classificatore principale
248
+ if 'main_classifier_config' in config:
249
+ main_config = config['main_classifier_config']
250
+ # Assumo che tu abbia un metodo per creare un classificatore dalla sua config
251
+ model.main_classifier = model._create_classifier_from_config(main_config)
252
 
253
+ # b) Ora carica i pesi, perché model.main_classifier esiste
254
+ main_weights_path = os.path.join(local_model_path, "main_classifier/main_classifier.safetensors")
255
  main_state_dict = load_file(main_weights_path)
256
  model.main_classifier.load_state_dict(main_state_dict, strict=False)
257
 
258
+ main_labels_path = os.path.join(local_model_path, "main_classifier/id2label_main.json")
 
259
  with open(main_labels_path) as f:
260
  model.main_labels = json.load(f)
261
+
262
+ # 5. Carica i sub-classificatori
263
+ if 'sub_classifiers_config' in config:
264
+ for sub_name in model.sub_classifier_names:
265
+ try:
266
+ # --- PASSAGGIO MANCANTE INSERITO QUI ---
267
+ # a) Crea l'architettura del sub-classificatore
268
+ sub_config = config['sub_classifiers_config'][sub_name]
269
+ model.sub_classifiers[sub_name] = model._create_classifier_from_config(sub_config)
270
+
271
+ # b) Ora carica i suoi pesi
272
+ sub_weights_path = os.path.join(local_model_path, f"sub_classifiers/{sub_name}/{sub_name}.safetensors")
273
+ sub_state_dict = load_file(sub_weights_path)
274
+ model.sub_classifiers[sub_name].load_state_dict(sub_state_dict, strict=False)
 
 
 
 
 
275
 
276
+ # c) Carica le sue label
277
+ sub_labels_path = os.path.join(local_model_path, f"sub_classifiers/{sub_name}/{sub_name}_id2label.json")
278
+ with open(sub_labels_path) as f:
279
+ model.sub_labels[sub_name] = json.load(f)
280
+
281
+ except Exception as e:
282
+ print(f"⚠️ Avviso: impossibile caricare il sub-classificatore {sub_name}. Errore: {e}")
283
+ continue
284
+
285
+ # 6. Carica il mapping macro_to_sub se esiste
286
+ macro_to_sub_path = os.path.join(local_model_path, "macro_to_sub.json")
287
+ if os.path.exists(macro_to_sub_path):
288
  with open(macro_to_sub_path) as f:
289
  model.macro_to_sub = json.load(f)
290
+ else:
291
+ print("File macro_to_sub.json non trovato, uso mapping di default.")
292
 
293
  model.eval()
294
  return model
295
+
296
 
297
  def save_pretrained(self, save_directory: str):
298
  """