import json
import math
import os
import random

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset

from transformer import (
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
)


# Masked Autoencoder built on a Vision-Transformer-style encoder, for
# self-supervised pretraining.
class MAEViT(nn.Module):
    """
    Masked Autoencoder (MAE) for Vision Transformer.

    The encoder sees only the visible (unmasked) fraction of patches plus a
    CLS token. The decoder receives the encoded tokens together with learnable
    mask tokens placed at the masked positions, and predicts the raw pixels of
    every patch. The training loss (see `compute_mae_loss`) is taken over the
    masked patches only.
    """

    def __init__(
        self,
        # default values for ViT-B/16
        image_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        encoder_layers: int = 12,
        encoder_heads: int = 12,
        mlp_ratio: float = 4.0,
        mask_ratio: float = 0.75,
        decoder_embed_dim: int = 512,
        decoder_layers: int = 8,
        decoder_heads: int = 16,
        dropout: float = 0.0,
    ):
        """
        Args:
            image_size: side length in pixels of the square input images.
            patch_size: side length in pixels of each square patch; must divide image_size.
            in_chans: number of image channels.
            embed_dim: encoder token width.
            encoder_layers: number of encoder layers.
            encoder_heads: attention heads per encoder layer.
            mlp_ratio: MLP hidden width as a multiple of the token width
                (used for both encoder and decoder layers).
            mask_ratio: fraction of patches hidden from the encoder, in [0, 1].
            decoder_embed_dim: decoder token width.
            decoder_layers: number of decoder layers.
            decoder_heads: attention heads per decoder layer.
            dropout: dropout probability passed to the encoder and decoder layers.

        Raises:
            AssertionError: if image_size is not divisible by patch_size.
            ValueError: if mask_ratio is outside [0, 1].
        """
        super().__init__()
        assert image_size % patch_size == 0, "Image size must be divisible by patch size"
        # A ratio > 1 would make len_keep negative in random_masking and silently
        # slice from the end of the shuffle; a ratio < 0 would keep "more" than L.
        if not 0.0 <= mask_ratio <= 1.0:
            raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")

        self.in_chans = in_chans
        self.image_size = image_size
        self.patch_size = patch_size

        # Patchify AND embed in one step: a Conv2d whose kernel and stride both
        # equal patch_size visits each non-overlapping patch exactly once and
        # projects it to embed_dim channels. embed_dim is an independent model
        # width; it need not equal patch_size**2 * in_chans (it merely happens
        # to for the ViT-B/16 defaults). This is separate from `patchify`,
        # which produces the raw-pixel targets for the loss.
        self.conv_proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

        # Number of patches in the (square) patch grid.
        num_patches = (image_size // patch_size) ** 2
        self.mask_ratio = mask_ratio

        # Learnable CLS token, prepended to the encoder input.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.normal_(self.cls_token, std=0.02)

        # Transformer encoder over [CLS] + visible patch tokens.
        enc_layer = TransformerEncoderLayer(
            embed_dim=embed_dim,
            num_heads=encoder_heads,
            mlp_dim=int(embed_dim * mlp_ratio),
            dropout=dropout,
        )
        self.encoder = TransformerEncoder(enc_layer, encoder_layers, embed_dim)

        # Linear projection from encoder width to decoder width.
        self.enc_to_dec = nn.Linear(embed_dim, decoder_embed_dim, bias=False)

        # Learnable placeholder inserted by the decoder at every masked position.
        self.dec_mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        # Learnable positional embeddings; slot 0 (the "+ 1") belongs to CLS.
        self.dec_pos_embed = nn.Parameter(torch.empty(1, num_patches + 1, decoder_embed_dim))
        self.enc_pos_embed = nn.Parameter(torch.empty(1, num_patches + 1, embed_dim))
        nn.init.normal_(self.dec_mask_token, std=0.02)
        nn.init.normal_(self.enc_pos_embed, std=0.02)
        nn.init.normal_(self.dec_pos_embed, std=0.02)

        # Decoder over the full (visible + mask-token) sequence. It is also
        # given the encoder output (width embed_dim) in forward_decoder;
        # presumably for cross-attention -- semantics live in transformer.py.
        dec_layer = TransformerDecoderLayer(
            encoder_embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            num_heads=decoder_heads,
            mlp_dim=int(decoder_embed_dim * mlp_ratio),
            dropout=dropout,
        )
        self.decoder = TransformerDecoder(dec_layer, decoder_layers, embed_dim=decoder_embed_dim)

        # Per-token head predicting the raw pixels of one patch.
        self.pred = nn.Linear(decoder_embed_dim, (patch_size ** 2) * in_chans)
        # LayerNorm applied to the patch embeddings in forward_encoder, before
        # positional embeddings are added and before masking.
        self.norm = nn.LayerNorm(embed_dim)

    def patchify(self, imgs):
        """
        Split images into flattened, non-overlapping patches (loss targets).

        imgs: (B, C, H, W), H and W divisible by patch_size
        returns: (B, (H/p) * (W/p), p * p * C). Patches are in row-major grid
            order (the same order as conv_proj's flattened output); each patch
            is flattened as (row-in-patch, col-in-patch, channel).
        """
        B, C, H, W = imgs.shape
        p = self.patch_size
        assert H % p == 0 and W % p == 0, "Image dimensions must be divisible by the patch size."
        h = H // p
        w = W // p
        patches = imgs.reshape(B, C, h, p, w, p)
        # (B, C, h, p, w, p) -> (B, h, w, p, p, C)
        patches = torch.einsum('nchawb->nhwabc', patches)
        return patches.reshape(B, h * w, p * p * C)

    def unpatchify(self, x):
        """
        Inverse of `patchify` for a square patch grid (e.g. to visualise
        reconstructions; not used in the training forward pass).

        x: (B, N, p * p * in_chans) flattened pixel values, N a perfect square
        returns: (B, in_chans, sqrt(N) * p, sqrt(N) * p)
        """
        p = self.patch_size
        # Exact integer square root avoids float rounding on the patch count.
        h = w = math.isqrt(x.shape[1])
        assert h * w == x.shape[1], "Number of patches must be a perfect square."
        x = x.reshape(x.shape[0], h, w, p, p, self.in_chans)
        # (B, h, w, p, p, C) -> (B, C, h, p, w, p)
        x = torch.einsum('nhwpqc->nchpwq', x)
        # Derive the output size from the grid, so it does not depend on self.image_size.
        return x.reshape(x.shape[0], self.in_chans, h * p, w * p)

    def random_masking(self, x):
        """
        Per-sample random masking by shuffling patch indices with uniform noise.

        x: (B, L, D) patch tokens
        returns:
            x_masked: (B, len_keep, D) visible tokens, in shuffled order
            mask: (B, L) float tensor, 0 = visible, 1 = masked, in original patch order
            ids_restore: (B, L) indices that undo the shuffle
            ids_keep: (B, len_keep) original indices of the visible patches
        """
        B, L, D = x.shape
        # Number of patches to keep visible.
        len_keep = int(L * (1 - self.mask_ratio))

        noise = torch.rand(B, L, device=x.device)
        # Ascending sort of the noise: the first len_keep indices are kept.
        ids_shuffle = torch.argsort(noise, dim=1)
        # argsort of a permutation is its inverse, which unshuffles.
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Build the mask in shuffled order (first len_keep are visible), then
        # unshuffle it so mask[b, i] refers to original patch i. Without this
        # gather, the mask would not line up with the patchify targets.
        mask = torch.ones(B, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, 1, ids_restore)

        return x_masked, mask, ids_restore, ids_keep

    def forward_encoder(self, imgs):
        """
        Embed, randomly mask, and encode a batch of images.

        imgs: (B, in_chans, H, W) with (H/p) * (W/p) equal to the number of
            patches the positional embedding was built for
        returns:
            x_enc: encoder output for [CLS] + visible patches
                (input was (B, 1 + len_keep, embed_dim); assumes the encoder preserves shape)
            mask: (B, N), 0 = visible, 1 = masked
            ids_restore: (B, N) indices that undo the masking shuffle
        """
        # 1. Patch embedding: (B, embed_dim, H/p, W/p) -> (B, N, embed_dim)
        x = self.conv_proj(imgs)
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        B = x.shape[0]

        # 2. Positional embeddings for the patches (slot 0 is reserved for CLS).
        x = x + self.enc_pos_embed[:, 1:, :]

        # 3. Random masking.
        x_masked, mask, ids_restore, _ = self.random_masking(x)

        # 4. Prepend the CLS token (with its own positional embedding).
        cls_token = self.cls_token + self.enc_pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(B, -1, -1)
        x_enc = torch.cat([cls_tokens, x_masked], dim=1)

        # 5. Encode.
        x_enc = self.encoder(x_enc)
        return x_enc, mask, ids_restore

    def forward_decoder(self, x_enc, ids_restore):
        """
        Predict pixel values for every patch position.

        x_enc: encoder output, CLS token first, width embed_dim
        ids_restore: (B, N) indices that undo the masking shuffle
        returns: per-token pixel predictions of width p * p * in_chans for
            [CLS] + all N patches (index 0 is the CLS position, which the loss
            discards); assumes the decoder preserves sequence length.
        """
        # Project encoder tokens to decoder width.
        x_dec = self.enc_to_dec(x_enc)
        B, _, D = x_dec.shape

        # One mask token per masked patch; the "+ 1" accounts for CLS in x_dec.
        num_masked = ids_restore.shape[1] + 1 - x_dec.shape[1]
        mask_tokens = self.dec_mask_token.repeat(B, num_masked, 1)

        # Append mask tokens after the visible ones (shuffled order), then
        # unshuffle everything back to the original patch order.
        x_no_cls = torch.cat([x_dec[:, 1:, :], mask_tokens], dim=1)
        x_no_cls = torch.gather(x_no_cls, 1, ids_restore.unsqueeze(-1).repeat(1, 1, D))
        x_dec = torch.cat([x_dec[:, :1, :], x_no_cls], dim=1)

        x_dec = x_dec + self.dec_pos_embed[:, :x_dec.size(1), :]

        # The decoder also receives the (unprojected) encoder output.
        x_dec = self.decoder(x_dec, x_enc)

        # Pixel predictions for every token, including the CLS position.
        return self.pred(x_dec)

    def compute_mae_loss(self, imgs, pred, mask):
        """
        Mean squared error over the masked patches only.

        imgs: (B, C, H, W) original images
        pred: (B, 1 + L, p * p * C) decoder output; index 0 (CLS) is discarded
        mask: (B, L), 0 = visible (keep), 1 = masked (remove)
        returns: scalar loss tensor
        """
        target = self.patchify(imgs)
        pred = pred[:, 1:, :]
        # Per-patch MSE, shape (B, L).
        loss = ((pred - target) ** 2).mean(dim=-1)
        # Average over masked patches only; the epsilon avoids 0/0 when
        # nothing is masked (mask_ratio == 0).
        return (loss * mask).sum() / (mask.sum() + 1e-6)

    def forward(self, imgs):
        """
        Full MAE pass: encode, decode, and compute the reconstruction loss.

        imgs: (B, in_chans, H, W)
        returns: scalar reconstruction loss
        """
        x_enc, mask, ids_restore = self.forward_encoder(imgs)
        x_rec = self.forward_decoder(x_enc, ids_restore)
        return self.compute_mae_loss(imgs, x_rec, mask)