File size: 39,665 Bytes
9cf79cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
from functools import partial

import torch
import torch.nn as nn

from timm.models.vision_transformer import PatchEmbed, DropPath, Mlp

from util.pos_embed import get_2d_sincos_pos_embed

from taming.models.vqgan import VQModel, VQModel_w_Prompt
from omegaconf import OmegaConf
import numpy as np
import scipy.stats as stats
from compressai.entropy_models import EntropyBottleneck
from compressai.layers import conv3x3, subpel_conv3x3
import math
from torch import Tensor
from einops import rearrange, repeat
import torch.nn.functional as F
import torchac
from typing import Any, Callable, List, Optional, Tuple, Union
from util.utils import adaptively_split_and_pad, crop_and_reconstruct
from timm.models.layers import trunc_normal_
from util.rle import rle_encode, rle_decode


def mask_by_random_topk(batch_mask_len, batch_probs, temperature, token_all_mask=None):
    """Pick the `batch_mask_len` lowest-confidence positions per row.

    Confidence is the log-probability perturbed by temperature-scaled Gumbel
    noise; positions flagged in `token_all_mask` (already-known tokens) are
    pushed above every other confidence so they are never selected.  Returns
    a boolean mask marking positions to (re-)mask.
    """
    noise = np.random.gumbel(size=batch_probs.shape) * temperature
    gumbel_noise = torch.Tensor(noise).to(batch_probs.device)
    confidence = gumbel_noise + torch.log(batch_probs)

    if token_all_mask is not None:
        # Known tokens get a confidence above everything else so they survive.
        top_confidence = torch.max(confidence) + 1
        confidence = torch.where(token_all_mask, top_confidence, confidence)

    ranked, _ = torch.sort(confidence, dim=-1)
    # Threshold at the k-th smallest confidence (k = batch_mask_len).
    cut_off = ranked[:, batch_mask_len.long() - 1:batch_mask_len.long()]
    # Everything at or below the threshold gets masked.
    return confidence <= cut_off


