import torch
import torch.nn as nn
import math

DEBUG = False


class MultiHeadSelfAttention(nn.Module):

    def __init__(self, embedding_dim: int = 768, num_heads: int = 12) -> None:
        super(MultiHeadSelfAttention, self).__init__()

        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads

        # One linear projection each for the queries, keys, values and the output
        self.q_w, self.k_w, self.v_w, self.out_w = (nn.Linear(embedding_dim, embedding_dim) for _ in range(4))

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        if DEBUG: print(f'MSA Input shape (Q, K, V): {q.shape} [batch_size, n_patches, embedding_dim]')

        # Linear projections for Q, K and V; the shape is unchanged
        q, k, v = self.q_w(q), self.k_w(k), self.v_w(v)
        if DEBUG: print(f'Linear projection for Q, K, V: {q.shape} [batch_size, n_patches, embedding_dim]')

        # Split the embedding dimension into one chunk per head
        q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
        k = k.view(*k.shape[:-1], self.num_heads, self.head_dim)
        v = v.view(*v.shape[:-1], self.num_heads, self.head_dim)
        if DEBUG: print(f'Splitting the last dimension once for each head: {q.shape} [batch_size, n_patches, num_heads, head_dim]')

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        if DEBUG: print(f'Swap patches and heads so that the heads come first: {q.shape} [batch_size, num_heads, n_patches, head_dim]')

        attention_scores = torch.matmul(q, k.mT) / math.sqrt(self.head_dim)
        if DEBUG: print(f'Compute attention scores for each head (scaled dot product): {attention_scores.shape} [batch_size, num_heads, n_patches, n_patches]')

        attention_weights = torch.softmax(attention_scores, dim=-1)
        if DEBUG: print(f'Softmax of attention scores: {attention_weights.shape} [batch_size, num_heads, n_patches, n_patches]')

        weighted_sum = torch.matmul(attention_weights, v)
        if DEBUG: print(f'Weighted sum of Values: {weighted_sum.shape} [batch_size, num_heads, n_patches, head_dim]')

        weighted_sum = weighted_sum.transpose(1, 2).contiguous()
        if DEBUG: print(f'Swap again the patches and the heads: {weighted_sum.shape} [batch_size, n_patches, num_heads, head_dim]')

        weighted_sum = weighted_sum.view(*weighted_sum.shape[:-2], -1)
        if DEBUG: print(f'Recover the original dimensions by merging the last 2: {weighted_sum.shape} [batch_size, n_patches, embedding_dim]')

        output = self.out_w(weighted_sum)
        if DEBUG: print(f'(Output) Linear projection of the weighted sum: {output.shape} [batch_size, n_patches, embedding_dim]')

        return output


class MSABlock(nn.Module):

    def __init__(self, embedding_dim: int = 768, num_heads: int = 12) -> None:
        super(MSABlock, self).__init__()
        self.msa = MultiHeadSelfAttention(embedding_dim=embedding_dim, num_heads=num_heads)
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer_norm(x)
        return self.msa(x, x, x)


class MLPBlock(nn.Module):

    def __init__(self, embedding_dim: int = 768, hidden_size: int = 3072) -> None:
        super(MLPBlock, self).__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=hidden_size),
            nn.GELU(),
            nn.Linear(in_features=hidden_size, out_features=embedding_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(self.layer_norm(x))


class TransformerEncoderBlock(nn.Module):

    def __init__(self, embedding_dim: int = 768, hidden_size: int = 3072, num_heads: int = 12) -> None:
        super(TransformerEncoderBlock, self).__init__()
        self.msa = MSABlock(embedding_dim=embedding_dim, num_heads=num_heads)
        self.mlp = MLPBlock(embedding_dim=embedding_dim, hidden_size=hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual connections around the attention and MLP sub-blocks
        x = self.msa(x) + x
        x = self.mlp(x) + x
        return x


if __name__ == '__main__':

    DEBUG = True
    x = torch.rand(5, 197, 768)
    msa = MultiHeadSelfAttention()
    out = msa(x, x, x)
    print(out.shape)
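
    # Illustrative extension, not part of the original demo: run a full
    # TransformerEncoderBlock on the same dummy ViT-style input (batch of 5
    # sequences of 197 patch tokens, embedding_dim=768). The residual block
    # preserves the shape: [batch_size, n_patches, embedding_dim].
    DEBUG = False  # silence the per-step prints for this second pass
    encoder_block = TransformerEncoderBlock()
    out = encoder_block(x)
    print(out.shape)  # expected: torch.Size([5, 197, 768])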