| | |
| | |
| |
|
| | |
| | |
| |
|
| | """Some utilities for backbones, in particular for windowing""" |
| |
|
| | from typing import Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
def window_partition(x, window_size):
    """
    Cut a [B, H, W, C] token grid into non-overlapping square windows.

    The spatial dims are zero-padded at the bottom and right so that both
    become multiples of ``window_size`` before splitting.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.
    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    batch, height, width, channels = x.shape

    # Amount needed to round each spatial dim up to a multiple of window_size
    # (zero when it already divides evenly).
    extra_h = (window_size - height % window_size) % window_size
    extra_w = (window_size - width % window_size) % window_size
    if extra_h > 0 or extra_w > 0:
        # F.pad pads from the last dim backwards: (C: 0, 0), (W: 0, extra_w), (H: 0, extra_h).
        x = F.pad(x, (0, 0, 0, extra_w, 0, extra_h))
    padded_h = height + extra_h
    padded_w = width + extra_w

    rows = padded_h // window_size
    cols = padded_w // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Move the two window-index axes (rows, cols) ahead of the intra-window axes.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = grid.view(-1, window_size, window_size, channels)
    return windows, (padded_h, padded_w)
| |
|
| |
|
def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Window unpartition into original sequences and removing padding.

    Reassembles windows laid out as produced by ``window_partition`` back into
    a [B, Hp, Wp, C] grid, then crops away any bottom/right padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.
    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    # Each image contributes (Hp / window_size) * (Wp / window_size) windows,
    # so dividing the leading dim by that count recovers the batch size.
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(
        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
    )
    # Interleave window-row/col indices with intra-window offsets, then merge
    # them back into full spatial dims.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        # Drop the padded rows/columns; contiguous() so callers get a dense tensor.
        x = x[:, :H, :W, :].contiguous()
    return x
| |
|
| |
|
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.

    A single strided ``nn.Conv2d`` projects the image into ``embed_dim``
    channels. The result is returned channels-last.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, ...] = (7, 7),
        stride: Tuple[int, ...] = (4, 4),
        padding: Tuple[int, ...] = (3, 3),
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): images of shape [B, in_chans, H, W].
        Returns:
            torch.Tensor: patch embeddings of shape [B, H', W', embed_dim],
            where H', W' are the conv output spatial sizes. This is a permuted
            view of the conv output and is not made contiguous.
        """
        x = self.proj(x)
        # [B, embed_dim, H', W'] -> [B, H', W', embed_dim] (channels-last).
        x = x.permute(0, 2, 3, 1)
        return x
| |
|