# Uploaded to Hugging Face by Yeqing0814 via huggingface_hub
# ("Upload folder using huggingface_hub"), commit a6dd040 (verified).
import torch
import torch.nn.functional as F
from .position import PositionEmbeddingSine
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    """Build a dense 2-D grid of (x, y) sample coordinates.

    Args:
        h_min, h_max: first and last vertical (y) coordinate; ``len_h``
            evenly spaced values are generated between them, inclusive.
        w_min, w_max: first and last horizontal (x) coordinate; ``len_w``
            evenly spaced values are generated between them, inclusive.
        len_h, len_w: number of grid rows / columns.
        device: torch device to create the grid on. Required.

    Returns:
        float32 tensor of shape ``[len_h, len_w, 2]`` with
        ``grid[i, j] == (x_j, y_i)``, i.e. the last axis is ordered (x, y).

    Raises:
        AssertionError: if ``device`` is None.
    """
    assert device is not None
    xs = torch.linspace(w_min, w_max, len_w, device=device)  # [W]
    ys = torch.linspace(h_min, h_max, len_h, device=device)  # [H]
    # Broadcast the two 1-D axes rather than calling torch.meshgrid:
    # meshgrid without ``indexing=`` emits a UserWarning on torch >= 1.10,
    # while the ``indexing=`` keyword does not exist on older releases.
    # Broadcasting yields the same [H, W] layout on every version.
    y, x = torch.broadcast_tensors(ys[:, None], xs[None, :])  # each [H, W]
    grid = torch.stack((x, y), -1).float()  # [H, W, 2]
    return grid
def normalize_coords(coords, h, w):
    """Rescale pixel coordinates so that [0, size - 1] maps onto [-1, 1].

    Args:
        coords: tensor whose last axis holds (x, y) pixel coordinates,
            e.g. ``[B, H, W, 2]``.
        h, w: image height and width used for the rescaling.

    Returns:
        Tensor shaped like ``coords``: x = 0 -> -1 and x = w - 1 -> 1,
        and likewise for y with h.
    """
    # Half-extent per axis, ordered (x, y) to match the last axis of coords.
    half_extent = torch.tensor(
        [(w - 1) / 2.0, (h - 1) / 2.0],
        dtype=torch.float32,
        device=coords.device,
    )
    centered = coords - half_extent
    return centered / half_extent  # [-1, 1]
def normalize_img(img0, img1):
    """Standardize an image pair with the ImageNet channel mean / std.

    Expects pixel values in [0, 255] (per the original note). Each image is
    first scaled to [0, 1], then standardized per channel. The constants are
    reshaped to ``[1, 3, 1, 1]``, so the inputs must be ``[N, 3, H, W]`` (or
    otherwise broadcast against that).

    NOTE(review): the constants are placed on ``img1``'s device, so ``img0``
    is assumed to live on the same device.

    Returns:
        Tuple ``(img0, img1)`` of normalized tensors.
    """
    target_device = img1.device
    channel_mean = torch.tensor([0.485, 0.456, 0.406], device=target_device).view(1, 3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225], device=target_device).view(1, 3, 1, 1)

    def _standardize(image):
        return (image / 255.0 - channel_mean) / channel_std

    return _standardize(img0), _standardize(img1)
def split_feature(
    feature,
    num_splits=2,
    channel_last=False,
):
    """Cut each feature map into a K x K grid of windows stacked along batch.

    Args:
        feature: 4-D tensor, ``[B, C, H, W]`` or, when ``channel_last``,
            ``[B, H, W, C]``. H and W must both be divisible by ``num_splits``.
        num_splits: number of windows K along each spatial axis.
        channel_last: whether the channel axis is the last one.

    Returns:
        ``[B*K*K, C, H/K, W/K]`` (or ``[B*K*K, H/K, W/K, C]``). The window in
        row-block ``i`` and column-block ``j`` of batch item ``n`` is placed at
        batch index ``n*K*K + i*K + j``.
    """
    k = num_splits
    if channel_last:
        batch, height, width, channels = feature.size()
    else:
        batch, channels, height, width = feature.size()
    assert height % k == 0 and width % k == 0
    win_h, win_w = height // k, width // k

    if channel_last:
        # [B, K, H/K, K, W/K, C] -> [B, K(rows), K(cols), H/K, W/K, C]
        windows = feature.view(batch, k, win_h, k, win_w, channels).permute(0, 1, 3, 2, 4, 5)
        return windows.reshape(batch * k * k, win_h, win_w, channels)

    # [B, C, K, H/K, K, W/K] -> [B, K(rows), K(cols), C, H/K, W/K]
    windows = feature.view(batch, channels, k, win_h, k, win_w).permute(0, 2, 4, 1, 3, 5)
    return windows.reshape(batch * k * k, channels, win_h, win_w)
