import torch
import torch.nn as nn
import torch.nn.functional as F
def window_partition(x, window_size):
    """
    Partition a feature map into non-overlapping windows.
    x: (B, H, W, C)
    Returns windows of shape: (num_windows*B, window_size*window_size, C)
    """
    B, H, W, C = x.shape
    # Split H and W into (num_windows, window_size) pairs.
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # Bring the window-grid dims together, then the intra-window dims.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    # Flatten each window into a sequence of window_size*window_size tokens.
    windows = x.view(-1, window_size * window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Reverse of window_partition.
    windows: (num_windows*B, window_size*window_size, C)
    Returns: (B, H, W, C)
    """
    # Recover the batch size from the total number of windows.
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H, W, -1)
    return x
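# Illustrative shape walkthrough (an example added here, not part of the
# original listing; the concrete sizes are arbitrary): for x of shape
# (2, 8, 8, 3) and window_size=4, window_partition(x, 4) returns
# (2 * 2 * 2, 16, 3) = (8, 16, 3), and window_reverse(windows, 4, 8, 8)
# restores (2, 8, 8, 3). When H and W are multiples of window_size,
# window_reverse(window_partition(x, ws), ws, H, W) reproduces x exactly.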
class SwinWindowAttention(nn.Module):
    """
    A simplified Swin-like window attention block:
    1) Partition the input into non-overlapping windows
    2) Perform multi-head self-attention within each window
    3) Merge the windows back into a feature map
    """
    def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.window_size = window_size
        self.num_heads = num_heads

        self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (B, C, H, W) -> (B, H, W, C) for window partitioning.
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).contiguous()

        # Pad H and W up to multiples of the window size.
        pad_h = (self.window_size - H % self.window_size) % self.window_size
        pad_w = (self.window_size - W % self.window_size) % self.window_size
        if pad_h or pad_w:
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
        Hp, Wp = x.shape[1], x.shape[2]

        # Partition into windows and run self-attention within each window.
        windows = window_partition(x, self.window_size)
        attn_windows, _ = self.mha(windows, windows, windows)
        attn_windows = self.dropout(attn_windows)

        # Merge the windows back into a (B, Hp, Wp, C) feature map.
        x = window_reverse(attn_windows, self.window_size, Hp, Wp)

        # Remove any padding that was added.
        if pad_h or pad_w:
            x = x[:, :H, :W, :].contiguous()

        # Back to (B, C, H, W).
        x = x.permute(0, 3, 1, 2).contiguous()
        return x
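# Minimal smoke test (an illustrative sketch, not part of the original listing).
# The values embed_dim=96, window_size=7, num_heads=3 and the (2, 96, 60, 60)
# input are arbitrary assumptions; 60 is deliberately not a multiple of 7 so
# the padding path is exercised.
if __name__ == "__main__":
    block = SwinWindowAttention(embed_dim=96, window_size=7, num_heads=3)
    dummy = torch.randn(2, 96, 60, 60)  # (B, C, H, W)
    out = block(dummy)
    print(out.shape)  # expected: torch.Size([2, 96, 60, 60])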