File size: 2,433 Bytes
5d2fa0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import timm
import torch
import torch.nn as nn


class PlantDiseaseModel(nn.Module):
    """Image classifier wrapping a ``timm`` model, with optional freezing.

    Args:
        config: Configuration object. Reads ``config.model.backbone``,
            ``pretrained``, ``dropout``, ``drop_path``, ``freeze_backbone``
            and ``freeze_bn``.
        num_classes: Number of output classes passed to ``timm.create_model``.
    """

    # Layer types treated as BatchNorm by ``freeze_bn``.
    _BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

    def __init__(self, config, num_classes):
        super().__init__()
        self.backbone_name = config.model.backbone
        # Set to True by ``freeze_bn``; ``train`` consults it so that putting
        # the model back into training mode does not un-freeze BN layers.
        self._bn_frozen = False

        self.model = timm.create_model(
            self.backbone_name,
            pretrained=config.model.pretrained,
            num_classes=num_classes,
            drop_rate=config.model.dropout,
            drop_path_rate=config.model.drop_path,
        )

        if config.model.freeze_backbone:
            self._freeze_backbone()
        if config.model.freeze_bn:
            self.freeze_bn()

    def _classifier_modules(self):
        """Return the wrapped model's classifier head(s) as a list of modules.

        ``get_classifier`` may return either a single module or a tuple of
        modules; both forms are normalised to a list. Non-module entries are
        dropped.
        """
        heads = self.model.get_classifier()
        if not isinstance(heads, (tuple, list)):
            heads = (heads,)
        return [head for head in heads if isinstance(head, nn.Module)]

    def _freeze_backbone(self):
        """Disable gradients for every parameter except the classifier head."""
        for param in self.model.parameters():
            param.requires_grad = False

        if hasattr(self.model, "get_classifier"):
            head_params = (
                p for head in self._classifier_modules() for p in head.parameters()
            )
        else:
            # Fallback: identify head parameters by their qualified name.
            head_params = (
                p
                for name, p in self.model.named_parameters()
                if "head" in name or "classifier" in name
            )
        for param in head_params:
            param.requires_grad = True

    def _bn_modules(self):
        """Yield every BatchNorm layer contained in the wrapped model."""
        return (m for m in self.model.modules() if isinstance(m, self._BN_TYPES))

    def freeze_bn(self):
        """Freeze all BatchNorm layers.

        Puts each BN layer in eval mode, so its running statistics are no
        longer updated, and disables gradients for its affine weight and
        bias. The frozen state persists across later ``train()`` calls.
        """
        self._bn_frozen = True
        for module in self._bn_modules():
            module.eval()
            for param in (module.weight, module.bias):
                if param is not None:
                    param.requires_grad = False

    def train(self, mode=True):
        """Set training mode, keeping BatchNorm layers in eval mode if frozen.

        ``nn.Module.train`` recursively switches every submodule, which would
        otherwise re-enable running-stat updates on frozen BN layers.
        """
        super().train(mode)
        # getattr default: the flag may be absent if __init__ was bypassed.
        if mode and getattr(self, "_bn_frozen", False):
            for module in self._bn_modules():
                module.eval()
        return self

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output."""
        return self.model(x)


def get_param_groups(model, base_lr, head_lr, weight_decay):
    """Split trainable parameters into backbone and head optimizer groups.

    Args:
        model: Module exposing the wrapped network as ``model.model``
            (e.g. ``PlantDiseaseModel``).
        base_lr: Learning rate for non-head (backbone) parameters.
        head_lr: Learning rate for classifier-head parameters.
        weight_decay: Weight decay applied to both groups.

    Returns:
        A two-element list ``[backbone_group, head_group]`` of param-group
        dicts for a ``torch.optim`` optimizer. Parameters with
        ``requires_grad=False`` are excluded. A group's ``params`` list may be
        empty, e.g. when the backbone is frozen.
    """
    inner = model.model
    if hasattr(inner, "get_classifier"):
        heads = inner.get_classifier()
        # get_classifier may return a tuple of head modules instead of one module.
        if not isinstance(heads, (tuple, list)):
            heads = (heads,)
        candidates = (
            p for head in heads if isinstance(head, nn.Module) for p in head.parameters()
        )
    else:
        # Fallback: identify head parameters by their qualified name.
        candidates = (
            p
            for name, p in model.named_parameters()
            if "head" in name or "classifier" in name
        )

    # Keyed by identity: dedups parameters shared between heads and keeps
    # first-seen order so the optimizer never sees a parameter twice.
    head_by_id = {id(p): p for p in candidates}

    head_params = [p for p in head_by_id.values() if p.requires_grad]
    backbone_params = [
        p for p in model.parameters() if id(p) not in head_by_id and p.requires_grad
    ]
    return [
        {"params": backbone_params, "lr": base_lr, "weight_decay": weight_decay},
        {"params": head_params, "lr": head_lr, "weight_decay": weight_decay},
    ]