from . import partial_fc
from . import fc


def get_classifier(classifier_cfg, margin_loss_fn, model_cfg, num_classes, rank, world_size):
    """Create the classifier selected by ``classifier_cfg.name``.

    Args:
        classifier_cfg: Config object; this function reads its ``name``,
            ``start_from`` and ``freeze`` attributes and forwards the whole
            object to the chosen class's ``from_config``.
        margin_loss_fn: Margin loss forwarded to ``from_config``. When it is
            ``None`` no classifier is built.
        model_cfg: Forwarded unchanged to ``from_config``.
        num_classes: Forwarded unchanged to ``from_config``.
        rank: Forwarded unchanged to ``from_config``.
        world_size: Forwarded unchanged to ``from_config``.

    Returns:
        The object returned by ``from_config`` of the selected class, or
        ``None`` when ``margin_loss_fn`` is ``None``.

    Raises:
        ValueError: If ``classifier_cfg.name`` is neither ``'partial_fc'``
            nor ``'fc'``.
    """
    # Without a margin loss there is nothing to build; report it and bail out.
    if margin_loss_fn is None:
        print("No margin loss function provided, classifier will not be created")
        return None

    # Resolve the implementing class first so construction happens in one
    # place. Only the selected branch touches its module attribute.
    kind = classifier_cfg.name
    if kind == 'partial_fc':
        classifier_cls = partial_fc.PartialFCClassifier
    elif kind == 'fc':
        classifier_cls = fc.FCClassifier
    else:
        raise ValueError(f"Unknown classifier: {classifier_cfg.name}")

    head = classifier_cls.from_config(
        classifier_cfg,
        margin_loss_fn,
        model_cfg,
        num_classes,
        rank,
        world_size,
    )

    # Optionally restore state from the path given in the config.
    checkpoint_path = classifier_cfg.start_from
    if checkpoint_path:
        head.load_state_dict_from_path(checkpoint_path)

    # Optionally freeze the classifier by disabling gradients on every
    # parameter it exposes.
    if classifier_cfg.freeze:
        for weight in head.parameters():
            weight.requires_grad = False

    return head