File size: 1,706 Bytes
d68fa90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from torchvision import datasets, models, transforms

def set_parameter_requires_grad(model, n_resblock_finetune):
    """Mark which parameters of ``model`` are trainable, in place.

    Selection is done purely by substring match on the names yielded by
    ``model.named_parameters()``:

    * ``n_resblock_finetune >= 1`` unfreezes names containing ``'layer4'``,
      ``>= 2`` additionally ``'layer3'``, ``>= 3`` ``'layer2'`` and
      ``>= 4`` ``'layer1'``.
    * ``n_resblock_finetune >= 5`` unfreezes every parameter.
    * Any parameter whose name contains ``'bn'`` is always left frozen,
      whatever the value of ``n_resblock_finetune``.
    * Everything else is frozen.

    Args:
        model: Object exposing ``named_parameters()`` whose parameters have
            a writable ``requires_grad`` attribute.
        n_resblock_finetune: One of 0, 1, 2, 3, 4 or 5.

    Returns:
        None. ``requires_grad`` is updated on the parameters themselves.

    Raises:
        AssertionError: If ``n_resblock_finetune`` is not one of 0-5.
    """
    assert n_resblock_finetune in (0, 1, 2, 3, 4, 5)

    # Smallest n_resblock_finetune at which each name fragment is unfrozen.
    unlock_thresholds = (
        ('layer4', 1),
        ('layer3', 2),
        ('layer2', 3),
        ('layer1', 4),
    )
    unfreeze_everything = n_resblock_finetune >= 5

    for param_name, tensor in model.named_parameters():
        in_unlocked_stage = any(
            n_resblock_finetune >= threshold and fragment in param_name
            for fragment, threshold in unlock_thresholds
        )
        # The 'bn' rule wins over every unfreeze rule above.
        name_has_bn = 'bn' in param_name
        wants_grad = (unfreeze_everything or in_unlocked_stage) and not name_has_bn
        tensor.requires_grad = bool(wants_grad)

def initialize_model(model_name, n_resblock_finetune, use_pretrained=True):
    """Create a torchvision ResNet and freeze part of it for fine-tuning.

    Args:
        model_name: ``"resnet18"`` or ``"resnet34"``.
        n_resblock_finetune: Forwarded unchanged to
            ``set_parameter_requires_grad`` to choose which parameters stay
            trainable.
        use_pretrained: Passed as ``pretrained=`` to the torchvision
            constructor.

    Returns:
        A ``(model_ft, input_size, feature_size)`` tuple: the constructed
        model, the constant ``224``, and ``model_ft.fc.in_features``.

    Raises:
        SystemExit: With exit status 1, after printing a message, when
            ``model_name`` is not one of the supported names.
    """
    # Resolve the constructor first so an unsupported name fails before any
    # model is built.
    if model_name == "resnet18":
        build_backbone = models.resnet18
    elif model_name == 'resnet34':
        build_backbone = models.resnet34
    else:
        print("Invalid model name, exiting...")
        # Raise SystemExit directly rather than calling the `site`-provided
        # exit() helper: that helper is intended for the interactive shell,
        # may be absent, and with no argument reports success (status 0).
        raise SystemExit(1)

    # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in
    # favour of `weights=`; kept as-is so older torchvision keeps working.
    model_ft = build_backbone(pretrained=use_pretrained)
    set_parameter_requires_grad(model_ft, n_resblock_finetune)
    feature_size = model_ft.fc.in_features
    input_size = 224  # same value for both supported architectures

    return model_ft, input_size, feature_size