import math

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    "window_partition",
    "window_unpartition",
    "add_decomposed_rel_pos",
    "get_abs_pos",
    "PatchEmbed",
]


def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition.
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)
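
# Illustrative usage (a minimal sketch; the shapes below are assumed example
# values, not part of this module):
#   x = torch.randn(2, 7, 7, 32)                # [B, H, W, C]
#   windows, (Hp, Wp) = window_partition(x, 4)  # 7x7 is padded to 8x8, so
#                                               # windows: [8, 4, 4, 32], (Hp, Wp) == (8, 8)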


def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Reverse window partition into original sequences and remove padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x
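
# Illustrative round trip (a minimal sketch; shapes are assumed examples):
#   x = torch.randn(2, 7, 7, 32)
#   windows, (Hp, Wp) = window_partition(x, 4)            # [8, 4, 4, 32], (8, 8)
#   y = window_unpartition(windows, 4, (Hp, Wp), (7, 7))  # padding cropped back off
#   assert y.shape == x.shape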


def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate the rel pos embeddings to the required number of offsets.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with the shorter length if shapes of q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]
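
# Illustrative usage (a minimal sketch; sizes are assumed examples):
#   rel_pos = torch.randn(2 * 14 - 1, 64)  # one embedding per offset in [-13, 13]
#   Rh = get_rel_pos(14, 14, rel_pos)      # [14, 14, 64]: embedding for each (q, k) pair
# If rel_pos holds a different number of offsets (e.g. a checkpoint trained at another
# resolution), it is first linearly interpolated to 2 * max(q_size, k_size) - 1 entries.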


def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py  # noqa B950

    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    # Per-axis position terms: rel_h[b, h, w, k] = sum_c r_q[b, h, w, c] * Rh[h, k, c].
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn
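
# Illustrative usage (a minimal sketch; shapes are assumed examples):
#   B, C, (q_h, q_w), (k_h, k_w) = 2, 64, (14, 14), (14, 14)
#   q = torch.randn(B, q_h * q_w, C)
#   attn = q @ torch.randn(B, C, k_h * k_w)  # stand-in attention logits [B, 196, 196]
#   rel_pos_h = torch.randn(2 * 14 - 1, C)
#   rel_pos_w = torch.randn(2 * 14 - 1, C)
#   attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, (q_h, q_w), (k_h, k_w))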


def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
    dimension for the original embeddings.

    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C).
    """
    h, w = hw
    if has_cls_token:
        abs_pos = abs_pos[:, 1:]
    xy_num = abs_pos.shape[1]
    size = int(math.sqrt(xy_num))
    assert size * size == xy_num

    if size != h or size != w:
        # Grid size changed: bicubic-resize the embeddings to the new (h, w).
        new_abs_pos = F.interpolate(
            abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
            size=(h, w),
            mode="bicubic",
            align_corners=False,
        )
        return new_abs_pos.permute(0, 2, 3, 1)
    else:
        return abs_pos.reshape(1, h, w, -1)
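
# Illustrative usage (a minimal sketch; the ViT-style numbers are assumed examples):
#   abs_pos = torch.randn(1, 1 + 14 * 14, 768)  # cls token + 14x14 grid of embeddings
#   pos = get_abs_pos(abs_pos, True, (20, 20))  # bicubic-resized to [1, 20, 20, 768]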


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x):
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x
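
# Illustrative usage (a minimal sketch; shapes are assumed examples):
#   embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), embed_dim=768)
#   img = torch.randn(2, 3, 224, 224)  # [B, in_chans, H, W]
#   tokens = embed(img)                # [2, 14, 14, 768], i.e. [B, H/16, W/16, embed_dim]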