|
|
|
|
|
|
|
|
|
|
|
|
| """Some utilities for backbones, in particular for windowing"""
|
|
|
| from typing import Tuple
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
|
|
def window_partition(x, window_size):
    """
    Split a channels-last feature map into non-overlapping square windows.

    The bottom and right edges are zero-padded first when H or W is not a
    multiple of ``window_size``.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): side length of each square window.
    Returns:
        windows: tensor with [B * num_windows, window_size, window_size, C],
            windows ordered row-major within each batch element.
        (Hp, Wp): height and width after padding, needed to undo the split.
    """
    batch, height, width, channels = x.shape

    # Smallest non-negative amount that rounds each spatial dim up to a
    # multiple of window_size (0 when it already divides evenly).
    extra_h = (-height) % window_size
    extra_w = (-width) % window_size
    if extra_h or extra_w:
        # F.pad lists pads from the last dim backwards: (C lo/hi, W lo/hi, H lo/hi).
        x = F.pad(x, (0, 0, 0, extra_w, 0, extra_h))
    padded_h = height + extra_h
    padded_w = width + extra_w

    rows = padded_h // window_size
    cols = padded_w // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Put the two window-index axes side by side, then fold them into the batch axis.
    windows = grid.permute(0, 1, 3, 2, 4, 5).reshape(
        -1, window_size, window_size, channels
    )
    return windows, (padded_h, padded_w)
|
|
|
|
|
def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Window unpartition into original sequences and removing padding.

    Stitches windows laid out as [B * num_windows, ws, ws, C] (row-major
    window order within each batch element) back into a [B, Hp, Wp, C] grid,
    then crops the bottom/right padding down to the original (H, W).

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.
    Returns:
        x: unpartitioned sequences with [B, H, W, C]. When cropping happens the
            result is a strided (non-contiguous) view; call ``.contiguous()``
            if a dense layout is required.
    """
    Hp, Wp = pad_hw
    H, W = hw
    # Each batch element contributes (Hp / ws) * (Wp / ws) windows.
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    # Channel count is inferred via -1.
    x = windows.reshape(
        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
    )
    # Interleave window-row with in-window-row (and likewise for columns)
    # so the reshape yields the full padded image.
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :]
    return x
|
|
|
|
|
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.

    A single strided convolution projects the image into a grid of
    ``embed_dim``-channel patch tokens, which are returned channels-last.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, ...] = (7, 7),
        stride: Tuple[int, ...] = (4, 4),
        padding: Tuple[int, ...] = (3, 3),
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): images with [B, in_chans, H, W].
        Returns:
            torch.Tensor: patch embeddings with [B, H', W', embed_dim], where
                H' and W' are the projection convolution's output size.
        """
        x = self.proj(x)
        # B C H W -> B H W C. permute returns a non-contiguous view.
        x = x.permute(0, 2, 3, 1)
        return x
|
|
|