|
|
from torch.utils.data import Dataset
|
|
|
import torch.nn as nn
|
|
|
from PIL import Image
|
|
|
import json
|
|
|
import os
|
|
|
import random
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
from transformer import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
|
|
|
|
|
|
|
|
|
class MAEViT(nn.Module):
    """Masked Autoencoder (MAE) for Vision Transformer.

    The encoder sees only a random subset of image patches (the "visible"
    patches); the decoder receives the encoded visible tokens plus learnable
    mask tokens and reconstructs pixel values for every patch. Training
    minimizes mean-squared error on the masked patches only.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        encoder_layers: int = 12,
        encoder_heads: int = 12,
        mlp_ratio: float = 4.0,
        mask_ratio: float = 0.75,
        decoder_embed_dim: int = 512,
        decoder_layers: int = 8,
        decoder_heads: int = 16,
        dropout: float = 0.0,
    ):
        """
        Args:
            image_size: Side length of the (square) input image.
            patch_size: Side length of each square patch; must divide image_size.
            in_chans: Number of input image channels.
            embed_dim: Encoder token dimension.
            encoder_layers: Number of encoder transformer layers.
            encoder_heads: Number of encoder attention heads.
            mlp_ratio: Hidden-size multiplier for transformer MLP blocks.
            mask_ratio: Fraction of patches hidden from the encoder.
            decoder_embed_dim: Decoder token dimension.
            decoder_layers: Number of decoder transformer layers.
            decoder_heads: Number of decoder attention heads.
            dropout: Dropout rate passed to the transformer layers.
        """
        super().__init__()
        assert image_size % patch_size == 0, "Image size must be divisible by patch size"
        self.in_chans = in_chans
        self.image_size = image_size
        self.patch_size = patch_size

        # Patch embedding: strided conv mapping each p x p patch to one token.
        self.conv_proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        num_patches = (image_size // patch_size) ** 2
        self.mask_ratio = mask_ratio

        # Learnable class token prepended to the visible tokens.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.normal_(self.cls_token, std=0.02)

        enc_layer = TransformerEncoderLayer(
            embed_dim=embed_dim,
            num_heads=encoder_heads,
            mlp_dim=int(embed_dim * mlp_ratio),
            dropout=dropout,
        )
        self.encoder = TransformerEncoder(enc_layer, encoder_layers, embed_dim)

        # Projects encoder output into the (typically narrower) decoder width.
        self.enc_to_dec = nn.Linear(embed_dim, decoder_embed_dim, bias=False)

        # Learnable stand-in token for each masked patch on the decoder side,
        # plus separate positional embeddings for encoder and decoder.
        # Index 0 of each pos-embed table is reserved for the class token.
        self.dec_mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.dec_pos_embed = nn.Parameter(torch.empty(1, num_patches + 1, decoder_embed_dim))
        self.enc_pos_embed = nn.Parameter(torch.empty(1, num_patches + 1, embed_dim))

        nn.init.normal_(self.dec_mask_token, std=0.02)
        nn.init.normal_(self.enc_pos_embed, std=0.02)
        nn.init.normal_(self.dec_pos_embed, std=0.02)

        dec_layer = TransformerDecoderLayer(
            encoder_embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            num_heads=decoder_heads,
            mlp_dim=int(decoder_embed_dim * mlp_ratio),
            dropout=dropout,
        )
        self.decoder = TransformerDecoder(dec_layer, decoder_layers, embed_dim=decoder_embed_dim)

        # Maps each decoder token back to the raw pixels of one patch.
        self.pred = nn.Linear(decoder_embed_dim, (patch_size ** 2) * in_chans)

        # Applied to patch tokens right after the conv projection
        # (NOTE(review): pre-encoder norm; MAE variants often norm post-encoder
        # instead — kept as the original design).
        self.norm = nn.LayerNorm(embed_dim)

    def patchify(self, imgs):
        """
        Split images into flattened non-overlapping patches.

        imgs: (B, C, H, W) with H and W divisible by patch_size
        returns: (B, N, patch_size * patch_size * C) where N = (H/p) * (W/p)
        """
        B, C, H, W = imgs.shape
        p = self.patch_size
        assert H % p == 0 and W % p == 0, "Image dimensions must be divisible by the patch size."

        h = H // p
        w = W // p
        # (B, C, h, p, w, p) -> (B, h, w, p, p, C) -> (B, h*w, p*p*C)
        patches = imgs.reshape(B, C, h, p, w, p)
        patches = torch.einsum('nchawb->nhwabc', patches)
        patches = patches.reshape(B, h * w, p * p * C)
        return patches

    def unpatchify(self, x):
        """
        Inverse of `patchify` for a square patch grid.

        x: (B, N, patch_size * patch_size * C) with N a perfect square
        returns: (B, C, H, W) where H = W = sqrt(N) * patch_size
        """
        p = self.patch_size
        h = int(x.shape[1] ** 0.5)
        w = h
        assert h * w == x.shape[1]
        x = x.reshape(x.shape[0], h, w, p, p, self.in_chans)
        x = torch.einsum('nhwpqc->nchpwq', x)
        # Derive the output size from the patch grid rather than assuming
        # self.image_size, so any (square) input that patchify produced
        # round-trips correctly.
        imgs = x.reshape(x.shape[0], self.in_chans, h * p, w * p)
        return imgs

    def random_masking(self, x):
        """
        Perform per-sample random masking by shuffling with argsort of noise.

        x: (B, L, D) sequence of patch tokens
        returns:
            x_masked: (B, len_keep, D) visible patch tokens
            mask: (B, L) binary mask in the original patch order,
                  0 = visible (kept), 1 = masked (removed)
            ids_restore: (B, L) indices that undo the shuffle
            ids_keep: (B, len_keep) indices of the kept patches
        """
        B, L, D = x.shape
        len_keep = int(L * (1 - self.mask_ratio))

        # Random permutation per sample: ascending argsort of uniform noise.
        noise = torch.rand(B, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        # argsort of the permutation gives its inverse.
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the first len_keep entries of the shuffled order.
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Build the binary mask in shuffled order, then unshuffle it.
        mask = torch.ones(B, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, 1, ids_restore)

        return x_masked, mask, ids_restore, ids_keep

    def forward_encoder(self, imgs):
        """
        Embed, mask, and encode the visible patches.

        imgs: (B, C, H, W)
        returns: (encoded tokens incl. cls at index 0, mask, ids_restore)
        """
        # Patchify via conv, then flatten the spatial grid into a sequence.
        x = self.conv_proj(imgs)
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        B, N, D = x.shape

        # Add positional embeddings for patch tokens (slot 0 is for cls).
        x = x + self.enc_pos_embed[:, 1:, :]

        x_masked, mask, ids_restore, ids_keep = self.random_masking(x)

        # Prepend class token (with its own positional embedding).
        cls_token = self.cls_token + self.enc_pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(B, -1, -1)
        x_enc = torch.cat([cls_tokens, x_masked], dim=1)

        x_enc = self.encoder(x_enc)

        return x_enc, mask, ids_restore

    def forward_decoder(self, x_enc, ids_restore):
        """
        Reconstruct all patch tokens from the encoded visible tokens.

        x_enc: (B, 1 + len_keep, embed_dim) encoder output (cls at index 0)
        ids_restore: (B, L) inverse-shuffle indices from random_masking
        returns: (B, 1 + L, patch_size**2 * in_chans) pixel predictions
                 (index 0 corresponds to the cls token)
        """
        # Project encoder tokens into decoder width.
        x_dec = self.enc_to_dec(x_enc)

        B, L, D = x_dec.shape
        # One mask token per patch the encoder did not see
        # (+1 accounts for the cls token present in x_dec).
        mask_tokens = self.dec_mask_token.repeat(B, ids_restore.shape[1] + 1 - x_dec.shape[1], 1)

        # Append mask tokens, then unshuffle back to the original patch order.
        x_no_cls = torch.cat([x_dec[:, 1:, :], mask_tokens], dim=1)
        x_no_cls = torch.gather(x_no_cls, 1, ids_restore.unsqueeze(-1).repeat(1, 1, D))
        x_dec = torch.cat([x_dec[:, :1, :], x_no_cls], dim=1)

        # Add decoder positional embeddings.
        x_dec = x_dec + self.dec_pos_embed[:, :x_dec.size(1), :]

        # Decoder cross-attends to the raw encoder output.
        x_dec = self.decoder(x_dec, x_enc)

        # Per-token linear head to raw patch pixels.
        x_rec = self.pred(x_dec)

        return x_rec

    def compute_mae_loss(self, imgs, pred, mask):
        """
        Mean-squared-error loss averaged over masked patches only.

        imgs: (N, C, H, W) original images
        pred: (N, 1 + L, p*p*C) predictions; index 0 is the cls token and is
              discarded before comparison
        mask: (N, L), 0 = keep (visible), 1 = remove (masked)
        """
        target = self.patchify(imgs)

        # Drop the cls-token prediction so pred aligns with the L patches.
        pred = pred[:, 1:, :]
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)

        # Average only over masked patches; epsilon guards mask.sum() == 0.
        loss = (loss * mask).sum() / (mask.sum() + 1e-6)

        return loss

    def forward(self, imgs):
        """
        Forward pass for MAE: encode, decode, and compute reconstruction loss.

        imgs: (B, C, H, W)
        returns: scalar reconstruction loss on the masked patches
        """
        x_enc, mask, ids_restore = self.forward_encoder(imgs)
        x_rec = self.forward_decoder(x_enc, ids_restore)
        loss = self.compute_mae_loss(imgs, x_rec, mask)
        return loss
|
|
|
|