```python
import torch.nn as nn
from xformers.ops import memory_efficient_attention


class MEAttention(nn.Module):
    """Multi-head self-attention backed by xformers' memory_efficient_attention."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.0,
        proj_drop=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        # Kept for interface parity with timm's Attention; attention dropout is
        # not applied inside memory_efficient_attention in this forward pass.
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # Project to q, k, v and split heads: [3, B, H, N, D]
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        # MEA expects [B, N, H, D], whereas timm uses [B, H, N, D]
        x = memory_efficient_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            scale=self.scale,
        )
        # Output is [B, N, H, D]; merge heads back to [B, N, C]
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
```
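For reference, a minimal smoke test of the module might look like the sketch below. It assumes xformers is installed and a CUDA device is available (the memory-efficient kernels run on GPU tensors), and the dimensions chosen (`dim=768`, `num_heads=12`, a 196-token sequence) are illustrative, not from the original.

```python
import torch

# Hypothetical shapes: batch of 2, 196 tokens, embedding dim 768
attn = MEAttention(dim=768, num_heads=12).cuda().half()
x = torch.randn(2, 196, 768, device="cuda", dtype=torch.float16)  # [B, N, C]

out = attn(x)
print(out.shape)  # torch.Size([2, 196, 768])
```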