| import math
|
|
|
| import torch
|
| import torch.nn.functional as F
|
|
|
|
|
def combine_tokens(template_tokens, search_tokens, mode='direct', return_res=False):
    """Merge template and search token sequences into one sequence along dim 1.

    Args:
        template_tokens: tensor indexed as (B, len_t, C). In 'partition' mode
            len_t must be a perfect square (a flattened square map); otherwise
            the internal ``view`` raises.
        search_tokens: tensor indexed as (B, len_s, C). In 'partition' mode
            ``int(sqrt(len_s))`` is used as the side of the search map when
            computing the merged resolution.
        mode: one of
            'direct'           -> [template, search]
            'template_central' -> [search[:len_s // 2], template, search[len_s // 2:]]
            'partition'        -> [rearranged template, search], where the
                                  template map is split into row bands that are
                                  laid side by side (see ``_partition_template``).
        return_res: only honoured in 'partition' mode. There it makes the
            function return ``(merged, merged_h, merged_w)``. It is ignored in
            the other modes.

    Returns:
        The merged tensor, or ``(merged, merged_h, merged_w)`` in 'partition'
        mode with ``return_res=True``.

    Raises:
        NotImplementedError: if ``mode`` is not one of the modes above.
    """
    len_s = search_tokens.shape[1]

    if mode == 'direct':
        return torch.cat((template_tokens, search_tokens), dim=1)

    if mode == 'template_central':
        central_pivot = len_s // 2
        first_half = search_tokens[:, :central_pivot, :]
        second_half = search_tokens[:, central_pivot:, :]
        return torch.cat((first_half, template_tokens, second_half), dim=1)

    if mode == 'partition':
        feat_size_s = int(math.sqrt(len_s))
        template_tokens, template_rows = _partition_template(template_tokens)
        merged_feature = torch.cat([template_tokens, search_tokens], dim=1)
        if return_res:
            # Resolution of the merged sequence seen as the rearranged template
            # rows stacked above a feat_size_s x feat_size_s search map. This
            # is only a consistent 2-D layout when the rearranged template
            # width equals feat_size_s; that is not checked here.
            return merged_feature, feat_size_s + template_rows, feat_size_s
        return merged_feature

    raise NotImplementedError(f"Unsupported combine mode: {mode!r}")


