File size: 1,203 Bytes
fb24bef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from . import partial_fc
from . import fc

def get_classifier(classifier_cfg, margin_loss_fn, model_cfg, num_classes, rank, world_size):
    """Build the classification head selected by ``classifier_cfg.name``.

    Args:
        classifier_cfg: Config object read for ``name`` ('partial_fc' or 'fc'),
            ``start_from`` (passed to ``load_state_dict_from_path`` when truthy)
            and ``freeze`` (when truthy, every parameter gets
            ``requires_grad = False``).
        margin_loss_fn: Margin loss handed to the classifier's ``from_config``.
            When it is None, no classifier is built and ``classifier_cfg`` is
            never read.
        model_cfg, num_classes, rank, world_size: Forwarded unchanged to the
            selected class's ``from_config``.

    Returns:
        The constructed classifier, or None when ``margin_loss_fn`` is None.

    Raises:
        ValueError: If ``classifier_cfg.name`` is not 'partial_fc' or 'fc'.
    """
    # Guard clause: with no margin loss there is nothing to build.
    if margin_loss_fn is None:
        print("No margin loss function provided, classifier will not be created")
        return None

    # Resolve the factory first, then invoke it once with the shared arguments.
    kind = classifier_cfg.name
    if kind == 'partial_fc':
        build = partial_fc.PartialFCClassifier.from_config
    elif kind == 'fc':
        build = fc.FCClassifier.from_config
    else:
        raise ValueError(f"Unknown classifier: {kind}")

    head = build(classifier_cfg, margin_loss_fn, model_cfg, num_classes, rank, world_size)

    checkpoint_path = classifier_cfg.start_from
    if checkpoint_path:
        head.load_state_dict_from_path(checkpoint_path)

    if classifier_cfg.freeze:
        # Freeze: mark every parameter as not requiring gradients.
        for weight in head.parameters():
            weight.requires_grad = False

    return head