File size: 1,203 Bytes
fb24bef | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | from . import partial_fc
from . import fc
def get_classifier(classifier_cfg, margin_loss_fn, model_cfg, num_classes, rank, world_size):
    """Build the classification head selected by ``classifier_cfg.name``.

    Args:
        classifier_cfg: Config object read for ``name``, ``start_from`` and
            ``freeze``; it is also passed to the chosen class's ``from_config``.
        margin_loss_fn: Margin loss forwarded to ``from_config``. When it is
            ``None`` no classifier is built and ``classifier_cfg`` is not read.
        model_cfg: Forwarded unchanged to ``from_config``.
        num_classes: Forwarded unchanged to ``from_config``.
        rank: Forwarded unchanged to ``from_config`` (presumably this
            process's distributed rank -- confirm against callers).
        world_size: Forwarded unchanged to ``from_config`` (presumably the
            number of distributed processes -- confirm against callers).

    Returns:
        ``None`` (after printing a notice) if ``margin_loss_fn`` is ``None``.
        Otherwise it returns the classifier built by
        ``PartialFCClassifier.from_config`` (``name == 'partial_fc'``) or
        ``FCClassifier.from_config`` (``name == 'fc'``). The classifier is
        loaded from ``classifier_cfg.start_from`` when that value is truthy.
        It has ``requires_grad`` cleared on all parameters when
        ``classifier_cfg.freeze`` is truthy.

    Raises:
        ValueError: If ``classifier_cfg.name`` is neither ``'partial_fc'``
            nor ``'fc'``.
    """
    # Guard clause: without a margin loss, skip construction entirely.
    if margin_loss_fn is None:
        print("No margin loss function provided, classifier will not be created")
        return None

    # Pick the implementing class; only the selected module attribute is accessed.
    kind = classifier_cfg.name
    if kind == 'partial_fc':
        head_cls = partial_fc.PartialFCClassifier
    elif kind == 'fc':
        head_cls = fc.FCClassifier
    else:
        raise ValueError(f"Unknown classifier: {kind}")

    head = head_cls.from_config(
        classifier_cfg, margin_loss_fn, model_cfg, num_classes, rank, world_size
    )

    # Optional warm start; runs before freezing, matching the original order.
    checkpoint = classifier_cfg.start_from
    if checkpoint:
        head.load_state_dict_from_path(checkpoint)

    if classifier_cfg.freeze:
        # Turn off requires_grad on every parameter the head exposes.
        for weight in head.parameters():
            weight.requires_grad = False

    return head
|