|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from .position import PositionEmbeddingSine |
|
|
|
|
|
|
|
|
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    """Build a regular 2-D grid of (x, y) coordinates.

    Args:
        h_min, h_max: first and last y value; ``len_h`` evenly spaced values
            are generated between them (both ends inclusive).
        w_min, w_max: first and last x value; ``len_w`` evenly spaced values
            are generated between them (both ends inclusive).
        len_h: number of grid rows.
        len_w: number of grid columns.
        device: device on which to create the grid. Must be given explicitly.

    Returns:
        float32 tensor of shape ``(len_h, len_w, 2)`` with
        ``grid[j, i] == (x_i, y_j)``, i.e. the last dimension is ordered (x, y).

    Raises:
        AssertionError: if ``device`` is None.
    """
    assert device is not None

    xs = torch.linspace(w_min, w_max, len_w, device=device)
    ys = torch.linspace(h_min, h_max, len_h, device=device)

    # Broadcast instead of torch.meshgrid: calling meshgrid without the
    # ``indexing`` argument warns on recent PyTorch, while that argument does
    # not exist on older releases. Broadcasting yields the same values on all.
    x_grid = xs.unsqueeze(0).expand(len_h, -1)  # x varies along columns
    y_grid = ys.unsqueeze(1).expand(-1, len_w)  # y varies along rows

    return torch.stack((x_grid, y_grid), dim=-1).float()
|
|
|
|
|
|
|
|
def normalize_coords(coords, h, w):
    """Map pixel coordinates linearly onto the range [-1, 1].

    Pixel ``0`` maps to ``-1`` and pixel ``w - 1`` (resp. ``h - 1``) maps to
    ``1``.

    Args:
        coords: tensor whose last dimension holds ``(x, y)`` pixel
            coordinates; it is broadcast against a length-2 constant.
        h: image height in pixels (used to normalize the y component).
        w: image width in pixels (used to normalize the x component).

    Returns:
        Tensor with the same shape as ``coords``. The constant is float32,
        so the result dtype follows PyTorch type promotion with float32.

    Note:
        ``h == 1`` or ``w == 1`` gives a zero denominator, so the
        corresponding component becomes inf/nan.
    """
    # Build the constant directly on the target device with the recommended
    # torch.tensor factory instead of the legacy torch.Tensor constructor plus
    # a host-to-device copy on every call.
    center = torch.tensor(
        [(w - 1) / 2.0, (h - 1) / 2.0], dtype=torch.float32, device=coords.device
    )
    return (coords - center) / center
|
|
|
|
|
|
|
|
def normalize_img(img0, img1):
    """Rescale two images from [0, 255] and standardize them per channel.

    Each image is divided by 255 and then normalized with fixed per-channel
    mean/std constants (the widely used ImageNet statistics, presumably in RGB
    order -- confirm against the data loader).

    Args:
        img0: image tensor with 3 channels on dim 1; the constants are shaped
            ``(1, 3, 1, 1)``, so a ``(B, 3, H, W)`` layout is assumed.
        img1: second image tensor in the same layout. Its device is used for
            the constants, so both images are expected to share a device.

    Returns:
        Tuple ``(img0_normalized, img1_normalized)``.
    """
    target_device = img1.device
    channel_mean = torch.tensor([0.485, 0.456, 0.406], device=target_device).reshape(1, 3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225], device=target_device).reshape(1, 3, 1, 1)

    def _standardize(image):
        # [0, 255] -> [0, 1], then zero-mean / unit-variance per channel.
        return (image / 255.0 - channel_mean) / channel_std

    return _standardize(img0), _standardize(img1)
