"""
SAE Model Script  ver: Oct 28th 2023 15:30
SAE stands for Shuffled AutoEncoder, designed for PuzzleTuning

# References:
Based on MAE code.
https://github.com/facebookresearch/mae

"""

import os

from functools import partial

import torch
import torch.nn as nn

from timm.models.vision_transformer import PatchEmbed, Block

from SSL_structures.pos_embed import get_2d_sincos_pos_embed

from Backbone.VPT_structure import VPT_ViT


class ShuffledAutoEncoderViT(VPT_ViT):
    """
    Shuffled Autoencoder with VisionTransformer backbone

    prompt_mode: "Deep" / "Shallow"; by default None (plain ViT encoder without prompt tokens)
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, group_shuffle_size=-1,
                 prompt_mode=None, Prompt_Token_num=20, basic_state_dict=None, decoder=None, decoder_rep_dim=None):

        if prompt_mode is None:
            super().__init__()
            # SAE encoder specifics (this part is the same as a plain ViT)
            # --------------------------------------------------------------------------
            self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)  # BCHW -> BNC
            num_patches = self.patch_embed.num_patches

            # the learnable cls token is still used, but no cls head is needed
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            # set and freeze encoder_pos_embed,  use the fixed sin-cos embedding for tokens + mask_token
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)
            # Encoder blocks
            self.blocks = nn.ModuleList([  # qk_scale=None fixme related to timm version
                Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
                for i in range(depth)])
            self.norm = norm_layer(embed_dim)

            self.prompt_mode = prompt_mode
            # --------------------------------------------------------------------------

        else:
            super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
                             embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                             norm_layer=norm_layer, Prompt_Token_num=Prompt_Token_num, VPT_type=prompt_mode,
                             basic_state_dict=None)  # first, set the encoder state_dict to None here; it is loaded at the end
            num_patches = self.patch_embed.num_patches  # set patch_embed of VPT
            # set and freeze encoder_pos_embed,  use the fixed sin-cos embedding for tokens + mask_token
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

            self.prompt_mode = prompt_mode
            # Freeze the encoder parameters except for the Prompt Tokens
            self.Freeze()

        # SAE decoder specifics. TODO: as a low-level backbone, exploration for future segmentation is needed
        # --------------------------------------------------------------------------
        # if the feature dimension of encoder and decoder are different, use decoder_embed to align them
        if embed_dim != decoder_embed_dim:
            self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        else:
            self.decoder_embed = nn.Identity()

        # set decoder
        if decoder is not None:
            self.decoder = decoder
            # The decoder uses an FC layer to reconstruct image patches, unlike the encoder, which uses a CNN (PatchEmbed) to split them
            self.decoder_pred = nn.Linear(decoder_rep_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch

        else:
            self.decoder = None
            # set and freeze decoder_pos_embed,  use the fixed sin-cos embedding for tokens + mask_token
            self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                                  requires_grad=False)
            self.decoder_blocks = nn.ModuleList([  # qk_scale=None fixme related to timm version
                Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
                for i in range(decoder_depth)])
            self.decoder_norm = norm_layer(decoder_embed_dim)

            # The decoder uses an FC layer to reconstruct image patches, unlike the encoder, which uses a CNN (PatchEmbed) to split them
            self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch

        # --------------------------------------------------------------------------
        # this controls the puzzle group
        self.group_shuffle_size = group_shuffle_size

        # whether or not to use norm_pix_loss
        self.norm_pix_loss = norm_pix_loss
        # parameter initialization
        self.initialize_weights()

        # load basic state_dict of backbone for Transfer-learning-based tuning
        if basic_state_dict is not None:
            self.load_state_dict(basic_state_dict, strict=False)

    def initialize_weights(self):
        # initialization
        # initialize a 2d positional encoding of (embed_dim, grid) by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches ** .5),
                                            cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        if self.decoder is None:
            # initialize a 2d positional encoding of (embed_dim, grid) by sin-cos embedding
            decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                        int(self.patch_embed.num_patches ** .5),
                                                        cls_token=True)
            self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))  # xavier_uniform keeps input/output variance equal, in both forward and backward passes

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        # torch.nn.init.normal_(self.prompt_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # initialize nn.Linear and nn.LayerNorm
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs, patch_size=None):
        """
        Break image to patch tokens

        input:
        imgs: (B, 3, H, W)

        output:
        x: (B, num_patches, patch_size**2 *3) AKA [B, num_patches, flatten_dim]
        """
        # patch_size
        patch_size = self.patch_embed.patch_size[0] if patch_size is None else patch_size

        # assert H == W and the image shape is divisible by patch_size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % patch_size == 0
        # patch count per row or column
        h = w = imgs.shape[2] // patch_size

        # use reshape to split patch [B, C, H, W] -> [B, C, h_p, patch_size, w_p, patch_size]
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, patch_size, w, patch_size))

        # ReArrange dimensions [B, C, h_p, patch_size, w_p, patch_size] -> [B, h_p, w_p, patch_size, patch_size, C]
        x = torch.einsum('nchpwq->nhwpqc', x)
        # ReArrange dimensions [B, h_p, w_p, patch_size, patch_size, C] -> [B, num_patches, flatten_dim]
        x = x.reshape(shape=(imgs.shape[0], h * w, patch_size ** 2 * 3))
        return x
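
    # patchify shape sketch (assuming img_size=224 and patch_size=16, so h = w = 224 // 16 = 14):
    #   imgs [B, 3, 224, 224] -> x [B, 14 * 14 = 196, 16 * 16 * 3 = 768]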

    def patchify_decoder(self, imgs, patch_size=None):
        """
        Break image to patch tokens

        FIXME: note that patch_size here should more reasonably default to the decoder's network settings

        input:
        imgs: (B, CLS, H, W)

        output:
        x: (B, num_patches, -1) AKA [B, num_patches, -1]
        """
        # patch_size
        patch_size = self.patch_embed.patch_size[0] if patch_size is None else patch_size

        # assert H == W and the image shape is divisible by patch_size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % patch_size == 0
        # patch count per row or column
        h = w = imgs.shape[2] // patch_size

        # use reshape to split patch [B, CLS, H, W] -> [B, CLS, h_p, patch_size, w_p, patch_size]
        x = imgs.reshape(shape=(imgs.shape[0], -1, h, patch_size, w, patch_size))

        # ReArrange dimensions [B, CLS, h_p, patch_size, w_p, patch_size] -> [B, h_p, w_p, patch_size, patch_size, CLS]
        x = torch.einsum('nchpwq->nhwpqc', x)
        # ReArrange dimensions [B, h_p, w_p, patch_size, patch_size, C] -> [B, num_patches, flatten_dim]
        x = x.reshape(shape=(imgs.shape[0], h * w, -1))
        return x

    def unpatchify(self, x, patch_size=None):
        """
        Decoding encoded patch tokens

        input:
        x: (B, num_patches, patch_size**2 *3) AKA [B, num_patches, flatten_dim]

        output:
        imgs: (B, 3, H, W)
        """
        # patch_size
        p = self.patch_embed.patch_size[0] if patch_size is None else patch_size

        # square root of num_patches (num_patches must exclude the CLS token)
        h = w = int(x.shape[1] ** .5)
        # assert num_patches excludes the CLS token
        assert h * w == x.shape[1]

        # ReArrange dimensions [B, num_patches, flatten_dim] -> [B, h_p, w_p, patch_size, patch_size, C]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        # ReArrange dimensions [B, h_p, w_p, patch_size, patch_size, C] -> [B, C, h_p, patch_size, w_p, patch_size]
        x = torch.einsum('nhwpqc->nchpwq', x)
        # use reshape to compose patch [B, C, h_p, patch_size, w_p, patch_size] -> [B, C, H, W]
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
        return imgs
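
    # unpatchify shape sketch (assuming patch_size=16): x [B, 196, 768] -> imgs [B, 3, 224, 224]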

    def fix_position_shuffling(self, x, fix_position_ratio, puzzle_patch_size):
        """
        Fix-position shuffling

        Randomly select the fixed patches via per-sample shuffling.
        The fixed patches are reserved as Position Tokens;
        the remaining patches are randomly shuffled across the batch, since they serve as Relation Tokens.

        Per-sample shuffling is done by argsorting random noise.
        The batch-wise shuffle is done by shuffling all indices at once.

        input:
        x: [B, 3, H, W], input image tensor
        fix_position_ratio  float
        puzzle_patch_size  int

        output: x_puzzled, mask
        x_puzzled: [B, 3, H, W]
        mask: [B, 3, H, W], binary mask marking fixed pixel positions with 0 (shuffled positions with 1)
        """
        # Break img into puzzle patches with the size of puzzle_patch_size  [B, num_puzzle_patches, D_puzzle]
        x = self.patchify(x, puzzle_patch_size)
        # output: x: (B, num_patches, patch_size**2 *3) AKA [B, num_patches, flatten_dim]
        B, num_puzzle_patches, D = x.shape

        # num of fix_position puzzle patches
        len_fix_position = int(num_puzzle_patches * fix_position_ratio)
        num_shuffled_patches = num_puzzle_patches - len_fix_position
        # create a noise tensor to prepare shuffle idx of puzzle patches
        noise = torch.rand(B, num_puzzle_patches, device=x.device)  # [B,num_puzzle_patches] noise in [0, 1]

        # for each sample in the batch, argsort the noise ascending to get the index matrix of original positions (different per sample)
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        # argsorting the index matrix again gives each original position's rank in the sorted order
        ids_restore = torch.argsort(ids_shuffle, dim=1)
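        # e.g., noise = [0.9, 0.1, 0.5] -> ids_shuffle = [1, 2, 0] (indices in ascending noise order);
        # ids_restore = argsort(ids_shuffle) = [2, 0, 1], so gathering with ids_restore undoes the shuffle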

        # keep the first subset: the front part is fixed, the rest is puzzled
        ids_fix = ids_shuffle[:, :len_fix_position]  # [B,num_puzzle_patches] -> [B,fix_patches]
        # fix_patches=num_puzzle_patches * fix_position_ratio   len_fix_position
        ids_puzzle = ids_shuffle[:, len_fix_position:]  # [B,num_puzzle_patches] -> [B,puzzle_patches]
        # puzzle_patches=num_puzzle_patches*(1-fix_position_ratio)  num_shuffled_patches

        # set puzzle patch
        # ids_?.unsqueeze(-1).repeat(1, 1, D)
        # [B,?_patches] -> [B,?_patches,1] (at each place with the idx of ori patch) -> [B,?_patches,D]

        # torch.gather to select patch groups: x_fixed of [B,fix_patches,D] and x_puzzle of [B,puzzle_patches,D]
        # patches to keep in place; different for each sample in the batch
        x_fixed = torch.gather(x, dim=1, index=ids_fix.unsqueeze(-1).repeat(1, 1, D))
        # patches to be shuffled; different for each sample in the batch
        x_puzzle = torch.gather(x, dim=1, index=ids_puzzle.unsqueeze(-1).repeat(1, 1, D))

        # batch- and patch-wise shuffling is needed, otherwise the restore step would simply undo the puzzle
        if self.group_shuffle_size == -1 or self.group_shuffle_size == B:
            puzzle_shuffle_indices = torch.randperm(B * num_shuffled_patches, device=x.device, requires_grad=False)
        else:
            assert B > self.group_shuffle_size > 0 and B % self.group_shuffle_size == 0
            # build [B // group_shuffle_size, num_shuffled_patches * group_shuffle_size] noise in [0, 1]
            group_noise = torch.rand(B // self.group_shuffle_size, num_shuffled_patches * self.group_shuffle_size, device=x.device)
            # get shuffled indices within each group of (num_shuffled_patches * group_shuffle_size)
            group_ids_shuffle = torch.argsort(group_noise, dim=1)
            # offset each group's indices by its group position, then stack back into one tensor
            group_ids_shuffle = torch.stack([group_ids_shuffle[i] +
                                             num_shuffled_patches * self.group_shuffle_size * i
                                             for i in range(B // self.group_shuffle_size)])
            # flatten into indices over all (B * num_shuffled_patches)
            puzzle_shuffle_indices = group_ids_shuffle.view(-1)
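            # e.g., B=4, group_shuffle_size=2 -> 2 groups of 2 samples; group i's indices are offset by
            # num_shuffled_patches * group_shuffle_size * i, so patches are only exchanged within their own group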

        # puzzle_shuffle_indices is a random permutation of 0 .. B * num_shuffled_patches - 1 (inclusive)
        x_puzzle = x_puzzle.view(B * num_shuffled_patches, D)[puzzle_shuffle_indices].view(B, num_shuffled_patches, D)
        # use it to shuffle all to-be-shuffled patches across the batch, then reshape back to per-sample batches
        # pack up all puzzle patches
        x = torch.cat([x_fixed, x_puzzle], dim=1)

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([B, num_puzzle_patches, D], device=x.device, requires_grad=False)  # no grad
        mask[:, :len_fix_position, :] = 0  # set the first len_fix_position tokens to 0, the rest to 1

        # unshuffle to restore the fixed positions
        x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, D))
        # torch.gather to generate restored binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, D))

        # unpatchify to obtain puzzle images and their mask
        x = self.unpatchify(x, puzzle_patch_size)
        mask = self.unpatchify(mask, puzzle_patch_size)

        return x, mask  # x_puzzled and mask
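
    # Worked example for fix_position_shuffling (assuming 224x224 inputs, puzzle_patch_size=32, fix_position_ratio=0.25):
    #   num_puzzle_patches = (224 // 32) ** 2 = 49 and len_fix_position = int(49 * 0.25) = 12,
    #   so 12 puzzle patches stay in place per image while the other 37 are shuffled across the batch.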

    def forward_puzzle(self, imgs, fix_position_ratio=0.25, puzzle_patch_size=32):
        """
        Transform the input images to puzzle images

        input:
        x: [B, 3, H, W], input image tensor
        fix_position_ratio  float
        puzzle_patch_size  int

        output: x_puzzled, mask
        x_puzzled: [B, 3, H, W]
        mask: [B, 3, H, W], binary mask marking fixed pixel positions with 0 (shuffled positions with 1)
        """
        x_puzzled, mask = self.fix_position_shuffling(imgs, fix_position_ratio, puzzle_patch_size)
        return x_puzzled, mask

    def forward_encoder(self, imgs):
        """
        :param imgs: [B, C, H, W], sequence of imgs

        :return: Encoder output: encoded tokens, CLS token, embedded puzzle (for illustration)
        x: [B, num_patches, D], sequence of patch tokens (with the cls token split off)
        CLS_token: [B, 1, D]
        embed_puzzle: [B, num_patches, D], detached patch embeddings of the puzzled image
        """

        if self.prompt_mode is None:  # ViT
            # embed patches
            x = self.patch_embed(imgs)

            # add pos embed before concatenate the cls token
            x = x + self.pos_embed[:, 1:, :]

            # detach the embedded puzzle for the embed_puzzle output
            embed_puzzle = x.data.detach()

            # append cls token
            cls_token = self.cls_token + self.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(x.shape[0], -1, -1)  # batch fix
            x = torch.cat((cls_tokens, x), dim=1)

            # apply Transformer blocks
            for blk in self.blocks:
                x = blk(x)

        else:  # VPT
            x = self.patch_embed(imgs)
            # add pos embed before concatenate the cls token
            x = x + self.pos_embed[:, 1:, :]

            # detach the embedded puzzle for the embed_puzzle output
            embed_puzzle = x.data.detach()  # copy of the embedded original puzzle (for illustration)

            # append cls token
            cls_token = self.cls_token + self.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(x.shape[0], -1, -1)  # batch fix
            x = torch.cat((cls_tokens, x), dim=1)

            if self.VPT_type == "Deep":

                Prompt_Token_num = self.Prompt_Tokens.shape[1]

                for i in range(len(self.blocks)):
                    # concatenate Prompt_Tokens
                    Prompt_Tokens = self.Prompt_Tokens[i].unsqueeze(0)
                    # firstly concatenate
                    x = torch.cat((x, Prompt_Tokens.expand(x.shape[0], -1, -1)), dim=1)
                    num_tokens = x.shape[1]
                    # lastly remove, a good trick
                    x = self.blocks[i](x)[:, :num_tokens - Prompt_Token_num]

            else:  # self.VPT_type == "Shallow"
                Prompt_Token_num = self.Prompt_Tokens.shape[1]

                # concatenate Prompt_Tokens
                Prompt_Tokens = self.Prompt_Tokens.expand(x.shape[0], -1, -1)
                x = torch.cat((x, Prompt_Tokens), dim=1)
                num_tokens = x.shape[1]
                # A whole sequential process
                x = self.blocks(x)[:, :num_tokens - Prompt_Token_num]

        # last norm of Transformer
        x = self.norm(x)

        CLS_token = x[:, :1, :]
        x = x[:, 1:, :]

        # Encoder output: encoded tokens, CLS token, embedded original puzzle (for illustration)
        return x, CLS_token, embed_puzzle

    def forward_decoder(self, x):
        """
        Decoder to reconstruct the puzzle image
        [B, 1 + num_patches, D_Encoder] -> [B, 1 + num_patches, D_Decoder] -> [B, num_patches, p*p*3]

        :param x: [B, 1 + num_patches, D_Encoder], sequence of Tokens (including the cls token)

        :return: Decoder output: reconstructed tokens
        x: [B, num_patches, patch_size ** 2 * in_chans], sequence of Patch Tokens
        """

        if self.decoder is None:
            # embed tokens: [B, num_encoded_tokens, D_Encoder] -> [B, num_encoded_tokens, D_Decoder]
            x = self.decoder_embed(x)
            # print(x.shape)
            # add pos embed
            x = x + self.decoder_pos_embed

            # apply Transformer blocks
            for blk in self.decoder_blocks:
                x = blk(x)
            x = self.decoder_norm(x)

            # Reconstruction projection
            x = self.decoder_pred(x)
            # remove cls token
            x = x[:, 1:, :]
            # print("x shape: ", x.shape)  # [B, N, p*p*3]

        else:
            # remove cls token
            x = x[:, 1:, :]
            # embed tokens: [B, num_encoded_tokens, D_Encoder] -> [B, num_encoded_tokens, D_Decoder]
            x = self.decoder_embed(x)
            # unpatchify to recover the image form [B, C, H, W]
            x = self.unpatchify(x)  # restore image by Encoder
            # apply decoder module to segment the output of encoder
            x = self.decoder(x)  # one-hot seg decoder [B, CLS, H, W]
            # the output of segmentation is transformed to [B, N, Dec]
            x = self.patchify_decoder(x)  # TODO: design something more meaningful here
            # Convert the number of channels to match image for loss function
            x = self.decoder_pred(x)  # [B, N, Dec] -> [B, N, p*p*3]
            # print(x.shape)

        return x
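
    # forward_decoder paths in brief: with decoder=None, tokens -> decoder_embed -> Transformer blocks
    # -> decoder_norm -> decoder_pred; with an external decoder, tokens -> decoder_embed -> unpatchify
    # -> decoder (segmentation CNN) -> patchify_decoder -> decoder_pred.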

    def forward_loss(self, imgs, pred, mask):
        """
        MSE loss between the reconstructed patches and the original image

        Input:
        imgs: [B, 3, H, W], Encoder input image
        pred: [B, num_patches, p*p*3], Decoder reconstructed image
        mask: [B, num_patches, p*p*3], 0 is keep, 1 is puzzled

        """
        # print("pred shape: ", pred.shape)  # [64, 196, 768]
        # target imgs: [B, 3, H, W] -> [B, num_patches, p*p*3]
        target = self.patchify(imgs)
        # print("target shape: ", target.shape)  # [64, 196, 768]
        # use mask as a patch indicator [B, num_patches, D] -> [B, num_patches]
        mask = mask[:, :, 0]  # binary mask: 1 for shuffled patches, 0 for reserved (fixed) patches

        if self.norm_pix_loss:  # Normalize the target image patches
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        # MSE loss
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [B, num_patches], mean loss on each patch pixel

        loss = (loss * mask).sum() / mask.sum()  # mean loss over shuffled patches only, a scalar

        return loss
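
    # Note: mask.sum() counts only the shuffled positions, so the loss is averaged over the patches
    # the model actually has to restore, mirroring MAE's masked-loss convention.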

    def forward(self, imgs, fix_position_ratio=0.25, puzzle_patch_size=32, combined_pred_illustration=False):
        # STEP 1: Puzzle making
        # create puzzle images: [B, 3, H, W]
        imgs_puzzled, mask = self.forward_puzzle(imgs, fix_position_ratio, puzzle_patch_size)

        # Visualization of imgs_puzzled_patches sequence: [B, num_patches, p*p*3]
        imgs_puzzled_patches = self.patchify(imgs_puzzled)
        # here, the patch size automatically follows the encoder's patch embedding

        # STEP 2: Puzzle understanding
        # Encoder to obtain latent tokens and embed_puzzle: [B, num_patches, D]
        latent_puzzle, CLS_token, embed_puzzle = self.forward_encoder(imgs_puzzled)
        # VPT could output more tokens? Currently we concatenate prompts first and remove them last, so it's fine

        # STEP 3: Puzzle restoring

        # step 3.(a) prepare the decoder input indicator mask at the encoder output stage:
        mask_patches_pp3 = self.patchify(mask)  # mark relation tokens with 1 [B, num_patches, p*p*3]
        # here, the patch size automatically follows the encoder's patch embedding

        # Reassign mask indicator shape to the encoder output dim
        if mask_patches_pp3.shape[-1] != latent_puzzle.shape[-1]:
            # [B, num_patches, p*p*3] -> [B, num_patches, 1] -> [B, num_patches, D]
            mask_patches = mask_patches_pp3[:, :, :1].expand(-1, -1, latent_puzzle.shape[-1])
        else:
            mask_patches = mask_patches_pp3

        # anti_mask: [B, num_patches, D], binary mask indicating fix position with 1 instead of 0
        anti_mask = mask_patches * -1 + 1  # a simple trick to flip the binary mask with little computation

        # Position hint
        # in mask, 0 marks Position Tokens, so multiplying keeps only the Relation Tokens
        latent_tokens = latent_puzzle * mask_patches  # take out the relation tokens (latent_tokens here)
        # in anti_mask, 0 marks Relation Tokens, so multiplying keeps only the Position Tokens
        hint_tokens = embed_puzzle * anti_mask  # anti_mask takes out the hint tokens (position tokens)
        # group decoder tokens: [B, num_patches, D]
        latent = latent_tokens + hint_tokens
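        # e.g., at a fixed position (mask 0) the encoder output is zeroed and the raw embedding passes
        # through as a hint; at a shuffled position (mask 1) the encoder output passes and the hint is zeroed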
        # append back the cls token at the first -> [B, 1+num_patches, D]
        x = torch.cat([CLS_token, latent], dim=1)

        # step 3.(b) Decoder to obtain Reconstructed image patches:
        # [B, 1+num_patches,D] -> [B, 1+num_patches, D_Decoder] -> [B, num_patches, p*p*3]
        pred = self.forward_decoder(x)

        # combined pred
        anti_mask_patches_pp3 = mask_patches_pp3 * -1 + 1  # fix position with 1, relation patches with 0
        hint_img_patches = imgs_puzzled_patches * anti_mask_patches_pp3
        pred_img_patches = pred * mask_patches_pp3  # mark relation tokens with 1, fix position with 0
        pred_with_hint_imgs = hint_img_patches + pred_img_patches

        # MSE loss between the reconstructed patches and the original image
        loss = self.forward_loss(imgs, pred, mask_patches)
        # print(loss)  # check whether the loss is working

        if combined_pred_illustration:
            return loss, pred_with_hint_imgs, imgs_puzzled_patches
        else:
            return loss, pred, imgs_puzzled_patches


def sae_vit_base_patch16_dec512d8b(dec_idx=None, **kwargs):
    print("Decoder:", dec_idx)

    model = ShuffledAutoEncoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def sae_vit_large_patch16_dec512d8b(dec_idx=None, **kwargs):
    print("Decoder:", dec_idx)

    model = ShuffledAutoEncoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def sae_vit_huge_patch14_dec512d8b(dec_idx=None, **kwargs):
    print("Decoder:", dec_idx)

    model = ShuffledAutoEncoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


# decoder
def sae_vit_base_patch16_dec(dec_idx=None, num_classes=3, img_size=224, **kwargs):
    # num_classes yields one-hot segmentation rather than reconstruction; we still need to design how to do the reconstruction to enable pre-training

    if dec_idx == 'swin_unet':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg
        decoder = ViT_seg(num_classes=num_classes, img_size=img_size, patch_size=16)

    elif dec_idx == 'transunet':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        transunet_name = 'R50-ViT-B_16'
        transunet_patches_size = 16
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg

        config_vit = CONFIGS_Transunet_seg[transunet_name]
        config_vit.n_classes = num_classes
        config_vit.n_skip = 3

        if transunet_name.find('R50') != -1:
            config_vit.patches.grid = (
                int(img_size / transunet_patches_size), int(img_size / transunet_patches_size))
        decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes)

    elif dec_idx == 'UTNetV2':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg
        decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes)

    else:
        print('no effective decoder!')
        return -1

    print('dec_idx: ', dec_idx)

    model = ShuffledAutoEncoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder,
        **kwargs)
    return model


def sae_vit_large_patch16_dec(dec_idx=None, num_classes=3, img_size=224, **kwargs):
    # num_classes yields one-hot segmentation rather than reconstruction; we still need to design how to do the reconstruction to enable pre-training

    if dec_idx == 'swin_unet':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg
        decoder = ViT_seg(num_classes=num_classes, img_size=img_size, patch_size=16)

    elif dec_idx == 'transunet':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        transunet_name = 'R50-ViT-B_16'
        transunet_patches_size = 16
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg

        config_vit = CONFIGS_Transunet_seg[transunet_name]
        config_vit.n_classes = num_classes
        config_vit.n_skip = 3

        if transunet_name.find('R50') != -1:
            config_vit.patches.grid = (
                int(img_size / transunet_patches_size), int(img_size / transunet_patches_size))
        decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes)

    elif dec_idx == 'UTNetV2':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg
        decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes)

    else:
        print('no effective decoder!')
        return -1

    print('dec_idx: ', dec_idx)

    model = ShuffledAutoEncoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder,
        **kwargs)
    return model


def sae_vit_huge_patch14_dec(dec_idx=None, num_classes=3, img_size=224, **kwargs):
    # num_classes yields one-hot segmentation rather than reconstruction; we still need to design how to do the reconstruction to enable pre-training

    if dec_idx == 'swin_unet':
        decoder_embed_dim = 14 * 14 * 3
        decoder_rep_dim = 14 * 14 * 3

        from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg
        decoder = ViT_seg(num_classes=num_classes, img_size=img_size, patch_size=16)

    elif dec_idx == 'transunet':
        decoder_embed_dim = 14 * 14 * 3
        decoder_rep_dim = 14 * 14 * 3

        transunet_name = 'R50-ViT-B_16'
        transunet_patches_size = 16
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg

        config_vit = CONFIGS_Transunet_seg[transunet_name]
        config_vit.n_classes = num_classes
        config_vit.n_skip = 3

        if transunet_name.find('R50') != -1:
            config_vit.patches.grid = (
                int(img_size / transunet_patches_size), int(img_size / transunet_patches_size))
        decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes)

    elif dec_idx == 'UTNetV2':
        decoder_embed_dim = 14 * 14 * 3
        decoder_rep_dim = 14 * 14 * 3

        from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg
        decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes)

    else:
        print('no effective decoder!')
        return -1

    print('dec_idx: ', dec_idx)

    model = ShuffledAutoEncoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder,
        **kwargs)
    return model


# set recommended archs following MAE
sae_vit_base_patch16 = sae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
sae_vit_large_patch16 = sae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
sae_vit_huge_patch14 = sae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks

# Equipped with decoders
sae_vit_base_patch16_decoder = sae_vit_base_patch16_dec  # decoder: 768 dim, HYF decoders
sae_vit_large_patch16_decoder = sae_vit_large_patch16_dec  # decoder: 768 dim, HYF decoders
sae_vit_huge_patch14_decoder = sae_vit_huge_patch14_dec  # decoder: 768 dim, HYF decoders
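
# Minimal usage sketch (hypothetical values; constructor defaults as defined above):
#   model = sae_vit_base_patch16(img_size=224, decoder=None)
#   loss, pred, puzzled = model(torch.rand(2, 3, 224, 224),
#                               fix_position_ratio=0.25, puzzle_patch_size=32)
#   loss.backward()  # pred: [2, 196, 16 * 16 * 3 = 768]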

if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    img_size = 224

    '''
    num_classes = 3  # set to 3 for 3 channel
    x = torch.rand(2, 3, img_size, img_size, device=device)
    '''

    image_tensor_path = './temp-tensors/color.pt'
    x = torch.load(image_tensor_path)
    x = x.to(device)  # Tensor.to is not in-place; reassign to actually move the tensor

    # model = sae_vit_base_patch16(img_size=img_size, decoder=None)
    # model = sae_vit_huge_patch14(img_size=img_size, decoder=None)
    # model = sae_vit_base_patch16_decoder(prompt_mode="Deep", dec_idx='swin_unet', img_size=img_size)
    model = sae_vit_base_patch16(img_size=img_size, decoder=None, group_shuffle_size=2)

    '''
    # ViT_Prompt

    import timm
    from pprint import pprint
    model_names = timm.list_models('*vit*')
    pprint(model_names)

    basic_model = timm.create_model('vit_base_patch16_384', pretrained=True)  # edge size must match img_size (384)

    basic_state_dict = basic_model.state_dict()
                                        
    model = sae_vit_base_patch16(img_size=384, prompt_mode='Deep', Prompt_Token_num=20, basic_state_dict=basic_state_dict)
    
    prompt_state_dict = model.obtain_prompt()
    VPT = VPT_ViT(img_size=384, VPT_type='Deep', Prompt_Token_num=20, basic_state_dict=basic_state_dict)
    VPT.load_prompt(prompt_state_dict)
    VPT.to(device)
    pred = VPT(x)
    print(pred)
    '''

    model.to(device)

    loss, pred, imgs_puzzled_patches = model(x, fix_position_ratio=0.25, puzzle_patch_size=32,
                                             combined_pred_illustration=True)
    # combined_pred_illustration=True overlays the hint tokens on the prediction; False returns the raw prediction


    # Visualize the results
    from utils.visual_usage import *
    from torchvision.transforms import ToPILImage  # explicit import, in case the wildcard above does not provide it

    imgs_puzzled_batch = unpatchify(imgs_puzzled_patches, patch_size=16)
    recons_img_batch = unpatchify(pred, patch_size=16)  # hoisted out of the loop; it does not depend on img_idx
    for img_idx in range(len(imgs_puzzled_batch)):
        puzzled_img = imgs_puzzled_batch.cpu()[img_idx]
        puzzled_img = ToPILImage()(puzzled_img)
        puzzled_img.save(os.path.join('./temp-figs/', 'puzzled_sample_' + str(img_idx) + '.jpg'))

        recons_img = recons_img_batch.cpu()[img_idx]
        recons_img = ToPILImage()(recons_img)
        recons_img.save(os.path.join('./temp-figs/', 'recons_sample_' + str(img_idx) + '.jpg'))

    print(loss, '\n')

    print(loss.shape, '\n')

    print(pred.shape, '\n')

    print(imgs_puzzled_patches.shape, '\n')