File size: 4,269 Bytes
cce011e
 
5feebb1
cce011e
 
 
 
 
 
 
5feebb1
cce011e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5feebb1
cce011e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5feebb1
cce011e
 
 
 
 
 
 
 
 
 
5feebb1
cce011e
 
 
 
 
 
 
 
 
 
 
 
 
 
5feebb1
cce011e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import torch.nn as nn 
import math

# Module-level switch: when True, MultiHeadSelfAttention.forward prints the
# tensor shape after every step (the __main__ block below turns it on).
DEBUG = False


class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each passed through their own linear
    projection, split into ``num_heads`` heads of size
    ``embedding_dim // num_heads``, attended per head, merged back to
    ``embedding_dim`` and passed through an output projection.

    Args:
        embedding_dim: Size of the last dimension of the inputs and output.
        num_heads: Number of attention heads; must be positive and divide
            ``embedding_dim`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embedding_dim``.
    """

    def __init__(self, embedding_dim: int = 768, num_heads: int = 12) -> None:
        super().__init__()

        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if embedding_dim % num_heads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads

        # Created in the order q, k, v, out so parameter initialisation consumes
        # the RNG exactly as before.
        self.q_w = nn.Linear(embedding_dim, embedding_dim)
        self.k_w = nn.Linear(embedding_dim, embedding_dim)
        self.v_w = nn.Linear(embedding_dim, embedding_dim)
        self.out_w = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Attend from ``q`` over ``k`` / ``v``.

        The head split uses ``transpose(1, 2)``, so inputs are expected to be
        3-D: ``[batch_size, n_tokens, embedding_dim]``. ``k`` and ``v`` must
        share a token count; it may differ from ``q``'s.

        Args:
            q: Query tensor, ``[batch_size, n_q, embedding_dim]``.
            k: Key tensor, ``[batch_size, n_kv, embedding_dim]``.
            v: Value tensor, ``[batch_size, n_kv, embedding_dim]``.

        Returns:
            Tensor of shape ``[batch_size, n_q, embedding_dim]``.
        """
        if DEBUG: print(f'MSA Input shape (Q, K, V): {q.shape}: [batch_size, n_patches, embedding_dim]')

        # Linear projections for Q, K, V, each reshaped to expose the heads.
        if DEBUG: print(f'Linear projection for Q, K, V: {q.shape} [batch_size, n_patches, embedding_dim]')
        q = self.q_w(q).view(*q.shape[:-1], self.num_heads, self.head_dim)
        k = self.k_w(k).view(*k.shape[:-1], self.num_heads, self.head_dim)
        # Values must go through their own projection (previously q_w was reused).
        v = self.v_w(v).view(*v.shape[:-1], self.num_heads, self.head_dim)
        if DEBUG: print(f'Splitting the last dimension once for each head: {q.shape} [batch_size, n_patches, num_heads, head_dim]')

        # Put the head axis before the token axis so matmul works per head.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        if DEBUG: print(f'Swap patches and head to have the head come first: {q.shape} [batch_size, num_heads, n_patches, head_dim]')

        attention_scores = torch.matmul(q, k.mT) / math.sqrt(self.head_dim)
        if DEBUG: print(f'Compute attention scores for each head (scaled dot product): {attention_scores.shape} [batch_size, num_heads, n_patches, n_patches]')

        attention_weights = torch.softmax(attention_scores, dim=-1)
        if DEBUG: print(f'Softmax of attention scores: {attention_weights.shape} [batch_size, num_heads, n_patches, n_patches]')

        weighted_sum = torch.matmul(attention_weights, v)
        if DEBUG: print(f'Weighted sum of Values: {weighted_sum.shape} [batch_size, num_heads, n_patches, head_dim]')

        # contiguous() is required before view() after a transpose.
        weighted_sum = weighted_sum.transpose(1, 2).contiguous()
        if DEBUG: print(f'Swap again the patches and the heads: {weighted_sum.shape} [batch_size, n_patches, num_heads, head_dim]')

        weighted_sum = weighted_sum.view(*weighted_sum.shape[:-2], -1)
        if DEBUG: print(f'Recover the original dimensions by merging the last 2: {weighted_sum.shape} [batch_size, n_patches, embedding_dim]')

        output = self.out_w(weighted_sum)
        if DEBUG: print(f'(Output) Linear projection of the weighted sum: {output.shape} [batch_size, n_patches, embedding_dim]')

        return output
        

class MSABlock(nn.Module):
    """Pre-norm self-attention sub-block: LayerNorm followed by attention.

    The normalised input is used as query, key and value of a
    ``MultiHeadSelfAttention`` module. No residual connection is applied here
    (``TransformerEncoderBlock`` adds it).

    Args:
        embedding_dim: Size of the last input/output dimension.
        num_heads: Number of attention heads.
    """

    def __init__(self, embedding_dim: int = 768, num_heads: int = 12) -> None:
        super().__init__()
        self.msa = MultiHeadSelfAttention(embedding_dim=embedding_dim, num_heads=num_heads)
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``MSA(LN(x), LN(x), LN(x))``."""
        normed = self.layer_norm(x)
        return self.msa(normed, normed, normed)

class MLPBlock(nn.Module):
    """Pre-norm feed-forward sub-block.

    Computes ``Linear(GELU(Linear(LN(x))))``, expanding ``embedding_dim`` to
    ``hidden_size`` and projecting back. No residual connection is applied here.

    Args:
        embedding_dim: Size of the last input/output dimension.
        hidden_size: Width of the intermediate layer.
    """

    def __init__(self, embedding_dim: int = 768, hidden_size: int = 3072) -> None:
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        expand = nn.Linear(in_features=embedding_dim, out_features=hidden_size)
        project = nn.Linear(in_features=hidden_size, out_features=embedding_dim)
        # Kept as an nn.Sequential so parameter names (mlp.0.*, mlp.2.*) stay stable.
        self.mlp = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``MLP(LN(x))``."""
        normed = self.layer_norm(x)
        return self.mlp(normed)


class TransformerEncoderBlock(nn.Module):
    """One pre-norm Transformer encoder layer.

    Applies ``y = x + MSABlock(x)`` and then returns ``y + MLPBlock(y)``.

    Args:
        embedding_dim: Size of the last input/output dimension.
        hidden_size: Width of the feed-forward intermediate layer.
        num_heads: Number of attention heads.
    """

    def __init__(self, embedding_dim: int = 768, hidden_size: int = 3072, num_heads: int = 12) -> None:
        super().__init__()
        self.msa = MSABlock(embedding_dim=embedding_dim, num_heads=num_heads)
        self.mlp = MLPBlock(embedding_dim=embedding_dim, hidden_size=hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the attention and feed-forward sub-blocks, each with a residual."""
        attended = self.msa(x) + x
        return self.mlp(attended) + attended

if __name__ == '__main__':
    # Smoke run: push a random [5, 197, 768] batch through one attention
    # module with per-step shape tracing enabled, then print the output shape.
    DEBUG = True
    tokens = torch.rand(5, 197, 768)
    attention = MultiHeadSelfAttention()
    result = attention(tokens, tokens, tokens)
    print(result.shape)