# Uploaded by prekshyam: "Emotion Classifier, MAE/ViT architecture uploaded"
# Commit 9d79189 (verified)
from torch.utils.data import Dataset
import torch.nn as nn
from PIL import Image
import json
import os
import random
import torch
import numpy as np
from transformer import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
#includes both MAE and Vision Transformer for pretraining
class MAEViT(nn.Module):
    """
    Masked Autoencoder (MAE) for Vision Transformer.

    The encoder sees only a random subset of the patch tokens (plus a CLS
    token). The decoder receives the encoded visible tokens together with
    learnable mask tokens and predicts pixel values for every patch.
    ``forward`` returns the mean-squared reconstruction error computed on the
    masked patches only.
    """

    def __init__(
        self,
        # default values for ViT-B/16
        image_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        encoder_layers: int = 12,
        encoder_heads: int = 12,
        mlp_ratio: float = 4.0,
        mask_ratio: float = 0.75,
        decoder_embed_dim: int = 512,
        decoder_layers: int = 8,
        decoder_heads: int = 16,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert image_size % patch_size == 0, "Image size must be divisible by patch size"
        self.in_chans = in_chans
        self.image_size = image_size
        self.patch_size = patch_size
        # Conv2d "patchify + embed" trick: a kernel of size patch_size moved with
        # stride patch_size visits each non-overlapping patch exactly once and
        # linearly projects it to an embed_dim-wide token. Reconstruction
        # targets are built separately from raw pixels by `patchify`.
        self.conv_proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,  # token width (not necessarily patch_size**2 * in_chans)
            kernel_size=patch_size,  # one kernel window == one patch
            stride=patch_size,  # windows do not overlap
        )
        # patches per side, squared
        num_patches = (image_size // patch_size) ** 2
        # fraction of patches hidden from the encoder (0.75 is the usual MAE setting)
        self.mask_ratio = mask_ratio
        # learnable CLS token prepended to the encoder input
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.normal_(self.cls_token, std=0.02)
        # Transformer encoder over [CLS] + visible patch tokens
        enc_layer = TransformerEncoderLayer(
            embed_dim=embed_dim,
            num_heads=encoder_heads,
            mlp_dim=int(embed_dim * mlp_ratio),
            dropout=dropout,
        )
        self.encoder = TransformerEncoder(enc_layer, encoder_layers, embed_dim)
        # project encoder tokens into the decoder width
        self.enc_to_dec = nn.Linear(embed_dim, decoder_embed_dim, bias=False)
        # learnable placeholder inserted at every masked patch position for the decoder
        self.dec_mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        # learned positional embeddings; the extra (+1) slot at index 0 is for CLS
        self.dec_pos_embed = nn.Parameter(torch.empty(1, num_patches + 1, decoder_embed_dim))
        self.enc_pos_embed = nn.Parameter(torch.empty(1, num_patches + 1, embed_dim))
        nn.init.normal_(self.dec_mask_token, std=0.02)
        nn.init.normal_(self.enc_pos_embed, std=0.02)
        nn.init.normal_(self.dec_pos_embed, std=0.02)
        # Transformer decoder: consumes the full (visible + mask) token sequence
        dec_layer = TransformerDecoderLayer(
            encoder_embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            num_heads=decoder_heads,
            mlp_dim=int(decoder_embed_dim * mlp_ratio),
            dropout=dropout,  # regularizer against overfitting
        )
        self.decoder = TransformerDecoder(dec_layer, decoder_layers, embed_dim=decoder_embed_dim)
        # per-token head predicting the flattened pixels of one patch
        self.pred = nn.Linear(decoder_embed_dim, (patch_size ** 2) * in_chans)
        # applied to patch embeddings right after conv_proj (see forward_encoder)
        self.norm = nn.LayerNorm(embed_dim)

    def patchify(self, imgs):
        """
        Split images into flattened, non-overlapping patches.

        imgs: (B, C, H, W), with H and W divisible by patch_size
        returns: (B, N, patch_size * patch_size * C), N = (H/p) * (W/p),
            patches in row-major grid order, each flattened as (row, col, channel)
        """
        B, C, H, W = imgs.shape
        p = self.patch_size
        assert H % p == 0 and W % p == 0, "Image dimensions must be divisible by the patch size."
        h = H // p
        w = W // p
        patches = imgs.reshape(B, C, h, p, w, p)
        patches = torch.einsum('nchpwq->nhwpqc', patches)
        return patches.reshape(B, h * w, p * p * C)

    def unpatchify(self, x):
        """
        Inverse of ``patchify``: rebuild square images from flattened patches.

        Not used in the training path; useful for visualizing reconstructions.

        x: (B, N, patch_size * patch_size * in_chans), N a perfect square
        returns: (B, in_chans, sqrt(N) * patch_size, sqrt(N) * patch_size)
        """
        p = self.patch_size
        h = w = int(round(x.shape[1] ** 0.5))
        assert h * w == x.shape[1]
        x = x.reshape(x.shape[0], h, w, p, p, self.in_chans)
        x = torch.einsum('nhwpqc->nchpwq', x)
        # derive spatial size from the patch grid rather than self.image_size,
        # so any grid produced by `patchify` round-trips
        return x.reshape(x.shape[0], self.in_chans, h * p, w * p)

    def random_masking(self, x):
        """
        Perform per-sample random masking by shuffling.

        x: (B, L, D) patch tokens
        returns:
            x_masked: (B, len_keep, D) visible patches, in shuffled order
            mask: (B, L), 0 = visible, 1 = masked, in original patch order
            ids_restore: (B, L) indices that undo the shuffle
            ids_keep: (B, len_keep) original positions of the visible patches
        """
        B, L, D = x.shape
        # Number of patches to keep. The epsilon stops float error in
        # (1 - mask_ratio) from truncating an exact integer down by one
        # (e.g. int(10 * (1 - 0.9)) evaluates to 0 instead of 1).
        len_keep = int(L * (1 - self.mask_ratio) + 1e-6)
        # random permutation per sample: lowest-noise patches are kept
        noise = torch.rand(B, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        # build the mask in shuffled order (first len_keep are visible), then
        # unshuffle it back into the original patch order
        mask = torch.ones(B, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, 1, ids_restore)
        return x_masked, mask, ids_restore, ids_keep

    def forward_encoder(self, imgs):
        """
        Embed patches, randomly mask them, and encode [CLS] + visible tokens.

        imgs: (B, in_chans, H, W); the patch grid must match enc_pos_embed
        returns:
            x_enc: output of self.encoder for [CLS] + visible tokens
                (presumably (B, 1 + len_keep, embed_dim); depends on transformer.py)
            mask: (B, num_patches), 1 = masked, 0 = visible
            ids_restore: (B, num_patches) indices that undo the random shuffle
        """
        # 1. Patch embedding: (B, embed_dim, H/p, W/p) -> (B, N, embed_dim)
        x = self.conv_proj(imgs)
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        B = x.shape[0]
        # 2. Positional embeddings for patch tokens (slot 0 is reserved for CLS)
        x = x + self.enc_pos_embed[:, 1:, :]
        # 3. Random masking
        x_masked, mask, ids_restore, _ = self.random_masking(x)
        # 4. Prepend the CLS token (with its positional embedding) for each sample
        cls_tokens = (self.cls_token + self.enc_pos_embed[:, :1, :]).expand(B, -1, -1)
        x_enc = torch.cat([cls_tokens, x_masked], dim=1)
        # 5. Encoder forward
        x_enc = self.encoder(x_enc)
        return x_enc, mask, ids_restore

    def forward_decoder(self, x_enc, ids_restore):
        """
        Predict pixel values for every patch from the encoded visible tokens.

        x_enc: encoder output, CLS token first followed by the visible tokens
        ids_restore: (B, num_patches) from ``random_masking``
        returns: per-token pixel predictions, (B, 1 + num_patches, p*p*in_chans)
            assuming self.decoder preserves sequence length; index 0 is the
            CLS slot and is dropped by ``compute_mae_loss``
        """
        x_dec = self.enc_to_dec(x_enc)
        B, _, D = x_dec.shape
        # one learnable mask token for each patch the encoder did not see
        num_masked = ids_restore.shape[1] + 1 - x_dec.shape[1]
        mask_tokens = self.dec_mask_token.repeat(B, num_masked, 1)
        # append mask tokens after the visible tokens (CLS excluded), then
        # unshuffle so every token sits at its original patch position
        x_no_cls = torch.cat([x_dec[:, 1:, :], mask_tokens], dim=1)
        x_no_cls = torch.gather(x_no_cls, 1, ids_restore.unsqueeze(-1).repeat(1, 1, D))
        x_dec = torch.cat([x_dec[:, :1, :], x_no_cls], dim=1)
        x_dec = x_dec + self.dec_pos_embed[:, :x_dec.size(1), :]
        # the unprojected encoder output is also handed to the decoder;
        # NOTE(review): presumably cross-attention memory — confirm in transformer.py
        x_dec = self.decoder(x_dec, x_enc)
        return self.pred(x_dec)

    def compute_mae_loss(self, imgs, pred, mask):
        """
        Mean squared reconstruction error averaged over masked patches only.

        imgs: [B, C, H, W] original images
        pred: [B, 1 + L, p*p*C] decoder output; index 0 (CLS slot) is dropped
        mask: [B, L], 0 is keep (visible), 1 is remove (masked)
        returns: scalar loss tensor
        """
        target = self.patchify(imgs)
        pred = pred[:, 1:, :]  # drop the CLS-token prediction
        loss = ((pred - target) ** 2).mean(dim=-1)  # [B, L] per-patch MSE
        # only masked patches are scored; the epsilon avoids 0/0 when nothing is masked
        return (loss * mask).sum() / (mask.sum() + 1e-6)

    def forward(self, imgs):
        """
        Forward pass for MAE: encode, decode, and compute reconstruction loss.

        imgs: [B, in_chans, H, W]
        returns: scalar reconstruction loss
        """
        x_enc, mask, ids_restore = self.forward_encoder(imgs)
        x_rec = self.forward_decoder(x_enc, ids_restore)
        return self.compute_mae_loss(imgs, x_rec, mask)