def _partition_template(template_tokens):
    """Rearrange a flattened square template map into horizontal row bands.

    The H x W template map (H == W == int(sqrt(len_t))) is zero-padded at the
    top until its height is a multiple of ``window_size = ceil(H / 2)``. It is
    then cut into consecutive bands of ``window_size`` rows, and the bands are
    concatenated side by side along the width axis.

    Returns:
        (tokens, rows): the rearranged map flattened back to (B, rows * cols, C),
        and the number of rows of the rearranged map.
    """
    B, len_t, C = template_tokens.shape
    feat_size_t = int(math.sqrt(len_t))
    window_size = math.ceil(feat_size_t / 2.)
    H = W = feat_size_t
    template_map = template_tokens.view(B, H, W, C)

    # F.pad consumes the pad tuple from the last dim backwards:
    # (C_front, C_back, W_left, W_right, H_top, H_bottom). Only H is padded,
    # and only at the top.
    pad_t = (window_size - H % window_size) % window_size
    template_map = F.pad(template_map, (0, 0, 0, 0, pad_t, 0))
    padded_h = template_map.shape[1]

    bands = template_map.view(B, padded_h // window_size, window_size, W, C)
    # Concatenate every band rather than hard-coding bands 0 and 1. For H >= 2
    # there are exactly two bands, so the result is the same. A 1x1 template
    # has a single band and no longer raises IndexError.
    template_map = torch.cat(bands.unbind(dim=1), dim=2)
    return template_map.reshape(B, -1, C), template_map.shape[1]
|
|
|
|
|
def recover_tokens(merged_tokens, len_template_token, len_search_token, mode='direct'):
    """Reorder a merged token sequence so that template tokens come first.

    Args:
        merged_tokens: tensor indexed as (B, N, C) holding the merged sequence.
        len_template_token: number of template tokens in the sequence.
        len_search_token: number of search tokens in the sequence.
        mode: 'direct', 'template_central' or 'partition'.

    Returns:
        For 'template_central' (layout [search[:pivot], template, search[pivot:]]
        with pivot = len_search_token // 2): a new tensor
        [template, search[:pivot], search[pivot:]] along dim 1.
        For 'direct' and 'partition': ``merged_tokens`` itself, unchanged. For
        'partition' no template reordering is undone.

    Raises:
        NotImplementedError: if ``mode`` is not one of the modes above.
    """
    if mode in ('direct', 'partition'):
        return merged_tokens

    if mode == 'template_central':
        central_pivot = len_search_token // 2
        len_remain = len_search_token - central_pivot
        len_half_and_t = central_pivot + len_template_token
        num_tokens = merged_tokens.shape[1]

        first_half = merged_tokens[:, :central_pivot, :]
        template_tokens = merged_tokens[:, central_pivot:len_half_and_t, :]
        # Take the trailing len_remain tokens with an explicit (clamped) start.
        # ``-len_remain:`` would select the WHOLE sequence when len_remain == 0.
        second_half = merged_tokens[:, max(num_tokens - len_remain, 0):, :]
        return torch.cat((template_tokens, first_half, second_half), dim=1)

    raise NotImplementedError(f"Unsupported recover mode: {mode!r}")
|
|
|
|
|
def window_partition(x, window_size: int):
    """Cut channels-last feature maps into non-overlapping square windows.

    H and W are expected to be multiples of ``window_size``; otherwise the
    internal ``view`` raises.

    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C), ordered by
        batch, then window row, then window column.
    """
    batch, height, width, channels = x.shape
    grid_rows = height // window_size
    grid_cols = width // window_size
    tiles = x.view(batch, grid_rows, window_size, grid_cols, window_size, channels)
    # Put the two window-grid axes side by side so that each window's pixels
    # become one contiguous (window_size, window_size, C) slab.
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, channels)
|
|
|
|
|
def window_reverse(windows, window_size: int, H: int, W: int):
    """Stitch square windows back into full channels-last feature maps.

    Undoes the window ordering used by ``window_partition`` (batch, then
    window row, then window column).

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    windows_per_image = H * W / window_size / window_size
    # Batch size via float division and truncation; exact when H and W are
    # multiples of window_size.
    batch = int(windows.shape[0] / windows_per_image)
    grid = windows.view(batch, H // window_size, W // window_size, window_size, window_size, -1)
    # Interleave grid axes with in-window axes to restore row/column order.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(batch, H, W, -1)
|
|
|
|
|
def token2feature(tokens):
    """Fold a flattened square token sequence back into a channels-first map.

    Args:
        tokens: tensor indexed as (B, L, D). L is expected to be a perfect
            square; otherwise the internal ``view`` raises.

    Returns:
        Contiguous tensor of shape (B, D, side, side) with
        ``side = int(L ** 0.5)``. Token ``l`` lands at spatial position
        ``(l // side, l % side)``.
    """
    batch, length, dim = tokens.shape
    side = int(length ** 0.5)
    channels_first = tokens.transpose(1, 2)
    return channels_first.view(batch, dim, side, side).contiguous()
|
|
|
|
|
def feature2token(x):
    """Flatten a channels-first 4-D feature map into a sequence of tokens.

    Args:
        x: 4-D tensor indexed as (B, C, H, W).

    Returns:
        Contiguous tensor of shape (B, H * W, C). Spatial positions are read in
        row-major order over the last two dims, so position (i, j) becomes
        token ``i * W + j``.
    """
    B, C, H, W = x.shape
    # reshape instead of view: view raises when the last two dims cannot be
    # merged without a copy (e.g. a spatially cropped x[..., :k] or a map with
    # H and W transposed). reshape copies in that case and is a view otherwise.
    tokens = x.reshape(B, C, H * W).permute(0, 2, 1).contiguous()
    return tokens