import einops
import numpy as np
import torch
from huggingface_hub import PyTorchModelHubMixin
from hydra.utils import instantiate
from lightly.models import utils
# https://docs.lightly.ai/self-supervised-learning/examples/mae.html
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from PIL import Image
from timm.models.vision_transformer import VisionTransformer
class MAE(torch.nn.Module, PyTorchModelHubMixin):
def __init__(self, cfg):
super().__init__()
vit: VisionTransformer = instantiate(cfg.ssl_model.vit, img_size=cfg.ssl_aug.standard_view.output_size)
self.patch_size = vit.patch_embed.patch_size[0]
# Get MAE backbone
self.backbone = MaskedVisionTransformerTIMM(vit=vit)
self.sequence_length = self.backbone.sequence_length
self.encoder_dim = vit.embed_dim # for convenience later
# Get decoder
self.decoder = MAEDecoderTIMM(
num_patches=vit.patch_embed.num_patches,
patch_size=self.patch_size,
embed_dim=vit.embed_dim,
decoder_embed_dim=cfg.ssl_model.decoder.embed_dim,
decoder_depth=cfg.ssl_model.decoder.depth,
decoder_num_heads=cfg.ssl_model.decoder.num_heads,
mlp_ratio=cfg.ssl_model.decoder.mlp_ratio,
proj_drop_rate=cfg.ssl_model.decoder.dropout,
attn_drop_rate=cfg.ssl_model.decoder.attention_dropout,
)
self.mask_ratio = cfg.ssl_model.mask_ratio # saved as model parameter, not aug, since it is applied within model
self.criterion = torch.nn.MSELoss()
def forward_encoder(self, images, idx_keep=None):
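        """Encode only the kept tokens.

        Shapes are assumptions based on lightly's MaskedVisionTransformerTIMM.encode:
        images is [B, C, H, W], idx_keep is [B, num_keep] token indices (class token at
        index 0), and the return value is [B, num_keep, embed_dim].
        """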
return self.backbone.encode(images=images, idx_keep=idx_keep)
def forward_decoder(self, x_encoded, idx_keep, idx_mask):
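        """Predict pixel values for the masked tokens.

        x_encoded is the output of forward_encoder ([B, num_keep, embed_dim]); the returned
        prediction is expected to have shape [B, num_masked, patch_size**2 * 3], matching
        the patchified targets used in training_step.
        """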
# build decoder input
batch_size = x_encoded.shape[0]
x_decode = self.decoder.embed(x_encoded)
x_masked = utils.repeat_token(self.decoder.mask_token, (batch_size, self.sequence_length))
x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))
# decoder forward pass
x_decoded = self.decoder.decode(x_masked)
# predict pixel values for masked tokens
x_pred = utils.get_at_index(x_decoded, idx_mask)
x_pred = self.decoder.predict(x_pred)
return x_pred
def training_step(self, batch, batch_idx):
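        """Randomly mask tokens, encode the visible ones, reconstruct the masked patches,
        and return the MSE reconstruction loss together with the encoded tokens."""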
        images = batch["image"]  # the batch contains only a single view per image
batch_size = images.shape[0]
idx_keep, idx_mask = utils.random_token_mask(
size=(batch_size, self.sequence_length),
mask_ratio=self.mask_ratio,
device=images.device,
)
x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
# decode and calculate loss (encoder no longer directly used)
x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)
# get image patches for masked tokens
patches = utils.patchify(images, self.patch_size)
# must adjust idx_mask for missing class token
# (class token was added after calculating which indices to mask,
# so we need to subtract 1 from idx_mask to get the new indices that are masked)
target = utils.get_at_index(patches, idx_mask - 1)
loss = self.criterion(x_pred, target)
return loss, x_encoded
def validation_step(self, batch, batch_idx, dataloader_idx=0):
        images = batch["image"]  # the batch contains only a single view per image
batch_size = images.shape[0]
idx_keep, idx_mask = utils.random_token_mask(
size=(batch_size, self.sequence_length),
mask_ratio=self.mask_ratio,
device=images.device,
)
x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)
# get image patches for masked tokens
patches = utils.patchify(images, self.patch_size)
# must adjust idx_mask for missing class token
target = utils.get_at_index(patches, idx_mask - 1)
loss = self.criterion(x_pred, target)
return loss, None
def predict_step(self, batch, batch_idx):
idx_keep, idx_mask = self.mask_random_indices(batch)
return self.predict(batch, idx_mask=idx_mask, idx_keep=idx_keep)
def mask_random_indices(self, batch):
idx_keep, idx_mask = utils.random_token_mask(
size=(batch["image"].shape[0], self.sequence_length), # (batch_size, seq_len)
mask_ratio=self.mask_ratio,
device=batch["image"].device,
)
return idx_keep, idx_mask
def predict(self, batch, idx_mask, idx_keep=None):
        # Not used during training; provided only as a convenience API.
        # Note that idx_mask comes first, since it is the argument callers most often set.
        # Index 0 is the class token and is never masked.
        # Callers must add 1 to all patch indices before passing them in; this method assumes
        # that offset has already been applied.
assert idx_mask is not None
if idx_keep is None: # probably a user only providing idx_mask, not using predict_step above
all_indices = set(range(0, self.sequence_length))
idx_keep = []
for row in idx_mask:
keep_row = list(all_indices - set(row.tolist()))
idx_keep.append(keep_row)
idx_keep = torch.tensor(idx_keep).to(idx_mask.device)
images = batch["image"]
batch_size = images.shape[0]
x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)
# get masked and reconstructed images
im_masked, im_reconstructed = self.mask_and_reconstruct_images(mask=idx_mask, num_images=batch_size, y=x_pred, x=images)
# calculate MSE (copied from above, but with per-image reduction not per-batch reduction)
patches = utils.patchify(images, self.patch_size) # does not change batch dim
target = utils.get_at_index(patches, idx_mask - 1)
mse_per_patch = torch.nn.MSELoss(reduction="none")(x_pred, target)
mse_per_image = mse_per_patch.view(batch_size, -1).mean(dim=1) # reduce all dimensions but batch
return {
'id_str': batch['id_str'],
'images': image_batch_to_pil_list(images),
'encoded': x_encoded,
'masked': image_batch_to_pil_list(im_masked),
'reconstructed': image_batch_to_pil_list(im_reconstructed),
'reconstruction_error': mse_per_image
}
def mask_and_reconstruct_images(self, mask, num_images, y, x):
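        """Build full-size masked and reconstructed images for visualisation.

        mask holds token indices (offset by +1 relative to patch indices because of the
        class token), y holds the predicted pixels for those tokens, and x is the original
        image batch. Masked patches are zeroed in the first returned image and replaced by
        the predictions in the second.
        """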
im_masked = self.patchify(x) # still the original image, just reshaped
im_reconstructed = im_masked.clone() # same for now, but will become the reconstructed images
        # if mask is None, both masked and reconstructed are just the original image, so do nothing;
        # otherwise zero out / overwrite the masked patches below
if mask is not None:
for batch_index in range(num_images):
                # num_images may exceed the actual batch size, so stop once the batch is exhausted
                if batch_index >= x.shape[0]:
                    break
# replace values with either 0 or the predicted fill values
for mask_idx, token_idx in enumerate(mask[batch_index]):
im_masked[batch_index, token_idx - 1] = 0 # set masked pixels to 0
im_reconstructed[batch_index, token_idx - 1, :] = y[batch_index, mask_idx, :] # set masked pixels to predicted pixels
# depatchify i.e. reshape back like original image
im_masked = self.unpatchify(im_masked)
im_reconstructed = self.unpatchify(im_reconstructed)
return im_masked, im_reconstructed
def unpatchify(self, x):
# i.e. [b, h*w, p*p*c] -> [b, c, h*p, w*p], where p is patch size
return einops.rearrange(
x,
"b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
p1=self.patch_size,
p2=self.patch_size,
b=x.shape[0],
c=3,
h=int(np.sqrt(x.shape[1])),
w=int(np.sqrt(x.shape[1])),
)
def patchify(self, x):
        # Note: in the einops pattern, "h" and "w" are the number of patches per side
        # (image height/width divided by the patch size), and p1/p2 are the patch size.
        # In other words: x is an image of shape [b, c, H, W], and the output is
        # [b, (H/p)*(W/p), p*p*c], i.e. [b, num_patches, patch_size**2 * channels].
return einops.rearrange(
x,
"b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
p1=self.patch_size,
p2=self.patch_size,
b=x.shape[0],
c=3,
h=x.shape[-2] // self.patch_size,
w=x.shape[-1] // self.patch_size,
)
@property
def encoder(self):
return self.backbone.vit # hopefully equivalent to self.backbone.encode(x, idx_keep=all)
def image_batch_to_pil_list(images):
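    """Convert a [B, C, H, W] float tensor with values in [0, 1] to a list of PIL images."""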
images = einops.rearrange(images, 'b c h w -> b h w c')
images = torch.clamp(images, 0, 1)*255
images = images.cpu().numpy()
images = images.astype(np.uint8)
return [Image.fromarray(im) for im in images]
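# A minimal usage sketch (not part of the original module). The config below is hypothetical
# and only illustrates the fields this class reads (cfg.ssl_model.vit, cfg.ssl_model.decoder.*,
# cfg.ssl_model.mask_ratio, cfg.ssl_aug.standard_view.output_size); the real project presumably
# supplies an equivalent config via Hydra.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "ssl_model": {
            "vit": {
                # Hydra target instantiated in __init__; any timm ViT constructor works here.
                "_target_": "timm.models.vision_transformer.VisionTransformer",
                "patch_size": 16,
                "embed_dim": 384,
                "depth": 12,
                "num_heads": 6,
            },
            "decoder": {
                "embed_dim": 256,
                "depth": 4,
                "num_heads": 4,
                "mlp_ratio": 4.0,
                "dropout": 0.0,
                "attention_dropout": 0.0,
            },
            "mask_ratio": 0.75,
        },
        "ssl_aug": {"standard_view": {"output_size": 224}},
    })
    model = MAE(cfg)

    # Dummy batch with the keys the model expects: float images in [0, 1] and string ids.
    batch = {
        "image": torch.rand(2, 3, 224, 224),
        "id_str": ["example_0", "example_1"],
    }
    loss, _ = model.training_step(batch, batch_idx=0)
    print("reconstruction loss:", loss.item())

    # Reconstruct specific patches: the indices below refer to patches, so add 1 for the
    # class token before calling predict, as the comments on predict() require.
    patch_indices = torch.tensor([[0, 1, 2], [10, 11, 12]])
    out = model.predict(batch, idx_mask=patch_indices + 1)
    print("per-image reconstruction error:", out["reconstruction_error"])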