class FactorizedEntropyModel(EntropyBottleneck):
    """Fully-factorized entropy model over a single-channel token sequence.

    Thin wrapper around compressai's ``EntropyBottleneck`` that adapts it to
    inputs already laid out as ``[b, c, seq_len]`` and to single-channel
    compress/decompress round-trips.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, training: Optional[bool] = None) -> Tuple[Tensor, Tensor]:
        """Quantize ``x`` against the learned medians and return
        ``(outputs, likelihood)``.

        Quantization always runs in "dequantize" mode — the training-time
        additive-noise path is deliberately commented out below.
        """
        if training is None:
            training = self.training

        # Input is already shaped [b, c, seq_len]; no transpose needed.
        shape = x.size()

        # Add noise or quantize
        # outputs = self.quantize(
        #     x, "noise" if training else "dequantize", means.long()
        # )
        # NOTE(review): the medians are truncated via .long() before being used
        # as quantization offsets — confirm this integer rounding is intentional
        # (compressai's quantize() normally receives float means).
        means = self._get_medians()
        outputs = self.quantize(
            x, "dequantize", means.long()
        )

        if not torch.jit.is_scripting():
            likelihood = self._likelihood(outputs)
            if self.use_likelihood_bound:
                likelihood = self.likelihood_lower_bound(likelihood)
        else:
            raise NotImplementedError("TorchScript is not yet supported")

        return outputs, likelihood
    
    def compress(self, x):
        """Entropy-code ``x`` into byte strings (single-channel layout)."""
        # Build per-element CDF indexes for the single-channel sequence.
        indexes = self._build_indexes(x.size())
        # Medians act as the quantization offsets.
        medians = self._get_medians().detach()
        # Broadcast the medians to x's shape.
        medians = medians.expand_as(x)
        # Delegate the actual range coding to the base class.
        return super().compress(x, indexes, medians)

    def decompress(self, strings, size):
        """Decode byte ``strings`` back to a ``(len(strings), 1, *size)`` tensor."""
        # Expected output size includes the single channel; ``size`` is the seq_len.
        output_size = (len(strings), 1, *size)
        # Build the CDF indexes on the same device as the quantized CDF table.
        indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
        # Reshape the medians to match the expected output shape.
        medians = self._extend_ndims(self._get_medians().detach(), len(size))
        medians = medians.expand(len(strings), 1, *([-1] * len(size)))
        # Delegate decoding to the base class.
        return super().decompress(strings, indexes, medians.dtype, medians)
    
    def _preprocess(self, x):
        # [b, c, h, w] -> [b, h, w, c]
        x = x.permute(0, 2, 3, 1).contiguous()
        return x


class LayerNorm(nn.LayerNorm):
    """LayerNorm that upcasts to fp32 internally so fp16 inputs stay stable."""

    def forward(self, x: torch.Tensor):
        # Normalize in float32, then cast back to the caller's dtype.
        input_dtype = x.dtype
        normed = super().forward(x.type(torch.float32))
        return normed.type(input_dtype)


# Prefer apex's fused LayerNorm when available; fall back to the local
# fp32-upcasting LayerNorm otherwise.  Catch only ImportError — a bare
# `except:` would also hide unrelated failures inside apex.
try:
    from apex.normalization import FusedLayerNorm
except ImportError:
    FusedLayerNorm = LayerNorm


class ImportancePredictor(nn.Module):
    """Per-token importance scorer.

    Input:  z_q: [b, (h*w), c]
    Output: importance_score: [b, N]
    """
    def __init__(self, embed_dim=768):
        super().__init__()
        # Input projection: norm + linear + GELU.
        self.in_conv = nn.Sequential(
            FusedLayerNorm(embed_dim, eps=1e-5),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU()
        )

        # Funnel MLP down to a single score per token.
        self.out_conv = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Linear(embed_dim // 2, embed_dim // 4),
            nn.GELU(),
            nn.Linear(embed_dim // 4, 1)
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal Linear weights; unit/zero LayerNorm affine.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, policy):
        x = self.in_conv(x)
        B, N, C = x.size()
        half = C // 2
        # First half of the channels: per-token (local) features.
        local_feat = x[:, :, :half]
        # Second half: policy-weighted mean over kept tokens (global context).
        pooled = (x[:, :, half:] * policy).sum(dim=1, keepdim=True)
        pooled = pooled / torch.sum(policy, dim=1, keepdim=True)
        fused = torch.cat([local_feat, pooled.expand(B, N, half)], dim=-1)
        scores = self.out_conv(fused)
        return scores.squeeze(-1)  # [b, N, 1] -> [b, N]


class Attention(nn.Module):
    """Multi-head self-attention that also returns the attention map."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # One projection, then split into per-head q/k/v: (3, B, heads, N, head_dim).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Compute attention scores in fp32 even under AMP, for stability.
        with torch.cuda.amp.autocast(enabled=False):
            scores = (q.float() @ k.float().transpose(-2, -1)) * self.scale

        # Subtract the row max before softmax (numerically stable softmax).
        scores = scores - torch.max(scores, dim=-1, keepdim=True)[0]
        weights = scores.softmax(dim=-1)
        weights = self.attn_drop(weights)

        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        # Return both the attended features and the attention weights
        # (the weights describe pairwise token relevance).
        return out, weights


