Added missing ViTForEmotionClassification class
maevit.py CHANGED
@@ -245,6 +245,75 @@ class MAEViT(nn.Module):
 
         return loss
 
+# for finetuning
+class ViTForEmotionClassification(nn.Module):
+    """
+    ViT for classification
+    Encoder only
+    """
+    def __init__(
+        self,
+        # default values for ViT-B/16
+        image_size: int = 224,
+        patch_size: int = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        encoder_layers: int = 12,
+        encoder_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        dropout: float = 0.0,
+        num_classes: int = 9,  # number of emotion classes (changed by Preksha; originally 7)
+    ):
+        super().__init__()
+        assert image_size % patch_size == 0, "Image size must be divisible by patch size"
+        self.patch_size = patch_size
+
+        # Patch embedding: project each patch to a token of dimension embed_dim
+        self.conv_proj = nn.Conv2d(
+            in_channels=in_chans,
+            out_channels=embed_dim,  # token dimension (for ViT-B/16 this equals patch_size**2 * 3 = 768)
+            kernel_size=patch_size,  # the kernel covers exactly one square patch
+            stride=patch_size,       # stride equals the patch size, so patches do not overlap
+        )
+
+        num_patches = (image_size // patch_size) ** 2
+        self.enc_pos_embed = nn.Parameter(torch.empty(1, num_patches + 1, embed_dim))
+        nn.init.normal_(self.enc_pos_embed, std=0.02)
+
+        # CLS token: a learnable vector that comes to hold an embedding of the whole image
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        nn.init.normal_(self.cls_token, std=0.02)  # normal initialization
+
+        # Transformer encoder: learns contextual relationships between patches and produces embeddings
+        enc_layer = TransformerEncoderLayer(
+            embed_dim=embed_dim,
+            num_heads=encoder_heads,  # for multi-head attention
+            mlp_dim=int(embed_dim * mlp_ratio),
+            dropout=dropout,  # used in the MLP
+        )
+        self.encoder = TransformerEncoder(enc_layer, encoder_layers, embed_dim)  # self-attention & feed-forward
+
+        self.norm = nn.LayerNorm(embed_dim)
+        self.head_norm = nn.LayerNorm(embed_dim)
+        self.head = nn.Linear(embed_dim, num_classes)  # 9 emotions
+
+    def forward(self, imgs):
+        # 1. Patch embedding
+        x = self.conv_proj(imgs)          # [B, embed_dim, H/ps, W/ps]
+        x = x.flatten(2).transpose(1, 2)  # [B, N, embed_dim]
+        x = self.norm(x)                  # [B, N, embed_dim]
+        B, N, D = x.shape
+
+        # 2. Prepend the CLS token and add positional embeddings
+        cls_tokens = self.cls_token.expand(B, -1, -1)  # repeat for batch size
+        x = torch.cat([cls_tokens, x], dim=1)          # [B, N+1, embed_dim]
+        x = x + self.enc_pos_embed
+
+        # 3. Transformer encoder
+        x = self.encoder(x)
+
+        # 4. Classify from the CLS token
+        logits = self.head(self.head_norm(x[:, 0]))
+
+        return logits
+
 class ViTForEmotionClassificationMLP(ViTForEmotionClassification):
     """
     Replace the linear head with MLP
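With the ViT-B/16 defaults, a 224x224 input yields (224/16)^2 = 196 patches, so the encoder sees 197 tokens once the CLS token is prepended, and the head maps the CLS embedding to 9 emotion logits. A quick sanity check for the new class is a forward pass on a dummy batch; the sketch below is illustrative only and assumes maevit.py is on the import path and that the custom TransformerEncoderLayer / TransformerEncoder used above are defined earlier in the same file.

import torch
from maevit import ViTForEmotionClassification  # assumes maevit.py is importable

model = ViTForEmotionClassification()  # ViT-B/16 defaults, 9 emotion classes
model.eval()
imgs = torch.randn(2, 3, 224, 224)     # dummy batch: [B, C, H, W]
with torch.no_grad():
    logits = model(imgs)               # [2, 9]
probs = logits.softmax(dim=-1)         # per-class emotion probabilities
print(logits.shape)                    # torch.Size([2, 9])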