mwalmsley committed on
Commit
7c6ba0f
·
verified ·
1 Parent(s): 568c582

Create mae_timm_simplified.py

Browse files
Files changed (1) hide show
  1. mae_timm_simplified.py +222 -0
mae_timm_simplified.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+
3
+ import einops
4
+ import numpy as np
5
+ import torch
6
+ from hydra.utils import instantiate
7
+ from lightly.models import utils
8
+ # https://docs.lightly.ai/self-supervised-learning/examples/mae.html
9
+ from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
10
+ from timm.models.vision_transformer import VisionTransformer
11
+
12
+ from huggingface_hub import PyTorchModelHubMixin
13
class MAE(torch.nn.Module, PyTorchModelHubMixin):
    """Masked autoencoder (MAE) built from a timm ViT encoder and a lightly decoder.

    Pattern follows https://docs.lightly.ai/self-supervised-learning/examples/mae.html
    """

    def __init__(self, cfg):
        """Build the MAE from a hydra config.

        Args:
            cfg: hydra config. Reads ``cfg.ssl_model.vit`` (ViT instantiation target),
                ``cfg.ssl_aug.standard_view.output_size`` (input image size),
                ``cfg.ssl_model.decoder.*`` (decoder hyperparameters), and
                ``cfg.ssl_model.mask_ratio``.
        """
        super().__init__()

        vit: VisionTransformer = instantiate(cfg.ssl_model.vit, img_size=cfg.ssl_aug.standard_view.output_size)

        # patch_embed.patch_size is (h, w); assumes square patches
        self.patch_size = vit.patch_embed.patch_size[0]

        # Get MAE backbone (wraps the ViT so tokens can be masked during encoding)
        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
        self.sequence_length = self.backbone.sequence_length

        self.encoder_dim = vit.embed_dim  # for convenience later

        # Get decoder
        self.decoder = MAEDecoderTIMM(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            embed_dim=vit.embed_dim,
            decoder_embed_dim=cfg.ssl_model.decoder.embed_dim,
            decoder_depth=cfg.ssl_model.decoder.depth,
            decoder_num_heads=cfg.ssl_model.decoder.num_heads,
            mlp_ratio=cfg.ssl_model.decoder.mlp_ratio,
            proj_drop_rate=cfg.ssl_model.decoder.dropout,
            attn_drop_rate=cfg.ssl_model.decoder.attention_dropout,
        )
        # saved as model parameter, not aug, since it is applied within the model
        self.mask_ratio = cfg.ssl_model.mask_ratio

        self.criterion = torch.nn.MSELoss()

    def forward_encoder(self, images, idx_keep=None):
        """Encode images, keeping only tokens at idx_keep (all tokens if None)."""
        return self.backbone.encode(images=images, idx_keep=idx_keep)

    def forward_decoder(self, x_encoded, idx_keep, idx_mask):
        """Predict pixel values for the masked tokens from the encoded kept tokens.

        Args:
            x_encoded: encoder output for the kept tokens.
            idx_keep: token indices that were passed through the encoder.
            idx_mask: token indices that were masked out.

        Returns:
            Predicted patch pixels for the masked tokens only.
        """
        # build decoder input: mask tokens everywhere, then scatter in the encoded tokens
        batch_size = x_encoded.shape[0]
        x_decode = self.decoder.embed(x_encoded)
        x_masked = utils.repeat_token(self.decoder.mask_token, (batch_size, self.sequence_length))
        x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))

        # decoder forward pass
        x_decoded = self.decoder.decode(x_masked)

        # predict pixel values for masked tokens only
        x_pred = utils.get_at_index(x_decoded, idx_mask)
        x_pred = self.decoder.predict(x_pred)
        return x_pred

    def _masked_reconstruction(self, images):
        """Shared train/val logic: randomly mask, encode, decode, compute MSE.

        Returns:
            (loss, x_encoded) tuple.
        """
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)

        # decode and calculate loss (encoder no longer directly used)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)
        # must adjust idx_mask for missing class token
        # (class token was added after calculating which indices to mask,
        # so we need to subtract 1 from idx_mask to get the new indices that are masked)
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(x_pred, target)
        return loss, x_encoded

    def training_step(self, batch, batch_idx):
        """Compute the MAE reconstruction loss on a training batch.

        Returns:
            (loss, x_encoded) tuple.
        """
        # batch["image"] contains only a single view
        return self._masked_reconstruction(batch["image"])

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Compute the MAE reconstruction loss on a validation batch.

        NOTE(review): masking is random here too, so the validation loss is stochastic.

        Returns:
            (loss, None) tuple.
        """
        loss, _ = self._masked_reconstruction(batch["image"])
        return loss, None

    def predict_step(self, batch, batch_idx):
        """Reconstruct a batch under a freshly sampled random mask."""
        idx_keep, idx_mask = self.mask_random_indices(batch)
        return self.predict(batch, idx_mask=idx_mask, idx_keep=idx_keep)

    def mask_random_indices(self, batch):
        """Sample (idx_keep, idx_mask) for this batch using the configured mask_ratio."""
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch["image"].shape[0], self.sequence_length),  # (batch_size, seq_len)
            mask_ratio=self.mask_ratio,
            device=batch["image"].device,
        )
        return idx_keep, idx_mask

    def predict(self, batch, idx_mask, idx_keep=None):
        """Reconstruct images under a given mask. Convenience API, not used during training.

        Note the order of arguments is idx_mask first, as this is what most people change!

        idx 0 is the class token and is never masked.
        The user must add 1 to all indices before passing to predict! Assumes this is already done.

        Args:
            batch: dict with 'image' (tensor [b, c, h, w]) and 'id_str' keys.
            idx_mask: token indices to mask (class-token offset already applied).
            idx_keep: token indices to keep. Derived as the complement of idx_mask if None.

        Returns:
            dict with original/masked/reconstructed PIL images, encodings,
            and per-image reconstruction MSE.
        """
        assert idx_mask is not None

        if idx_keep is None:  # probably a user only providing idx_mask, not using predict_step above
            all_indices = set(range(0, self.sequence_length))
            idx_keep = []
            for row in idx_mask:
                # sorted for a deterministic token order (set difference order is arbitrary)
                keep_row = sorted(all_indices - set(row.tolist()))
                idx_keep.append(keep_row)
            idx_keep = torch.tensor(idx_keep).to(idx_mask.device)

        images = batch["image"]
        batch_size = images.shape[0]

        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get masked and reconstructed images
        im_masked, im_reconstructed = self.mask_and_reconstruct_images(mask=idx_mask, num_images=batch_size, y=x_pred, x=images)

        # calculate MSE (as in the train/val loss, but with per-image reduction not per-batch reduction)
        patches = utils.patchify(images, self.patch_size)  # does not change batch dim
        target = utils.get_at_index(patches, idx_mask - 1)
        mse_per_patch = torch.nn.MSELoss(reduction="none")(x_pred, target)
        mse_per_image = mse_per_patch.view(batch_size, -1).mean(dim=1)  # reduce all dimensions but batch

        return {
            'id_str': batch['id_str'],
            'images': image_batch_to_pil_list(images),
            'encoded': x_encoded,
            'masked': image_batch_to_pil_list(im_masked),
            'reconstructed': image_batch_to_pil_list(im_reconstructed),
            'reconstruction_error': mse_per_image
        }

    def mask_and_reconstruct_images(self, mask, num_images, y, x):
        """Build the masked and reconstructed image batches as pixel tensors.

        Args:
            mask: masked token indices per image, including the +1 class-token offset.
            num_images: how many images of the batch to process.
            y: predicted patch pixels for the masked tokens, aligned with ``mask``.
            x: original images [b, c, h, w].

        Returns:
            (im_masked, im_reconstructed) image tensors shaped like ``x``.
        """
        im_masked = self.patchify(x)  # still the original image, just reshaped
        im_reconstructed = im_masked.clone()  # same for now, but will become the reconstructed images

        # if mask is None, both masked and reconstructed are just the original image: do nothing
        if mask is not None:
            # min() guards against num_images exceeding the actual batch size
            for batch_index in range(min(num_images, x.shape[0])):
                # replace values with either 0 or the predicted fill values
                for mask_idx, token_idx in enumerate(mask[batch_index]):
                    # token_idx - 1 adjusts for the class token at index 0
                    im_masked[batch_index, token_idx - 1] = 0  # set masked pixels to 0
                    im_reconstructed[batch_index, token_idx - 1, :] = y[batch_index, mask_idx, :]  # set masked pixels to predicted pixels

        # depatchify i.e. reshape back like original image
        im_masked = self.unpatchify(im_masked)
        im_reconstructed = self.unpatchify(im_reconstructed)
        return im_masked, im_reconstructed

    def unpatchify(self, x):
        """Reshape patches back to images: [b, h*w, p*p*c] -> [b, c, h*p, w*p], p = patch size.

        Assumes a square grid of patches (h == w).
        """
        num_patches_per_side = int(np.sqrt(x.shape[1]))
        # infer channel count from the patch dimension rather than assuming RGB
        channels = x.shape[2] // (self.patch_size ** 2)
        return einops.rearrange(
            x,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            p1=self.patch_size,
            p2=self.patch_size,
            c=channels,
            h=num_patches_per_side,
            w=num_patches_per_side,
        )

    def patchify(self, x):
        """Reshape images to patches: [b, c, h, w] -> [b, (h//p)*(w//p), p*p*c], p = patch size.

        Note: einops infers c, h and w from the input shape, so any channel count works.
        """
        return einops.rearrange(
            x,
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch_size,
            p2=self.patch_size,
        )

    @property
    def encoder(self):
        """The underlying timm ViT; hopefully equivalent to self.backbone.encode(x, idx_keep=all)."""
        return self.backbone.vit
214
+
215
+
216
def image_batch_to_pil_list(images):
    """Convert a batch tensor [b, c, h, w] with values in [0, 1] to a list of PIL images."""
    # PIL expects channels-last uint8 arrays
    channels_last = einops.rearrange(images, 'b c h w -> b h w c')
    scaled = torch.clamp(channels_last, 0, 1) * 255
    frames = scaled.cpu().numpy().astype(np.uint8)
    return [Image.fromarray(frame) for frame in frames]