|
|
|
|
|
|
|
|
def split_feature(
    feature,
    num_splits=2,
    channel_last=False,
):
    """Cut a 4-D feature map into ``num_splits x num_splits`` spatial windows.

    The windows are stacked along the batch dimension, ordered as
    ``batch_index * num_splits**2 + row_window * num_splits + col_window``.

    Args:
        feature: ``(B, C, H, W)`` tensor, or ``(B, H, W, C)`` when
            ``channel_last`` is True.
        num_splits: number of windows along each of H and W.
        channel_last: whether the channel dimension is last.

    Returns:
        ``(B * num_splits**2, C, H // num_splits, W // num_splits)`` tensor,
        or the channel-last equivalent.

    Raises:
        AssertionError: if H or W is not divisible by ``num_splits``.
    """
    if channel_last:
        batch, height, width, channels = feature.size()
    else:
        batch, channels, height, width = feature.size()
    assert height % num_splits == 0 and width % num_splits == 0

    n = num_splits
    win_h = height // n
    win_w = width // n
    num_windows = batch * n * n

    if channel_last:
        # (B, nH, wh, nW, ww, C) -> (B, nH, nW, wh, ww, C)
        tiled = feature.view(batch, n, win_h, n, win_w, channels)
        return tiled.permute(0, 1, 3, 2, 4, 5).reshape(num_windows, win_h, win_w, channels)

    # (B, C, nH, wh, nW, ww) -> (B, nH, nW, C, wh, ww)
    tiled = feature.view(batch, channels, n, win_h, n, win_w)
    return tiled.permute(0, 2, 4, 1, 3, 5).reshape(num_windows, channels, win_h, win_w)
|
|
|
|
|
|
|
|
def merge_splits(
    splits,
    num_splits=2,
    channel_last=False,
):
    """Reassemble windows stacked on the batch dimension into full feature maps.

    Inverse of the window layout ``(B * n * n, ...)``, where window
    ``b * n * n + row * n + col`` holds the ``(row, col)`` tile of image ``b``.

    Args:
        splits: ``(B * n * n, C, h, w)`` tensor, or ``(B * n * n, h, w, C)``
            when ``channel_last`` is True, with ``n == num_splits``.
        num_splits: number of windows along each spatial dimension.
        channel_last: whether the channel dimension is last.

    Returns:
        ``(B, C, n * h, n * w)`` tensor, or ``(B, n * h, n * w, C)`` when
        ``channel_last`` is True.
    """
    n = num_splits

    if channel_last:
        total, win_h, win_w, channels = splits.size()
        batch = total // n // n
        # (B, nH, nW, h, w, C) -> (B, nH, h, nW, w, C)
        grid = splits.view(batch, n, n, win_h, win_w, channels).permute(0, 1, 3, 2, 4, 5)
        return grid.reshape(batch, n * win_h, n * win_w, channels)

    total, channels, win_h, win_w = splits.size()
    batch = total // n // n
    # (B, nH, nW, C, h, w) -> (B, C, nH, h, nW, w)
    grid = splits.view(batch, n, n, channels, win_h, win_w).permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(batch, channels, n * win_h, n * win_w)
