File size: 791 Bytes
383bfb8
2a2cec1
 
383bfb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from .classifier_ops import *

# Classifier names accepted by get_classifier(); any other value of
# args.classifier makes it raise NotImplementedError.
classifier_list = ["FCNorm", "CosNorm", "DotProduct", "DistFC"]

       
def get_classifier(args):
    """Build the classifier head selected by ``args.classifier``.

    Args:
        args: Configuration object exposing the attributes ``classifier``
            (one of the names in ``classifier_list``), ``classifier_bias``,
            ``num_features`` and ``num_classes``.

    Returns:
        A new instance of the class whose name matches ``args.classifier``.

    Raises:
        NotImplementedError: If ``args.classifier`` is not in
            ``classifier_list``.
    """
    # Read every option up front so a missing attribute surfaces before
    # the classifier name is validated.
    use_bias = args.classifier_bias
    in_dim = args.num_features
    out_dim = args.num_classes

    if args.classifier not in classifier_list:
        raise NotImplementedError("Unsupported Classifier: {}".format(args.classifier))

    # Ordered (name, builder) table. The builders are lazy so that only the
    # selected class is looked up and instantiated.
    # NOTE(review): DotProduct receives num_classes before num_features,
    # unlike the other three — confirm against its constructor signature.
    builders = (
        ("FCNorm", lambda: FCNorm(in_dim, out_dim)),
        ("CosNorm", lambda: CosNorm(in_dim, out_dim)),
        ("DotProduct", lambda: DotProduct(out_dim, in_dim, use_bias)),
        ("DistFC", lambda: DistFC(in_dim, out_dim)),
    )

    # Membership in classifier_list was checked above, so one entry is
    # expected to match.
    for label, build in builders:
        if args.classifier == label:
            classifier = build()
            break

    return classifier