prekshyam committed on
Commit 4283e0c · verified · 1 Parent(s): 35ade0c

Added the Visualizer, Transformer, MAE, and a test image

Files changed (5)
  1. .gitattributes +1 -0
  2. ModelVisualizer.py +89 -0
  3. guineapig.jpg +3 -0
  4. maevit.py +246 -0
  5. transformer.py +239 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ guineapig.jpg filter=lfs diff=lfs merge=lfs -text
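The added pattern is what `git lfs track "guineapig.jpg"` would append: it routes the new image through Git LFS rather than storing the binary directly in the repository history.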
ModelVisualizer.py ADDED
@@ -0,0 +1,89 @@
+ import torch
+ import torchvision.transforms as transforms
+
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ from maevit import MAEViT
+
+
+ def visualize(model_path, img_path, figure_name):
+
+     model = MAEViT(
+         image_size=224,
+         patch_size=16,
+         embed_dim=128,
+         encoder_layers=2,
+         encoder_heads=4,
+         mlp_ratio=2.0,
+         mask_ratio=0.75,
+         decoder_embed_dim=64,
+         decoder_layers=2,
+         decoder_heads=4,
+         dropout=0.1
+     )
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model.to(device)
+     checkpoint = torch.load(model_path, map_location=device)
+     model.load_state_dict(checkpoint)
+     model.eval()
+
+     # ImageNet normalization, matching the training transform
+     to_tensor = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=[0.485, 0.456, 0.406],
+             std=[0.229, 0.224, 0.225]
+         )
+     ])
+
+     img = Image.open(img_path).convert('RGB')
+     x = to_tensor(img).unsqueeze(0).to(device)  # [1, 3, 224, 224]
+
+     with torch.no_grad():
+         x_enc, mask, ids_restore = model.forward_encoder(x)
+         x_rec_patches = model.forward_decoder(x_enc, ids_restore)
+
+     img_rec = model.unpatchify(x_rec_patches[:, 1:, :])  # exclude CLS -> [1, 3, 224, 224]
+     img_patches = model.patchify(x)                      # [1, num_patches, patch_dim]
+
+     # zero out the masked patches to show what the encoder actually saw
+     masked_patches = img_patches.clone()
+     mask = mask.unsqueeze(-1).to(torch.bool)             # [1, num_patches, 1]
+     masked_patches = masked_patches.masked_fill(mask, 0)
+
+     img_masked = model.unpatchify(masked_patches)        # [1, 3, 224, 224]
+
+     # undo the ImageNet normalization for display
+     inv_normalize = transforms.Normalize(
+         mean=[-m / s for m, s in zip((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))],
+         std=[1 / s for s in (0.229, 0.224, 0.225)]
+     )
+
+     def to_img(tensor):
+         img = tensor.squeeze(0).cpu()
+         img = inv_normalize(img)
+         img = img.permute(1, 2, 0).clamp(0, 1).numpy()
+         return img
+
+     orig_np = to_img(x)
+     masked_np = to_img(img_masked)
+     recon_np = to_img(img_rec)
+
+     # Plot original, masked input, and reconstruction side by side
+     fig, axes = plt.subplots(1, 3, figsize=(15, 5))
+     for ax, im, title in zip(axes,
+                              [orig_np, masked_np, recon_np],
+                              ['Original', 'Masked Input', 'Reconstruction']):
+         ax.imshow(im)
+         ax.set_title(title)
+         ax.axis('off')
+     plt.tight_layout()
+     plt.savefig(figure_name)  # save before show(), which may clear the figure
+     plt.show()
+
+
+ visualize('MAE1.bin', img_path='guineapig.jpg', figure_name='figures/MAE_visualization1.png')
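A quick sanity check of the masking used above is to count the zeroed patches. The sketch below is not part of this commit; it assumes the same MAEViT hyperparameters as in visualize() and should report roughly mask_ratio of the patches as masked:

import torch
from maevit import MAEViT

model = MAEViT(image_size=224, patch_size=16, embed_dim=128,
               encoder_layers=2, encoder_heads=4, mlp_ratio=2.0,
               mask_ratio=0.75, decoder_embed_dim=64,
               decoder_layers=2, decoder_heads=4, dropout=0.1)
model.eval()

x = torch.randn(1, 3, 224, 224)  # stand-in for a normalized image
with torch.no_grad():
    _, mask, _ = model.forward_encoder(x)

# 224/16 = 14 patches per side, so 196 patches total; with mask_ratio=0.75
# this prints 147 masked (196 - int(196 * 0.25)).
print(int(mask.sum().item()), 'of', mask.numel(), 'patches masked')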
guineapig.jpg ADDED

Git LFS Details

  • SHA256: ed50506aea7fd0841fa80732e2ceaa5b881b50ca2f55652f567ff8343a130747
  • Pointer size: 131 Bytes
  • Size of remote file: 310 kB
maevit.py ADDED
@@ -0,0 +1,246 @@
+ import torch
+ import torch.nn as nn
+ from transformer import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
+
+ # Includes both the MAE wrapper and the Vision Transformer pieces used for pretraining
+ class MAEViT(nn.Module):
+     """
+     Masked Autoencoder (MAE) for Vision Transformer.
+     Encoder sees only a fraction of patches; decoder reconstructs all patches.
+     """
+     def __init__(
+         self,
+         # default values follow ViT-B/16
+         image_size: int = 224,
+         patch_size: int = 16,
+         in_chans: int = 3,
+         embed_dim: int = 768,
+         encoder_layers: int = 12,
+         encoder_heads: int = 12,
+         mlp_ratio: float = 4.0,
+         mask_ratio: float = 0.75,
+         decoder_embed_dim: int = 512,
+         decoder_layers: int = 8,
+         decoder_heads: int = 16,
+         dropout: float = 0.0,
+     ):
+         super().__init__()
+         assert image_size % patch_size == 0, "Image size must be divisible by patch size"
+         self.in_chans = in_chans
+         self.image_size = image_size
+         self.patch_size = patch_size
+
+         # Conv2d trick to patchify AND embed in one step (different from the
+         # patchify() method below, which is used to build the loss target)
+         self.conv_proj = nn.Conv2d(
+             in_channels=in_chans,
+             out_channels=embed_dim,  # each patch is projected to an embed_dim-dimensional vector
+             kernel_size=patch_size,  # the kernel covers exactly one (square) patch
+             stride=patch_size        # the kernel steps patch-by-patch, so patches do not overlap
+         )
+         num_patches = (image_size // patch_size) ** 2  # image_size // patch_size counts patches per side
+         self.mask_ratio = mask_ratio  # masking 75% gives the best results in the MAE paper
+
+         # CLS token: a learnable vector that comes to summarize the whole image
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         nn.init.normal_(self.cls_token, std=0.02)  # normal distribution
+
+         # Transformer encoder: learns contextual relationships between patches, generates embeddings
+         enc_layer = TransformerEncoderLayer(
+             embed_dim=embed_dim,
+             num_heads=encoder_heads,             # for multi-head attention
+             mlp_dim=int(embed_dim * mlp_ratio),
+             dropout=dropout                      # used in the MLP
+         )
+         self.encoder = TransformerEncoder(enc_layer, encoder_layers, embed_dim)  # self-attention & feed-forward
+
+         # Encoder -> decoder (linear projection)
+         self.enc_to_dec = nn.Linear(embed_dim, decoder_embed_dim, bias=False)
+
+         # Decoder mask token (learnable placeholder inserted at each masked patch,
+         # which the decoder turns into a reconstruction) and positional embeddings
+         self.dec_mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+         self.dec_pos_embed = nn.Parameter(torch.empty(1, num_patches + 1, decoder_embed_dim))  # +1 for the CLS token
+         self.enc_pos_embed = nn.Parameter(torch.empty(1, num_patches + 1, embed_dim))
+
+         nn.init.normal_(self.dec_mask_token, std=0.02)
+         nn.init.normal_(self.enc_pos_embed, std=0.02)
+         nn.init.normal_(self.dec_pos_embed, std=0.02)
+
+         # Transformer decoder: predicts the masked patches
+         dec_layer = TransformerDecoderLayer(
+             encoder_embed_dim=embed_dim,
+             decoder_embed_dim=decoder_embed_dim,
+             num_heads=decoder_heads,
+             mlp_dim=int(decoder_embed_dim * mlp_ratio),
+             dropout=dropout  # regularizer to keep the model from fitting to noise
+         )
+         self.decoder = TransformerDecoder(dec_layer, decoder_layers, embed_dim=decoder_embed_dim)
+
+         # Reconstruction head: predicts raw pixels for each patch
+         self.pred = nn.Linear(decoder_embed_dim, (patch_size ** 2) * in_chans)
+
+         self.norm = nn.LayerNorm(embed_dim)
+
+     # patchify: converts an image into a sequence of flattened patches
+     def patchify(self, imgs):
+         """
+         imgs: (B, C, H, W)
+         returns: (B, N, patch_size * patch_size * C)
+         """
+         B, C, H, W = imgs.shape
+         p = self.patch_size
+         assert H % p == 0 and W % p == 0, "Image dimensions must be divisible by the patch size."
+
+         h = H // p
+         w = W // p
+         patches = imgs.reshape(B, C, h, p, w, p)
+         patches = torch.einsum('nchawb->nhwabc', patches)
+         patches = patches.reshape(B, h * w, p * p * C)
+         return patches
+
+     # unpatchify: reconstructs an image from flattened patches (the inverse of patchify);
+     # not needed for training, but used for visualization and debugging
+     def unpatchify(self, x):
+         # x: (B, num_patches, patch_size*patch_size*in_chans), flattened pixel values
+         # returns imgs: (B, in_chans, image_size, image_size)
+         p = self.patch_size
+         h = int(x.shape[1] ** 0.5)
+         w = h
+         assert h * w == x.shape[1]
+         x = x.reshape(x.shape[0], h, w, p, p, self.in_chans)
+         x = torch.einsum('nhwpqc->nchpwq', x)
+         imgs = x.reshape([x.shape[0], self.in_chans, self.image_size, self.image_size])
+         return imgs
+
+     def random_masking(self, x):
+         """
+         Perform per-sample random masking by shuffling.
+         returns:
+             x_masked: Tensor holding only the visible patches
+             mask: Tensor indicating which patches are visible (0) or masked (1)
+             ids_restore: Tensor to restore the original order of patches
+             ids_keep: indices of the kept (visible) patches
+         """
+         B, L, D = x.shape
+         # number of patches to keep
+         len_keep = int(L * (1 - self.mask_ratio))
+
+         # choose visible patches by sorting per-sample noise
+         noise = torch.rand(B, L, device=x.device)
+         ids_shuffle = torch.argsort(noise, dim=1)
+
+         # indices that undo the shuffle
+         ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+         # indices of kept patches
+         ids_keep = ids_shuffle[:, :len_keep]
+         # gather the visible patches
+         x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+         # binary mask over patches (1 = masked, 0 = visible)
+         mask = torch.ones(B, L, device=x.device)
+         mask[:, :len_keep] = 0  # mark the first len_keep (shuffled) positions as visible
+         # unshuffle the mask back into the original patch order
+         mask = torch.gather(mask, 1, ids_restore)
+
+         return x_masked, mask, ids_restore, ids_keep
+
+     def forward_encoder(self, imgs):
+
+         # 1. Patch embedding
+         x = self.conv_proj(imgs)          # [B, embed_dim, H/ps, W/ps]
+         x = x.flatten(2).transpose(1, 2)  # [B, N, embed_dim]
+         x = self.norm(x)                  # [B, N, embed_dim]
+         B, N, D = x.shape
+
+         # 2. Add positional embeddings (without the class token)
+         x = x + self.enc_pos_embed[:, 1:, :]
+
+         # 3. Random masking
+         x_masked, mask, ids_restore, ids_keep = self.random_masking(x)
+
+         # 4. Encoder input (CLS token + visible patches)
+         cls_token = self.cls_token + self.enc_pos_embed[:, :1, :]  # class token with positional embedding
+         cls_tokens = cls_token.expand(B, -1, -1)                   # repeat for the batch
+         x_enc = torch.cat([cls_tokens, x_masked], dim=1)
+
+         # 5. Encoder forward
+         x_enc = self.encoder(x_enc)
+
+         return x_enc, mask, ids_restore
+
+     def forward_decoder(self, x_enc, ids_restore):
+         # project the encoder output into the decoder embedding space
+         x_dec = self.enc_to_dec(x_enc)
+
+         B, L, D = x_dec.shape
+         # one mask token per masked patch (the +1 accounts for the CLS token in x_dec)
+         mask_tokens = self.dec_mask_token.repeat(B, ids_restore.shape[1] + 1 - x_dec.shape[1], 1)
+
+         # append mask tokens and unshuffle the sequence back to the original patch order
+         x_no_cls = torch.cat([x_dec[:, 1:, :], mask_tokens], dim=1)
+         x_no_cls = torch.gather(x_no_cls, 1, ids_restore.unsqueeze(-1).repeat(1, 1, D))
+         x_dec = torch.cat([x_dec[:, :1, :], x_no_cls], dim=1)
+
+         # add positional embeddings
+         x_dec = x_dec + self.dec_pos_embed[:, :x_dec.size(1), :]
+
+         # decoder forward (cross-attends over the encoder features)
+         x_dec = self.decoder(x_dec, x_enc)
+
+         # predict pixels for every patch position
+         x_rec = self.pred(x_dec)
+
+         return x_rec
+
+     def compute_mae_loss(self, imgs, pred, mask):
+         """
+         Mean squared error over the masked patches only.
+         imgs: [N, 3, H, W]
+         pred: [N, L + 1, p*p*3] (includes the CLS position)
+         mask: [N, L], 0 = keep, 1 = remove
+         """
+         target = self.patchify(imgs)
+
+         pred = pred[:, 1:, :]     # drop the CLS position
+         loss = (pred - target) ** 2
+         loss = loss.mean(dim=-1)  # per-patch MSE
+         # average only over masked patches; visible patches carry no loss
+         loss = (loss * mask).sum() / (mask.sum() + 1e-6)
+
+         return loss
+
+     def forward(self, imgs):
+         """
+         Forward pass for MAE: encode, decode, and compute the reconstruction loss.
+         imgs: [B, 3, H, W]
+         returns: reconstruction loss
+         """
+
+         # 1. Forward encoder
+         x_enc, mask, ids_restore = self.forward_encoder(imgs)
+
+         # 2. Forward decoder
+         x_rec = self.forward_decoder(x_enc, ids_restore)
+
+         # 3. Masked reconstruction loss
+         loss = self.compute_mae_loss(imgs, x_rec, mask)
+
+         return loss
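As a smoke test for the module above (not part of this commit), a random batch can be pushed through the full MAE forward pass, and the patchify/unpatchify pair can be checked as exact inverses. A sketch, assuming the small configuration from ModelVisualizer.py:

import torch
from maevit import MAEViT

model = MAEViT(image_size=224, patch_size=16, embed_dim=128,
               encoder_layers=2, encoder_heads=4, mlp_ratio=2.0,
               mask_ratio=0.75, decoder_embed_dim=64,
               decoder_layers=2, decoder_heads=4, dropout=0.1)

imgs = torch.randn(2, 3, 224, 224)

# patchify followed by unpatchify should reproduce the input exactly
patches = model.patchify(imgs)  # [2, 196, 768]
assert torch.allclose(model.unpatchify(patches), imgs, atol=1e-6)

# forward() returns the scalar masked-reconstruction loss
loss = model(imgs)
print(loss.item())  # a finite float
loss.backward()     # gradients flow through encoder, decoder, and head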
transformer.py ADDED
@@ -0,0 +1,239 @@
+ import torch
+ import torch.nn as nn
+ import copy
+ import math
+
+ class MultiHeadAttention(nn.Module):
+     """
+     Multi-Head Attention:
+     1) Linear projections for Q, K, V
+     2) Scaled dot-product attention per head
+     3) Concatenate heads and final linear projection
+     https://arxiv.org/pdf/1706.03762
+     """
+     def __init__(self, embed_dim: int, key_dim: int, num_heads: int, dropout: float = 0.0):
+         super().__init__()
+
+         assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+         self.embed_dim = embed_dim
+         self.key_dim = key_dim
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         self.scale = math.sqrt(self.head_dim)  # sqrt(d_k), used to scale the attention scores
+
+         # Separate projections for query, key, and value: three different transformations.
+         # key_dim may differ from embed_dim (cross-attention over encoder features).
+         self.q_proj = nn.Linear(embed_dim, embed_dim)
+         self.k_proj = nn.Linear(key_dim, embed_dim)
+         self.v_proj = nn.Linear(key_dim, embed_dim)
+
+         # Output projection after concatenating heads (embed_dim -> embed_dim)
+         self.out_proj = nn.Linear(embed_dim, embed_dim)
+
+         # Dropouts
+         self.attn_dropout = nn.Dropout(dropout)
+         self.proj_dropout = nn.Dropout(dropout)
+
+     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             query: [batch, seq_q, embed_dim]
+             key:   [batch, seq_k, key_dim]
+             value: [batch, seq_k, key_dim]
+         Returns:
+             out: [batch, seq_q, embed_dim]
+         """
+         B, seq_q, _ = query.size()
+         _, seq_k, _ = key.size()
+
+         # 1) Project inputs and split into heads
+         q = self.q_proj(query).view(B, seq_q, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, seq_q, head_dim]
+         k = self.k_proj(key).view(B, seq_k, self.num_heads, self.head_dim).transpose(1, 2)    # [B, heads, seq_k, head_dim]
+         v = self.v_proj(value).view(B, seq_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, seq_k, head_dim]
+
+         # 2) Scaled dot-product attention (k.transpose swaps the last two dims)
+         scores = (q @ k.transpose(-1, -2)) / self.scale  # [B, heads, seq_q, seq_k]
+         weights = torch.softmax(scores, dim=-1)
+         weights = self.attn_dropout(weights)
+         attn = weights @ v                               # [B, heads, seq_q, head_dim]
+
+         # 3) Concatenate heads
+         attn = attn.transpose(1, 2).contiguous().view(B, seq_q, self.embed_dim)  # [B, seq_q, embed_dim]
+
+         # 4) Final projection
+         out = self.out_proj(attn)
+         out = self.proj_dropout(out)
+
+         return out
+
+
+ class TransformerEncoderLayer(nn.Module):
+     """
+     Transformer Encoder Layer (pre-norm):
+     1) Multi-head self-attention
+     2) Feed-forward network
+     3) Residual connections + LayerNorm
+     """
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         mlp_dim: int,
+         dropout: float = 0.1,
+     ):
+         super().__init__()
+
+         # 1) Self-attention
+         self.self_attn = MultiHeadAttention(
+             embed_dim=embed_dim,
+             key_dim=embed_dim,  # self-attention uses the same dimension for Q, K, V
+             num_heads=num_heads,
+             dropout=dropout,
+         )
+
+         # 2) Feed-forward network
+         self.ffn = nn.Sequential(
+             nn.Linear(embed_dim, mlp_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(mlp_dim, embed_dim),
+         )
+
+         # 3) LayerNorms and dropouts for the residual branches
+         self.norm1 = nn.LayerNorm(embed_dim)
+         self.norm2 = nn.LayerNorm(embed_dim)
+         self.attn_dropout = nn.Dropout(dropout)
+         self.ff_dropout = nn.Dropout(dropout)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x: Tensor of shape [batch, seq_len, embed_dim]
+         Returns:
+             Tensor of the same shape
+         """
+         # 1) Self-attention block
+         x_norm = self.norm1(x)               # normalize input (pre-norm)
+         attn_out = self.self_attn(x_norm, x_norm, x_norm)
+         x = x + self.attn_dropout(attn_out)  # residual connection
+
+         # 2) Feed-forward block
+         x_norm_ff = self.norm2(x)
+         ff = self.ffn(x_norm_ff)
+         x = x + self.ff_dropout(ff)          # residual connection
+
+         return x
+
+
+ class TransformerEncoder(nn.Module):
+
+     def __init__(self, encoder_layer: TransformerEncoderLayer, num_layers: int, embed_dim: int):
+         super().__init__()
+
+         # Clone the provided encoder_layer num_layers times
+         self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
+         self.norm = nn.LayerNorm(embed_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+         # Pass the input through each encoder layer in turn
+         for layer in self.layers:
+             x = layer(x)
+
+         # Apply final normalization
+         x = self.norm(x)
+
+         return x
+
+
+ class TransformerDecoderLayer(nn.Module):
+     """
+     Transformer Decoder Layer (pre-norm):
+     1) Self-attention
+     2) Cross-attention (over encoder features)
+     3) Feed-forward network
+     4) Residual connections + LayerNorm
+     """
+     def __init__(
+         self,
+         encoder_embed_dim: int,
+         decoder_embed_dim: int,
+         num_heads: int,
+         mlp_dim: int,
+         dropout: float = 0.1,
+     ):
+         super().__init__()
+
+         # 1. Self-attention on the decoder input
+         self.self_attn = MultiHeadAttention(decoder_embed_dim, decoder_embed_dim, num_heads, dropout)
+
+         # 2. Cross-attention over encoder features
+         self.cross_attn = MultiHeadAttention(decoder_embed_dim, encoder_embed_dim, num_heads, dropout)
+
+         # 3. Feed-forward network
+         self.ffn = nn.Sequential(
+             nn.Linear(decoder_embed_dim, mlp_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(mlp_dim, decoder_embed_dim)
+         )
+
+         # 4. LayerNorms and dropouts
+         self.norm1 = nn.LayerNorm(decoder_embed_dim)
+         self.norm2 = nn.LayerNorm(decoder_embed_dim)
+         self.norm3 = nn.LayerNorm(decoder_embed_dim)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.dropout3 = nn.Dropout(dropout)
+
+     def forward(self, x: torch.Tensor, encoder_features: torch.Tensor) -> torch.Tensor:
+
+         # 1) Self-attention block
+         x_norm = self.norm1(x)  # normalize input (pre-norm)
+         sa = self.self_attn(x_norm, x_norm, x_norm)
+         x = x + self.dropout1(sa)
+
+         # 2) Cross-attention block
+         x_norm_ca = self.norm2(x)
+         ca = self.cross_attn(x_norm_ca, encoder_features, encoder_features)
+         x = x + self.dropout2(ca)
+
+         # 3) Feed-forward block
+         x_norm_ff = self.norm3(x)
+         ff = self.ffn(x_norm_ff)
+         x = x + self.dropout3(ff)
+
+         return x
+
+
+ class TransformerDecoder(nn.Module):
+
+     def __init__(
+         self,
+         decoder_layer: TransformerDecoderLayer,
+         num_layers: int,
+         embed_dim: int,
+     ):
+         super().__init__()
+
+         # Clone the provided decoder_layer num_layers times
+         self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
+
+         self.norm = nn.LayerNorm(embed_dim)
+
+     def forward(self, x: torch.Tensor, encoder_features: torch.Tensor) -> torch.Tensor:
+
+         # Pass the input through each decoder layer in turn
+         for layer in self.layers:
+             x = layer(x, encoder_features)
+
+         x = self.norm(x)
+
+         return x
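A shape check for the attention and the encoder/decoder stacks (a sketch, not part of this commit; the dimensions are arbitrary, chosen to mimic the MAE setup where the decoder is narrower than the encoder):

import torch
from transformer import (MultiHeadAttention, TransformerEncoder,
                         TransformerEncoderLayer, TransformerDecoder,
                         TransformerDecoderLayer)

B, seq_enc, seq_dec, d_enc, d_dec = 2, 50, 197, 128, 64

# Cross-attention: queries live in decoder space, keys/values in encoder space
attn = MultiHeadAttention(embed_dim=d_dec, key_dim=d_enc, num_heads=4)
out = attn(torch.randn(B, seq_dec, d_dec),
           torch.randn(B, seq_enc, d_enc),
           torch.randn(B, seq_enc, d_enc))
assert out.shape == (B, seq_dec, d_dec)

enc = TransformerEncoder(TransformerEncoderLayer(d_enc, 4, 256), 2, d_enc)
memory = enc(torch.randn(B, seq_enc, d_enc))  # [B, seq_enc, d_enc]

dec_layer = TransformerDecoderLayer(encoder_embed_dim=d_enc, decoder_embed_dim=d_dec,
                                    num_heads=4, mlp_dim=128)
dec = TransformerDecoder(dec_layer, 2, embed_dim=d_dec)
y = dec(torch.randn(B, seq_dec, d_dec), memory)  # [B, seq_dec, d_dec]
assert y.shape == (B, seq_dec, d_dec)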