import math

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    "window_partition",
    "window_unpartition",
    "add_decomposed_rel_pos",
    "get_abs_pos",
    "PatchEmbed",
]


def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)
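
# A minimal usage sketch (illustrative shapes, not part of the module): a
# 14x14 feature map with window_size=8 is zero-padded to 16x16, then split
# into a 2x2 grid of windows per image.
#
#   x = torch.randn(2, 14, 14, 96)                  # [B, H, W, C]
#   windows, (Hp, Wp) = window_partition(x, window_size=8)
#   # windows.shape == (8, 8, 8, 96) -- B * num_windows = 2 * 4
#   # (Hp, Wp) == (16, 16)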


def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Unpartition windows back into original sequences and remove padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x
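
# Round-trip sketch (same illustrative shapes as above): unpartitioning with
# the padded and original sizes recovers the input of window_partition
# exactly, since the zero padding is sliced off.
#
#   x = torch.randn(2, 14, 14, 96)
#   windows, pad_hw = window_partition(x, 8)
#   y = window_unpartition(windows, 8, pad_hw, (14, 14))
#   # y.shape == (2, 14, 14, 96) and torch.equal(y, x)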


def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]
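
# Sketch (assumed sizes): with q_size == k_size == 8, the table must cover
# 2 * 8 - 1 = 15 relative offsets; a table of any other length would first be
# linearly interpolated to 15 rows.
#
#   rel_pos = torch.randn(15, 64)                   # (L, C) embedding table
#   R = get_rel_pos(8, 8, rel_pos)
#   # R.shape == (8, 8, 64): one C-dim embedding per (query, key) offset pair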


def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py  # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn
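
# Sketch (hypothetical dimensions): biasing the attention map of an 8x8 token
# grid with decomposed height/width embeddings; the attn shape is unchanged.
#
#   B, h, w, C = 2, 8, 8, 64
#   attn = torch.zeros(B, h * w, h * w)
#   q = torch.randn(B, h * w, C)
#   rel_pos_h = torch.randn(2 * h - 1, C)           # (Lh, C)
#   rel_pos_w = torch.randn(2 * w - 1, C)           # (Lw, C)
#   attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, (h, w), (h, w))
#   # attn.shape == (B, h * w, h * w)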


def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
    dimension for the original embeddings.
    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    h, w = hw
    if has_cls_token:
        abs_pos = abs_pos[:, 1:]
    xy_num = abs_pos.shape[1]
    size = int(math.sqrt(xy_num))
    assert size * size == xy_num

    if size != h or size != w:
        new_abs_pos = F.interpolate(
            abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
            size=(h, w),
            mode="bicubic",
            align_corners=False,
        )

        return new_abs_pos.permute(0, 2, 3, 1)
    else:
        return abs_pos.reshape(1, h, w, -1)
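
# Sketch (assumes embeddings pretrained on a 14x14 patch grid plus a cls
# token): the cls embedding is dropped and the grid is bicubically resized to
# match a 16x16 input grid.
#
#   abs_pos = torch.randn(1, 1 + 14 * 14, 768)
#   pos = get_abs_pos(abs_pos, has_cls_token=True, hw=(16, 16))
#   # pos.shape == (1, 16, 16, 768)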


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x):
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x
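
# Sketch (standard ViT-B/16 settings, illustrative only): a 224x224 RGB image
# becomes a 14x14 grid of 768-dim patch embeddings in channels-last layout.
#
#   embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), embed_dim=768)
#   y = embed(torch.randn(1, 3, 224, 224))          # input is [B, C, H, W]
#   # y.shape == (1, 14, 14, 768)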