# Minimal Masked Autoencoder (MAE): patchify, random masking, ViT encoder over visible
# patches, lightweight decoder with mask tokens, and per-patch MSE loss on masked patches.
import math

import torch
import torch.nn as nn


def patchify(x, patch_size=8):
    """Split images (B, C, H, W) into flattened patches (B, N, C * patch_size**2)."""
    b, c, h, w = x.shape
    assert h % patch_size == 0 and w % patch_size == 0, "Image size must be divisible by patch_size"
    th = h // patch_size
    tw = w // patch_size
    out = x.reshape(b, c, th, patch_size, tw, patch_size)
    out = out.permute(0, 2, 4, 1, 3, 5).contiguous()
    out = out.view(b, th * tw, c * (patch_size ** 2))
    return out


def unpatchify(x, patch_size=8):
    """Inverse of patchify: (B, N, C * patch_size**2) -> (B, C, H, W), assuming a square patch grid."""
    b, z, p = x.shape
    c = p // (patch_size ** 2)
    th = int(math.sqrt(z))
    tw = th
    h = th * patch_size
    w = tw * patch_size
    x = x.view(b, th, tw, c, patch_size, patch_size)
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
    out = x.view(b, c, h, w)
    return out


def random_mask(x, mask_ratio=0.75):
    """Randomly drop a fraction of patches per sample.

    Returns the kept patches, a binary mask (1 = masked) in the original patch order,
    and the shuffle/restore index tensors needed to reassemble the full sequence.
    """
    b, n, p = x.shape
    len_keep = int(n * (1 - mask_ratio))
    noise = torch.rand(b, n, device=x.device)        # per-patch random score
    ids_shuffle = torch.argsort(noise, dim=1)        # ascending: first len_keep are kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # maps shuffled order back to original
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, p))
    mask = torch.ones(b, n, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return x_masked, mask, ids_restore, ids_keep


def mae_loss(pred, target, mask):
    """Mean squared error computed only over masked patches.

    pred/target: (B, N, P); mask: (B, N) with 1 = masked.
    """
    mask = mask.unsqueeze(-1).float()  # (B, N, 1)
    loss = (pred - target) ** 2
    loss = (loss * mask).sum() / mask.sum().clamp_min(1.0)
    return loss


class PositionalEncoding(nn.Module):
    """Learned positional embeddings, gathered for the visible patches only."""

    def __init__(self, num_patches, hidden_dim=768):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.empty(1, num_patches, hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x, visible_indices):
        # x: (B, len_keep, D); visible_indices: (B, len_keep)
        B, L, D = x.shape
        pos = self.pos_embed.expand(B, -1, -1)                          # (B, N, D)
        idx = visible_indices.unsqueeze(-1).expand(B, L, pos.size(-1))  # (B, L, D)
        visible_pos = torch.gather(pos, 1, idx)                         # (B, L, D)
        return x + visible_pos


class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: LayerNorm -> self-attention and LayerNorm -> MLP, each with a residual."""

    def __init__(self, hidden_dim, mlp_dim, num_heads, dropout):
        super().__init__()
        self.layernorm1 = nn.LayerNorm(hidden_dim)
        self.multihead = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads,
                                               dropout=dropout, batch_first=True)
        self.layernorm2 = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, hidden_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        residual = x
        x = self.layernorm1(x)
        attn, _ = self.multihead(x, x, x)
        x = residual + attn
        residual = x
        x = self.layernorm2(x)
        x = self.mlp(x)
        return residual + x


class MAEEncoder(nn.Module):
    """ViT-style encoder that only processes the visible patches, i.e. roughly
    (1 - mask_ratio) * num_patches tokens per image; patch_dim = c * patch_size**2
    is the flattened pixel dimension of one patch."""

    def __init__(self, patch_dim, num_patches=(384 // 8) ** 2, hidden_dim=768, mlp_dim=768 * 4,
                 num_heads=8, depth=12, dropout=0.25, mask_ratio=0.75, patch_size=8):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        self.patch_embed = nn.Linear(patch_dim, hidden_dim)
        self.pos_embed = PositionalEncoding(num_patches=num_patches, hidden_dim=hidden_dim)
        self.transformer = nn.ModuleList([
            TransformerBlock(hidden_dim=hidden_dim, mlp_dim=mlp_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(depth)
        ])
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x_in):
        x_p = patchify(x_in, self.patch_size)
        x_masked, mask, ids_restore, ids_keep = random_mask(x_p, self.mask_ratio)
        x = self.patch_embed(x_masked)
        x = self.pos_embed(x, ids_keep)
        for attn_layer in self.transformer:
            x = attn_layer(x)
        return x, mask, ids_keep, ids_restore


class MAEDecoder(nn.Module):
    """Lightweight decoder that re-inserts mask tokens and reconstructs pixel values per patch."""

    def __init__(self, c, num_patches, patch_size, encoder_dim, decoder_dim, decoder_depth,
                 mlp_dim, num_heads, dropout):
        super().__init__()
        self.num_patches = num_patches
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.mask_token = nn.Parameter(torch.empty(1, 1, decoder_dim))
        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim)
        self.pos_embed = nn.Parameter(torch.empty(1, num_patches, decoder_dim))
        self.transformer = nn.ModuleList([
            TransformerBlock(hidden_dim=decoder_dim, mlp_dim=mlp_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(decoder_depth)
        ])
        self.layernorm = nn.LayerNorm(decoder_dim)
        self.pred = nn.Linear(decoder_dim, c * (patch_size ** 2))
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.mask_token, std=0.02)

    def forward(self, x, ids_keep, ids_restore):
        b = x.size(0)
        xdec = self.enc_to_dec(x)
        len_keep = xdec.size(1)
        num_patches = ids_restore.size(1)
        num_mask = num_patches - len_keep
        # append mask tokens, then unshuffle back to the original patch order
        mask_token = self.mask_token.expand(b, num_mask, -1)
        x_ = torch.cat([xdec, mask_token], dim=1)
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x_.size(-1)))
        x_ = x_ + self.pos_embed
        for block in self.transformer:
            x_ = block(x_)
        x_ = self.layernorm(x_)
        out = self.pred(x_)
        return out


class MaskedAutoEncoder(nn.Module):
    """Full MAE: masked ViT encoder plus pixel-reconstruction decoder."""

    def __init__(self, c=1, mask_ratio=0.75, dropout=0.25, img_size=384, encoder_dim=768,
                 mlp_dim=3072, decoder_dim=512, encoder_depth=12, encoder_head=8,
                 decoder_depth=8, decoder_head=8, patch_size=8):
        super().__init__()
        self.patch_size = patch_size
        self.encoder = MAEEncoder(patch_dim=c * (patch_size ** 2),
                                  num_patches=(img_size // patch_size) ** 2,
                                  hidden_dim=encoder_dim, mlp_dim=mlp_dim, num_heads=encoder_head,
                                  depth=encoder_depth, dropout=dropout,
                                  mask_ratio=mask_ratio, patch_size=patch_size)
        self.decoder = MAEDecoder(c, num_patches=(img_size // patch_size) ** 2, patch_size=patch_size,
                                  encoder_dim=encoder_dim, decoder_dim=decoder_dim,
                                  decoder_depth=decoder_depth, mlp_dim=mlp_dim,
                                  num_heads=decoder_head, dropout=dropout)

    def forward(self, x):
        encoded, mask, ids_keep, ids_restore = self.encoder(x)
        decoded = self.decoder(encoded, ids_keep, ids_restore)
        xpatched = patchify(x, self.patch_size)
        return xpatched, decoded, mask

    @staticmethod
    def testme():
        img = torch.rand(1, 1, 384, 384)
        mae = MaskedAutoEncoder()
        a, b, c = mae(img)
        print(a.shape)  # target patches: (1, 2304, 64)
        print(b.shape)  # reconstructed patches: (1, 2304, 64)
        print(c.shape)  # mask: (1, 2304)