Yuto2007 commited on
Commit
cef8a3d
·
verified ·
1 Parent(s): 776e5ae

Update unified_cell_classifier.py

Browse files
Files changed (1) hide show
  1. unified_cell_classifier.py +4 -4
unified_cell_classifier.py CHANGED
@@ -250,12 +250,12 @@ class UnifiedCellClassifier(nn.Module):
250
  model.main_classifier = model._create_classifier_from_config(main_config)
251
 
252
  # Carica i pesi del main classifier
253
- main_weights_path = get_file_path("main_classifier.bin")
254
- main_state_dict = torch.load(main_weights_path, map_location="cpu")
255
  model.main_classifier.load_state_dict(main_state_dict)
256
 
257
  # Carica le label del main classifier
258
- main_labels_path = get_file_path("id2label_main.json")
259
  with open(main_labels_path) as f:
260
  model.main_labels = json.load(f)
261
 
@@ -271,7 +271,7 @@ class UnifiedCellClassifier(nn.Module):
271
 
272
  # Carica i pesi del sub-classificatore
273
  sub_weights_path = get_file_path(f"sub_classifiers/{sub_name}.bin")
274
- sub_state_dict = torch.load(sub_weights_path, map_location="cpu")
275
  model.sub_classifiers[sub_name].load_state_dict(sub_state_dict)
276
 
277
  # Carica le label del sub-classificatore
 
250
  model.main_classifier = model._create_classifier_from_config(main_config)
251
 
252
  # Carica i pesi del main classifier
253
+ main_weights_path = get_file_path("main_classifier/main_classifier.bin")
254
+ main_state_dict = torch.load(main_weights_path)
255
  model.main_classifier.load_state_dict(main_state_dict)
256
 
257
  # Carica le label del main classifier
258
+ main_labels_path = get_file_path("main_classifier/id2label_main.json")
259
  with open(main_labels_path) as f:
260
  model.main_labels = json.load(f)
261
 
 
271
 
272
  # Carica i pesi del sub-classificatore
273
  sub_weights_path = get_file_path(f"sub_classifiers/{sub_name}.bin")
274
+ sub_state_dict = torch.load(sub_weights_path)
275
  model.sub_classifiers[sub_name].load_state_dict(sub_state_dict)
276
 
277
  # Carica le label del sub-classificatore