import math
from typing import Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    _vmap_for_bhqkv,
    create_block_mask,
    create_mask,
    flex_attention,
)

try:
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
except ImportError:
    from torch._higher_order_ops.flex_attention import TransformGetItemToIndex

from torch._dynamo import disable


def generate_alibi_bias(H=12):
    """Return one ALiBi-style slope per head: slope_h = 2 ** (-(h + 1) / H)."""
    alibi_bias = []
    for h in range(H):
        alibi_bias.append(-((h + 1) / H))
    alibi_bias = torch.tensor(alibi_bias)
    alibi_bias = torch.exp2(alibi_bias)
    return alibi_bias
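# Worked example (illustrative): generate_alibi_bias(4)
#   -> tensor([0.8409, 0.7071, 0.5946, 0.5000]), i.e. 2**-0.25, 2**-0.5, 2**-0.75, 2**-1.
# Later heads get smaller slopes, so their distance bias penalises far-away tokens more gently.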


def get_rel_bias_func(scale, coords=None, qk_scale=1.0):
    """Build a flex_attention score_mod that subtracts a log-distance relative bias.

    `scale` holds one slope per head (e.g. from generate_alibi_bias) and `coords`
    holds per-token 2D positions of shape (B, L, 2).
    """
    def patch_coords_rel_bias(score, b, h, q_idx, kv_idx):
        if coords is None:
            return score
        with torch.no_grad():
            # Euclidean distance between query and key positions, clamped and
            # log-compressed so the bias grows slowly with distance.
            dx = coords[b, q_idx][0] - coords[b, kv_idx][0]
            dy = coords[b, q_idx][1] - coords[b, kv_idx][1]
            dist = torch.sqrt(dx * dx + dy * dy)
            dist = dist.clamp(max=1000)
            dist = torch.log1p(dist)
        bias = dist * scale[h] * qk_scale
        return score - bias
    return patch_coords_rel_bias
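# Hedged usage sketch (the names below are illustrative, not part of this module):
#   slopes = generate_alibi_bias(num_heads).to(coords.device)       # (num_heads,)
#   score_mod = get_rel_bias_func(slopes, coords, qk_scale=head_dim ** -0.5)
#   out = flex_attention(q, k, v, score_mod=score_mod)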


def key_padding_mask(mask):
    """Wrap a boolean key-padding mask (True = padded position) as a flex_attention mask_mod."""
    def padding_mask(b, h, q_idx, kv_idx):
        # Keep a (q_idx, kv_idx) pair only when the key token is not padding.
        return ~mask[b, kv_idx]
    return padding_mask
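# Hedged usage sketch (tensor names are illustrative, not part of this module):
#   pad = torch.zeros(B, L, dtype=torch.bool)
#   pad[:, valid_len:] = True                      # True marks padded key positions
#   block_mask = create_block_mask(key_padding_mask(pad), B, H, L, L, device=pad.device)
#   out = flex_attention(q, k, v, block_mask=block_mask)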


class FlexCore(nn.Module):
    """Thin wrapper around flex_attention so forward hooks can be registered on it."""
    def forward(self, q, k, v, score_mod=None, block_mask=None, return_lse=False):
        """
        `return_lse=True` is meant to be used together with the ATTN_MAP_VIS wrapper.
        Even then, flex_attention does not return the attention-score matrix; it returns
        the attention output (plus the log-sum-exp of the scores).
        """
        return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask, return_lse=return_lse)
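# Hedged sketch of the forward-hook use case this wrapper exists for (e.g. an
# attention-map visualiser such as the ATTN_MAP_VIS wrapper mentioned above;
# all names below are illustrative):
#   core = FlexCore()
#   captured = {}
#   def _grab(module, inputs, output):
#       # with return_lse=True, flex_attention returns (attn_out, logsumexp)
#       captured["lse"] = output[1].detach() if isinstance(output, tuple) else None
#   handle = core.register_forward_hook(_grab)
#   ...  # run the forward pass, then read captured["lse"]
#   handle.remove()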


class Flex_Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 12,
        qkv_bias: bool = True,
        proj_drop: float = 0.,
        use_rel_bias: bool = True,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Fused QKV projection parameters, initialised with std = head_dim ** -0.5.
        self.scale = self.head_dim ** -0.5
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        self.f_attn = FlexCore()

        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)

        self.max_distance = 16

    def build_rel_bias(self, coords):
        # Dense helper: for coords of shape (B, L, 2), returns the (B, L, L) matrix of
        # log1p pairwise Euclidean distances that underlies the relative bias.
        return torch.log1p(torch.cdist(coords, coords, p=2))
    def forward(self, x, coords=None, attn_mask: Optional[torch.Tensor] = None, return_attn_score=False):
        N, L, C = x.shape

        # Project to Q, K, V and split into per-head (N, num_heads, L, head_dim) tensors.
        x_proj = F.linear(x, self.in_proj_weight, self.in_proj_bias).contiguous()
        q, k, v = [t.contiguous() for t in x_proj.chunk(3, -1)]

        q = q.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        k = k.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        v = v.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()

        if attn_mask is not None:
            # attn_mask is a key-padding mask (True = padded); turn it into a BlockMask.
            mask_func = create_block_mask(
                key_padding_mask(attn_mask), N, self.num_heads, L, L, device=x.device
            )

        qk_scale = q.size(-1) ** -0.5
        score_mod = None
        if coords is not None:
            slopes = generate_alibi_bias(self.num_heads).to(coords.device)
            score_mod = get_rel_bias_func(slopes, coords, qk_scale)
        # Note: with return_lse=True, flex_attention returns (output, logsumexp); the
        # ATTN_MAP_VIS wrapper referenced in FlexCore is expected to handle that tuple.
        x = self.f_attn(
            q, k, v,
            score_mod=score_mod,
            block_mask=mask_func if attn_mask is not None else None,
            return_lse=return_attn_score,
        )

        # Merge heads back to (N, L, C) and apply the output projection.
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.reshape(N, L, C).contiguous()

        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
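

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. The shapes, the
    # padding convention (True = padded key), and the assumption that eager
    # flex_attention / create_block_mask run on the chosen device are illustrative.
    torch.manual_seed(0)
    attn = Flex_Attention(dim=192, num_heads=12)
    x = torch.randn(2, 128, 192)                  # (N, L, C)
    coords = torch.rand(2, 128, 2) * 100.0        # per-token (x, y) positions
    pad = torch.zeros(2, 128, dtype=torch.bool)   # key-padding mask, True = padded
    pad[:, 96:] = True
    out = attn(x, coords=coords, attn_mask=pad)
    print(out.shape)                              # expected: torch.Size([2, 128, 192])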