File size: 4,219 Bytes
25986db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Encoder modules: we use ViT for the encoder.
"""

from torch import nn
from lib.utils.misc import is_main_process
from lib.models.atctrack import vit as vit_module
from lib.models.atctrack import vit_mm as vitmm_module




class EncoderBase(nn.Module):
    """Wrap an encoder module and optionally freeze most of its parameters.

    When ``train_encoder`` is false, every parameter of ``encoder`` gets
    ``requires_grad`` switched off unless ``open_layers`` selects it:

    * the first two entries of ``open_layers`` are compared against the
      complete parameter name (exact match);
    * every remaining entry is treated as a substring, and a parameter stays
      trainable if any of them occurs in its name.

    When ``train_encoder`` is true, no ``requires_grad`` flag is touched.

    Args:
        encoder: Module to wrap; stored as ``self.body``.
        train_encoder: Whether the whole encoder stays trainable.
        open_layers: Selectors for parameters that remain trainable while the
            rest is frozen (see above). Only consulted for freezing when
            ``train_encoder`` is false.
        num_channels: Stored unchanged as ``self.num_channels``.
    """

    def __init__(self, encoder: nn.Module, train_encoder: bool, open_layers: list, num_channels: int):
        super().__init__()
        exact_names = open_layers[0:2]
        name_fragments = open_layers[2:]

        # Whether to freeze at all does not depend on the individual
        # parameter, so decide once instead of once per parameter.
        if not train_encoder:
            # TODO: allow users to specify which layers to freeze.
            for name, parameter in encoder.named_parameters():
                keep_trainable = name in exact_names or any(
                    fragment in name for fragment in name_fragments
                )
                if not keep_trainable:
                    parameter.requires_grad_(False)

        self.body = encoder
        self.num_channels = num_channels

    def forward(self, template_list, search_list, text_src, seq):
        """Call the wrapped encoder with all four inputs and return its output."""
        return self.body(template_list, search_list, text_src, seq)

    def forward_rgb(self, template_list, search_list):
        """Return ``self.body.forward_rgb(template_list, search_list)``."""
        return self.body.forward_rgb(template_list, search_list)


class Encoder(EncoderBase):
    """ViT encoder.

    Looks up a constructor called ``name`` in ``vitmm_module`` (when the
    lower-cased name contains ``"vitmm"``) or otherwise in ``vit_module``
    (when it contains ``"vit"``), calls it to build the backbone, and passes
    the result to ``EncoderBase`` together with a channel count inferred from
    ``name``.

    Args:
        name: Attribute name of the constructor to call on the selected
            module. Module selection is case-insensitive; the size markers
            (``"_base_"``, ``"_large_"``, ``"_huge_"``) are matched
            case-sensitively.
        train_encoder: Forwarded to ``EncoderBase``.
        pretrain_type: Forwarded to the backbone constructor.
        search_size: Forwarded to the backbone constructor.
        search_number: Forwarded to the backbone constructor.
        template_size: Forwarded to the backbone constructor.
        template_number: Forwarded to the backbone constructor.
        open_layers: Forwarded to ``EncoderBase``.
        cfg: Config object. Although it defaults to ``None``, both supported
            branches read ``cfg.MODEL.ENCODER.DROP_PATH`` and
            ``cfg.MODEL.ENCODER.USE_CHECKPOINT`` (the ``vitmm`` branch also
            reads ``cfg.MODEL.INTERFACE_TYPE``, ``cfg.MODEL.INTERFACE_DIM``
            and ``cfg.MODEL.ENCODER.INSTRUCT``), so a config must be supplied.

    Raises:
        ValueError: If ``name`` contains neither ``"vitmm"`` nor ``"vit"``
            (case-insensitive).
    """

    # Checked in order; the first marker found in the name wins.
    _CHANNELS_BY_MARKER = (("_base_", 768), ("_large_", 1024), ("_huge_", 1280))
    # Used when the name carries none of the markers above.
    _DEFAULT_CHANNELS = 768

    @staticmethod
    def _infer_num_channels(name: str) -> int:
        """Return the channel count associated with the size marker in ``name``."""
        for marker, channels in Encoder._CHANNELS_BY_MARKER:
            if marker in name:
                return channels
        return Encoder._DEFAULT_CHANNELS

    def __init__(self, name: str,
                 train_encoder: bool,
                 pretrain_type: str,
                 search_size: int,
                 search_number: int,
                 template_size: int,
                 template_number: int,
                 open_layers: list,
                 cfg=None):
        lowered = name.lower()
        # "vitmm" contains "vit", so a single check rejects every unsupported name.
        if "vit" not in lowered:
            raise ValueError(
                f"Unsupported encoder type {name!r}: expected a name containing 'vit' or 'vitmm'."
            )

        use_vitmm = "vitmm" in lowered
        factory = getattr(vitmm_module if use_vitmm else vit_module, name)

        # Arguments shared by both backbone families.
        # NOTE(review): `pretrained` is whatever is_main_process() returns;
        # presumably only the main process loads pretrained weights — confirm.
        build_kwargs = dict(
            pretrained=is_main_process(),
            pretrain_type=pretrain_type,
            search_size=search_size,
            template_size=template_size,
            search_number=search_number,
            template_number=template_number,
            drop_path_rate=cfg.MODEL.ENCODER.DROP_PATH,
            use_checkpoint=cfg.MODEL.ENCODER.USE_CHECKPOINT,
        )
        if use_vitmm:
            # Extra options accepted only by the vitmm constructors.
            build_kwargs.update(
                interface_type=cfg.MODEL.INTERFACE_TYPE,
                interface_dim=cfg.MODEL.INTERFACE_DIM,
                instruct=cfg.MODEL.ENCODER.INSTRUCT,
            )

        encoder = factory(**build_kwargs)
        num_channels = self._infer_num_channels(name)
        super().__init__(encoder, train_encoder, open_layers, num_channels)



def build_encoder(cfg):
    """Create the ``Encoder`` described by ``cfg``.

    The encoder is marked trainable only when ``cfg.TRAIN.ENCODER_MULTIPLIER``
    is greater than zero and ``cfg.TRAIN.FREEZE_ENCODER`` equals ``False``.

    Args:
        cfg: Config object providing ``MODEL.ENCODER.TYPE``,
            ``MODEL.ENCODER.PRETRAIN_TYPE``, ``DATA.SEARCH.SIZE``,
            ``DATA.SEARCH.NUMBER``, ``DATA.TEMPLATE.SIZE``,
            ``DATA.TEMPLATE.NUMBER``, ``TRAIN.ENCODER_OPEN``,
            ``TRAIN.ENCODER_MULTIPLIER`` and ``TRAIN.FREEZE_ENCODER``.

    Returns:
        The constructed ``Encoder`` instance.
    """
    train_cfg = cfg.TRAIN
    if train_cfg.ENCODER_MULTIPLIER > 0:
        # Kept as an equality test on purpose: only a value equal to False
        # enables training, which `not ...` would not guarantee.
        train_encoder = train_cfg.FREEZE_ENCODER == False  # noqa: E712
    else:
        # FREEZE_ENCODER is intentionally not read when the multiplier is not positive.
        train_encoder = False

    return Encoder(
        name=cfg.MODEL.ENCODER.TYPE,
        train_encoder=train_encoder,
        pretrain_type=cfg.MODEL.ENCODER.PRETRAIN_TYPE,
        search_size=cfg.DATA.SEARCH.SIZE,
        search_number=cfg.DATA.SEARCH.NUMBER,
        template_size=cfg.DATA.TEMPLATE.SIZE,
        template_number=cfg.DATA.TEMPLATE.NUMBER,
        open_layers=cfg.TRAIN.ENCODER_OPEN,
        cfg=cfg,
    )