import torch.nn as nn

from timm.models.layers import DropPath


class FeedForward(nn.Module):
    """MLP sub-layer of a transformer block: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout, out_dim=None):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        if out_dim is None:
            out_dim = dim
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(dropout)

    def unwrapped(self):
        return self

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    """Multi-head self-attention; returns both the output tokens and the attention map."""

    def __init__(self, dim, heads, dropout):
        super().__init__()
        self.heads = heads
        head_dim = dim // heads
        self.scale = head_dim ** -0.5
        self.attn = None

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    def unwrapped(self):
        return self

    def forward(self, x, mask=None):
        B, N, C = x.shape
        # Joint projection to queries, keys and values, split into heads:
        # (B, N, C) -> (3, B, heads, N, C // heads)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.heads, C // self.heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )

        # Scaled dot-product attention over the token dimension
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back into the channel dimension and project out
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x, attn
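

# A minimal shape walkthrough of Attention.forward (comments only; dim=192 and
# heads=3 are illustrative assumptions, not values taken from this file):
#   x                         : (B, N, 192)
#   qkv after reshape/permute : (3, B, 3, N, 64)
#   q @ k.transpose(-2, -1)   : (B, 3, N, N)  per-head attention weights
#   attn @ v -> transpose/reshape -> (B, N, 192)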


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each wrapped in a
    residual connection with stochastic depth (DropPath)."""

    def __init__(self, dim, heads, mlp_dim, dropout, drop_path):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads, dropout)
        self.mlp = FeedForward(dim, mlp_dim, dropout)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, mask=None, return_attention=False):
        y, attn = self.attn(self.norm1(x), mask)
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
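

# Minimal smoke test (not part of the original module): run a Block on a dummy
# token sequence and check output shapes. The sizes below are illustrative
# assumptions, not values taken from this file.
if __name__ == "__main__":
    import torch

    block = Block(dim=192, heads=3, mlp_dim=768, dropout=0.1, drop_path=0.1)
    x = torch.randn(2, 16, 192)  # (batch, tokens, dim)

    out = block(x)
    assert out.shape == (2, 16, 192)

    # return_attention=True short-circuits and returns the attention map instead
    attn = block(x, return_attention=True)
    assert attn.shape == (2, 3, 16, 16)  # (batch, heads, tokens, tokens)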