Spaces:
Runtime error
Runtime error
# Copyright (c) 2021 NVIDIA CORPORATION. Licensed under the MIT license.
# Written by Chen Zhu during an internship at NVIDIA, zhuchen.eric@gmail.com
import math
from torch import nn
import torch
from timm.models.layers import trunc_normal_
import torch.nn.functional as F
class AttentionLS(nn.Module):
    """Implementation for long-short term attention.

    Flexible options for using window attention, global token and dynamic
    projection. The attention scores of the three branches are concatenated
    along the key dimension in the fixed order
    [dynamic-projection keys | window keys | global-token keys]; the slicing
    offsets in ``forward`` rely on exactly this order.

    Args:
        dim: input and output feature dimension.
        num_heads: number of attention heads.
        qkv_bias: whether to use bias for the projection of query, key and values.
        qk_scale: scale factor on query and key for numerical stability.
            By default, set to square root of head dimensions.
        attn_drop: dropout probability for attention matrix.
        proj_drop: dropout probability for the final output.
        rpe: whether to use relative position encoding.
        nglo: number of global tokens (e.g., CLS).
        dp_rank: rank ``r`` of the dynamic low-rank projection (long-range
            branch); 0 disables that branch.
        w: window/segment size of the short-term window attention branch;
            0 disables that branch.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        rpe=False,
        nglo=1,
        dp_rank=2,
        w=2,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5

        # Fused projection producing query, key and value in one Linear.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.nglo = nglo
        # Equals to segment size (w) in the paper.
        self.window_size = w
        # Equals to r in the paper.
        self.dp_rank = dp_rank

        if self.dp_rank > 0:
            # Produces per-token scores that form the dynamic projection
            # matrix (P_i in the paper): dp_rank rows per head.
            self.to_dynamic_projection = nn.Linear(dim, dp_rank * num_heads)
        # The LN of DualLN corresponding to dynamic projection
        self.dual_ln_dp = nn.LayerNorm(dim)
        # The LN of DualLN corresponding to all the tokens
        self.dual_ln_full = nn.LayerNorm(dim)

        # Adapted from ViL: https://github.com/microsoft/vision-longformer/blob/main/src/models/layers/longformer2d.py#L55-L100
        # We only add RPE to window attention.
        # Unnecessary to add bias for global tokens, since DualLN already adds biases.
        self.rpe = rpe
        if rpe:
            # handle the border conditions...
            w_pad = int(w * 0.5)
            self.local_relative_position_bias_table = nn.Parameter(
                torch.zeros(2 * (w + w_pad - 1) * (2 * w_pad + w + 1) + 1, num_heads)
            )
            trunc_normal_(self.local_relative_position_bias_table, std=0.02)

            # get pair-wise relative position index between each query position
            # (w x w window) and each key position (padded (w+2*w_pad)^2 tile)
            coords_h = torch.arange(-w_pad, w_pad + w)
            coords_w = torch.arange(-w_pad, w_pad + w)
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, 2w, 2w
            coords = (
                coords.view(2, (w + w_pad * 2) ** 2).transpose(0, 1).unsqueeze(0)
            )  # 1, 4w**2, 2
            q_coords_hw = torch.arange(0, w)
            q_coords = torch.stack(
                torch.meshgrid([q_coords_hw, q_coords_hw])
            )  # 2, w, w
            q_coords = q_coords.view(2, w**2).transpose(0, 1).unsqueeze(1)  # w**2, 1, 2
            relative_coords = q_coords - coords
            relative_coords += w_pad + w - 1  # shift to start from 0
            # flatten the 2-D relative offset into a single table index
            relative_coords[:, :, 0] *= 2 * w_pad + w
            relative_position_index = relative_coords.sum(-1)  # w^2, 4w^2
            self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x, nx=None, ny=None):
        """Compute long-short attention over a token sequence.

        Args:
            x: (B, N, C) tokens with the ``nglo`` global tokens first,
               followed by N - nglo feature tokens laid out on a 2-D grid.
            nx: side length of the feature-token grid; inferred as
                ``int(sqrt(N))`` when None.
            ny: unused in this method (the grid is treated as square via
                ``self.img_size``).

        Returns:
            (B, N, C) attended and projected features.
        """
        B, N, C = x.shape
        N_feat = N - self.nglo
        # NOTE(review): the fallback uses sqrt(N) including the global tokens,
        # not sqrt(N_feat); for nglo=1 int(sqrt(w*w + 1)) still floors to w,
        # but confirm for larger nglo or pass nx explicitly.
        self.img_size = int(math.sqrt(N)) if nx is None else nx
        qkv = self.qkv(x)

        # query, key, value
        q, k, v = qkv.chunk(3, dim=2)
        q = q.mul(self.scale)

        # Layer norm on the projected keys and values (the "full" side of DualLN)
        k = self.dual_ln_full(k)
        v = self.dual_ln_full(v)

        # output size: bsz x n_heads x seqlen x d
        # Split off the global (CLS) tokens before reshaping to heads.
        if self.nglo > 0:
            q_cls, q = q[:, : self.nglo], q[:, self.nglo :]
            k_cls, k = k[:, : self.nglo], k[:, self.nglo :]
            v_cls, v = v[:, : self.nglo], v[:, self.nglo :]

            q_cls = q_cls.reshape(
                B, self.nglo, self.num_heads, C // self.num_heads
            ).transpose(1, 2)
            k_cls = k_cls.reshape(
                B, self.nglo, self.num_heads, C // self.num_heads
            ).transpose(1, 2)
            v_cls = v_cls.reshape(
                B, self.nglo, self.num_heads, C // self.num_heads
            ).transpose(1, 2)
        q = q.reshape(B, N_feat, self.num_heads, C // self.num_heads).transpose(1, 2)
        k = k.reshape(B, N_feat, self.num_heads, C // self.num_heads).transpose(1, 2)
        v = v.reshape(B, N_feat, self.num_heads, C // self.num_heads).transpose(1, 2)

        # Long-range Attention (Dynamic Projection)
        if self.dp_rank > 0:
            # b x h x r x (l w)
            # Compute the projection matrix (P_i in the paper)
            c_scores = (
                self.to_dynamic_projection(x[:, self.nglo :])
                .transpose(1, 2)
                .contiguous()
                .view(B, self.num_heads, self.dp_rank, -1)
            )
            # c_scores = c_scores.softmax(dim=-1, dtype=torch.float32).to(x)
            c_scores = c_scores.softmax(dim=-1).to(
                x
            )  # Changed when experimenting with mixed precision (Johannes S.)
            # b x h x r x d -- low-rank keys: projection scores applied to k
            k_lms = c_scores.matmul(k)
            k_lms = k_lms.transpose(1, 2).contiguous().view(B, self.dp_rank, -1)
            # The "dp" side of DualLN, then back to b x h x d x r for the matmul.
            k_lms = (
                self.dual_ln_dp(k_lms)
                .view(B, self.dp_rank, self.num_heads, -1)
                .contiguous()
                .permute(0, 2, 3, 1)
            )
            # b x h x (lw) x r
            dots_all = q.matmul(k_lms)

            if self.window_size > 0:
                # Switch the order of dimensions if using window attention.
                dots_all = self.group_dots(dots_all)
        else:
            dots_all = None

        # Short-term Attention (Window Attention)
        # In our window attention, each token attends to at most (4w^2) tokens.
        if self.window_size > 0:
            dots_win = self.compute_window_scores(q, k)
            w2 = int(self.window_size * self.window_size)

            if self.rpe:
                w_pad = int(0.5 * self.window_size)
                local_relative_position_bias = self.local_relative_position_bias_table[
                    self.relative_position_index.view(-1)
                ].view(
                    1, w2, (w_pad * 2 + self.window_size) ** 2, -1
                )  # w^2, kv_nums, H
                # broadcast the bias over batch and both window-group axes
                local_relative_position_bias = (
                    local_relative_position_bias.permute(0, 3, 1, 2)
                    .expand(B, -1, -1, -1)
                    .unsqueeze(2)
                    .unsqueeze(2)
                )

                dots_win += local_relative_position_bias
            if dots_all is None:
                dots_all = dots_win
            else:
                # Window scores are appended AFTER the dp scores on the key axis.
                dots_all = torch.cat([dots_all, dots_win], dim=-1)

        # Global token.
        if self.nglo > 0:
            # and compute the scores of queries on CLS
            dots_q_cls = q.matmul(k_cls.transpose(-1, -2))

            if self.window_size > 0:
                dots_q_cls = self.group_dots(dots_q_cls)
            # CLS scores go LAST on the key axis (see attn_cls slicing below).
            dots_all = torch.cat([dots_all, dots_q_cls], dim=-1)

        # One softmax over the concatenated [dp | window | cls] key axis.
        # attn = dots_all.softmax(dim=-1, dtype=torch.float32).to(x)
        attn = dots_all.softmax(dim=-1).to(
            x
        )  # Changed when experimenting with mixed precision (Johannes S.)
        attn = self.attn_drop(attn)
        out = 0
        if self.window_size > 0:
            # The window keys sit right after the (optional) dp_rank keys.
            offset = max(0, self.dp_rank)
            kv_group_size = self.window_size
            total_win_size = max(1, self.window_size // 2) * 2 + kv_group_size
            attn_win = attn[:, :, :, :, :, offset : offset + total_win_size**2]
            out += self.compute_window_pv(attn_win, v)
            attn = self.ungroup_dots(attn)

        # attn will be b x h x lw x n_k from now on
        if self.dp_rank > 0:
            attn_lm = attn[:, :, :, : self.dp_rank]
            # Low-rank values: the same projection scores applied to v.
            v_lms = (
                # c_scores.matmul(v.float())
                c_scores.matmul(
                    v
                )  # Changed when experimenting with mixed precision (Johannes S.)
                .to(v)
                .transpose(1, 2)
                .contiguous()
                .view(B, self.dp_rank, -1)
            )
            v_lms = (
                self.dual_ln_dp(v_lms)
                .view(B, self.dp_rank, self.num_heads, -1)
                .contiguous()
                .transpose(1, 2)
            )

            out += attn_lm.matmul(v_lms)

        if self.nglo > 0:
            # Feature tokens attending to the CLS keys (last nglo columns).
            attn_cls = attn[:, :, :, -self.nglo :]
            out += attn_cls.matmul(
                v_cls
            )  # Changed. Was `.mul` instead of `.matmul`. (JWS)

            # b x h x 1 x lw -- CLS queries attend to [cls | all feature] keys.
            cls_inner = q_cls.matmul(k_cls.transpose(-1, -2))
            cls_dots = q_cls.matmul(
                k.transpose(-1, -2)
            )  # Changed. Was `out` instead of `k`. (JWS)
            cls_dots = torch.cat([cls_inner, cls_dots], dim=-1)

            # cls_dots = cls_dots.softmax(dim=-1, dtype=torch.float32).to(x)
            cls_dots = cls_dots.softmax(dim=-1).to(
                x
            )  # Changed when experimenting with mixed precision (Johannes S.)
            cls_next = cls_dots[:, :, :, self.nglo :].matmul(
                v
            )  # the post_cls variant # Changed. Was `out` instead of `v`. (JWS)
            cls_next += cls_dots[:, :, :, : self.nglo].matmul(v_cls)

            # Re-attach the updated global tokens at the front of the sequence.
            out = torch.cat([cls_next, out], dim=2)
        out = out.transpose(1, 2).contiguous().view(B, N, -1)

        # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

    def compute_window_scores(self, q, k):
        """Compute the inner products for the window attention.

        First, divide the query into non-overlapping windows.
        Then, use torch.as_strided (implemented in self.get_overlapping_tiles)
        to create a view of the keys that corresponds to the windows with at
        most 2x memory overhead.
        Finally, compute the inner product, masking the zero-padded border
        positions with -inf so they vanish in the later softmax.
        """
        # q: b h (l w) d
        b, h, _, d = q.shape
        side_size = max(self.window_size // 2, 1)
        # q_group_size: segment size
        kv_width = 2 * side_size + self.window_size  # assuming q_stride=1
        q_n_group = self.img_size // self.window_size
        q_tiles = q.reshape(
            b, h, q_n_group, self.window_size, q_n_group, self.window_size, d
        ).permute(0, 1, 2, 4, 3, 5, 6)
        # q_tiles: b x h x n_group x n_group x w^2 x d
        q_tiles = q_tiles.contiguous().view(b, h, q_n_group, q_n_group, -1, d)

        # k_tiles: b x h x n_group x n_group x 9w^2 x d
        k_tiles = (
            self.get_overlapping_tiles(k)
            .contiguous()
            .view(b, h, q_n_group, q_n_group, -1, d)
        )
        # dot_tiles: b x h x n_group x n_group x w^2 x 9w^2
        dot_tiles = q_tiles.matmul(k_tiles.transpose(-1, -2))

        # fill "-inf" into the zero-padding parts:
        # view the key axis as a 2-D (kv_width x kv_width) tile; for the
        # border window-groups, the first/last side_size rows (dim 5) and
        # columns (dim 6) come from padding and must be masked out.
        dot_tiles = dot_tiles.view(b, h, q_n_group, q_n_group, -1, kv_width, kv_width)

        dot_tiles[:, :, 0, :, :, :side_size].fill_(float("-inf"))
        dot_tiles[:, :, -1, :, :, -side_size:].fill_(float("-inf"))
        dot_tiles[:, :, :, 0, :, :, :side_size].fill_(float("-inf"))
        dot_tiles[:, :, :, -1, :, :, -side_size:].fill_(float("-inf"))

        dot_tiles = dot_tiles.view(b, h, q_n_group, q_n_group, -1, kv_width**2)
        return dot_tiles

    def get_overlapping_tiles(self, x):
        """Get overlapping tiles in the 2D spatial domain, ensuring each query
        computes correlation with all neighbors.

        Returns a zero-copy ``as_strided`` view of shape
        b x h x n_group x n_group x total_size x total_size x d, where each
        tile is the window plus a ``side_size`` halo on every side.
        """
        # x: b h (l w) d
        b, h, _, d = x.shape
        side_size = max(self.window_size // 2, 1)
        total_size = 2 * side_size + self.window_size
        kv_group_size = self.window_size
        kv_width = self.img_size

        x = x.view(b, h, kv_width, kv_width, d)
        # zero-pad the spatial borders so every window has a full halo
        x = F.pad(x, [0, 0, side_size, side_size, side_size, side_size], value=0)

        out_shape = [
            b,
            h,
            kv_width // kv_group_size,
            kv_width // kv_group_size,
            total_size,
            total_size,
            d,
        ]
        in_stride = x.stride()
        # Consecutive window-groups advance by kv_group_size spatial steps,
        # so neighbouring tiles overlap by the halo — no data is copied.
        out_stride = [
            in_stride[0],
            in_stride[1],
            in_stride[2] * kv_group_size,
            in_stride[3] * kv_group_size,
            in_stride[2],
            in_stride[3],
            in_stride[4],
        ]

        # note we ignored the boundary here
        return x.as_strided(size=out_shape, stride=out_stride)

    def compute_window_pv(self, attn, v):
        """Compute the inner product of attention matrix and the values for
        the window attention, returning b x h x (lw) x d."""
        b, h, n_group, _, w2, n_k = attn.shape
        d = v.shape[-1]
        v_tiles = (
            self.get_overlapping_tiles(v)
            .contiguous()
            .view(b, h, n_group, n_group, -1, d)
        )

        # b x h x n_group x n_group x w^2 x d
        pv = attn.matmul(v_tiles)
        # return: b x h x (lw) x d
        ret = self.ungroup_dots(pv)

        return ret

    def group_dots(self, dots):
        """Reshape b x h x (lw) x n from row-major spatial order into
        per-window groups: b x h x n_group x n_group x w^2 x n."""
        b, h = dots.shape[:2]
        n_group = self.img_size // self.window_size
        dots = dots.reshape(
            b, h, n_group, self.window_size, n_group, self.window_size, -1
        ).permute(0, 1, 2, 4, 3, 5, 6)
        dots = dots.contiguous().view(
            b, h, n_group, n_group, self.window_size * self.window_size, -1
        )
        return dots

    def ungroup_dots(self, dots):
        """Inverse of group_dots: flatten the grouped layout back to
        b x h x (lw) x n_keys row-major spatial order."""
        b, h, n_group, _, _, n_keys = dots.shape
        dots = dots.reshape(
            b, h, n_group, n_group, self.window_size, self.window_size, -1
        ).permute(0, 1, 2, 4, 3, 5, 6)
        dots = dots.contiguous().view(b, h, -1, n_keys)
        return dots