def merge_splits(
    splits,
    num_splits=2,
    channel_last=False,
):
    """Reassemble windows produced by ``split_feature`` into full feature maps.

    Args:
        splits: ``[B*K*K, C, h, w]`` or, when ``channel_last``,
            ``[B*K*K, h, w, C]``. The batch index is ``n*K*K + i*K + j`` for
            row-block ``i`` and column-block ``j``. The leading size must be
            a multiple of K*K, otherwise the ``view`` raises.
        num_splits: number of windows K along each spatial axis.
        channel_last: whether the channel axis is the last one.

    Returns:
        ``[B, C, K*h, K*w]`` (or ``[B, K*h, K*w, C]``).
    """
    k = num_splits
    if channel_last:
        total, win_h, win_w, channels = splits.size()
        batch = total // k // k
        # [B, K(rows), K(cols), h, w, C] -> [B, K(rows), h, K(cols), w, C]
        tiled = splits.view(batch, k, k, win_h, win_w, channels).permute(0, 1, 3, 2, 4, 5)
        return tiled.contiguous().view(batch, k * win_h, k * win_w, channels)

    total, channels, win_h, win_w = splits.size()
    batch = total // k // k
    # [B, K(rows), K(cols), C, h, w] -> [B, C, K(rows), h, K(cols), w]
    tiled = splits.view(batch, k, k, channels, win_h, win_w).permute(0, 3, 1, 4, 2, 5)
    return tiled.contiguous().view(batch, channels, k * win_h, k * win_w)
def generate_shift_window_attn_mask(
    input_resolution,
    window_size_h,
    window_size_w,
    shift_size_h,
    shift_size_w,
    device=torch.device("cuda"),
):
    """Build the attention mask for shifted-window attention (SW-MSA).

    Adapted from Swin Transformer:
    https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py

    Three row bands and three column bands split the map into 9 regions, and
    each region gets its own integer label. After the label map is cut into
    windows, two positions in the same window get 0.0 when their labels match
    and -100.0 otherwise.

    Args:
        input_resolution: ``(H, W)`` of the feature map.
        window_size_h, window_size_w: window height / width.
        shift_size_h, shift_size_w: shift along each axis.
        device: device for the mask (defaults to CUDA).

    Returns:
        Tensor ``[num_windows, Wh*Ww, Wh*Ww]`` of 0.0 / -100.0 values.

    NOTE(review): the number of windows per side is taken from the width only
    (``W // window_size_w``) and used for both axes, so this assumes
    ``H // window_size_h == W // window_size_w``.
    """
    height, width = input_resolution
    labels = torch.zeros((1, height, width, 1), device=device)  # 1 H W 1

    row_bands = (
        slice(0, -window_size_h),
        slice(-window_size_h, -shift_size_h),
        slice(-shift_size_h, None),
    )
    col_bands = (
        slice(0, -window_size_w),
        slice(-window_size_w, -shift_size_w),
        slice(-shift_size_w, None),
    )
    # Row-major labelling: region id grows along columns first, then rows.
    band_pairs = ((rows, cols) for rows in row_bands for cols in col_bands)
    for region_id, (rows, cols) in enumerate(band_pairs):
        labels[:, rows, cols, :] = region_id

    windows_per_side = input_resolution[-1] // window_size_w
    label_windows = split_feature(labels, num_splits=windows_per_side, channel_last=True)
    label_windows = label_windows.view(-1, window_size_h * window_size_w)

    # Pairwise label difference inside each window: non-zero => different region.
    pair_diff = label_windows.unsqueeze(1) - label_windows.unsqueeze(2)
    blocked = pair_diff != 0
    allowed = pair_diff == 0
    return pair_diff.masked_fill(blocked, float(-100.0)).masked_fill(allowed, float(0.0))
def feature_add_position(feature0, feature1, attn_splits, feature_channels):
    """Add one sine positional encoding to both feature maps of a pair.

    The encoding is always computed from ``feature0`` (or its windows) and
    added to both inputs. When ``attn_splits > 1`` both maps are first cut
    into ``attn_splits x attn_splits`` windows, the encoding is computed on
    the windowed tensor, and the windows are merged back. The positions are
    then presumably window-local, depending on ``PositionEmbeddingSine``.

    Args:
        feature0, feature1: feature maps in the layout ``split_feature``
            expects by default (``[B, C, H, W]``).
        attn_splits: windows per spatial axis; values <= 1 mean no windowing.
        feature_channels: C; the encoder is built with ``C // 2`` features,
            presumably so its output has C channels -- TODO confirm.

    Returns:
        Tuple ``(feature0, feature1)`` with the encoding added.
    """
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    if attn_splits <= 1:
        position = pos_enc(feature0)
        return feature0 + position, feature1 + position

    # add position in split windows
    windows0 = split_feature(feature0, num_splits=attn_splits)
    windows1 = split_feature(feature1, num_splits=attn_splits)
    position = pos_enc(windows0)
    merged0 = merge_splits(windows0 + position, num_splits=attn_splits)
    merged1 = merge_splits(windows1 + position, num_splits=attn_splits)
    return merged0, merged1
def mv_feature_add_position(features, attn_splits, feature_channels):
    """Add a sine positional encoding to a stacked batch of feature maps.

    This is the single-tensor counterpart of ``feature_add_position``. The
    name suggests a multi-view batch (views folded into the batch axis) --
    NOTE(review): confirm with callers. When ``attn_splits > 1`` the encoding
    is computed and added per window, and the windows are then merged back.

    Args:
        features: 4-D tensor ``[B*V, C, H, W]``.
        attn_splits: windows per spatial axis; values <= 1 mean no windowing.
        feature_channels: C; the encoder is built with ``C // 2`` features.

    Returns:
        Tensor shaped like ``features`` with the encoding added.
    """
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
    assert features.dim() == 4  # [B*V, C, H, W]

    if attn_splits <= 1:
        return features + pos_enc(features)

    # add position in split windows
    windows = split_feature(features, num_splits=attn_splits)
    windows = windows + pos_enc(windows)
    return merge_splits(windows, num_splits=attn_splits)