|
|
|
|
|
|
|
|
def generate_shift_window_attn_mask(
    input_resolution,
    window_size_h,
    window_size_w,
    shift_size_h,
    shift_size_w,
    device=torch.device("cuda"),
):
    """Build the additive attention mask for shifted-window attention.

    The ``(H, W)`` plane is divided into 3 x 3 regions by the slices
    ``[0, -window)``, ``[-window, -shift)`` and ``[-shift, end)`` along each
    axis, and each region gets a distinct id (0..8). The plane is then cut
    into windows with ``split_feature``. Within each window, a pair of tokens
    from different regions gets ``-100`` and a pair from the same region gets
    ``0``. The mask is presumably added to attention logits before a softmax,
    which effectively blocks attention across regions.

    Args:
        input_resolution: ``(H, W)`` of the feature map.
        window_size_h: window height in tokens.
        window_size_w: window width in tokens.
        shift_size_h: vertical shift in tokens.
        shift_size_w: horizontal shift in tokens.
        device: device on which to build the mask (defaults to CUDA).

    Returns:
        float tensor of shape
        ``(num_windows, window_size_h * window_size_w, window_size_h * window_size_w)``
        where ``num_windows == (W // window_size_w) ** 2``.

    Raises:
        ValueError: if the feature map does not tile exactly into the same
            number of windows along H and W.
    """
    h, w = input_resolution
    num_splits = w // window_size_w

    # split_feature cuts H and W into the same number of pieces. The windows
    # therefore only hold window_size_h x window_size_w tokens when both axes
    # tile exactly into num_splits windows. Otherwise the view() below would
    # silently mix tokens from different windows.
    if h != num_splits * window_size_h or w != num_splits * window_size_w:
        raise ValueError(
            f"input_resolution {tuple(input_resolution)} does not tile into an "
            f"equal number of {window_size_h}x{window_size_w} windows along both axes"
        )

    img_mask = torch.zeros((1, h, w, 1), device=device)

    h_slices = (
        slice(0, -window_size_h),
        slice(-window_size_h, -shift_size_h),
        slice(-shift_size_h, None),
    )
    w_slices = (
        slice(0, -window_size_w),
        slice(-window_size_w, -shift_size_w),
        slice(-shift_size_w, None),
    )

    # Label every region with a unique id; later slices overwrite earlier ones.
    region_id = 0
    for h_slice in h_slices:
        for w_slice in w_slices:
            img_mask[:, h_slice, w_slice, :] = region_id
            region_id += 1

    mask_windows = split_feature(img_mask, num_splits=num_splits, channel_last=True)
    mask_windows = mask_windows.view(-1, window_size_h * window_size_w)

    # attn_mask[k, i, j] is non-zero exactly when tokens i and j of window k
    # carry different region ids. Entries that are already zero stay zero.
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0))

    return attn_mask
|
|
|
|
|
|
|
|
def feature_add_position(feature0, feature1, attn_splits, feature_channels):
    """Add a sine positional encoding to a pair of feature maps.

    The encoding is computed from ``feature0``, or from its windows when
    ``attn_splits > 1``, and the same encoding is added to both inputs. When
    windows are used, the encoding is computed on the split tensor and the
    result is merged back. Presumably this gives every window the same local
    encoding; this depends on PositionEmbeddingSine.

    Args:
        feature0: 4-D feature tensor, channel-first ``(B, C, H, W)`` as
            expected by ``split_feature``'s default layout.
        feature1: feature tensor with the same shape as ``feature0``.
        attn_splits: number of windows per spatial axis. Values <= 1 disable
            windowing.
        feature_channels: channel count. The encoder is built with
            ``num_pos_feats=feature_channels // 2`` (presumably so its output
            matches ``feature_channels``; confirm in PositionEmbeddingSine).

    Returns:
        Tuple ``(feature0 + pos, feature1 + pos)`` in the original layout.
    """
    encoder = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    if attn_splits <= 1:
        pos = encoder(feature0)
        return feature0 + pos, feature1 + pos

    windows0 = split_feature(feature0, num_splits=attn_splits)
    windows1 = split_feature(feature1, num_splits=attn_splits)
    pos = encoder(windows0)

    return (
        merge_splits(windows0 + pos, num_splits=attn_splits),
        merge_splits(windows1 + pos, num_splits=attn_splits),
    )
|
|
|
|
|
|
|
|
def mv_feature_add_position(features, attn_splits, feature_channels):
    """Add a sine positional encoding to a single 4-D feature tensor.

    Single-tensor counterpart of the paired version. The encoding is computed
    from ``features``, or from its windows when ``attn_splits > 1``, and then
    added to it.

    Args:
        features: 4-D tensor, channel-first ``(B, C, H, W)`` as expected by
            ``split_feature``'s default layout.
        attn_splits: number of windows per spatial axis. Values <= 1 disable
            windowing.
        feature_channels: channel count. The encoder is built with
            ``num_pos_feats=feature_channels // 2``.

    Returns:
        ``features + pos`` in the original layout.

    Raises:
        AssertionError: if ``features`` is not 4-D.
    """
    encoder = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    assert features.dim() == 4

    if attn_splits <= 1:
        return features + encoder(features)

    windows = split_feature(features, num_splits=attn_splits)
    windows = windows + encoder(windows)
    return merge_splits(windows, num_splits=attn_splits)
|
|
|