Update unified_cell_classifier.py
Browse files
unified_cell_classifier.py
CHANGED
|
@@ -252,7 +252,7 @@ class UnifiedCellClassifier(nn.Module):
|
|
| 252 |
# Carica i pesi del main classifier
|
| 253 |
main_weights_path = get_file_path("main_classifier/main_classifier.bin")
|
| 254 |
main_state_dict = torch.load(main_weights_path)
|
| 255 |
-
model.main_classifier.load_state_dict(main_state_dict,
|
| 256 |
|
| 257 |
# Carica le label del main classifier
|
| 258 |
main_labels_path = get_file_path("main_classifier/id2label_main.json")
|
|
@@ -272,7 +272,7 @@ class UnifiedCellClassifier(nn.Module):
|
|
| 272 |
# Carica i pesi del sub-classificatore
|
| 273 |
sub_weights_path = get_file_path(f"sub_classifiers/{sub_name}.bin")
|
| 274 |
sub_state_dict = torch.load(sub_weights_path)
|
| 275 |
-
model.sub_classifiers[sub_name].load_state_dict(sub_state_dict,
|
| 276 |
|
| 277 |
# Carica le label del sub-classificatore
|
| 278 |
sub_labels_path = get_file_path(f"sub_classifiers/{sub_name}_id2label.json")
|
|
|
|
| 252 |
# Carica i pesi del main classifier
|
| 253 |
main_weights_path = get_file_path("main_classifier/main_classifier.bin")
|
| 254 |
main_state_dict = torch.load(main_weights_path)
|
| 255 |
+
model.main_classifier.load_state_dict(main_state_dict, strict=False)
|
| 256 |
|
| 257 |
# Carica le label del main classifier
|
| 258 |
main_labels_path = get_file_path("main_classifier/id2label_main.json")
|
|
|
|
| 272 |
# Carica i pesi del sub-classificatore
|
| 273 |
sub_weights_path = get_file_path(f"sub_classifiers/{sub_name}.bin")
|
| 274 |
sub_state_dict = torch.load(sub_weights_path)
|
| 275 |
+
model.sub_classifiers[sub_name].load_state_dict(sub_state_dict, strict=False)
|
| 276 |
|
| 277 |
# Carica le label del sub-classificatore
|
| 278 |
sub_labels_path = get_file_path(f"sub_classifiers/{sub_name}_id2label.json")
|