File size: 4,219 Bytes
"""
Encoder modules: we use ViT for the encoder.
"""
from torch import nn
from lib.utils.misc import is_main_process
from lib.models.atctrack import vit as vit_module
from lib.models.atctrack import vit_mm as vitmm_module
class EncoderBase(nn.Module):
    """Wrap an encoder network and optionally freeze its parameters.

    When ``train_encoder`` is false, every parameter of ``encoder`` has
    ``requires_grad`` switched off unless it is selected by ``open_layers``.
    Parameters that stay open are left untouched (they are not forced to
    ``requires_grad=True``). When ``train_encoder`` is true nothing is
    modified and ``open_layers`` is ignored.

    Args:
        encoder: Module to wrap; stored as ``self.body``.
        train_encoder: If true, no parameter is frozen.
        open_layers: Selectors for parameters that must stay trainable when
            the encoder is otherwise frozen. The first two entries are
            compared for equality against full parameter names; every
            remaining entry is treated as a substring of the parameter name.
            ``None`` is accepted and means "no open layers".
        num_channels: Stored as ``self.num_channels``.
    """

    def __init__(self, encoder: nn.Module, train_encoder: bool, open_layers: list, num_channels: int):
        super().__init__()
        # The decision to freeze does not depend on the individual parameter,
        # so it is taken once here rather than inside the parameter loop.
        if not train_encoder:
            self._freeze_parameters(encoder, open_layers)
        self.body = encoder
        self.num_channels = num_channels

    @staticmethod
    def _freeze_parameters(encoder, open_layers):
        """Disable gradients for every parameter not selected by ``open_layers``."""
        if open_layers is None:
            open_layers = []
        open_items = open_layers[0:2]   # matched against the exact parameter name
        open_blocks = open_layers[2:]   # matched as substrings of the parameter name
        for name, parameter in encoder.named_parameters():
            if name in open_items:
                continue
            if any(open_block in name for open_block in open_blocks):
                continue
            # TODO: allow users to specify which layers to freeze.
            parameter.requires_grad_(False)

    def forward(self, template_list, search_list, text_src, seq):
        """Call the wrapped encoder with all four inputs and return its output."""
        return self.body(template_list, search_list, text_src, seq)

    def forward_rgb(self, template_list, search_list):
        """Delegate to the wrapped encoder's ``forward_rgb`` and return its output."""
        return self.body.forward_rgb(template_list, search_list)
class Encoder(EncoderBase):
    """ViT encoder.

    Looks up the factory called ``name`` in ``vitmm_module`` (when the
    lower-cased name contains ``"vitmm"``) or in ``vit_module`` (when it
    contains ``"vit"``), builds the network and hands it to ``EncoderBase``.

    Args:
        name: Name of the factory attribute to call. It also determines the
            channel width through the case-sensitive markers ``"_base_"``,
            ``"_large_"`` and ``"_huge_"``.
        train_encoder: Forwarded to ``EncoderBase``.
        pretrain_type: Forwarded to the factory.
        search_size: Forwarded to the factory.
        search_number: Forwarded to the factory.
        template_size: Forwarded to the factory.
        template_number: Forwarded to the factory.
        open_layers: Forwarded to ``EncoderBase``.
        cfg: Configuration object. Despite the ``None`` default it is
            dereferenced for every supported ``name``
            (``cfg.MODEL.ENCODER.DROP_PATH`` etc.), so it must be provided.

    Raises:
        ValueError: If ``name`` contains neither ``"vitmm"`` nor ``"vit"``
            (case-insensitive).
    """

    # Ordered (marker, width) pairs: the first marker found in the name wins.
    _CHANNELS_BY_MARKER = (("_base_", 768), ("_large_", 1024), ("_huge_", 1280))
    # Width used when the name carries none of the markers above.
    _DEFAULT_CHANNELS = 768

    def __init__(self, name: str,
                 train_encoder: bool,
                 pretrain_type: str,
                 search_size: int,
                 search_number: int,
                 template_size: int,
                 template_number: int,
                 open_layers: list,
                 cfg=None):
        lowered = name.lower()
        # "vitmm" must be tested first because every such name also contains "vit".
        if "vitmm" in lowered:
            factory_module = vitmm_module
            multimodal = True
        elif "vit" in lowered:
            factory_module = vit_module
            multimodal = False
        else:
            raise ValueError(
                f"Unsupported encoder type {name!r}: expected a name containing 'vitmm' or 'vit'"
            )

        factory = getattr(factory_module, name)
        # NOTE(review): `pretrained` is set from is_main_process(); presumably
        # only the main process loads pretrained weights -- confirm.
        factory_kwargs = dict(
            pretrained=is_main_process(),
            pretrain_type=pretrain_type,
            search_size=search_size,
            template_size=template_size,
            search_number=search_number,
            template_number=template_number,
            drop_path_rate=cfg.MODEL.ENCODER.DROP_PATH,
            use_checkpoint=cfg.MODEL.ENCODER.USE_CHECKPOINT,
        )
        if multimodal:
            # Extra arguments accepted only by the vit_mm factories.
            factory_kwargs.update(
                interface_type=cfg.MODEL.INTERFACE_TYPE,
                interface_dim=cfg.MODEL.INTERFACE_DIM,
                instruct=cfg.MODEL.ENCODER.INSTRUCT,
            )
        encoder = factory(**factory_kwargs)

        num_channels = self._num_channels_for(name)
        super().__init__(encoder, train_encoder, open_layers, num_channels)

    @classmethod
    def _num_channels_for(cls, name):
        """Return the channel width implied by the size marker in ``name``."""
        for marker, width in cls._CHANNELS_BY_MARKER:
            if marker in name:
                return width
        return cls._DEFAULT_CHANNELS
def build_encoder(cfg):
    """Construct an ``Encoder`` from the configuration object ``cfg``.

    The encoder is marked trainable only when
    ``cfg.TRAIN.ENCODER_MULTIPLIER > 0`` and ``cfg.TRAIN.FREEZE_ENCODER``
    compares equal to ``False``. ``FREEZE_ENCODER`` is not read when the
    multiplier test fails.

    Args:
        cfg: Configuration providing the ``TRAIN``, ``MODEL`` and ``DATA``
            fields read below.

    Returns:
        The constructed ``Encoder`` instance.
    """
    trainable = cfg.TRAIN.ENCODER_MULTIPLIER > 0
    if trainable:
        # Equality (not `not ...`) is kept on purpose so that non-bool values
        # such as None are handled exactly as before.
        trainable = cfg.TRAIN.FREEZE_ENCODER == False  # noqa: E712

    return Encoder(
        name=cfg.MODEL.ENCODER.TYPE,
        train_encoder=trainable,
        pretrain_type=cfg.MODEL.ENCODER.PRETRAIN_TYPE,
        search_size=cfg.DATA.SEARCH.SIZE,
        search_number=cfg.DATA.SEARCH.NUMBER,
        template_size=cfg.DATA.TEMPLATE.SIZE,
        template_number=cfg.DATA.TEMPLATE.NUMBER,
        open_layers=cfg.TRAIN.ENCODER_OPEN,
        cfg=cfg,
    )
|