import torch
import torch.nn as nn

from .attention import (
    single_head_full_attention,
    single_head_full_attention_1d,
    single_head_split_window_attention,
    single_head_split_window_attention_1d,
)
from .utils import generate_shift_window_attn_mask, generate_shift_window_attn_mask_1d


class TransformerLayer(nn.Module):
    def __init__(
        self,
        d_model=128,
        nhead=1,
        no_ffn=False,
        ffn_dim_expansion=4,
    ):
        super().__init__()

        self.dim = d_model
        self.nhead = nhead
        self.no_ffn = no_ffn

        # single-head attention projections
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # no FFN after self-attention; FFN and its norm only exist for cross-attention layers
        if not self.no_ffn:
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        shifted_window_attn_mask_1d=None,
        attn_type="swin",
        with_shift=False,
        attn_num_splits=None,
    ):
        # source, target: [B, L, C]
        query, key, value = source, target, target

        # detect self-attention (query == key) to choose 2D vs 1D attention below
        is_self_attn = (query - key).abs().max() < 1e-6

        # single-head attention
        query = self.q_proj(query)  # [B, L, C]
        key = self.k_proj(key)  # [B, L, C]
        value = self.v_proj(value)  # [B, L, C]

        if attn_type == "swin" and attn_num_splits > 1:  # both self- and cross-attention use 2D swin
            if self.nhead > 1:
                # multi-head variant is not implemented
                raise NotImplementedError
            else:
                message = single_head_split_window_attention(
                    query,
                    key,
                    value,
                    num_splits=attn_num_splits,
                    with_shift=with_shift,
                    h=height,
                    w=width,
                    attn_mask=shifted_window_attn_mask,
                )

        elif attn_type == "self_swin2d_cross_1d":  # self-attention: 2D swin, cross-attention: full 1D
            if self.nhead > 1:
                raise NotImplementedError
            else:
                if is_self_attn:
                    if attn_num_splits > 1:
                        message = single_head_split_window_attention(
                            query,
                            key,
                            value,
                            num_splits=attn_num_splits,
                            with_shift=with_shift,
                            h=height,
                            w=width,
                            attn_mask=shifted_window_attn_mask,
                        )
                    else:
                        # global 2D attention
                        message = single_head_full_attention(query, key, value)
                else:
                    # cross-attention along the width dimension only
                    message = single_head_full_attention_1d(
                        query,
                        key,
                        value,
                        h=height,
                        w=width,
                    )

        elif attn_type == "self_swin2d_cross_swin1d":  # self-attention: 2D swin, cross-attention: 1D swin
            if self.nhead > 1:
                raise NotImplementedError
            else:
                if is_self_attn:
                    if attn_num_splits > 1:
                        # self-attention with 2D shifted local windows
                        message = single_head_split_window_attention(
                            query,
                            key,
                            value,
                            num_splits=attn_num_splits,
                            with_shift=with_shift,
                            h=height,
                            w=width,
                            attn_mask=shifted_window_attn_mask,
                        )
                    else:
                        # global 2D attention
                        message = single_head_full_attention(query, key, value)
                else:
                    if attn_num_splits > 1:
                        assert shifted_window_attn_mask_1d is not None
                        # cross-attention with 1D shifted local windows
                        message = single_head_split_window_attention_1d(
                            query,
                            key,
                            value,
                            num_splits=attn_num_splits,
                            with_shift=with_shift,
                            h=height,
                            w=width,
                            attn_mask=shifted_window_attn_mask_1d,
                        )
                    else:
                        message = single_head_full_attention_1d(
                            query,
                            key,
                            value,
                            h=height,
                            w=width,
                        )

        else:  # default: global full attention
            message = single_head_full_attention(query, key, value)

        message = self.merge(message)
        message = self.norm1(message)

        if not self.no_ffn:
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        return source + message


class TransformerBlock(nn.Module):
    """Self-attention + cross-attention + FFN."""

    def __init__(
        self,
        d_model=128,
        nhead=1,
        ffn_dim_expansion=4,
    ):
        super().__init__()

        self.self_attn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            no_ffn=True,
            ffn_dim_expansion=ffn_dim_expansion,
        )

        self.cross_attn_ffn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            ffn_dim_expansion=ffn_dim_expansion,
        )

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        shifted_window_attn_mask_1d=None,
        attn_type="swin",
        with_shift=False,
        attn_num_splits=None,
    ):
        # source, target: [B, L, C]

        # self-attention
        source = self.self_attn(
            source,
            source,
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            attn_type=attn_type,
            with_shift=with_shift,
            attn_num_splits=attn_num_splits,
        )

        # cross-attention followed by FFN
        source = self.cross_attn_ffn(
            source,
            target,
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
            attn_type=attn_type,
            with_shift=with_shift,
            attn_num_splits=attn_num_splits,
        )

        return source


class FeatureTransformer(nn.Module):
    def __init__(
        self,
        num_layers=6,
        d_model=128,
        nhead=1,
        ffn_dim_expansion=4,
    ):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    d_model=d_model,
                    nhead=nhead,
                    ffn_dim_expansion=ffn_dim_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        feature1,
        attn_type="swin",
        attn_num_splits=None,
        **kwargs,
    ):
        b, c, h, w = feature0.shape
        assert self.d_model == c

        feature0 = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        feature1 = feature1.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]

        # 2D attention mask for shifted-window (swin) attention
        if "swin" in attn_type and attn_num_splits > 1:
            window_size_h = h // attn_num_splits
            window_size_w = w // attn_num_splits

            # compute the mask once and reuse it in all layers
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )
        else:
            shifted_window_attn_mask = None

        # 1D attention mask for shifted-window cross-attention along the width
        if "swin1d" in attn_type and attn_num_splits > 1:
            window_size_w = w // attn_num_splits

            # compute the mask once and reuse it in all layers
            shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d(
                input_w=w,
                window_size_w=window_size_w,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )
        else:
            shifted_window_attn_mask_1d = None

        # concatenate along the batch dimension so both directions share one forward pass
        concat0 = torch.cat((feature0, feature1), dim=0)  # [2B, H*W, C]
        concat1 = torch.cat((feature1, feature0), dim=0)  # [2B, H*W, C]

        for i, layer in enumerate(self.layers):
            concat0 = layer(
                concat0,
                concat1,
                height=h,
                width=w,
                attn_type=attn_type,
                # shift the attention windows on every other layer
                with_shift="swin" in attn_type and attn_num_splits > 1 and i % 2 == 1,
                attn_num_splits=attn_num_splits,
                shifted_window_attn_mask=shifted_window_attn_mask,
                shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
            )

            # update the cross-attention target by swapping the two halves
            concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)

        feature0, feature1 = concat0.chunk(chunks=2, dim=0)  # [B, H*W, C]

        # reshape back to [B, C, H, W]
        feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()

        return feature0, feature1
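

# Minimal usage sketch (not part of the original module): a smoke test that runs one
# forward pass with random features to illustrate the expected input/output shapes.
# The sizes below (batch 2, 128 channels, 32x64 features) and attn_num_splits=2 are
# illustrative assumptions; because of the relative imports above, this would need to
# be run as a module inside the package (e.g. `python -m <package>.transformer`).
if __name__ == "__main__":
    b, c, h, w = 2, 128, 32, 64
    feature0 = torch.randn(b, c, h, w)
    feature1 = torch.randn(b, c, h, w)

    model = FeatureTransformer(num_layers=6, d_model=c, nhead=1, ffn_dim_expansion=4)
    out0, out1 = model(feature0, feature1, attn_type="swin", attn_num_splits=2)

    # both outputs keep the [B, C, H, W] shape of the inputs
    print(out0.shape, out1.shape)  # expected: torch.Size([2, 128, 32, 64]) for both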