Create mae_timm_simplified.py
Browse files- mae_timm_simplified.py +222 -0
mae_timm_simplified.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
|
| 3 |
+
import einops
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from hydra.utils import instantiate
|
| 7 |
+
from lightly.models import utils
|
| 8 |
+
# https://docs.lightly.ai/self-supervised-learning/examples/mae.html
|
| 9 |
+
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
|
| 10 |
+
from timm.models.vision_transformer import VisionTransformer
|
| 11 |
+
|
| 12 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 13 |
+
class MAE(torch.nn.Module, PyTorchModelHubMixin):
    """Masked Autoencoder (MAE) built on a timm VisionTransformer via lightly's TIMM modules.

    Masking is applied inside the model (not as a data augmentation): random patch
    tokens are dropped, the visible tokens are encoded by the ViT backbone, and a
    lightweight decoder predicts the raw pixel values of the masked patches.
    Loss is MSE between predicted and true pixels of the masked patches only.

    https://docs.lightly.ai/self-supervised-learning/examples/mae.html
    """

    def __init__(self, cfg):
        """Build the encoder backbone and pixel decoder from the hydra config ``cfg``.

        Reads ``cfg.ssl_model.vit``, ``cfg.ssl_model.decoder.*``,
        ``cfg.ssl_model.mask_ratio`` and ``cfg.ssl_aug.standard_view.output_size``.
        """
        super().__init__()

        vit: VisionTransformer = instantiate(cfg.ssl_model.vit, img_size=cfg.ssl_aug.standard_view.output_size)

        self.patch_size = vit.patch_embed.patch_size[0]

        # Get MAE backbone (handles token masking + ViT encoding)
        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
        self.sequence_length = self.backbone.sequence_length

        self.encoder_dim = vit.embed_dim  # for convenience later

        # Get decoder (reconstructs pixel values of masked patches)
        self.decoder = MAEDecoderTIMM(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            embed_dim=vit.embed_dim,
            decoder_embed_dim=cfg.ssl_model.decoder.embed_dim,
            decoder_depth=cfg.ssl_model.decoder.depth,
            decoder_num_heads=cfg.ssl_model.decoder.num_heads,
            mlp_ratio=cfg.ssl_model.decoder.mlp_ratio,
            proj_drop_rate=cfg.ssl_model.decoder.dropout,
            attn_drop_rate=cfg.ssl_model.decoder.attention_dropout,
        )
        self.mask_ratio = cfg.ssl_model.mask_ratio  # saved as model parameter, not aug, since it is applied within model

        self.criterion = torch.nn.MSELoss()

    def forward_encoder(self, images, idx_keep=None):
        """Encode ``images``, keeping only the tokens at ``idx_keep`` (all tokens if None)."""
        return self.backbone.encode(images=images, idx_keep=idx_keep)

    def forward_decoder(self, x_encoded, idx_keep, idx_mask):
        """Decode encoded visible tokens and predict pixel values at ``idx_mask``."""
        # build decoder input: mask tokens everywhere, then scatter in the
        # projected encoded tokens at the kept indices
        batch_size = x_encoded.shape[0]
        x_decode = self.decoder.embed(x_encoded)
        x_masked = utils.repeat_token(self.decoder.mask_token, (batch_size, self.sequence_length))
        x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))

        # decoder forward pass
        x_decoded = self.decoder.decode(x_masked)

        # predict pixel values for masked tokens only
        x_pred = utils.get_at_index(x_decoded, idx_mask)
        x_pred = self.decoder.predict(x_pred)
        return x_pred

    def _masked_reconstruction(self, images):
        """Shared train/val path: mask, encode, decode, and compute the MSE loss.

        Returns ``(loss, x_encoded)``.
        (Extracted helper — training_step and validation_step previously
        duplicated this pipeline line-for-line.)
        """
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)

        # decode and calculate loss (encoder output no longer directly used)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)
        # must adjust idx_mask for missing class token
        # (class token was added after calculating which indices to mask,
        # so we need to subtract 1 from idx_mask to get the new indices that are masked)
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(x_pred, target)
        return loss, x_encoded

    def training_step(self, batch, batch_idx):
        """One training step; returns ``(loss, x_encoded)``."""
        images = batch["image"]  # views contains only a single view
        return self._masked_reconstruction(images)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """One validation step; returns ``(loss, None)``."""
        images = batch["image"]  # views contains only a single view
        loss, _ = self._masked_reconstruction(images)
        return loss, None

    def predict_step(self, batch, batch_idx):
        """Reconstruct a batch under a freshly drawn random mask."""
        idx_keep, idx_mask = self.mask_random_indices(batch)
        return self.predict(batch, idx_mask=idx_mask, idx_keep=idx_keep)

    def mask_random_indices(self, batch):
        """Draw a random token mask for this batch; returns ``(idx_keep, idx_mask)``."""
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch["image"].shape[0], self.sequence_length),  # (batch_size, seq_len)
            mask_ratio=self.mask_ratio,
            device=batch["image"].device,
        )
        return idx_keep, idx_mask

    def predict(self, batch, idx_mask, idx_keep=None):
        """Reconstruct images under a user-supplied mask. Handy API, not used in training.

        Note the order of arguments is idx_mask first, as this is what most people change!

        Token idx 0 is the class token and is never masked; the caller must
        already have added 1 to all patch indices before passing them in.

        Returns a dict with the original/masked/reconstructed images as PIL lists,
        the encoded tokens, and the per-image reconstruction MSE.
        """
        assert idx_mask is not None

        if idx_keep is None:  # probably a user only providing idx_mask, not using predict_step above
            # keep every token (incl. class token 0) that is not masked
            all_indices = set(range(0, self.sequence_length))
            idx_keep = []
            for row in idx_mask:
                keep_row = list(all_indices - set(row.tolist()))
                idx_keep.append(keep_row)
            idx_keep = torch.tensor(idx_keep).to(idx_mask.device)

        images = batch["image"]
        batch_size = images.shape[0]

        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get masked and reconstructed images
        im_masked, im_reconstructed = self.mask_and_reconstruct_images(mask=idx_mask, num_images=batch_size, y=x_pred, x=images)

        # calculate MSE (same targets as training, but with per-image reduction not per-batch reduction)
        patches = utils.patchify(images, self.patch_size)  # does not change batch dim
        target = utils.get_at_index(patches, idx_mask - 1)
        mse_per_patch = torch.nn.MSELoss(reduction="none")(x_pred, target)
        mse_per_image = mse_per_patch.view(batch_size, -1).mean(dim=1)  # reduce all dimensions but batch

        return {
            'id_str': batch['id_str'],
            'images': image_batch_to_pil_list(images),
            'encoded': x_encoded,
            'masked': image_batch_to_pil_list(im_masked),
            'reconstructed': image_batch_to_pil_list(im_reconstructed),
            'reconstruction_error': mse_per_image
        }

    def mask_and_reconstruct_images(self, mask, num_images, y, x):
        """Build the visualisation pair (masked image, reconstructed image).

        mask: [b, n_masked] token indices (class-token offset, i.e. patch index + 1)
        num_images: max number of images to process
        y: [b, n_masked, patch_dim] predicted pixel values for the masked tokens
        x: [b, c, h, w] original images
        """
        im_masked = self.patchify(x)  # still the original image, just reshaped
        im_reconstructed = im_masked.clone()  # same for now, but will become the reconstructed images

        # if mask is None, both masked and reconstructed are just the original image, do nothing
        if mask is not None:
            # bugfix: the old loop also checked `batch_index > num_images`, which can
            # never be true inside range(num_images) — dead code, now removed.
            # min() guards against num_images exceeding the actual batch size.
            for batch_index in range(min(num_images, x.shape[0])):
                # replace values with either 0 or the predicted fill values
                for mask_idx, token_idx in enumerate(mask[batch_index]):
                    im_masked[batch_index, token_idx - 1] = 0  # set masked pixels to 0
                    im_reconstructed[batch_index, token_idx - 1, :] = y[batch_index, mask_idx, :]  # set masked pixels to predicted pixels

        # depatchify i.e. reshape back like original image
        im_masked = self.unpatchify(im_masked)
        im_reconstructed = self.unpatchify(im_reconstructed)
        return im_masked, im_reconstructed

    def unpatchify(self, x):
        """Inverse of patchify: [b, h*w, p*p*c] -> [b, c, h*p, w*p], where p is patch size.

        Assumes a square grid of patches (h == w == sqrt(num_patches)) and 3 channels.
        """
        return einops.rearrange(
            x,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            p1=self.patch_size,
            p2=self.patch_size,
            b=x.shape[0],
            c=3,
            h=int(np.sqrt(x.shape[1])),
            w=int(np.sqrt(x.shape[1])),
        )

    def patchify(self, x):
        """Reshape an image [b, c, h, w] into flat patches [b, num_patches, patch_size^2 * c].

        Confusingly, "h"/"w" in the einops pattern are the patch-grid dimensions
        (image height/width divided by patch size), not pixel dimensions.
        """
        return einops.rearrange(
            x,
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch_size,
            p2=self.patch_size,
            b=x.shape[0],
            c=3,
            h=x.shape[-2] // self.patch_size,
            w=x.shape[-1] // self.patch_size,
        )

    @property
    def encoder(self):
        """The underlying ViT backbone."""
        return self.backbone.vit  # hopefully equivalent to self.backbone.encode(x, idx_keep=all)
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def image_batch_to_pil_list(images):
    """Convert a float image batch tensor [b, c, h, w] (values in ~[0, 1]) to a list of PIL Images.

    Values are clamped to [0, 1], scaled to 0-255, and truncated to uint8.
    """
    # PIL expects channels-last layout
    channels_last = einops.rearrange(images, 'b c h w -> b h w c')
    scaled = torch.clamp(channels_last, 0, 1) * 255
    frames = scaled.cpu().numpy().astype(np.uint8)
    return [Image.fromarray(frame) for frame in frames]
|