class Block(nn.Module):
    """Pre-norm Transformer block: self-attention and MLP, each with residual."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth; identity when drop_path == 0.
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        # Optionally return just the attention map (e.g. for visualization).
        if return_attention:
            _, attn = self.attn(self.norm1(x))
            return attn
        attn_out, _ = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class LabelSmoothingCrossEntropy(nn.Module):
    """NLL loss with uniform label smoothing; returns per-sample losses."""

    def __init__(self, smoothing=0.1):
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        log_probs = torch.nn.functional.log_softmax(x, dim=-1)
        # Log-likelihood of the target class, per sample.
        target_ll = log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        # Uniform term: mean log-probability over all classes.
        uniform_ll = log_probs.mean(dim=-1)
        # Blend the two negative log-likelihoods by the smoothing weight.
        return self.confidence * (-target_ll) + self.smoothing * (-uniform_ll)


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1))) # (1, 257)
        # 创建一个形状为(1, max_position_embeddings)的缓冲张量position_ids,其包含了从0到max_position_embeddings-1的整数。
        # 这个缓冲张量将被用于获取position_embeddings的位置信息,以便在前向传播过程中使用

        torch.nn.init.normal_(self.word_embeddings.weight, std=.02)
        torch.nn.init.normal_(self.position_embeddings.weight, std=.02)

    def forward(
        self, input_ids
    ):
        input_shape = input_ids.size()  # input_ids: (B, N)(32,257)

        seq_length = input_shape[1]

        position_ids = self.position_ids[:, :seq_length]

        inputs_embeds = self.word_embeddings(input_ids) # (B, seq_len, embed_dim)

        position_embeddings = self.position_embeddings(position_ids)    # (1, seq_len, embed_dim)
        embeddings = inputs_embeds + position_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class MlmLayer(nn.Module):
    """Masked-language-model head.

    Projects decoder features into the word-embedding space and scores them
    against the (transposed) word-embedding table, yielding unnormalized
    per-token vocabulary logits of shape (b, seq_len, vocab_size).
    """

    def __init__(self, feat_emb_dim, word_emb_dim, vocab_size):
        super().__init__()
        self.fc = nn.Linear(feat_emb_dim, word_emb_dim)
        self.gelu = nn.GELU()
        self.ln = nn.LayerNorm(word_emb_dim)
        self.bias = nn.Parameter(torch.zeros(1, 1, vocab_size))

    def forward(self, x, word_embeddings):
        # x: (b, seq_len, feat_emb_dim)
        hidden = self.ln(self.gelu(self.fc(x)))
        # Score against the embedding table: (b, seq_len, vocab_size).
        logits = torch.matmul(hidden, word_embeddings.transpose(0, 1))
        return logits + self.bias


class MaskedGenerativeEncoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=256, patch_size=16, in_chans=3,     # need to change the default value of img_size
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 mask_ratio_min=0.5, mask_ratio_max=0.8, vqgan_ckpt_path='vqgan_jax_strongaug.ckpt'):
        """Build the model: frozen VQGAN tokenizer, BERT-style token embedding,
        ViT encoder, ViT decoder with MLM head, and a factorized entropy model.

        Args:
            img_size / patch_size / in_chans: input image geometry.
            embed_dim / depth / num_heads: encoder width, depth, heads.
            decoder_embed_dim / decoder_depth / decoder_num_heads: decoder config.
            mask_ratio_min / mask_ratio_max: range for the sampled mask rate.
            vqgan_ckpt_path: checkpoint for the (frozen) VQGAN tokenizer.
        """
        super().__init__()
        # --------------------------------------------------------------------------
        # VQGAN specifics
        config = OmegaConf.load('config/vqgan.yaml').model
        # self.vqgan = VQModel(ddconfig=config.params.ddconfig,
        #                      n_embed=config.params.n_embed, # 1024
        #                      embed_dim=config.params.embed_dim, # 256
        #                      ckpt_path=vqgan_ckpt_path)
        self.vqgan = VQModel_w_Prompt(ddconfig=config.params.ddconfig,
                             n_embed=config.params.n_embed, # 1024
                             embed_dim=config.params.embed_dim, # 256
                             ckpt_path=vqgan_ckpt_path)
        # The VQGAN tokenizer is frozen; only the transformer/entropy parts train.
        for param in self.vqgan.parameters():
            param.requires_grad = False

        self.codebook_size = config.params.n_embed  # 1024
        vocab_size = self.codebook_size + 1000 + 1  # 1024 codebook size, 1000 classes, 1 for mask token.
        self.fake_class_label = self.codebook_size + 1100 - 1024    # 1100
        self.mask_token_label = vocab_size - 1  # 2024
        # NOTE(review): max_position_embeddings=img_size+1 equals num_patches+1
        # (257) only for the default 256/16 configuration — confirm before
        # changing img_size or patch_size (the commented line suggests 256+1
        # was the intent).
        self.token_emb = BertEmbeddings(vocab_size=vocab_size,  # vocabulary: 1024 codebook entries + 1000 classes + 1 mask token
                                        hidden_size=embed_dim,
                                        max_position_embeddings=img_size +1,
                                        # max_position_embeddings=256+1,  # 256 patches + 1 class token
                                        dropout=0.1)

        # MAGE variant masking ratio
        self.mask_ratio_min = mask_ratio_min
        self.mask_ratio_max = mask_ratio_max

        # --------------------------------------------------------------------------
        # MAGE encoder specifics
        dropout_rate = 0.1
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)    # 256, 16, 3, 1024, (B,N,C) n: 256/16*256/16=256, c=1024
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.token_predictor = ImportancePredictor(config.params.embed_dim)   # predict importance tokens

        self.blocks = nn.ModuleList([   # encoder
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(depth)]) # depth=12 for mage-vitb, embed_dim=768
        self.norm = norm_layer(embed_dim)   # layer norm

        # --------------------------------------------------------------------------
        # MAGE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))    # decoder_embed_dim=512
        self.pad_with_cls_token = True

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim))  # learnable pos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(decoder_depth)])  # decoder_depth=8 for mage-vitb

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MlmLayer
        self.mlm_layer = MlmLayer(feat_emb_dim=decoder_embed_dim, word_emb_dim=embed_dim, vocab_size=vocab_size)
        self.norm_pix_loss = norm_pix_loss
        self.criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
        # --------------------------------------------------------------------------
        # Single-channel factorized entropy model for latent rate estimation.
        self.entropy_bottleneck = FactorizedEntropyModel(1)

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize positional embeddings (fixed sin-cos), tokens, and layers."""
        grid = int(self.patch_embed.num_patches ** .5)

        # Fixed (non-learned) 2D sin-cos position embeddings for encoder/decoder.
        enc_pos = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(enc_pos).float().unsqueeze(0))

        dec_pos = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], grid, cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(dec_pos).float().unsqueeze(0))

        # Initialize patch_embed's conv weight like an nn.Linear (xavier uniform).
        proj_w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(proj_w.view([proj_w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) here,
        # as the cutoff (2.) is too big to matter.
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)

        # Recursively initialize every nn.Linear / nn.LayerNorm.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_sample_mask_rate(self):
        # 生成一个 (0, 1] 范围内的随机数
        random_sample = 1 - torch.rand(1)
        # 映射到 mask_ratio_min 到 mask_ratio_max 的范围
        mask_rate = self.mask_ratio_min + random_sample * (self.mask_ratio_max - self.mask_ratio_min)
        return mask_rate.item()  # 转换为Python的标量值

    def get_cdf_token_mask(self, token_all_mask):
        """Per-position CDF prior for the token mask.

        Builds a Normal(0, 2) CDF over positions 1..seq_len, rescales it from
        [0.5, 1] to [0, 1], broadcasts it to (bsz, seq_len, seq_len), and
        left-pads the last dim with a zero, giving (bsz, seq_len, seq_len + 1).
        """
        bsz, seq_len = token_all_mask.size()
        prior = torch.distributions.Normal(0, 2)
        cdf = prior.cdf(torch.arange(1, seq_len + 1))
        # A zero-mean Normal CDF starts at 0.5; rescale [0.5, 1] -> [0, 1].
        cdf = (cdf - .5) * 2
        # Broadcast to (bsz, seq_len, seq_len): one prior row per position.
        cdf = cdf.unsqueeze(0).unsqueeze(0).expand(bsz, seq_len, -1)
        # Left-pad with a zero so the CDF starts at exactly 0.
        return F.pad(cdf, (1, 0))

    def pre_encoding(self, x, is_training=False, manual_mask_rate=None):
        """Tokenize images with the frozen VQGAN and sample token masks.

        input: x: (B, 3, H, W)

        Training: the mask rate is sampled from [mask_ratio_min, mask_ratio_max]
        and a fixed fraction of tokens is additionally dropped.  Inference:
        `manual_mask_rate` must be provided and nothing is dropped.

        Returns (gt_indices, token_indices, unmaksed_token_indices,
        token_all_mask, token_drop_mask, mask_ratio, z_H, z_W).
        """
        # ============ 1. tokenization ============ #
        with torch.no_grad():
            z_q, _, token_tuple = self.vqgan.encode(x)  # z_q: (B, 256, 16, 16), token_tuple: (B, 256, 16, 16)
        
        _, _, z_H, z_W = z_q.size()
        _, _, token_indices = token_tuple   # token_indices: (B*H*W,)(8192)
        token_indices = token_indices.reshape(z_q.size(0), -1)  # token_indices: (B, H*W)
        gt_indices = token_indices.clone().detach().long()  # gt_indcies: [b, seq_len]

        # ============ 2. masking process ============ # 
        bsz, seq_len = token_indices.size() # seq_len=h*w
        mask_ratio_min = self.mask_ratio_min    # 0.5

        if is_training:
            mask_rate = self.random_sample_mask_rate()
            num_dropped_tokens = int(np.ceil(seq_len * mask_ratio_min))
        else:
            num_dropped_tokens = 0
            if manual_mask_rate is not None:
                mask_rate = manual_mask_rate
            else:
                raise ValueError("mask_rate should be provided for inference!")

        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask_ratio = num_masked_tokens / seq_len    # for calculate vbr lambda 

        # it is possible that two elements of the noise is the same, so do a while loop to avoid it
        while True:
            noise = torch.rand(bsz, seq_len, device=x.device)  # noise in [0, 1]
            sorted_noise, _ = torch.sort(noise, dim=1)  # ascend: small is remove, large is keep
            if num_dropped_tokens > 0:
                cutoff_drop = sorted_noise[:, num_dropped_tokens-1:num_dropped_tokens]
            else:
                cutoff_drop = torch.zeros((bsz, 1), device=x.device)
            cutoff_mask = sorted_noise[:, num_masked_tokens-1:num_masked_tokens]
            token_drop_mask = (noise <= cutoff_drop).float() # 1 marks tokens that are dropped entirely
            token_all_mask = (noise <= cutoff_mask).float()  # 1 marks tokens that are masked
            # Retry on ties so every row has exactly the intended counts.
            if token_drop_mask.sum() == bsz*num_dropped_tokens and token_all_mask.sum() == bsz*num_masked_tokens:
                break
            else:
                print("Rerandom the noise!")

        # Gather the surviving (unmasked) tokens and their positions.
        unmasked_pos = token_all_mask == 0  # boolean (bsz, seq_len): True where kept
        unmaksed_token_indices = token_indices[unmasked_pos].reshape(bsz, -1)  # kept tokens [b, unmaksed_seq_len]

        return gt_indices, token_indices, unmaksed_token_indices, token_all_mask, token_drop_mask, mask_ratio, z_H, z_W

    def pre_decoding(self, gt_indices, unmaksed_token_indices, token_all_mask, token_drop_mask, z_H, z_W):
        """Re-insert kept tokens, split into patches, embed, and run the encoder.

        Args:
            gt_indices: (b, z_H*z_W) ground-truth token indices.
            unmaksed_token_indices: (b, n_kept) tokens that survived masking,
                gathered in row-major position order by pre_encoding.
            token_all_mask / token_drop_mask: (b, z_H*z_W) 0/1 masks.
            z_H, z_W: latent grid height/width.

        Returns:
            (x, token_indices, token_all_mask, token_drop_mask, patch_sizes,
             num_blocks_h, num_blocks_w, ori_shape, gt_indices)
        """
        bsz, seq_len = gt_indices.size()
        padded_token_indices = torch.full_like(gt_indices, fill_value=self.mask_token_label)
        # Scatter the kept tokens back to their original positions.  Boolean
        # indexing enumerates positions in row-major order, which matches the
        # order in which unmaksed_token_indices was gathered, so this single
        # vectorized assignment replaces the original O(b*seq_len) Python loop.
        unmasked_pos = token_all_mask == 0
        padded_token_indices[unmasked_pos] = unmaksed_token_indices.reshape(-1)

        token_indices = padded_token_indices    # [b0, z_H*z_W]
        # ============ Padding to a multiple of 16 with mask_token_label ============ #
        b0 = token_indices.size(0)
        token_indices = token_indices.reshape(b0, z_H, z_W).unsqueeze(1)    # back to image layout
        ori_shape = token_indices.shape
        # Split into patches: [new_bsz, c, 16, 16].
        token_indices, patch_sizes, num_blocks_h, num_blocks_w = adaptively_split_and_pad(token_indices, self.mask_token_label)
        new_bsz = token_indices.size(0)     # batch size after splitting
        token_indices = token_indices.squeeze(1).reshape(new_bsz, -1)   # [new_bsz, 256]

        # Pad gt_indices the same way so the loss can be computed patch-wise.
        gt_indices = gt_indices.reshape(b0, z_H, z_W).unsqueeze(1)
        gt_indices, patch_sizes, num_blocks_h, num_blocks_w = adaptively_split_and_pad(gt_indices, self.mask_token_label)
        gt_indices = gt_indices.squeeze(1).reshape(new_bsz, -1)
        # Pad token_all_mask.
        token_all_mask = token_all_mask.reshape(b0, z_H, z_W).unsqueeze(1)
        token_all_mask, _, _, _ = adaptively_split_and_pad(token_all_mask, self.mask_token_label)
        token_all_mask = token_all_mask.squeeze(1).reshape(new_bsz, -1)

        # Pad token_drop_mask.
        token_drop_mask = token_drop_mask.reshape(b0, z_H, z_W).unsqueeze(1)
        token_drop_mask, _, _, _ = adaptively_split_and_pad(token_drop_mask, self.mask_token_label)
        token_drop_mask = token_drop_mask.squeeze(1).reshape(new_bsz, -1)

        # ============ Adding the class token ============ #
        # Prepend a [CLS] token to aggregate sequence-level representations.
        token_indices = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
        token_indices[:, 0] = self.fake_class_label # [B, 257]
        # Extend both masks with a leading 0 so the [CLS] position is never
        # masked or dropped.
        token_drop_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_drop_mask], dim=1)
        token_all_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_all_mask], dim=1)
        token_indices = token_indices.long()

        # ============ Embedding and dropout ============ #
        # BERT-style token + position embedding.
        input_embeddings = self.token_emb(token_indices)    # [B, 257, embed_dim]
        bsz, seq_len, emb_dim = input_embeddings.shape

        # Physically remove the dropped tokens from the sequence.
        token_keep_mask = 1 - token_drop_mask
        input_embeddings_after_drop = input_embeddings[token_keep_mask.nonzero(as_tuple=True)].reshape(bsz, -1, emb_dim)

        # ============ Transformer encoding ============ #
        x = input_embeddings_after_drop # (B, seq_len_after_drop, embed_dim)
        for blk in self.blocks:
            x = blk(x)  # each block: multi-head self-attention + MLP
        x = self.norm(x)

        return x, token_indices, token_all_mask, token_drop_mask, patch_sizes, num_blocks_h, num_blocks_w, ori_shape, gt_indices

    def forward_decoding(self, x, token_drop_mask, token_all_mask):
        """
        Decode encoder features into per-position token logits.

        Args:
            x: output of forward_encoder()
            token_drop_mask: 1 at positions whose tokens were dropped
            token_all_mask: 1 at positions that are masked
        Returns:
            logits over the vocabulary for every sequence position
        """
        # ---- 1. project encoder features to decoder width ---- #
        x = self.decoder_embed(x)

        # ---- 2. build filler tokens for masked/dropped slots ---- #
        # either replicate the [CLS] embedding across the sequence, or tile
        # a dedicated learned mask token
        bsz, full_len = token_all_mask.shape[0], token_all_mask.shape[1]
        if self.pad_with_cls_token:
            pad_tokens = x[:, 0:1].repeat(1, full_len, 1)
        else:
            pad_tokens = self.mask_token.repeat(bsz, full_len, 1)

        # ---- 3. scatter the kept tokens back to their original slots ---- #
        restored = pad_tokens.clone()
        keep_positions = (1 - token_drop_mask).nonzero(as_tuple=True)
        restored[keep_positions] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
        # positions that survived dropping but are masked still get the filler
        restored = torch.where(token_all_mask.unsqueeze(-1).bool(), pad_tokens, restored)

        # ---- 4. add learned positional embedding, run decoder stack ---- #
        x = restored + self.decoder_pos_embed_learned
        for layer in self.decoder_blocks:
            x = layer(x)
        x = self.decoder_norm(x)

        # ---- 5. project to vocabulary logits via tied word embeddings ---- #
        word_embeddings = self.token_emb.word_embeddings.weight.data.detach()
        return self.mlm_layer(x, word_embeddings)

    def forward_loss(self, gt_indices, logits, mask):
        """
        Cross-entropy loss averaged over masked positions only.

        gt_indices has seq_len entries while logits/mask carry one extra
        leading [CLS] slot, so position 0 of logits and mask is skipped.
        """
        bsz, seq_len = gt_indices.size()
        # restrict logits to the codebook vocabulary and drop the class token
        flat_logits = logits[:, 1:, :self.codebook_size].reshape(bsz * seq_len, -1)
        flat_targets = gt_indices.reshape(bsz * seq_len)
        per_token = self.criterion(flat_logits, flat_targets).reshape(bsz, seq_len)
        token_mask = mask[:, 1:]
        # mean loss over removed (masked) patches only
        return (per_token * token_mask).sum() / token_mask.sum()

    def cal_lmbda(self, mask_ratio, A=5e-1, B=8):
        """Rate-control weight: grows exponentially as the mask ratio shrinks."""
        keep_ratio = 1 - mask_ratio
        return A * torch.exp(B * keep_ratio)
    
    def cal_loss(self, logits, gt_indices, mask, mask_ratio):
        """
        Compute the masked-token loss and its rate-control weight.

        Returns:
            (task_loss, lmbda) — cross-entropy on masked positions and the
            lambda balancing rate vs. distortion in codec optimization.
        """
        ratio_tensor = torch.tensor(mask_ratio)
        lmbda = self.cal_lmbda(ratio_tensor)
        task_loss = self.forward_loss(gt_indices, logits, mask)
        return task_loss, lmbda

    def forward(self, imgs, is_training=False, manual_mask_rate=None):
        """
        Full rate-distortion forward pass.

        Pipeline: tokenize and mask the images (pre_encoding), quantize the
        latent through the entropy bottleneck, run-length encode the token
        mask, then rebuild the token sequence (pre_decoding) and predict
        logits for masked tokens (forward_decoding).

        Args:
            imgs: input image batch.
            is_training: passed through to pre_encoding.
            manual_mask_rate: optional fixed masking-rate override.

        Returns:
            dict with decoder logits, bottleneck likelihoods, token
            bookkeeping tensors, mask statistics, and the shape metadata
            consumed later by gen_img().
        """
        ## ---------- encoding process ---------- ##
        gt_indices, token_indices, latent, token_all_mask, token_drop_mask, mask_ratio, z_H, z_W = self.pre_encoding(imgs, is_training, manual_mask_rate)
        latent = latent.unsqueeze(1)  # add a channel dim expected by the entropy bottleneck

        latent_hat, latent_likelihoods = self.entropy_bottleneck(latent)
        # sanity check: compare latent_hat with latent
        # print((latent_hat == latent).all())
        mask_stream, mask_len = rle_encode(token_all_mask)  # run-length encode the mask for the bitstream
        mask_vis = rearrange(token_all_mask, 'b (h w) -> b h w', h=z_H, w=z_W).unsqueeze(1)

        ## ---------- decoding process ---------- ##
        decoded_mask = rle_decode(mask_stream, token_all_mask.shape).float()
        decoded_mask = decoded_mask.cuda()  # NOTE(review): hard-coded CUDA; breaks CPU-only runs — confirm intended
        # print((decoded_mask == token_all_mask).all())
        latent_hat = latent_hat.squeeze(1)  # drop the channel dim added above
        # x, token_indices, token_all_mask, token_drop_mask, ori_shape, new_shape = self.pre_decoding(gt_indices, latent_hat, decoded_mask, token_drop_mask, z_H, z_W)
        x, token_indices, token_all_mask, token_drop_mask, patch_sizes, num_blocks_h, num_blocks_w, ori_shape, gt_indices = self.pre_decoding(gt_indices, latent_hat, decoded_mask, token_drop_mask, z_H, z_W)
        logits = self.forward_decoding(x, token_drop_mask, token_all_mask)
        ## calculate loss
        # task_loss, lmbda = self.cal_loss(logits, gt_indices, token_all_mask, mask_ratio)

        return_dict = {
            'logits': logits,
            'likelihoods': latent_likelihoods,
            # 'task_loss': task_loss,
            'token_indices': token_indices,
            'token_all_mask': token_all_mask,
            # 'bs_mask_token': bs_mask_token,
            'mask_len': mask_len,
            'mask_ratio': mask_ratio,
            # 'lambda': lmbda,
            'mask_vis': 1 - mask_vis,  # inverted mask for visualization
            'ori_shape': ori_shape,
            'patch_sizes': patch_sizes,
            'num_blocks_h': num_blocks_h, 
            'num_blocks_w': num_blocks_w,
        }
        return return_dict
    
    def gen_img(self, logits, token_all_mask, token_indices, ori_shape, patch_sizes, num_blocks_h, num_blocks_w, num_iter=6, choice_temperature=4.5):
        """
        Iteratively sample masked tokens and decode them into an image.

        MaskGIT-style inference: at every step, sample token ids from the
        decoder logits, keep the most confident predictions, re-mask the
        rest following a cosine schedule, and re-run the encoder + decoder
        on the partially filled sequence. The final ids are mapped through
        the VQGAN codebook back to pixels.

        Args:
            logits: decoder logits from the initial forward() pass.
            token_all_mask: mask token indices.
            token_indices: token indices of the input image after the vq tokenizer.
            ori_shape: original latent shape metadata (b, c, h, w) used for
                reassembling the patch grid.
            patch_sizes, num_blocks_h, num_blocks_w: patch-split layout info
                for crop_and_reconstruct.
            num_iter: number of iterations for sampling.
            choice_temperature: temperature for confidence-based re-masking.

        Returns:
            gen_images: image batch decoded by the VQGAN decoder.
        """
        bsz = logits.size(0)  # new_bsz
        codebook_emb_dim = 256
        codebook_size = 1024
        mask_token_id = self.mask_token_label
        _CONFIDENCE_OF_KNOWN_TOKENS = +np.inf  # known tokens are never selected for re-masking

        for step in range(num_iter):
            if step == 0:
                # first step reuses the logits already computed by forward()
                # print("enter in step==0")
                cur_ids = token_indices.clone().long()  # token_indices represent the current state of the sequence(unmasked tokens)
                cur_ids = cur_ids[:, 1:]  # drop the leading class-token column
                logits = logits[:, 1:, :codebook_size]

                sample_dist = torch.distributions.categorical.Categorical(logits=logits)
                sampled_ids = sample_dist.sample()  # sampled_ids = torch.argmax(logits, dim=-1)
                # get ids for next step
                unknown_map = (cur_ids == mask_token_id)    # unknown_map marks padded or masked positions
                sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)

                # Defines the mask ratio for the next round. The number to mask out is
                # determined by mask_ratio * unknown_number_in_the_beginning.
                ratio = 1. * (step + 1) / num_iter
                mask_ratio = np.cos(math.pi / 2. * ratio)   # cosine schedule: mask_ratio = cos(pi/2 * ratio)
                unknown_number_in_the_beginning = torch.sum(unknown_map, dim=-1, keepdims=True).float()
                # print('begin unknown:', unknown_number_in_the_beginning)
                unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
                mask_ratio = torch.tensor(mask_ratio).cuda()
                mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long() # tokens that remain masked after this iteration
                # Keeps at least one of prediction in this round and also masks out at least
                # one and for the next iteration
                mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                         torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))
                # print('len=', mask_len)

                # sample ids according to prediction confidence
                probs = torch.nn.functional.softmax(logits, dim=-1)
                selected_probs = torch.squeeze(
                    torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)

                selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()

                # Sample masking tokens for next iteration
                masking = mask_by_random_topk(torch.min(mask_len), selected_probs, choice_temperature * (1 - ratio))

                is_subset = (masking <= unknown_map).all()
                # print("masking is a subset of unknown_map:", is_subset.item())
                # Masks tokens with lower confidence.
                token_indices = torch.where(masking, mask_token_id, sampled_ids)    # known positions keep their sampled ids
            else:
                # subsequent steps: re-encode the partially filled sequence
                # print("enter in step > 0")
                cur_ids = token_indices.clone().long()  # .long(): to int64
                # re-prepend the class-token slot stripped at step 0
                token_indices = torch.cat(
                    [torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
                token_indices[:, 0] = self.fake_class_label
                token_indices = token_indices.long()
                token_all_mask = (token_indices == mask_token_id)

                # nothing is dropped at inference time
                token_drop_mask = torch.zeros_like(token_indices)

                # token embedding
                input_embeddings = self.token_emb(token_indices)   # get input embeddings

                # encoder
                x = input_embeddings
                for blk in self.blocks:
                    x = blk(x)
                x = self.norm(x)

                # decoder
                logits = self.forward_decoding(x, token_drop_mask, token_all_mask)
                logits = logits[:, 1:, :codebook_size]  # remove the cls token and dims > codebook_size

                # get token prediction
                sample_dist = torch.distributions.categorical.Categorical(logits=logits)
                sampled_ids = sample_dist.sample()  # sampled_ids = torch.argmax(logits, dim=-1)

                # get ids for next step
                # unknown_map: type bool, same shape as cur_ids and sampled_ids; marks positions whose value will be replaced
                unknown_map = (cur_ids == mask_token_id)
                sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)    # fill in sampled ids at unknown positions
                # Defines the mask ratio for the next round. The number to mask out is
                # determined by mask_ratio * unknown_number_in_the_beginning.
                ratio = 1. * (step + 1) / num_iter
                mask_ratio = np.cos(math.pi / 2. * ratio)   # cosine schedule: mask_ratio = cos(pi/2 * ratio)
                unknown_number_in_the_beginning = torch.sum(unknown_map, dim=1, keepdims=True).float()
                unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
                mask_ratio = torch.tensor(mask_ratio).cuda()
                mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long() # tokens that remain masked after this iteration
                # Keeps at least one of prediction in this round and also masks out at least
                # one and for the next iteration
                mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                         torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))

                # sample ids according to prediction confidence
                probs = torch.nn.functional.softmax(logits, dim=-1)
                selected_probs = torch.squeeze(
                    torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
                selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()

                # Sample masking tokens for next iteration
                masking = mask_by_random_topk(torch.min(mask_len), selected_probs, choice_temperature * (1 - ratio))
                # Masks tokens with lower confidence.
                token_indices = torch.where(masking, mask_token_id, sampled_ids)
        
        # new_bsz, _, new_h, new_w = new_shape
        # b0, _, h0, w0 = ori_shape
        # print("sampled_ids.shape=", sampled_ids.shape)
        sampled_ids = sampled_ids.reshape(bsz, 16, 16).unsqueeze(1)    # back to the per-patch token grid — NOTE(review): 16x16 is hard-coded; confirm it matches the patch split
        # print('sampled_ids.shape:', sampled_ids.shape)
        sampled_ids = crop_and_reconstruct(sampled_ids, patch_sizes, num_blocks_h, num_blocks_w)
        # print("sampled_ids.shape=", sampled_ids.shape)
        sampled_ids = sampled_ids.reshape(ori_shape[0], -1)
        # print("sampled_ids.shape=", sampled_ids.shape)
        # vqgan visualization
        z_q = self.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(ori_shape[0], ori_shape[2], ori_shape[3], codebook_emb_dim))
        gen_images = self.vqgan.decode(z_q)
        return gen_images

def mage_vit_base_patch16(**kwargs):
    """Build a ViT-Base MAGE model (patch 16, 768-dim encoder, depth 12)."""
    return MaskedGenerativeEncoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=768,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )


def mage_vit_large_patch16(**kwargs):
    """Build a ViT-Large MAGE model (patch 16, 1024-dim encoder, depth 24)."""
    return MaskedGenerativeEncoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=1024,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )