ATCTrack-VLM / lib /models /atctrack /encoder.py
SunXiang2025's picture
Upload ATCTrack-VLM code and selected checkpoints
25986db verified
"""
Encoder modules: wrap a ViT backbone (plain ``vit`` or multimodal ``vit_mm``) as the tracker encoder.
"""
from torch import nn
from lib.utils.misc import is_main_process
from lib.models.atctrack import vit as vit_module
from lib.models.atctrack import vit_mm as vitmm_module
class EncoderBase(nn.Module):
    """Wrap a backbone encoder and optionally freeze most of its parameters.

    When ``train_encoder`` is False, every parameter of ``encoder`` gets
    ``requires_grad=False`` except the ones selected by ``open_layers``.
    The wrapped module is stored as ``self.body``.
    """

    def __init__(self, encoder: nn.Module, train_encoder: bool, open_layers: list, num_channels: int):
        """
        Args:
            encoder: backbone module to wrap; stored as ``self.body``.
            train_encoder: if True, no parameter is frozen and ``open_layers``
                is not consulted.
            open_layers: parameters kept trainable when ``train_encoder`` is
                False. The first two entries are compared against full
                parameter names for exact equality; every later entry is
                matched as a substring of the parameter name. ``None`` is
                treated as an empty list (everything frozen).
            num_channels: channel width reported by the encoder, exposed as
                ``self.num_channels``.
        """
        super().__init__()
        open_layers = list(open_layers) if open_layers is not None else []
        # Positional convention: entries 0-1 are exact names, entries 2+ are substrings.
        open_items = open_layers[:2]
        open_blocks = open_layers[2:]
        if not train_encoder:
            for name, parameter in encoder.named_parameters():
                keep_open = name in open_items or any(block in name for block in open_blocks)
                if not keep_open:
                    parameter.requires_grad_(False)
        self.body = encoder
        self.num_channels = num_channels

    def forward(self, template_list, search_list, text_src, seq):
        """Pass all inputs unchanged to the wrapped encoder and return its output."""
        return self.body(template_list, search_list, text_src, seq)

    def forward_rgb(self, template_list, search_list):
        """Call the wrapped encoder's ``forward_rgb`` with template and search inputs only."""
        return self.body.forward_rgb(template_list, search_list)
class Encoder(EncoderBase):
    """ViT encoder built by name from the ``vit_mm`` or ``vit`` factory modules.

    Names containing "vitmm" (case-insensitive) are looked up in ``vit_mm`` and
    additionally receive the interface/instruct options from ``cfg``; other
    names containing "vit" are looked up in ``vit``. The output channel count
    is inferred from a "_base_" / "_large_" / "_huge_" tag in ``name``
    (case-sensitive), defaulting to 768.
    """

    # (size tag, output channels), checked in order; the first match wins.
    _CHANNELS_BY_SIZE = (("_base_", 768), ("_large_", 1024), ("_huge_", 1280))
    # Used when ``name`` carries none of the tags above.
    _DEFAULT_CHANNELS = 768

    @classmethod
    def _num_channels_from_name(cls, name: str) -> int:
        """Return the output channel count implied by the size tag in ``name``."""
        for tag, channels in cls._CHANNELS_BY_SIZE:
            if tag in name:
                return channels
        return cls._DEFAULT_CHANNELS

    def __init__(self, name: str,
                 train_encoder: bool,
                 pretrain_type: str,
                 search_size: int,
                 search_number: int,
                 template_size: int,
                 template_number: int,
                 open_layers: list,
                 cfg=None):
        """Build the named backbone and hand it to :class:`EncoderBase`.

        Args:
            name: factory function name, looked up in ``vit_mm`` or ``vit``.
            train_encoder: forwarded to :class:`EncoderBase`.
            pretrain_type: forwarded to the factory.
            search_size, search_number, template_size, template_number:
                forwarded to the factory.
            open_layers: forwarded to :class:`EncoderBase`.
            cfg: config node. ``cfg.MODEL.ENCODER.DROP_PATH`` and
                ``cfg.MODEL.ENCODER.USE_CHECKPOINT`` are always read;
                ``cfg.MODEL.INTERFACE_TYPE``, ``cfg.MODEL.INTERFACE_DIM`` and
                ``cfg.MODEL.ENCODER.INSTRUCT`` are read for "vitmm" names.
                Despite the ``None`` default, it is required.

        Raises:
            ValueError: if ``name`` contains neither "vitmm" nor "vit", or if
                ``cfg`` is None.
        """
        lowered = name.lower()
        # "vitmm" must be tested first: every "vitmm" name also contains "vit".
        if "vitmm" in lowered:
            factory_module, multimodal = vitmm_module, True
        elif "vit" in lowered:
            factory_module, multimodal = vit_module, False
        else:
            raise ValueError(
                f"Unsupported encoder type: {name!r} (expected a 'vit' or 'vitmm' model name)")
        if cfg is None:
            raise ValueError(f"cfg is required to build encoder {name!r}")

        factory = getattr(factory_module, name)
        factory_kwargs = {
            # Only the main process requests pretrained weights.
            # NOTE(review): presumably other ranks receive weights by sync/checkpoint - confirm.
            "pretrained": is_main_process(),
            "pretrain_type": pretrain_type,
            "search_size": search_size,
            "template_size": template_size,
            "search_number": search_number,
            "template_number": template_number,
            "drop_path_rate": cfg.MODEL.ENCODER.DROP_PATH,
            "use_checkpoint": cfg.MODEL.ENCODER.USE_CHECKPOINT,
        }
        if multimodal:
            # Extra options accepted only by the multimodal (vit_mm) factories.
            factory_kwargs.update(
                interface_type=cfg.MODEL.INTERFACE_TYPE,
                interface_dim=cfg.MODEL.INTERFACE_DIM,
                instruct=cfg.MODEL.ENCODER.INSTRUCT,
            )
        encoder = factory(**factory_kwargs)
        num_channels = self._num_channels_from_name(name)
        super().__init__(encoder, train_encoder, open_layers, num_channels)
def build_encoder(cfg):
    """Construct the tracker :class:`Encoder` described by ``cfg``.

    The encoder is left trainable only when ``cfg.TRAIN.ENCODER_MULTIPLIER``
    is positive and ``cfg.TRAIN.FREEZE_ENCODER`` compares equal to ``False``.
    """
    train_cfg = cfg.TRAIN
    # Equality with False (rather than `not ...`) is intentional: a value such
    # as None must not be treated as "unfrozen".
    is_trainable = (train_cfg.ENCODER_MULTIPLIER > 0) and (train_cfg.FREEZE_ENCODER == False)  # noqa: E712
    return Encoder(
        name=cfg.MODEL.ENCODER.TYPE,
        train_encoder=is_trainable,
        pretrain_type=cfg.MODEL.ENCODER.PRETRAIN_TYPE,
        search_size=cfg.DATA.SEARCH.SIZE,
        search_number=cfg.DATA.SEARCH.NUMBER,
        template_size=cfg.DATA.TEMPLATE.SIZE,
        template_number=cfg.DATA.TEMPLATE.NUMBER,
        open_layers=train_cfg.ENCODER_OPEN,
        cfg=cfg,
    )