import torch
import torch.nn as nn
from einops import repeat

from .utils import split_feature, merge_splits


def single_head_full_attention(q, k, v):
    # q, k, v: [B, L, C]
    assert q.dim() == k.dim() == v.dim() == 3

    scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** 0.5)  # [B, L, L]
    attn = torch.softmax(scores, dim=2)  # [B, L, L]
    out = torch.matmul(attn, v)  # [B, L, C]

    return out


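# Usage sketch (illustrative only; sizes are arbitrary): with q = k = v of
# shape [2, 16, 32] ([B, L, C]), `scores` and `attn` are [2, 16, 16] and the
# output is again [2, 16, 32]:
#
#     q = torch.randn(2, 16, 32)
#     out = single_head_full_attention(q, q, q)  # -> torch.Size([2, 16, 32])

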
def generate_shift_window_attn_mask(
    input_resolution,
    window_size_h,
    window_size_w,
    shift_size_h,
    shift_size_w,
    device=torch.device("cuda"),
):
    # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    # Calculate the attention mask for SW-MSA: after the cyclic shift, tokens
    # that originate from different windows must not attend to each other.
    h, w = input_resolution
    img_mask = torch.zeros((1, h, w, 1)).to(device)  # [1, H, W, 1]
    h_slices = (
        slice(0, -window_size_h),
        slice(-window_size_h, -shift_size_h),
        slice(-shift_size_h, None),
    )
    w_slices = (
        slice(0, -window_size_w),
        slice(-window_size_w, -shift_size_w),
        slice(-shift_size_w, None),
    )

    # assign a distinct id to each of the 3 x 3 shifted regions
    cnt = 0
    for h_slice in h_slices:
        for w_slice in w_slices:
            img_mask[:, h_slice, w_slice, :] = cnt
            cnt += 1

    mask_windows = split_feature(
        img_mask, num_splits=w // window_size_w, channel_last=True
    )
    mask_windows = mask_windows.view(-1, window_size_h * window_size_w)

    # token pairs from different regions get a large negative bias before softmax
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )

    return attn_mask


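# Usage sketch (illustrative only; sizes are arbitrary): an 8 x 8 map split
# into 2 x 2 windows of size 4 x 4 with shift 2 yields one mask per window:
#
#     mask = generate_shift_window_attn_mask(
#         (8, 8), 4, 4, 2, 2, device=torch.device("cpu")
#     )
#     mask.shape  # torch.Size([4, 16, 16]); entries are 0.0 or -100.0

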
def single_head_split_window_attention(
    q,
    k,
    v,
    num_splits=1,
    with_shift=False,
    h=None,
    w=None,
    attn_mask=None,
):
    """Single-head Swin-style window attention, optionally with shifted windows.

    Args:
        q: [B, L, C] with L = h * w
        k, v: [B, L, C], or [B, M, L, C] when attending to M views per query
    Returns:
        out: [B, L, C]
    """
    if not (q.dim() == k.dim() == v.dim() == 3):
        # cross-view case: k and v carry an extra view dimension
        assert k.dim() == v.dim() == 4
        assert h is not None and w is not None
        assert q.size(1) == h * w

        m = k.size(1)  # number of views in k and v

        b, _, c = q.size()

        b_new = b * num_splits * num_splits

        window_size_h = h // num_splits
        window_size_w = w // num_splits

        q = q.view(b, h, w, c)  # [B, H, W, C]
        k = k.view(b, m, h, w, c)  # [B, M, H, W, C]
        v = v.view(b, m, h, w, c)

        scale_factor = c**0.5

        if with_shift:
            assert attn_mask is not None  # compute once outside the layer
            shift_size_h = window_size_h // 2
            shift_size_w = window_size_w // 2

            q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
            k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(2, 3))
            v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(2, 3))

        q = split_feature(
            q, num_splits=num_splits, channel_last=True
        )  # [B*K*K, H/K, W/K, C]
        k = split_feature(
            k.permute(0, 2, 3, 4, 1).reshape(b, h, w, -1),  # [B, H, W, C*M]
            num_splits=num_splits,
            channel_last=True,
        )
        v = split_feature(
            v.permute(0, 2, 3, 4, 1).reshape(b, h, w, -1),
            num_splits=num_splits,
            channel_last=True,
        )

        # flatten each window; keys and values are ordered (spatial, view)
        k = (
            k.view(b_new, h // num_splits, w // num_splits, c, m)
            .permute(0, 3, 1, 2, 4)
            .reshape(b_new, c, -1)
        )  # [B*K*K, C, H/K*W/K*M]
        v = (
            v.view(b_new, h // num_splits, w // num_splits, c, m)
            .permute(0, 1, 2, 4, 3)
            .reshape(b_new, -1, c)
        )  # [B*K*K, H/K*W/K*M, C]

        scores = (
            torch.matmul(q.view(b_new, -1, c), k) / scale_factor
        )  # [B*K*K, H/K*W/K, H/K*W/K*M]

        if with_shift:
            # the key dimension is ordered (spatial, view), so the spatial mask
            # must be expanded per key position rather than tiled per view
            scores += attn_mask.repeat(b, 1, 1).repeat_interleave(m, dim=-1)

        attn = torch.softmax(scores, dim=-1)

        out = torch.matmul(attn, v)  # [B*K*K, H/K*W/K, C]

        out = merge_splits(
            out.view(b_new, h // num_splits, w // num_splits, c),
            num_splits=num_splits,
            channel_last=True,
        )  # [B, H, W, C]

        # shift back
        if with_shift:
            out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))

        out = out.view(b, -1, c)
    else:
        # intra-view case: q, k, v are all [B, L, C]
        assert q.dim() == k.dim() == v.dim() == 3

        assert h is not None and w is not None
        assert q.size(1) == h * w

        b, _, c = q.size()

        b_new = b * num_splits * num_splits

        window_size_h = h // num_splits
        window_size_w = w // num_splits

        q = q.view(b, h, w, c)  # [B, H, W, C]
        k = k.view(b, h, w, c)
        v = v.view(b, h, w, c)

        scale_factor = c**0.5

        if with_shift:
            assert attn_mask is not None  # compute once outside the layer
            shift_size_h = window_size_h // 2
            shift_size_w = window_size_w // 2

            q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
            k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
            v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))

        q = split_feature(
            q, num_splits=num_splits, channel_last=True
        )  # [B*K*K, H/K, W/K, C]
        k = split_feature(k, num_splits=num_splits, channel_last=True)
        v = split_feature(v, num_splits=num_splits, channel_last=True)

        scores = (
            torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1))
            / scale_factor
        )  # [B*K*K, H/K*W/K, H/K*W/K]

        if with_shift:
            scores += attn_mask.repeat(b, 1, 1)

        attn = torch.softmax(scores, dim=-1)

        out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*K*K, H/K*W/K, C]

        out = merge_splits(
            out.view(b_new, h // num_splits, w // num_splits, c),
            num_splits=num_splits,
            channel_last=True,
        )  # [B, H, W, C]

        # shift back
        if with_shift:
            out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))

        out = out.view(b, -1, c)

    return out


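# Usage sketch (illustrative only; sizes are arbitrary): intra-view windowed
# attention on a 16 x 16 map with 2 x 2 splits keeps the [B, L, C] interface:
#
#     q = torch.randn(2, 256, 32)  # L = 16 * 16
#     out = single_head_split_window_attention(q, q, q, num_splits=2, h=16, w=16)
#     out.shape  # torch.Size([2, 256, 32])
#
# With cross-view k, v of shape [B, M, L, C], each window token additionally
# attends over all M views, and the output is still [B, L, C].

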
def multi_head_split_window_attention(
    q,
    k,
    v,
    num_splits=1,
    with_shift=False,
    h=None,
    w=None,
    attn_mask=None,
    num_head=1,
):
    """Multi-head scaled dot-product attention within (optionally shifted) windows.

    Args:
        q: [N, L, D] with L = h * w
        k: [N, L, D]
        v: [N, L, D]
    Returns:
        out: [N, L, D]
    """
    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    b_new = b * num_splits * num_splits

    window_size_h = h // num_splits
    window_size_w = w // num_splits

    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    assert c % num_head == 0

    scale_factor = (c // num_head) ** 0.5

    if with_shift:
        assert attn_mask is not None  # compute once outside the layer
        shift_size_h = window_size_h // 2
        shift_size_w = window_size_w // 2

        q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))

    # the features are channel-last at this point
    q = split_feature(q, num_splits=num_splits, channel_last=True)  # [B*K*K, H/K, W/K, C]
    k = split_feature(k, num_splits=num_splits, channel_last=True)
    v = split_feature(v, num_splits=num_splits, channel_last=True)

    # split the channels into heads: [B*K*K, num_head, H/K*W/K, C/num_head]
    q = q.view(b_new, -1, num_head, c // num_head).permute(0, 2, 1, 3)
    k = k.view(b_new, -1, num_head, c // num_head).permute(0, 2, 3, 1)
    scores = torch.matmul(q, k) / scale_factor  # [B*K*K, num_head, H/K*W/K, H/K*W/K]

    if with_shift:
        scores += attn_mask.unsqueeze(1).repeat(b, num_head, 1, 1)

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(
        attn, v.view(b_new, -1, num_head, c // num_head).permute(0, 2, 1, 3)
    )  # [B*K*K, num_head, H/K*W/K, C/num_head]

    out = merge_splits(
        out.permute(0, 2, 1, 3).reshape(b_new, h // num_splits, w // num_splits, c),
        num_splits=num_splits,
        channel_last=True,
    )  # [B, H, W, C]

    # shift back
    if with_shift:
        out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))

    out = out.view(b, -1, c)

    return out


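# Usage sketch (illustrative only; sizes are arbitrary): same interface as the
# single-head variant, with channels split across heads inside each window:
#
#     q = torch.randn(2, 256, 32)
#     out = multi_head_split_window_attention(
#         q, q, q, num_splits=2, h=16, w=16, num_head=4
#     )
#     out.shape  # torch.Size([2, 256, 32])

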
class TransformerLayer(nn.Module):
    def __init__(
        self,
        d_model=256,
        nhead=1,
        attention_type="swin",
        no_ffn=False,
        ffn_dim_expansion=4,
        with_shift=False,
        add_per_view_attn=False,
        **kwargs,
    ):
        super(TransformerLayer, self).__init__()

        self.dim = d_model
        self.nhead = nhead
        self.attention_type = attention_type
        self.no_ffn = no_ffn
        self.add_per_view_attn = add_per_view_attn

        self.with_shift = with_shift

        # attention projections
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # optional FFN on the concatenated [source, message]
        if not self.no_ffn:
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        attn_num_splits=None,
        **kwargs,
    ):
        # the attention type can be overridden per call
        attn_type = kwargs.get("attn_type", self.attention_type)

        # source, target: [B, L, C]; self-attention when source is target
        query, key, value = source, target, target

        query = self.q_proj(query)  # [B, L, C]
        key = self.k_proj(key)  # [B, L, C] or [B, V, L, C]
        value = self.v_proj(value)

        if attn_type == "swin" and attn_num_splits > 1:
            if self.nhead > 1:
                message = multi_head_split_window_attention(
                    query,
                    key,
                    value,
                    num_splits=attn_num_splits,
                    with_shift=self.with_shift,
                    h=height,
                    w=width,
                    attn_mask=shifted_window_attn_mask,
                    num_head=self.nhead,
                )
            else:
                if self.add_per_view_attn:
                    # attend to each view separately by folding the view
                    # dimension into the batch, then sum the per-view messages
                    assert query.dim() == 3 and key.dim() == 4 and value.dim() == 4
                    b, l, c = query.size()
                    query = query.unsqueeze(1).repeat(
                        1, key.size(1), 1, 1
                    )  # [B, V, L, C]
                    query = query.view(-1, l, c)  # [B*V, L, C]
                    key = key.view(-1, l, c)
                    value = value.view(-1, l, c)
                    message = single_head_split_window_attention(
                        query,
                        key,
                        value,
                        num_splits=attn_num_splits,
                        with_shift=self.with_shift,
                        h=height,
                        w=width,
                        attn_mask=shifted_window_attn_mask,
                    )
                    # sum the messages from all views
                    message = message.view(b, -1, l, c).sum(1)
                else:
                    message = single_head_split_window_attention(
                        query,
                        key,
                        value,
                        num_splits=attn_num_splits,
                        with_shift=self.with_shift,
                        h=height,
                        w=width,
                        attn_mask=shifted_window_attn_mask,
                    )
        else:
            message = single_head_full_attention(query, key, value)  # [B, L, C]

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)

        if not self.no_ffn:
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        return source + message


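# Usage sketch (illustrative only; sizes are arbitrary): a layer performs
# self-attention when source and target are the same tensor, cross-attention
# otherwise; with attn_num_splits=1 it falls back to full attention:
#
#     layer = TransformerLayer(d_model=128, nhead=1, attention_type="swin")
#     x = torch.randn(2, 256, 128)  # [B, H*W, C] with H = W = 16
#     y = layer(x, x, height=16, width=16, attn_num_splits=1)

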
class TransformerBlock(nn.Module):
    """self attention + cross attention + FFN"""

    def __init__(
        self,
        d_model=256,
        nhead=1,
        attention_type="swin",
        ffn_dim_expansion=4,
        with_shift=False,
        add_per_view_attn=False,
        no_cross_attn=False,
        **kwargs,
    ):
        super(TransformerBlock, self).__init__()

        self.no_cross_attn = no_cross_attn

        if no_cross_attn:
            # a single self-attention layer with its own FFN
            self.self_attn = TransformerLayer(
                d_model=d_model,
                nhead=nhead,
                attention_type=attention_type,
                ffn_dim_expansion=ffn_dim_expansion,
                with_shift=with_shift,
                add_per_view_attn=add_per_view_attn,
            )
        else:
            # self-attention without FFN; the FFN follows the cross-attention
            self.self_attn = TransformerLayer(
                d_model=d_model,
                nhead=nhead,
                attention_type=attention_type,
                no_ffn=True,
                ffn_dim_expansion=ffn_dim_expansion,
                with_shift=with_shift,
            )

        self.cross_attn_ffn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            attention_type=attention_type,
            ffn_dim_expansion=ffn_dim_expansion,
            with_shift=with_shift,
            add_per_view_attn=add_per_view_attn,
        )

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        attn_num_splits=None,
        **kwargs,
    ):
        # source, target: [B, L, C]

        # self-attention
        source = self.self_attn(
            source,
            source,
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            attn_num_splits=attn_num_splits,
            **kwargs,
        )

        if self.no_cross_attn:
            return source

        # cross-attention and FFN
        source = self.cross_attn_ffn(
            source,
            target,
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            attn_num_splits=attn_num_splits,
            **kwargs,
        )

        return source


def batch_features(features, nn_matrix=None):
    """Batch per-view features into query / key-value pairs for cross-attention.

    Args:
        features: list of V tensors, each [B, C, H, W] or [B, H*W, C]
        nn_matrix: [B, V, K+1] nearest-view indices per view (column 0 is
            assumed to be the view itself); if given, only the K nearest
            views are gathered as key/value
    Returns:
        q: [V*B, ...] each view's own features
        kv: [V*B, K, ...] the other (or K nearest) views' features
    """
    q = []
    kv = []

    num_views = len(features)
    if nn_matrix is not None:
        # stack once so the nearest views can be gathered by index
        features_tensor = torch.stack(features, dim=1)  # [B, V, ...]

    for i in range(num_views):
        x = features.copy()
        q.append(x.pop(i))  # the i-th view is the query; the rest are key/value

        if nn_matrix is not None:
            # gather the K nearest views, excluding the view itself
            if features_tensor.dim() == 5:  # [B, V, C, H, W]
                c, h, w = features_tensor.shape[-3:]
                index = repeat(nn_matrix[:, i, 1:], "b v -> b v c h w", c=c, h=h, w=w)
            elif features_tensor.dim() == 4:  # [B, V, H*W, C]
                hw, c = features_tensor.shape[-2:]
                index = repeat(nn_matrix[:, i, 1:], "b v -> b v hw c", hw=hw, c=c)

            kv_x = torch.gather(features_tensor, dim=1, index=index)
        else:
            kv_x = torch.stack(x, dim=1)  # [B, V-1, ...]
        kv.append(kv_x)

    q = torch.cat(q, dim=0)  # [V*B, ...]
    kv = torch.cat(kv, dim=0)  # [V*B, K, ...]

    return q, kv


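# Usage sketch (illustrative only; sizes are arbitrary): with 3 views of
# [B, C, H, W] features and no nn_matrix, every view is paired with the
# remaining two:
#
#     feats = [torch.randn(2, 128, 16, 16) for _ in range(3)]
#     q, kv = batch_features(feats)
#     q.shape   # torch.Size([6, 128, 16, 16])
#     kv.shape  # torch.Size([6, 2, 128, 16, 16])

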
class MultiViewFeatureTransformer(nn.Module):
    def __init__(
        self,
        num_layers=6,
        d_model=128,
        nhead=1,
        attention_type="swin",
        ffn_dim_expansion=4,
        add_per_view_attn=False,
        no_cross_attn=False,
        **kwargs,
    ):
        super(MultiViewFeatureTransformer, self).__init__()

        self.attention_type = attention_type

        self.d_model = d_model
        self.nhead = nhead

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    d_model=d_model,
                    nhead=nhead,
                    attention_type=attention_type,
                    ffn_dim_expansion=ffn_dim_expansion,
                    # shift the windows in every other layer
                    with_shift=attention_type == "swin" and i % 2 == 1,
                    add_per_view_attn=add_per_view_attn,
                    no_cross_attn=no_cross_attn,
                )
                for i in range(num_layers)
            ]
        )

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # zero-init the final norms of layers beyond the sixth so the extra
        # layers start out close to identity mappings
        if num_layers > 6:
            for i in range(6, num_layers):
                self.layers[i].self_attn.norm1.weight.data.zero_()
                self.layers[i].self_attn.norm1.bias.data.zero_()
                self.layers[i].cross_attn_ffn.norm2.weight.data.zero_()
                self.layers[i].cross_attn_ffn.norm2.bias.data.zero_()

    def forward(
        self,
        multi_view_features,
        attn_num_splits=None,
        **kwargs,
    ):
        nn_matrix = kwargs.pop("nn_matrix", None)

        # multi_view_features: list of [B, C, H, W]
        b, c, h, w = multi_view_features[0].shape
        assert self.d_model == c

        num_views = len(multi_view_features)

        if self.attention_type == "swin" and attn_num_splits > 1:
            # the window size depends on the number of splits
            window_size_h = h // attn_num_splits
            window_size_w = w // attn_num_splits

            # compute the shifted-window attention mask once per forward pass
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=multi_view_features[0].device,
            )  # [K*K, H/K*W/K, H/K*W/K]
        else:
            shifted_window_attn_mask = None

        # batch the views: each view in turn acts as the query
        concat0, concat1 = batch_features(multi_view_features, nn_matrix=nn_matrix)
        concat0 = concat0.reshape(num_views * b, c, -1).permute(
            0, 2, 1
        )  # [V*B, H*W, C]
        c1_v = num_views - 1 if nn_matrix is None else nn_matrix.shape[-1] - 1
        concat1 = concat1.reshape(num_views * b, c1_v, c, -1).permute(
            0, 1, 3, 2
        )  # [V*B, V-1, H*W, C]

        for i, layer in enumerate(self.layers):
            concat0 = layer(
                concat0,
                concat1,
                height=h,
                width=w,
                shifted_window_attn_mask=shifted_window_attn_mask,
                attn_num_splits=attn_num_splits,
            )

            if i < len(self.layers) - 1:
                # list of [B, H*W, C]
                features = list(concat0.chunk(chunks=num_views, dim=0))
                # re-batch so key/value reflect the features updated by this layer
                concat0, concat1 = batch_features(features, nn_matrix=nn_matrix)

        features = concat0.chunk(chunks=num_views, dim=0)  # V x [B, H*W, C]
        features = [
            f.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() for f in features
        ]  # V x [B, C, H, W]

        return features


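# Usage sketch (illustrative only; sizes are arbitrary): the transformer
# consumes and returns a list of per-view feature maps:
#
#     model = MultiViewFeatureTransformer(num_layers=2, d_model=128)
#     feats = [torch.randn(2, 128, 16, 16) for _ in range(3)]
#     out = model(feats, attn_num_splits=2)  # list of 3 x [2, 128, 16, 16]

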
def batch_features_camera_parameters(
    features,
    intrinsics,
    extrinsics,
    nn_matrix=None,
    no_batch=False,
):
    """Batch per-view features and camera parameters into query / key-value pairs.

    Args:
        features: list of V tensors, each [B, C, H, W]
        intrinsics: list of V tensors, each [B, 3, 3]
        extrinsics: list of V tensors, each [B, 4, 4]
        nn_matrix: [B, V, K+1] nearest-view indices per view (column 0 is
            assumed to be the view itself)
        no_batch: if True, return per-view lists instead of stacked tensors
    Returns:
        q, q_intrinsics, q_extrinsics: [B*V, ...] each view's own features and cameras
        kv, kv_intrinsics, kv_extrinsics: [B*V, K, ...] the selected views'
            features and cameras
    """
    assert (
        features[0].dim() == 4 and intrinsics[0].dim() == 3 and extrinsics[0].dim() == 3
    )
    assert intrinsics[0].size(-1) == intrinsics[0].size(-2) == 3
    assert extrinsics[0].size(-1) == extrinsics[0].size(-2) == 4

    q = []
    q_intrinsics = []
    q_extrinsics = []
    kv = []
    kv_intrinsics = []
    kv_extrinsics = []

    num_views = len(features)
    if nn_matrix is not None:
        # stack once so the nearest views can be gathered by index
        features_tensor = torch.stack(features, dim=1)  # [B, V, C, H, W]
        intrinsics_tensor = torch.stack(intrinsics, dim=1)  # [B, V, 3, 3]
        extrinsics_tensor = torch.stack(extrinsics, dim=1)  # [B, V, 4, 4]

        num_selected_views = nn_matrix.size(-1) - 1
    else:
        num_selected_views = num_views - 1

    for i in range(num_views):
        # the i-th view is the query; the rest are key/value
        x = features.copy()
        q.append(x.pop(i))

        # matching camera parameters
        y = intrinsics.copy()
        q_intrinsics.append(y.pop(i))
        z = extrinsics.copy()
        q_extrinsics.append(z.pop(i))

        if nn_matrix is not None:
            # gather the K nearest views, excluding the view itself
            if features_tensor.dim() == 5:  # [B, V, C, H, W]
                c, h, w = features_tensor.shape[-3:]
                index = repeat(nn_matrix[:, i, 1:], "b v -> b v c h w", c=c, h=h, w=w)
            elif features_tensor.dim() == 4:  # [B, V, H*W, C]
                hw, c = features_tensor.shape[-2:]
                index = repeat(nn_matrix[:, i, 1:], "b v -> b v hw c", hw=hw, c=c)

            kv_x = torch.gather(features_tensor, dim=1, index=index)

            # gather the matching camera parameters
            index = repeat(nn_matrix[:, i, 1:], "b v -> b v 3 3")
            kv_y_intrinsics = torch.gather(intrinsics_tensor, dim=1, index=index)

            index = repeat(nn_matrix[:, i, 1:], "b v -> b v 4 4")
            kv_z_extrinsics = torch.gather(extrinsics_tensor, dim=1, index=index)
        else:
            kv_x = torch.stack(x, dim=1)  # [B, V-1, C, H, W]
            kv_y_intrinsics = torch.stack(y, dim=1)
            kv_z_extrinsics = torch.stack(z, dim=1)

        kv.append(kv_x)
        kv_intrinsics.append(kv_y_intrinsics)
        kv_extrinsics.append(kv_z_extrinsics)

    if no_batch:
        # keep the per-view lists
        return q, q_intrinsics, q_extrinsics, kv, kv_intrinsics, kv_extrinsics

    c, h, w = q[0].shape[1:]

    q = torch.stack(q, dim=1).view(-1, c, h, w)  # [B*V, C, H, W]
    q_intrinsics = torch.stack(q_intrinsics, dim=1).view(-1, 3, 3)
    q_extrinsics = torch.stack(q_extrinsics, dim=1).view(-1, 4, 4)
    kv = torch.stack(kv, dim=1).view(
        -1, num_selected_views, c, h, w
    )  # [B*V, K, C, H, W]
    kv_intrinsics = torch.stack(kv_intrinsics, dim=1).view(
        -1, num_selected_views, 3, 3
    )
    kv_extrinsics = torch.stack(kv_extrinsics, dim=1).view(
        -1, num_selected_views, 4, 4
    )

    return q, q_intrinsics, q_extrinsics, kv, kv_intrinsics, kv_extrinsics


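if __name__ == "__main__":
    # Hedged smoke test, not part of the original module. Run it as a module
    # (e.g. `python -m <package>.<this_file>`) so the relative import of
    # `.utils` resolves; all sizes below are arbitrary.
    feats = [torch.randn(2, 128, 16, 16) for _ in range(3)]
    intr = [torch.eye(3).expand(2, 3, 3).clone() for _ in range(3)]
    extr = [torch.eye(4).expand(2, 4, 4).clone() for _ in range(3)]

    q, qi, qe, kv, kvi, kve = batch_features_camera_parameters(feats, intr, extr)
    assert q.shape == (2 * 3, 128, 16, 16)
    assert kv.shape == (2 * 3, 2, 128, 16, 16)
    assert qi.shape == (6, 3, 3) and kve.shape == (6, 2, 4, 4)

    model = MultiViewFeatureTransformer(num_layers=2, d_model=128, nhead=1)
    out = model(feats, attn_num_splits=2)
    assert len(out) == 3 and out[0].shape == (2, 128, 16, 16)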