| """ |
| Encoder modules: we use ViT for the encoder. |
| """ |
|
|
| from torch import nn |
| from lib.utils.misc import is_main_process |
| from lib.models.atctrack import vit as vit_module |
| from lib.models.atctrack import vit_mm as vitmm_module |
|
|
|
|
|
|
|
|
class EncoderBase(nn.Module):
    """Wrap an encoder module and optionally freeze its parameters.

    The wrapped module is stored as ``self.body``; ``forward`` and
    ``forward_rgb`` simply delegate to it.

    Args:
        encoder: Module to wrap.
        train_encoder: When true, no parameter is modified. When false,
            ``requires_grad_(False)`` is applied to every parameter of
            ``encoder`` that ``open_layers`` does not exempt.
        open_layers: Exemptions from freezing. The first two entries are
            compared against full parameter names for equality; every entry
            from index 2 onwards is matched as a substring of the parameter
            name.
        num_channels: Stored unchanged as ``self.num_channels`` (presumably
            the encoder's feature width; it is not used inside this class).
    """

    def __init__(self, encoder: nn.Module, train_encoder: bool, open_layers: list, num_channels: int):
        super().__init__()
        # Sliced unconditionally so that an invalid ``open_layers`` fails the
        # same way regardless of ``train_encoder``.
        open_blocks = open_layers[2:]
        open_items = open_layers[0:2]

        # ``train_encoder`` does not change per parameter, so decide once
        # whether any freezing is needed instead of re-testing it in the loop.
        if not train_encoder:
            for name, parameter in encoder.named_parameters():
                keep_trainable = name in open_items or any(
                    open_block in name for open_block in open_blocks
                )
                if not keep_trainable:
                    parameter.requires_grad_(False)

        self.body = encoder
        self.num_channels = num_channels

    def forward(self, template_list, search_list, text_src, seq):
        """Call the wrapped encoder with all four inputs and return its output."""
        return self.body(template_list, search_list, text_src, seq)

    def forward_rgb(self, template_list, search_list):
        """Return ``self.body.forward_rgb(template_list, search_list)``."""
        return self.body.forward_rgb(template_list, search_list)
|
|
|
|
class Encoder(EncoderBase):
    """ViT encoder.

    Looks up a factory called ``name`` in ``vitmm_module`` (when the
    lower-cased name contains ``"vitmm"``) or otherwise in ``vit_module``
    (when it contains ``"vit"``), calls it to build the encoder body, and
    passes the result to ``EncoderBase``.

    Args:
        name: Attribute name of the factory inside the selected module. It
            also determines ``num_channels`` through the case-sensitive tags
            ``_base_`` (768), ``_large_`` (1024) and ``_huge_`` (1280); any
            other name falls back to 768.
        train_encoder: Forwarded to ``EncoderBase``.
        pretrain_type: Forwarded to the factory as ``pretrain_type``.
        search_size: Forwarded to the factory as ``search_size``.
        search_number: Forwarded to the factory as ``search_number``.
        template_size: Forwarded to the factory as ``template_size``.
        template_number: Forwarded to the factory as ``template_number``.
        open_layers: Forwarded to ``EncoderBase``.
        cfg: Config object. Although it defaults to ``None`` it is required
            in practice: ``cfg.MODEL.ENCODER.DROP_PATH`` and
            ``cfg.MODEL.ENCODER.USE_CHECKPOINT`` are always read, and the
            ``vitmm`` variant additionally reads ``cfg.MODEL.INTERFACE_TYPE``,
            ``cfg.MODEL.INTERFACE_DIM`` and ``cfg.MODEL.ENCODER.INSTRUCT``.

    Raises:
        ValueError: If ``name`` contains neither ``"vitmm"`` nor ``"vit"``
            (case-insensitive).
    """

    # (tag, width) pairs, tested in this order against the unmodified name.
    _CHANNELS_BY_TAG = (("_base_", 768), ("_large_", 1024), ("_huge_", 1280))
    _DEFAULT_CHANNELS = 768

    def __init__(self, name: str,
                 train_encoder: bool,
                 pretrain_type: str,
                 search_size: int,
                 search_number: int,
                 template_size: int,
                 template_number: int,
                 open_layers: list,
                 cfg=None):
        lowered = name.lower()
        # "vitmm" must be tested first because every such name also contains "vit".
        is_multimodal = "vitmm" in lowered
        if not is_multimodal and "vit" not in lowered:
            raise ValueError(
                f"Unsupported encoder type {name!r}: "
                "expected a name containing 'vitmm' or 'vit'."
            )

        factory = getattr(vitmm_module if is_multimodal else vit_module, name)

        # Keyword arguments shared by both factory families.
        factory_kwargs = dict(
            pretrained=is_main_process(),
            pretrain_type=pretrain_type,
            search_size=search_size,
            template_size=template_size,
            search_number=search_number,
            template_number=template_number,
            drop_path_rate=cfg.MODEL.ENCODER.DROP_PATH,
            use_checkpoint=cfg.MODEL.ENCODER.USE_CHECKPOINT,
        )
        if is_multimodal:
            # Extra options accepted only by the vitmm factories.
            factory_kwargs.update(
                interface_type=cfg.MODEL.INTERFACE_TYPE,
                interface_dim=cfg.MODEL.INTERFACE_DIM,
                instruct=cfg.MODEL.ENCODER.INSTRUCT,
            )

        encoder = factory(**factory_kwargs)
        num_channels = self._resolve_num_channels(name)
        super().__init__(encoder, train_encoder, open_layers, num_channels)

    @staticmethod
    def _resolve_num_channels(name: str) -> int:
        """Return the channel width implied by the size tag in ``name``."""
        for tag, width in Encoder._CHANNELS_BY_TAG:
            if tag in name:
                return width
        return Encoder._DEFAULT_CHANNELS
|
|
|
|
|
|
def build_encoder(cfg):
    """Build an ``Encoder`` from the settings in ``cfg``.

    The encoder is marked trainable only when ``cfg.TRAIN.ENCODER_MULTIPLIER``
    is positive and ``cfg.TRAIN.FREEZE_ENCODER`` equals ``False``.

    Args:
        cfg: Config object exposing the ``TRAIN``, ``MODEL.ENCODER`` and
            ``DATA.SEARCH`` / ``DATA.TEMPLATE`` fields read below; it is also
            passed through to ``Encoder`` itself.

    Returns:
        The constructed ``Encoder`` instance.
    """
    multiplier_is_positive = cfg.TRAIN.ENCODER_MULTIPLIER > 0
    # Equality (rather than ``not``) is intentional: a value such as ``None``
    # does not compare equal to ``False`` and therefore keeps the encoder frozen.
    train_encoder = multiplier_is_positive and (cfg.TRAIN.FREEZE_ENCODER == False)

    return Encoder(
        name=cfg.MODEL.ENCODER.TYPE,
        train_encoder=train_encoder,
        pretrain_type=cfg.MODEL.ENCODER.PRETRAIN_TYPE,
        search_size=cfg.DATA.SEARCH.SIZE,
        search_number=cfg.DATA.SEARCH.NUMBER,
        template_size=cfg.DATA.TEMPLATE.SIZE,
        template_number=cfg.DATA.TEMPLATE.NUMBER,
        open_layers=cfg.TRAIN.ENCODER_OPEN,
        cfg=cfg,
    )
|
|