| |
import logging
import math

import numpy as np
from scipy import interpolate
import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
# Public names exported by ``from <this module> import *``.
__all__ = [
    "window_partition", "window_unpartition", "add_decomposed_rel_pos",
    "get_abs_pos", "PatchEmbed", "VisionRotaryEmbeddingFast",
]
|
|
|
|
def window_partition(x, window_size):
    """
    Split a token map into non-overlapping square windows, padding if needed.

    H and W are zero-padded at the end (bottom / right) up to the next
    multiple of ``window_size`` before splitting.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    batch, height, width, channels = x.shape

    # How much is missing to reach the next multiple of window_size (0 if aligned).
    extra_h = -height % window_size
    extra_w = -width % window_size
    if extra_h or extra_w:
        # F.pad pads from the last dim backwards: C untouched, then W, then H.
        x = F.pad(x, (0, 0, 0, extra_w, 0, extra_h))
    padded_h = height + extra_h
    padded_w = width + extra_w

    rows = padded_h // window_size
    cols = padded_w // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-index axes together so each window becomes one entry.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = grid.view(-1, window_size, window_size, channels)
    return windows, (padded_h, padded_w)
|
|
|
|
def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Window unpartition into original sequences and removing padding.

    Inverse of ``window_partition``: the windows are stitched back into a
    padded [B, Hp, Wp, C] map, and the bottom/right padding is then cropped
    so the result matches the original (H, W).

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    # Each image contributes (Hp / window_size) * (Wp / window_size) windows.
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    # Interleave window-row / in-window-row (and likewise for columns) back
    # into full spatial rows and columns.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        # Keep the top-left H x W region, dropping the padded rows/columns.
        x = x[:, :H, :W, :].contiguous()
    return x
|
|
|
|
def _resize_rel_pos_log(rel_pos, dst_size, ratio=1.0903078):
    """
    Resize a (L, C) relative-position table to (dst_size, C) with cubic
    interpolation over geometrically spaced source offsets.

    The L source rows are placed at offsets 0, +/-c_1, ..., +/-c_k with
    c_1 = 1 and c_{i+1} = c_i + ratio ** i, and sampled at the integer
    offsets -(dst_size // 2) .. dst_size // 2.

    Args:
        rel_pos (Tensor): relative position table (L, C). L must be odd
            (and at least 4 rows for a cubic spline).
        dst_size (int): number of output rows.
        ratio (float): growth factor of the source spacing.
            NOTE(review): the default is a fixed constant, presumably tuned
            for one particular src/dst size pair — confirm before relying on
            it for other resolutions.

    Returns:
        Tensor of shape (dst_size, C) on ``rel_pos``'s device and dtype,
        differentiable with respect to ``rel_pos``.
    """
    src_size = rel_pos.shape[0]

    dis = []
    cur = 1
    for i in range(src_size // 2):
        dis.append(cur)
        cur += ratio ** (i + 1)
    src_coords = [-d for d in reversed(dis)] + [0] + dis

    half = dst_size // 2.0
    dst_coords = np.arange(-half, half + 0.1, 1.0)

    # A cubic spline is linear in the sampled values, so interpolating the
    # identity yields a constant (dst_size, src_size) matrix W such that
    # resized = W @ rel_pos. Applying W in torch keeps autograd, device and
    # dtype (calling .numpy() on a parameter that requires grad would raise).
    spline = interpolate.interp1d(
        src_coords, np.eye(src_size), kind="cubic", axis=0, fill_value="extrapolate"
    )
    compute_dtype = torch.promote_types(rel_pos.dtype, torch.float32)
    weights = torch.as_tensor(spline(dst_coords), dtype=compute_dtype, device=rel_pos.device)
    return (weights @ rel_pos.to(compute_dtype)).to(rel_pos.dtype)


def get_rel_pos(q_size, k_size, rel_pos, use_log_interpolation=True):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    If ``rel_pos`` does not have ``2 * max(q_size, k_size) - 1`` rows it is
    first resized to that length.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).
        use_log_interpolation (bool): resize with cubic interpolation over a
            geometrically spaced source grid (default, previous behavior);
            if False, use ``F.interpolate(mode="linear")`` instead.

    Returns:
        Extracted positional embeddings according to relative positions,
        with shape (q_size, k_size, C).
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] == max_rel_dist:
        rel_pos_resized = rel_pos
    elif use_log_interpolation:
        rel_pos_resized = _resize_rel_pos_log(rel_pos, max_rel_dist)
    else:
        # Treat the table as a C-channel 1D signal and resample its length.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)

    # Scale the coordinates of the shorter side when q and k sizes differ.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    # Shift so the most negative offset maps to row 0.
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]
|
|
|
|
def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950

    A height bias (from ``rel_pos_h``) and a width bias (from ``rel_pos_w``)
    are computed separately from the queries and added to the attention map.

    Args:
        attn (Tensor): attention map; must hold B * q_h * q_w * k_h * k_w
            elements (it is viewed as (B, q_h, q_w, k_h, k_w)).
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings,
        shaped (B, q_h * q_w, k_h * k_w).
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    table_h = get_rel_pos(q_h, k_h, rel_pos_h)
    table_w = get_rel_pos(q_w, k_w, rel_pos_w)

    batch, _, channels = q.shape
    q_grid = q.reshape(batch, q_h, q_w, channels)
    # Bias along the key-height axis, broadcast over key width (and vice versa).
    bias_h = torch.einsum("bhwc,hkc->bhwk", q_grid, table_h).unsqueeze(-1)
    bias_w = torch.einsum("bhwc,wkc->bhwk", q_grid, table_w).unsqueeze(-2)

    biased = attn.view(batch, q_h, q_w, k_h, k_w) + bias_h + bias_w
    return biased.view(batch, q_h * q_w, k_h * k_w)
|
|
|
|
def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
    dimension for the original embeddings.

    The (optionally cls-stripped) embeddings must form a square grid; when
    that grid differs from ``hw`` it is resized with bicubic interpolation.

    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    h, w = hw
    grid_tokens = abs_pos[:, 1:] if has_cls_token else abs_pos
    num_tokens = grid_tokens.shape[1]
    side = int(math.sqrt(num_tokens))
    assert side * side == num_tokens

    if side == h and side == w:
        return grid_tokens.reshape(1, h, w, -1)

    # Interpolate in channels-first layout, then return channels-last.
    resized = F.interpolate(
        grid_tokens.reshape(1, side, side, -1).permute(0, 3, 1, 2),
        size=(h, w),
        mode="bicubic",
        align_corners=False,
    )
    return resized.permute(0, 2, 3, 1)
|
|
|
|
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.

    A single strided ``nn.Conv2d`` projects the image; the output feature map
    is returned channels-last.
    """

    def __init__(
        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.proj = nn.Conv2d(in_chans, embed_dim, **conv_options)

    def forward(self, x):
        # (B, in_chans, H, W) -> (B, embed_dim, H', W') -> (B, H', W', embed_dim)
        return self.proj(x).permute(0, 2, 3, 1)
| |
|
|
|
|
|
|
| from math import pi |
|
|
| import torch |
| from torch import nn |
|
|
| from einops import rearrange, repeat |
|
|
|
|
def broadcat(tensors, dim = -1):
    """
    Concatenate ``tensors`` along ``dim`` after broadcasting every other axis.

    All inputs must share the same rank. On each axis other than ``dim`` the
    sizes may take at most two distinct values; every tensor is expanded to
    the largest size on that axis (so the smaller one has to be 1 for
    ``expand`` to succeed), while keeping its own size along ``dim``.

    Args:
        tensors (Sequence[Tensor]): tensors to join.
        dim (int): concatenation axis; negative values count from the end.

    Returns:
        Tensor: the broadcast-and-concatenated result.
    """
    ranks = {len(t.shape) for t in tensors}
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    rank = next(iter(ranks))
    if dim < 0:
        dim += rank

    # sizes_by_axis[a] holds the size of every tensor along axis a.
    sizes_by_axis = list(zip(*(list(t.shape) for t in tensors)))
    other_axes = [(axis, sizes) for axis, sizes in enumerate(sizes_by_axis) if axis != dim]
    assert all(len(set(sizes)) <= 2 for _, sizes in other_axes), 'invalid dimensions for broadcastable concatentation'

    target = {axis: max(sizes) for axis, sizes in other_axes}
    cat_sizes = sizes_by_axis[dim]
    expanded = []
    for index, tensor in enumerate(tensors):
        shape = [cat_sizes[index] if axis == dim else target[axis] for axis in range(rank)]
        expanded.append(tensor.expand(*shape))
    return torch.cat(expanded, dim = dim)
|
|
|
|
def rotate_half(x):
    """
    Rotate each consecutive feature pair by 90 degrees: (a, b) -> (-b, a).

    The last dimension is read as interleaved pairs (x[..., 0::2], x[..., 1::2]);
    it must have even size.
    """
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    first, second = pairs.unbind(dim = -1)
    rotated = torch.stack((-second, first), dim = -1)
    return rotated.flatten(start_dim = -2)
|
|
|
|
class VisionRotaryEmbedding(nn.Module):
    """
    2D rotary position embedding over an ``ft_seq_len`` x ``ft_seq_len`` grid.

    The same per-axis angle table is used for both grid axes; the two are
    broadcast-concatenated along the feature dimension and their cos/sin are
    stored as buffers ``freqs_cos`` / ``freqs_sin`` of shape
    (ft_seq_len, ft_seq_len, D) with D = 4 * len(freqs) (each frequency is
    duplicated, then the two axes are concatenated).
    """

    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
    ):
        """
        Args:
            dim (int): feature size used to derive the 'lang' / 'pixel'
                frequencies (dim // 2 of them).
            pt_seq_len (int): grid side the positions are expressed in.
            ft_seq_len (int, optional): grid side the tables are built for;
                defaults to ``pt_seq_len``. Positions are
                ``arange(ft_seq_len) / ft_seq_len * pt_seq_len``.
            custom_freqs (Tensor, optional): explicit 1D frequency tensor;
                overrides ``freqs_for``.
            freqs_for (str): 'lang', 'pixel' or 'constant'.
            theta (float): base for the 'lang' frequencies.
            max_freq (float): upper range for the 'pixel' frequencies.
            num_freqs (int): number of 'constant' frequencies.
        """
        super().__init__()
        # Compare with None: a multi-element tensor has no truth value, so
        # `if custom_freqs:` would raise for any real frequency tensor.
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Positions of the ft grid rescaled into pt_seq_len units.
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Height and width use identical positions and frequencies.
        freqs_axis = torch.einsum('..., f -> ... f', t, freqs)
        freqs_axis = repeat(freqs_axis, '... n -> ... (n r)', r = 2)
        freqs = broadcat((freqs_axis[:, None, :], freqs_axis[None, :, :]), dim = -1)

        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        # Library code should not print; expose the shape at debug level instead.
        logging.getLogger(__name__).debug('shape of rope freq: %s', tuple(self.freqs_cos.shape))

    def forward(self, t, start_index = 0):
        """
        Rotate features ``t[..., start_index:start_index + D]`` and pass the
        rest of the last dimension through unchanged.

        NOTE(review): the rotated slice is multiplied by buffers of shape
        (ft_seq_len, ft_seq_len, D), so t's trailing dims presumably are
        (ft_seq_len, ft_seq_len, >= start_index + D) — confirm against caller.
        """
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        # Check the end of the slice, not just its width: a too-short slice
        # would otherwise silently broadcast against the tables.
        assert end_index <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
        return torch.cat((t_left, t, t_right), dim = -1)
|
|
|
|
class VisionRotaryEmbeddingFast(nn.Module):
    """
    2D rotary position embedding with position-flattened cos/sin tables.

    ``freqs_cos`` / ``freqs_sin`` are buffers of shape (num_positions, D):
    the ``ft_seq_len`` x ``ft_seq_len`` grid flattened in row-major order, or,
    when ``real_img_size`` is given, that grid bicubically resampled to
    ``real_img_size`` and flattened. D = 4 * len(freqs).
    """

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
        real_img_size = None
    ):
        """
        Args:
            dim (int): feature size used to derive the 'lang' / 'pixel'
                frequencies (dim // 2 of them).
            pt_seq_len (int): grid side the positions are expressed in.
            ft_seq_len (int, optional): grid side the tables are built for;
                defaults to ``pt_seq_len``.
            custom_freqs (Tensor, optional): explicit 1D frequency tensor;
                overrides ``freqs_for``.
            freqs_for (str): 'lang', 'pixel' or 'constant'.
            theta (float): base for the 'lang' frequencies.
            max_freq (float): upper range for the 'pixel' frequencies.
            num_freqs (int): number of 'constant' frequencies.
            real_img_size (int or Tuple, optional): target grid size passed
                as ``size`` to ``F.interpolate`` to resample the tables.
        """
        super().__init__()
        # Compare with None: a multi-element tensor has no truth value, so
        # `if custom_freqs:` would raise for any real frequency tensor.
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Positions of the ft grid rescaled into pt_seq_len units.
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
        rot_dim = freqs.shape[-1]

        freqs_cos = freqs.cos().view(-1, rot_dim)
        freqs_sin = freqs.sin().view(-1, rot_dim)

        if real_img_size is not None:
            # NOTE(review): this interpolates cos/sin values, not the angles.
            freqs_cos = self._resize_table(freqs_cos, ft_seq_len, real_img_size)
            freqs_sin = self._resize_table(freqs_sin, ft_seq_len, real_img_size)

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

    @staticmethod
    def _resize_table(table, side, size):
        """Bicubically resample a flattened (side * side, D) table to ``size`` and re-flatten."""
        rot_dim = table.shape[-1]
        grid = table.reshape(1, side, side, rot_dim).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=size, mode="bicubic", align_corners=False)
        return grid.permute(0, 2, 3, 1).view(-1, rot_dim)

    def forward(self, t):
        """
        Apply the rotary embedding to ``t``.

        The (N, D) tables are indexed as (N, 1, D), so t's last three dims
        must broadcast with that: presumably (..., N, X, D) with positions on
        dim -3 — TODO confirm layout against the caller.
        """
        return t * self.freqs_cos[:, None, :] + rotate_half(t) * self.freqs_sin[:, None, :]
|
|