File size: 2,704 Bytes
582b238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from RoPE import apply_angles_2d, generate_angles_2d
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class Attention(nn.Module):
    """Multi-head self-attention over an ``H x W`` token grid.

    Tokens arrive flattened as ``[B, H*W, emb_dim]``. Queries, keys and values
    are reshaped to ``[B, n_heads, H, W, head_dim]``, passed through
    ``apply_angles_2d`` together with the precomputed ``freq`` buffer,
    flattened back to ``[B, n_heads, H*W, head_dim]`` and fed to
    ``F.scaled_dot_product_attention``. The heads are then merged and
    projected back to ``emb_dim``.

    NOTE(review): ``apply_angles_2d`` / ``generate_angles_2d`` come from the
    ``RoPE`` module and presumably implement 2D rotary position embeddings --
    confirm against that module. Standard rotary embeddings rotate only the
    queries and keys; this layer has always transformed the values as well,
    so that remains the default and ``rotate_values=False`` opts out.

    Args:
        H: Height of the token grid.
        W: Width of the token grid.
        emb_dim: Embedding size of each token; must be divisible by
            ``n_heads``.
        n_heads: Number of attention heads.
        rotate_values: When True (default, the historical behaviour) the
            values are also passed through ``apply_angles_2d``.

    Raises:
        ValueError: If ``emb_dim`` is not divisible by ``n_heads``.
    """

    def __init__(
        self,
        H: int,
        W: int,
        emb_dim: int,
        n_heads: int = 8,
        *,
        rotate_values: bool = True,
    ) -> None:
        super().__init__()
        # Fail early: a non-divisible emb_dim would otherwise be floored
        # silently here and only blow up inside einops at forward time.
        if emb_dim % n_heads != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.H = H
        self.W = W
        self.n_heads = n_heads
        self.rotate_values = rotate_values
        head_dim = emb_dim // n_heads
        # Single fused projection producing q, k and v side by side.
        self.qkv = nn.Linear(emb_dim, 3 * emb_dim, bias=False)
        # Kept as an instance attribute for backward compatibility; forward()
        # calls the imported function directly.
        self.apply_angles_2d = apply_angles_2d
        self.proj = nn.Linear(emb_dim, emb_dim)
        # Non-persistent: the table is rebuilt from (H, W, head_dim) on
        # construction and is not stored in the state dict.
        self.register_buffer(
            "freq", generate_angles_2d(H, W, head_dim), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to a batch of flattened grid tokens.

        Args:
            x: Tensor of shape ``[B, H*W, emb_dim]``.

        Returns:
            Tensor of shape ``[B, H*W, emb_dim]``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.ndim != 3:
            raise ValueError(
                f"Expected input of shape [B, N, D], got {tuple(x.shape)}"
            )
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        grid = dict(H=self.H, W=self.W, h=self.n_heads)

        # Flattened tokens -> per-head 2D grid.
        q, k, v = (
            rearrange(t, "B (H W) (h D) -> B h H W D", **grid)
            for t in (q, k, v)
        )

        q = apply_angles_2d(q, self.freq)
        k = apply_angles_2d(k, self.freq)
        if self.rotate_values:
            v = apply_angles_2d(v, self.freq)

        # Per-head 2D grid -> flattened tokens; the axis lengths are passed
        # again so einops verifies the shapes returned by apply_angles_2d.
        q, k, v = (
            rearrange(t, "B h H W D -> B h (H W) D", **grid)
            for t in (q, k, v)
        )

        x = F.scaled_dot_product_attention(q, k, v)
        # Merge the heads back into the embedding dimension.
        x = rearrange(x, "B h N D -> B N (h D)")
        return self.proj(x)

class ViTBlock(nn.Module):
    """Transformer block: attention and MLP sub-layers with residual connections.

    Each sub-layer starts with its own ``LayerNorm`` and its output is added
    back to its input. The MLP expands to ``4 * emb_dim``, applies GELU and
    dropout, projects back to ``emb_dim`` and applies dropout again.

    Args:
        H: Height of the token grid.
        W: Width of the token grid.
        emb_dim: Embedding size of each token.
        n_heads: Number of attention heads passed to ``Attention``.
        dropout: Dropout probability used inside the MLP.
    """

    def __init__(
        self,
        H: int,
        W: int,
        emb_dim: int,
        n_heads: int = 8,
        dropout: float = 0.1,
    ) -> None:
        # nn.Module must be initialised before attributes are assigned on it;
        # doing it first keeps this safe if a tensor or submodule is ever
        # added to the assignments below.
        super().__init__()
        self.H, self.W, self.emb_dim = H, W, emb_dim
        self.attn = nn.Sequential(
            nn.LayerNorm(emb_dim),
            Attention(H, W, emb_dim, n_heads=n_heads),
        )
        self.MLP = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, emb_dim * 4, bias=True),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(emb_dim * 4, emb_dim, bias=True),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on flattened grid tokens.

        Args:
            x: Tensor of shape ``[B, H*W, emb_dim]``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            AssertionError: If ``x`` is not 3-dimensional or its shape does
                not match ``[B, H*W, emb_dim]``. These are plain ``assert``
                statements, so they are skipped under ``python -O``.
        """
        assert x.ndim == 3, f"Expected shape [B, N, D], but got shape {x.shape}. You probably passed [B, H, W, D] instead."
        expected = torch.Size([x.shape[0], self.H * self.W, self.emb_dim])
        assert x.shape == expected, f"Expected shape [B, N, D] -> {expected}, got {x.shape}"
        x = x + self.attn(x)
        x = x + self.MLP(x)
        return x
  
# Smoke test: push one random batch of 64*64 tokens (emb_dim=384) through a
# block and print the resulting shape.
# NOTE: this is module-level code, so it also runs whenever the file is imported.
print(
    ViTBlock(64, 64, 384)(
        torch.randn(1, 64**2, 384)
    ).shape
)