prekshyam committed on
Commit 024d344 · verified · 1 Parent(s): 54c33cb

Added missing ViTForEmotionClassification class

Files changed (1)
  1. maevit.py +69 -0
maevit.py CHANGED
@@ -245,6 +245,75 @@ class MAEViT(nn.Module):

        return loss

+# for finetuning
+class ViTForEmotionClassification(nn.Module):
+    """
+    ViT for classification
+    Encoder only
+    """
+    def __init__(
+        self,
+        # default values for ViT-B-16
+        image_size: int = 224,
+        patch_size: int = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        encoder_layers: int = 12,
+        encoder_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        dropout: float = 0.0,
+        num_classes: int = 9,  # number of emotion classes (changed by Preksha; was originally 7)
+    ):
+        super().__init__()
+        assert image_size % patch_size == 0, "Image size must be divisible by patch size"
+        self.patch_size = patch_size
+
+        self.conv_proj = nn.Conv2d(
+            in_channels=in_chans,
+            out_channels=embed_dim,  # each patch is projected to an embed_dim vector (for ViT-B/16, patch_size**2 * 3 = 768)
+            kernel_size=patch_size,  # the kernel covers exactly one square patch
+            stride=patch_size,       # stride = patch size, so patches do not overlap
+        )
+
+        num_patches = (image_size // patch_size) ** 2
+        self.enc_pos_embed = nn.Parameter(torch.empty(1, num_patches + 1, embed_dim))
+        nn.init.normal_(self.enc_pos_embed, std=0.02)
+
+        # CLS token: a learnable vector that comes to hold an embedding for the whole image
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        nn.init.normal_(self.cls_token, std=0.02)  # normal initialization
+
+        # Transformer encoder: learns contextual relationships between patches, generates embeddings
+        enc_layer = TransformerEncoderLayer(
+            embed_dim=embed_dim,
+            num_heads=encoder_heads,  # for multi-head attention
+            mlp_dim=int(embed_dim * mlp_ratio),
+            dropout=dropout,  # used in the MLP
+        )
+        self.encoder = TransformerEncoder(enc_layer, encoder_layers, embed_dim)  # self-attention & feed-forward
+
+        self.norm = nn.LayerNorm(embed_dim)
+        self.head_norm = nn.LayerNorm(embed_dim)
+        self.head = nn.Linear(embed_dim, num_classes)  # 9 emotions
+
+    def forward(self, imgs):
+
+        # 1. Patch embedding
+        x = self.conv_proj(imgs)          # [B, embed_dim, H/ps, W/ps]
+        x = x.flatten(2).transpose(1, 2)  # [B, N, embed_dim]
+        x = self.norm(x)                  # [B, N, embed_dim]
+        B, N, D = x.shape
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)  # repeat for the batch
+        x = torch.cat([cls_tokens, x], dim=1)          # [B, N+1, embed_dim]
+        x = x + self.enc_pos_embed
+
+        x = self.encoder(x)
+
+        logits = self.head(self.head_norm(x[:, 0]))  # use the class token for classification
+
+        return logits
+
 class ViTForEmotionClassificationMLP(ViTForEmotionClassification):
     """
     Replace the linear head with MLP
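
For context, a minimal usage sketch of the newly added class (not part of this commit). It assumes maevit.py, including the TransformerEncoderLayer and TransformerEncoder helpers it references, is importable from the working directory; the import path is an assumption.

import torch
from maevit import ViTForEmotionClassification  # assumes maevit.py is on the import path

model = ViTForEmotionClassification()   # ViT-B/16 defaults, 9 emotion classes
imgs = torch.randn(2, 3, 224, 224)      # dummy batch of two 224x224 RGB images
logits = model(imgs)                    # [2, 9], one score per emotion class
preds = logits.argmax(dim=-1)           # predicted emotion indices

With the defaults, conv_proj splits each image into (224 // 16) ** 2 = 196 patches, the CLS token brings the encoder's sequence length to 197, and the linear head maps the CLS embedding to the 9 emotion logits.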