from PIL import Image

import einops
import numpy as np
import torch
from huggingface_hub import PyTorchModelHubMixin
from hydra.utils import instantiate
from lightly.models import utils
# https://docs.lightly.ai/self-supervised-learning/examples/mae.html
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from timm.models.vision_transformer import VisionTransformer


class MAE(torch.nn.Module, PyTorchModelHubMixin):
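    """Masked autoencoder (MAE) with a timm ViT encoder and a lightly decoder.

    Follows the lightly MAE example linked above: image patches are randomly masked,
    only the visible patches are encoded, and the decoder predicts the pixel values
    of the masked patches, trained with an MSE reconstruction loss.
    """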

    def __init__(self, cfg):
        super().__init__()
        
        vit: VisionTransformer = instantiate(cfg.ssl_model.vit, img_size=cfg.ssl_aug.standard_view.output_size)

        self.patch_size = vit.patch_embed.patch_size[0]

        # Get MAE backbone
        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
        self.sequence_length = self.backbone.sequence_length

        self.encoder_dim = vit.embed_dim  # for convenience later

        # Get decoder
        self.decoder = MAEDecoderTIMM(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            embed_dim=vit.embed_dim,
            decoder_embed_dim=cfg.ssl_model.decoder.embed_dim,
            decoder_depth=cfg.ssl_model.decoder.depth,
            decoder_num_heads=cfg.ssl_model.decoder.num_heads,
            mlp_ratio=cfg.ssl_model.decoder.mlp_ratio,
            proj_drop_rate=cfg.ssl_model.decoder.dropout,
            attn_drop_rate=cfg.ssl_model.decoder.attention_dropout,
        )
        self.mask_ratio = cfg.ssl_model.mask_ratio  # stored on the model rather than as an aug setting, since masking is applied inside the model

        self.criterion = torch.nn.MSELoss()

    def forward_encoder(self, images, idx_keep=None):
        return self.backbone.encode(images=images, idx_keep=idx_keep)

    def forward_decoder(self, x_encoded, idx_keep, idx_mask):
        # build decoder input
        batch_size = x_encoded.shape[0]
        x_decode = self.decoder.embed(x_encoded)
        x_masked = utils.repeat_token(self.decoder.mask_token, (batch_size, self.sequence_length))
        x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))

        # decoder forward pass
        x_decoded = self.decoder.decode(x_masked)

        # predict pixel values for masked tokens
        x_pred = utils.get_at_index(x_decoded, idx_mask)
        x_pred = self.decoder.predict(x_pred)
        return x_pred

    def training_step(self, batch, batch_idx):
        images = batch["image"]  # MAE uses only a single view per image (unlike multi-view contrastive methods)
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)

        # decode and calculate loss (encoder no longer directly used)

        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)
        # must adjust idx_mask for the missing class token:
        # idx_mask indexes the token sequence, where index 0 is the (never-masked) class token,
        # while patchify returns patches with no class token, so patch index = token index - 1
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(x_pred, target)

        return loss, x_encoded

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        images = batch["image"]  # MAE uses only a single view per image (unlike multi-view contrastive methods)
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)
        # must adjust idx_mask for missing class token
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(x_pred, target)

        return loss, None

    def predict_step(self, batch, batch_idx): 
        idx_keep, idx_mask = self.mask_random_indices(batch)
        return self.predict(batch, idx_mask=idx_mask, idx_keep=idx_keep)

    def mask_random_indices(self, batch):
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch["image"].shape[0], self.sequence_length),  # (batch_size, seq_len)
            mask_ratio=self.mask_ratio,
            device=batch["image"].device,
        )
        return idx_keep, idx_mask

    def predict(self, batch, idx_mask, idx_keep=None):  
        # not used during training; provided as a convenient inference API
        # note idx_mask comes first, since that is the argument callers most often want to set

        # token index 0 is the class token and is never masked, so patch i corresponds to token i + 1;
        # callers supplying their own idx_mask are assumed to have already added 1 to their patch indices
        
        assert idx_mask is not None

        if idx_keep is None:  # likely a caller providing only idx_mask, not using predict_step above
            all_indices = set(range(0, self.sequence_length))
            idx_keep = []
            for row in idx_mask:
                # keep every token (including the class token at index 0) that is not masked
                # sorted for a deterministic token order (set iteration order is not guaranteed)
                keep_row = sorted(all_indices - set(row.tolist()))
                idx_keep.append(keep_row)
            idx_keep = torch.tensor(idx_keep, device=idx_mask.device)

        images = batch["image"]
        batch_size = images.shape[0]

        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get masked and reconstructed images
        im_masked, im_reconstructed = self.mask_and_reconstruct_images(mask=idx_mask, num_images=batch_size, y=x_pred, x=images)

        # calculate MSE (copied from above, but with per-image reduction not per-batch reduction)
        patches = utils.patchify(images, self.patch_size)  # does not change batch dim
        target = utils.get_at_index(patches, idx_mask - 1)
        mse_per_patch = torch.nn.MSELoss(reduction="none")(x_pred, target)
        mse_per_image = mse_per_patch.view(batch_size, -1).mean(dim=1)  # reduce all dimensions but batch

        return {
            'id_str': batch['id_str'],
            'images': image_batch_to_pil_list(images),
            'encoded': x_encoded,
            'masked': image_batch_to_pil_list(im_masked),
            'reconstructed': image_batch_to_pil_list(im_reconstructed),
            'reconstruction_error': mse_per_image
        }
    

    def mask_and_reconstruct_images(self, mask, num_images, y, x):
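        """Return (masked, reconstructed) image batches as pixel tensors.

        x is the original image batch, y the predicted pixel values for the masked
        patches, and mask the (token) indices that were masked for each image.
        """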
        im_masked = self.patchify(x)  # still the original image, just reshaped
        im_reconstructed = im_masked.clone()  # same for now, but will become the reconstructed images

        # if mask is None, both masked and reconstructed are simply the original image, so do nothing;
        # otherwise, overwrite the masked patches
        if mask is not None:
            for batch_index in range(num_images):
                if batch_index >= x.shape[0]:
                    break  # asked for more images than the batch contains
                # replace masked patches with either 0 (masked view) or the predicted values (reconstruction)
                for mask_idx, token_idx in enumerate(mask[batch_index]):
                    # token_idx - 1 converts token index (class token at 0) to patch index
                    im_masked[batch_index, token_idx - 1] = 0  # blank out masked patches
                    im_reconstructed[batch_index, token_idx - 1, :] = y[batch_index, mask_idx, :]  # fill with predicted pixels

        # depatchify i.e. reshape back like original image
        im_masked = self.unpatchify(im_masked)
        im_reconstructed = self.unpatchify(im_reconstructed)
        return im_masked, im_reconstructed

    def unpatchify(self, x):
        # i.e. [b, h*w, p*p*c] -> [b, c, h*p, w*p], where p is patch size
        return einops.rearrange(
            x,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            p1=self.patch_size,
            p2=self.patch_size,
            b=x.shape[0],
            c=3,
            h=int(np.sqrt(x.shape[1])),
            w=int(np.sqrt(x.shape[1])),
        )

    def patchify(self, x):
        # note: in the einops pattern below, "h" and "w" are the number of patches along each axis
        # (image height/width divided by patch size), and p1/p2 are the patch size
        # i.e. an image batch of shape [b, c, h, w] is reshaped to
        # [b, (h/p)*(w/p), p*p*c], i.e. [b, num_patches, patch_size^2 * c]
        return einops.rearrange(
            x,
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch_size,
            p2=self.patch_size,
            b=x.shape[0],
            c=3,
            h=x.shape[-2] // self.patch_size,
            w=x.shape[-1] // self.patch_size,
        )

    @property
    def encoder(self):
        return self.backbone.vit  # hopefully equivalent to self.backbone.encode(x, idx_keep=all)


def image_batch_to_pil_list(images):
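    """Convert a [b, c, h, w] float tensor with values in [0, 1] to a list of PIL Images."""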
    images = einops.rearrange(images, 'b c h w -> b h w c')
    images = torch.clamp(images, 0, 1)*255
    images = images.cpu().numpy()
    images = images.astype(np.uint8)
    return [Image.fromarray(im) for im in images]
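

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the training pipeline).
    # The config values below are assumptions chosen just to exercise the fields that
    # MAE.__init__ reads; real runs load these from the project's hydra config instead.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "ssl_model": {
            "vit": {
                "_target_": "timm.models.vision_transformer.VisionTransformer",
                "patch_size": 8,
                "embed_dim": 192,
                "depth": 4,
                "num_heads": 3,
            },
            "decoder": {
                "embed_dim": 128,
                "depth": 2,
                "num_heads": 4,
                "mlp_ratio": 4.0,
                "dropout": 0.0,
                "attention_dropout": 0.0,
            },
            "mask_ratio": 0.75,
        },
        "ssl_aug": {"standard_view": {"output_size": 64}},
    })
    model = MAE(cfg).eval()

    batch = {
        "image": torch.rand(2, 3, 64, 64),  # values in [0, 1], as image_batch_to_pil_list expects
        "id_str": ["example_0", "example_1"],
    }
    with torch.no_grad():
        out = model.predict_step(batch, batch_idx=0)
    print(out["reconstruction_error"])  # per-image MSE on the masked patches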