File size: 37,848 Bytes
9cf79cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
from functools import partial

import torch
import torch.nn as nn

from timm.models.vision_transformer import PatchEmbed, DropPath, Mlp

from util.pos_embed import get_2d_sincos_pos_embed

from taming.models.vqgan import VQModel
from omegaconf import OmegaConf
import numpy as np
import scipy.stats as stats
from compressai.entropy_models import EntropyBottleneck
from compressai.layers import conv3x3, subpel_conv3x3
import math
from torch import Tensor
from einops import rearrange, repeat
import torch.nn.functional as F
import torchac
from typing import Any, Callable, List, Optional, Tuple, Union

SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64


def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
    """Return ``levels`` scale values spaced geometrically from ``min`` to ``max``.

    The values are evenly spaced in log-space and mapped back with ``exp``,
    so both end points are included.
    """
    log_lo = math.log(min)
    log_hi = math.log(max)
    return torch.linspace(log_lo, log_hi, levels).exp()

def ste_round(x: Tensor) -> Tensor:
    """Round ``x`` with a straight-through gradient.

    The forward value is ``round(x)``; the rounding residual is detached so
    the backward pass sees the identity function.
    """
    residual = (torch.round(x) - x).detach()
    return residual + x

def conv(in_channels, out_channels, kernel_size=5, stride=2):
    """Build an ``nn.Conv2d`` padded by half the kernel size."""
    half_kernel = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size, stride, half_kernel)

def mask_by_random_topk(mask_len, probs, temperature=1.0):
    """Select the ``mask_len`` least-confident entries of every row of ``probs``.

    Confidence is ``log(probs)`` perturbed by Gumbel noise scaled with
    ``temperature``; ``temperature=0`` makes the selection deterministic.

    Parameters:
    - mask_len: single-element tensor holding the number of entries to mask
      per row (it is squeezed to a scalar and used as a slice bound).
    - probs: 2-D tensor of probabilities, one row per sample.
    - temperature: scale applied to the Gumbel noise.

    Returns:
    - BoolTensor shaped like ``probs``; ``True`` marks entries whose noisy
      confidence is at or below the per-row cut-off.
    """
    mask_len = mask_len.squeeze()
    # Gumbel noise adds randomness to the ranking. It is created on the same
    # device as ``probs`` (instead of a hard-coded ``.cuda()``) so the function
    # also works on CPU and on non-default GPUs.
    gumbel_noise = torch.as_tensor(
        temperature * np.random.gumbel(size=probs.shape),
        dtype=torch.float32,
        device=probs.device,
    )
    confidence = torch.log(probs) + gumbel_noise
    sorted_confidence, _ = torch.sort(confidence, dim=-1)
    # Cut-off threshold: the ``mask_len``-th smallest confidence of each row.
    cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
    # Mask tokens whose confidence does not exceed the threshold.
    masking = (confidence <= cut_off)
    return masking

def adjust_mask_and_drop_embeddings(token_keep_mask):
    """
    Trim ``token_keep_mask`` so that its number of set entries is a perfect square.

    The count of non-zero entries is reduced to the largest perfect square
    that does not exceed it by clearing randomly chosen set entries.

    Parameters:
    - token_keep_mask: mask tensor indexed by (row, column); non-zero entries
      mark tokens to keep. It is modified in place.

    Returns:
    - The same ``token_keep_mask`` object after the adjustment.
    """
    set_positions = token_keep_mask.nonzero(as_tuple=True)
    kept = set_positions[0].size(0)
    # Largest perfect square that is <= the current number of kept tokens.
    surplus = kept - math.floor(math.sqrt(kept)) ** 2
    if surplus > 0:
        # Pick ``surplus`` distinct set entries at random and clear them.
        victims = torch.randperm(kept)[:surplus]
        token_keep_mask[set_positions[0][victims], set_positions[1][victims]] = False

    return token_keep_mask


class FactorizedEntropyModel(EntropyBottleneck):
    """``EntropyBottleneck`` subclass with entry points adapted to single-channel sequences.

    The original comments state that inputs are laid out as ``[b, c, seq_len]``.
    NOTE(review): nothing in this class enforces that layout — confirm with callers.
    """
    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded unchanged to EntropyBottleneck.
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, training: Optional[bool] = None) -> Tuple[Tensor, Tensor]:
        """Quantize ``x`` and return ``(outputs, likelihood)``.

        ``training`` defaults to ``self.training`` but is currently not used
        afterwards: the noise/dequantize switch is commented out and the
        "dequantize" mode is always applied.

        Raises:
            NotImplementedError: when called under TorchScript scripting.
        """
        if training is None:
            training = self.training

        # The input is already shaped [b, c, seq_len]; no transpose is needed.
        # NOTE(review): ``shape`` is not used below.
        shape = x.size()

        # Add noise or quantize
        means = self._get_medians()
        # outputs = self.quantize(
        #     x, "noise" if training else "dequantize", means.long()
        # )
        # The medians are cast to integers before being handed to quantize().
        outputs = self.quantize(
            x, "dequantize", means.long()
        )

        if not torch.jit.is_scripting():
            likelihood = self._likelihood(outputs)
            if self.use_likelihood_bound:
                likelihood = self.likelihood_lower_bound(likelihood)
        else:
            raise NotImplementedError("TorchScript is not yet supported")

        return outputs, likelihood
    
    def compress(self, x):
        """Compress ``x`` by delegating to the base-class ``compress`` with indexes and medians."""
        # Build the indexes; intended for single-channel sequence data.
        indexes = self._build_indexes(x.size())
        # Fetch the medians (detached from the autograd graph).
        medians = self._get_medians().detach()
        # Expand the medians to the shape of x.
        medians = medians.expand_as(x)
        # Delegate the actual coding to the base-class compress().
        return super().compress(x, indexes, medians)

    def decompress(self, strings, size):
        """Decompress ``strings`` into a tensor of shape ``(len(strings), 1, *size)``.

        ``size`` is the trailing shape after the channel axis (the original
        comment says it should be ``seq_len``).
        """
        # The expected output has exactly one channel.
        output_size = (len(strings), 1, *size)  # size is expected to be seq_len here
        # Build the indexes on the device holding the quantized CDF.
        indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
        # Fetch the medians and expand them to the expected output shape.
        medians = self._extend_ndims(self._get_medians().detach(), len(size))
        medians = medians.expand(len(strings), 1, *([-1] * len(size)))
        # Delegate the actual decoding to the base-class decompress().
        return super().decompress(strings, indexes, medians.dtype, medians)
    
    def _preprocess(self, x):
        # Moves axis 1 to the last position; this requires a 4-D input.
        # NOTE(review): not called anywhere in this class — confirm it is still needed.
        x = x.permute(0, 2, 3, 1).contiguous()
        return x


class Attention(nn.Module):
    """Multi-head self-attention that returns both the output and the attention map."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # A manually supplied qk_scale overrides the default 1/sqrt(head_dim)
        # (kept for compatibility with previously trained weights).
        default_scale = (dim // num_heads) ** -0.5
        self.scale = qk_scale or default_scale

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return ``(features, attention_weights)`` for input ``x`` of shape (B, N, C)."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        # Packed projection laid out as (3, B, num_heads, N, head_dim).
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        packed = packed.permute(2, 0, 3, 1, 4)
        query, key, value = packed[0], packed[1], packed[2]

        # The attention scores are always computed in float32, even under CUDA autocast.
        with torch.cuda.amp.autocast(enabled=False):
            scores = (query.float() @ key.float().transpose(-2, -1)) * self.scale

        # Subtract the row-wise maximum before the softmax.
        row_max = torch.max(scores, dim=-1, keepdim=True)[0]
        weights = (scores - row_max).softmax(dim=-1)
        weights = self.attn_drop(weights)

        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj_drop(self.proj(merged))
        # ``out`` is the feature after self-attention; ``weights`` is the attention
        # matrix describing how the sequence elements relate to each other.
        return out, weights


class Block(nn.Module):
    """Pre-norm Transformer block: self-attention and an MLP, each with a residual connection."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Stochastic depth is only instantiated when a positive rate is requested.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_features = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_features, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        """Run the block; with ``return_attention=True`` return only the attention weights."""
        attn_out, attn_weights = self.attn(self.norm1(x))
        if return_attention:
            return attn_weights
        x = x + self.drop_path(attn_out)
        return x + self.drop_path(self.mlp(self.norm2(x)))


class LabelSmoothingCrossEntropy(nn.Module):
    """Per-sample negative log-likelihood loss with label smoothing (no reduction)."""

    def __init__(self, smoothing=0.1):
        super().__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return one loss value per row of ``x`` for the class ids in ``target``."""
        log_p = torch.nn.functional.log_softmax(x, dim=-1)
        # Log-probability assigned to the target class of every row.
        picked = log_p.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        target_term = -picked
        # Cross-entropy against the uniform distribution over classes.
        uniform_term = -log_p.mean(dim=-1)
        return self.confidence * target_term + self.smoothing * uniform_term


class BertEmbeddings(nn.Module):
    """Sum of token and learned position embeddings, followed by LayerNorm and dropout."""

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)

        # The attribute keeps the TensorFlow-style name ``LayerNorm`` (not
        # snake-cased) so TensorFlow checkpoint variable names still match.
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

        # Buffer of shape (1, max_position_embeddings) holding 0..max-1. It is
        # saved with the module and sliced in forward() to look up positions.
        all_positions = torch.arange(max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions)

        for table in (self.word_embeddings, self.position_embeddings):
            torch.nn.init.normal_(table.weight, std=.02)

    def forward(
        self, input_ids
    ):
        """Embed ``input_ids`` of shape (B, N) into a (B, N, hidden_size) tensor."""
        seq_length = input_ids.size(1)
        token_part = self.word_embeddings(input_ids)
        # (1, N, hidden_size); broadcast over the batch in the sum below.
        position_part = self.position_embeddings(self.position_ids[:, :seq_length])
        normed = self.LayerNorm(token_part + position_part)
        return self.dropout(normed)


class MlmLayer(nn.Module):
    """Prediction head that scores hidden states against a word-embedding matrix."""

    def __init__(self, feat_emb_dim, word_emb_dim, vocab_size):
        super().__init__()
        self.fc = nn.Linear(feat_emb_dim, word_emb_dim)
        self.gelu = nn.GELU()
        self.ln = nn.LayerNorm(word_emb_dim)
        self.bias = nn.Parameter(torch.zeros(1, 1, vocab_size))

    def forward(self, x, word_embeddings):  # x: (b, seq_len, embed_dim)
        """Return unnormalised scores of shape (b, seq_len, vocab_size).

        ``word_embeddings`` is a (vocab_size, word_emb_dim) matrix; every
        position of ``x`` is projected, normalised and dotted with each row.
        """
        hidden = self.ln(self.gelu(self.fc(x)))
        scores = torch.matmul(hidden, word_embeddings.transpose(0, 1))
        return scores + self.bias


class MaskedGenerativeEncoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=256, patch_size=16, in_chans=3,     # need to change the default value of img_size
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 mask_ratio_min=0.5, mask_ratio_max=0.8, mask_ratio_mu=0.55, mask_ratio_std=0.25,
                 vqgan_ckpt_path='vqgan_jax_strongaug.ckpt'):
        """Build the frozen VQGAN tokenizer, the ViT encoder/decoder, the MLM head and the entropy model.

        The VQGAN configuration is always read from ``config/vqgan.yaml``
        (relative to the working directory); its weights come from
        ``vqgan_ckpt_path`` and are frozen. ``mask_ratio_mu`` and
        ``mask_ratio_std`` are accepted but unused because the truncated-normal
        generator below is commented out.
        """
        super().__init__()

        # --------------------------------------------------------------------------
        # VQGAN specifics
        config = OmegaConf.load('config/vqgan.yaml').model
        self.vqgan = VQModel(ddconfig=config.params.ddconfig,
                             n_embed=config.params.n_embed, # codebook size (1024 per the original note)
                             embed_dim=config.params.embed_dim, # code dimension (256 per the original note)
                             ckpt_path=vqgan_ckpt_path)
        # The tokenizer is frozen: none of its parameters receive gradients.
        for param in self.vqgan.parameters():
            param.requires_grad = False

        self.codebook_size = config.params.n_embed  # 1024 per the original note
        vocab_size = self.codebook_size + 1000 + 1  # 1024 codebook size, 1000 classes, 1 for mask token.
        self.fake_class_label = self.codebook_size + 1100 - 1024    # equals 1100 when codebook_size is 1024
        self.mask_token_label = vocab_size - 1  # last vocabulary id (2024 when codebook_size is 1024)
        self.token_emb = BertEmbeddings(vocab_size=vocab_size,  # vocabulary: codebook entries + 1000 classes + 1 mask token
                                        hidden_size=embed_dim,
                                        max_position_embeddings=img_size +1,
                                        # max_position_embeddings=256+1,  # 256 patches + 1 class token
                                        dropout=0.1)

        # MAGE variant masking ratio
        self.mask_ratio_min = mask_ratio_min
        self.mask_ratio_max = mask_ratio_max
        # self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - mask_ratio_mu) / mask_ratio_std,
        #                                             (mask_ratio_max - mask_ratio_mu) / mask_ratio_std,
        #                                             loc=mask_ratio_mu, scale=mask_ratio_std)

        # --------------------------------------------------------------------------
        # MAGE encoder specifics
        dropout_rate = 0.1
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)    # with the defaults: 256, 16, 3, 1024 -> (256/16)*(256/16)=256 patches
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([   # encoder
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(depth)]) # depth=12 for mage-vitb, embed_dim=768
        self.norm = norm_layer(embed_dim)   # layer norm
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAGE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))    # decoder_embed_dim=512 by default
        # When True, the decoder pads masked/dropped positions with the [CLS] feature
        # instead of ``self.mask_token``.
        self.pad_with_cls_token = True

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim))  # learnable pos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(decoder_depth)])  # decoder_depth=8 for mage-vitb

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MlmLayer
        self.mlm_layer = MlmLayer(feat_emb_dim=decoder_embed_dim, word_emb_dim=embed_dim, vocab_size=vocab_size)

        self.norm_pix_loss = norm_pix_loss

        self.criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
        # --------------------------------------------------------------------------
        # Single-channel entropy model; forward() applies it to the unmasked token indices.
        self.entropy_bottleneck = FactorizedEntropyModel(1)

    def initialize_weights(self):
        """Fill the fixed sin-cos position tables and initialise all learnable weights."""
        grid_size = int(self.patch_embed.num_patches**.5)

        # Encoder and decoder position tables are both fixed 2-D sin-cos
        # embeddings with an extra slot for the class token.
        for table in (self.pos_embed, self.decoder_pos_embed):
            sincos = get_2d_sincos_pos_embed(table.shape[-1], grid_size, cls_token=True)
            table.data.copy_(torch.from_numpy(sincos).float().unsqueeze(0))

        # The patch projection is initialised like an nn.Linear: its kernel is
        # flattened to 2-D before the Xavier initialisation.
        patch_weight = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(patch_weight.view([patch_weight.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        for param in (self.cls_token, self.mask_token, self.decoder_pos_embed_learned):
            torch.nn.init.normal_(param, std=.02)

        # Finally initialise every nn.Linear and nn.LayerNorm submodule.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_sample_mask_rate(self):
        # 生成一个 (0, 1] 范围内的随机数
        random_sample = 1 - torch.rand(1)
        # 映射到 mask_ratio_min 到 mask_ratio_max 的范围
        mask_rate = self.mask_ratio_min + random_sample * (self.mask_ratio_max - self.mask_ratio_min)
        return mask_rate.item()  # 转换为Python的标量值

    def get_cdf_token_mask(self, token_all_mask):
        """Build the CDF table used to arithmetic-code the token mask.

        For a mask of shape (bsz, seq_len) the result has shape
        (bsz, seq_len, seq_len + 1): the Normal(0, 2) CDF evaluated at
        1..seq_len and rescaled by ``(cdf - 0.5) * 2``, repeated for every
        position, with a leading 0 prepended on the last axis.
        """
        bsz, seq_len = token_all_mask.size()
        # --- use Normal distribution.
        positions = torch.arange(1, seq_len + 1)
        normal_cdf = torch.distributions.Normal(0, 2).cdf(positions)
        rescaled = 2 * (normal_cdf - .5)
        tiled = repeat(rescaled, 'Lp -> b s Lp', b=bsz, s=seq_len)
        # Prepend a single 0 on the last axis.
        return F.pad(tiled, (1, 0))
    
    # def get_cdf_token_mask(self, token_all_mask):
    #     bsz, seq_len = token_all_mask.size()
    #     # 直接生成一个0到1之间的线性空间
    #     linear_space = torch.linspace(0, 1, steps=seq_len+1)
    #     # 无需映射到-1到1
    #     cdf_mask_token = linear_space
    #     # 调整形状以匹配token_all_mask,并扩展到每个batch
    #     cdf_mask_token = repeat(cdf_mask_token, 'Lp -> b s Lp', 
    #                                 b=bsz, s=seq_len)
    #     # cdf_mask_token = cdf_mask_token.unsqueeze(0).unsqueeze(-1).repeat(bsz, 1, seq_len)
    #     # 添加填充以匹配原始代码的操作
    #     cdf_mask_token = F.pad(cdf_mask_token, (1, 0))
    #     return cdf_mask_token
    
    def pre_encoding(self, x, is_training=False, manual_mask_rate=None):
        """
        input: x: (B, 3, H, W)
        """
        # ============ 1. tokenization ============ #
        with torch.no_grad():
            z_q, _, token_tuple = self.vqgan.encode(x)  # z_q: (B, 256, 16, 16), token_tuple: (B, 256, 16, 16)
        
        _, _, token_indices = token_tuple   # token_indices: (B*H*W,)(8192)
        token_indices = token_indices.reshape(z_q.size(0), -1)  # token_indices: (B, H*W)
        gt_indices = token_indices.clone().detach().long()

        # ============ 2. masking process ============ # 
        bsz, seq_len = token_indices.size() # seq_len=h*w
        mask_ratio_min = self.mask_ratio_min    # 0.5

        if is_training:
            # mask_rate = self.mask_ratio_generator.rvs(1)[0]
            mask_rate = self.random_sample_mask_rate()
            num_dropped_tokens = int(np.ceil(seq_len * mask_ratio_min))
        else:
            num_dropped_tokens = 0
            if manual_mask_rate is not None:
                mask_rate = manual_mask_rate
            else:
                raise ValueError("mask_rate should be provided for inference!")
        
        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask_ratio = num_masked_tokens / seq_len    # for calculate vbr lambda 
        # it is possible that two elements of the noise is the same, so do a while loop to avoid it
        while True:
            noise = torch.rand(bsz, seq_len, device=x.device)  # noise in [0, 1]
            sorted_noise, _ = torch.sort(noise, dim=1)  # ascend: small is remove, large is keep
            if num_dropped_tokens > 0:
                cutoff_drop = sorted_noise[:, num_dropped_tokens-1:num_dropped_tokens]
            else:
                cutoff_drop = torch.zeros((bsz, 1), device=x.device)
            cutoff_mask = sorted_noise[:, num_masked_tokens-1:num_masked_tokens]
            token_drop_mask = (noise <= cutoff_drop).float() # 逻辑上标记那些token是被drop掉的
            token_all_mask = (noise <= cutoff_mask).float()  # 逻辑上标记那些token是被mask掉的
            if token_drop_mask.sum() == bsz*num_dropped_tokens and token_all_mask.sum() == bsz*num_masked_tokens:
                break
            else:
                print("Rerandom the noise!")

        # 获取unmasked token及其位置
        unmasked_pos = token_all_mask == 0  # 未被mask的位置
        unmaksed_token_indices = token_indices[unmasked_pos].reshape(bsz, -1)  # 未被mask的token

        return gt_indices, token_indices, unmaksed_token_indices, token_all_mask, token_drop_mask, mask_ratio

    def pre_decoding(self, gt_indices, unmaksed_token_indices, token_all_mask, token_drop_mask):
        bsz, seq_len = gt_indices.size()  
        padded_token_indices = torch.full_like(gt_indices, fill_value=self.mask_token_label)
        # 将未被mask的token填充回去
        # 我们需要一个计数器来追踪每个batch中已经填充了多少个unmaksed_token_indices
        unmasked_token_counter = [0 for _ in range(bsz)]
    
        for b in range(bsz):
            for idx in range(seq_len):
                # 如果当前位置未被mask,则从unmaksed_token_indices填充;否则,保留mask_token_label
                if (token_all_mask[b, idx] == 0):  # 检查是否未被mask
                    # 替换相应的unmaksed token
                    padded_token_indices[b, idx] = unmaksed_token_indices[b, unmasked_token_counter[b]]
                    # 更新计数器
                    unmasked_token_counter[b] += 1
        
        token_indices = padded_token_indices
        # ============ 3. Adding class token ============ #
        # concate class token, add [CLS] token to aggregate sequence-level representations
        token_indices = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
        token_indices[:, 0] = self.fake_class_label # [B, 257]
        # Masks (token_drop_mask and token_all_mask) are updated to account for the added class token, 
        # ensuring the first position is always kept by setting it to 0 (indicating "do not mask/drop")
        # 添加0向量,和token_indices,表示[CLS] token不会被mask/drop
        token_drop_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_drop_mask], dim=1)
        token_all_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_all_mask], dim=1)
        token_indices = token_indices.long()

        # ============ 4. Embedding and Dropout ============ #
        # bert embedding
        input_embeddings = self.token_emb(token_indices)    # get embeddings [B, 257, 768]
        # print("Input embedding shape:", input_embeddings.shape)
        bsz, seq_len, emb_dim = input_embeddings.shape

        # dropping
        token_keep_mask = 1 - token_drop_mask
        input_embeddings_after_drop = input_embeddings[token_keep_mask.nonzero(as_tuple=True)].reshape(bsz, -1, emb_dim)
        # print("Input embedding after drop shape:", input_embeddings_after_drop.shape)
        
        # ============ 5. Transformer encoding ============ #
        x = input_embeddings_after_drop # (B, seq_len_after_drop, embed_dim)    # 32, 129, 768
        for blk in self.blocks:
            x = blk(x)  # each block has a multi-head self-attention and a mlp
        x = self.norm(x)
        # print("Encoder representation shape:", x.shape)

        return x, token_indices, token_all_mask, token_drop_mask

    def forward_decoding(self, x, token_drop_mask, token_all_mask):
        """
            x: output x of forward_encoder()
            token_drop_mask: positions for dropped tokens
            token_all_mask: positions for masked tokens
        """
        # ============ 1. Prepare Embedding and padding tokens ============ #
        # embed tokens
        x = self.decoder_embed(x)   # input_embedding_after_padding

        # append mask tokens to sequence
        # replicates the [CLS] token embedding across the sequence length where masking is to be applied
        if self.pad_with_cls_token: # True
            mask_tokens = x[:, 0:1].repeat(1, token_all_mask.shape[1], 1)
        else:
            mask_tokens = self.mask_token.repeat(token_all_mask.shape[0], token_all_mask.shape[1], 1)

        # ============ 2. Prepare positional embedding ============ #
        # put undropped tokens into original sequence
        x_after_pad = mask_tokens.clone() # 未被drop的tokens被填充回去
        x_after_pad[(1 - token_drop_mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
        # set undropped but masked positions with mask
        x_after_pad = torch.where(token_all_mask.unsqueeze(-1).bool(), mask_tokens, x_after_pad)    # 被drop的也padding

        # add pos embed
        x = x_after_pad + self.decoder_pos_embed_learned    # add learnable pos embedding

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)

        x = self.decoder_norm(x)

        word_embeddings = self.token_emb.word_embeddings.weight.data.detach()
        logits = self.mlm_layer(x, word_embeddings)  #  produce predictions for masked tokens
        print("Logits shape:", logits.shape)

        return logits

    def forward_loss(self, gt_indices, logits, mask):
        bsz, seq_len = gt_indices.size()
        # logits and mask are with seq_len+1 but gt_indices is with seq_len
        loss = self.criterion(logits[:, 1:, :self.codebook_size].reshape(bsz*seq_len, -1), gt_indices.reshape(bsz*seq_len))
        loss = loss.reshape(bsz, seq_len)
        loss = (loss * mask[:, 1:]).sum() / mask[:, 1:].sum()  # mean loss on removed patches
        return loss

    def cal_lmbda(self, mask_ratio, A=5e-1, B=8):
        lmbda = A * torch.exp(B * (1 - mask_ratio))
        return lmbda
    
    def cal_loss(self, logits, gt_indices, mask, mask_ratio):
        mask_ratio = torch.tensor(mask_ratio)
        ## cal cross entropy loss
        task_loss = self.forward_loss(gt_indices, logits, mask)
        lmbda = self.cal_lmbda(mask_ratio)
        ## cal total loss for codec optimization
        return task_loss, lmbda

    def forward(self, imgs, is_training=False, manual_mask_rate=None):
        """Run the full encode / decode round trip and return a result dict.

        Encoding tokenizes and masks ``imgs``, passes the unmasked token
        indices through the entropy bottleneck and arithmetic-codes the mask
        with torchac. Decoding restores the mask from that bitstream, rebuilds
        the token sequence and predicts logits. The returned dict holds the
        logits, bottleneck likelihoods, task loss, token indices, mask, mask
        bitstream, mask ratio, lambda and a mask visualisation.

        ``manual_mask_rate`` is required when ``is_training`` is False
        (``pre_encoding`` raises ``ValueError`` otherwise).
        """
        ## ---------- encoding process ---------- ##
        # ``latent`` is the tensor of unmasked token indices returned by pre_encoding.
        gt_indices, token_indices, latent, token_all_mask, token_drop_mask, mask_ratio = self.pre_encoding(imgs, is_training, manual_mask_rate)
        # Insert a channel axis of size 1 for the single-channel entropy model.
        latent = latent.unsqueeze(1)

        latent_hat, latent_likelihoods = self.entropy_bottleneck(latent)
        # Debug aid: check whether latent_hat and latent are equal.
        # print((latent_hat == latent).all())
        # torchac is fed CPU tensors; mask values 0/1 are shifted to symbols 1/2.
        cdf_mask_token = self.get_cdf_token_mask(token_all_mask).cpu()
        sym = (token_all_mask.short() + 1).cpu()
        bs_mask_token = torchac.encode_float_cdf(cdf_mask_token, sym, check_input_bounds=True)
        # NOTE(review): the 16x16 grid is hard-coded, so this requires seq_len == 256.
        mask_vis = rearrange(token_all_mask, 'b (h w) -> b h w', h=16, w=16).unsqueeze(1)

        ## ---------- decoding process ---------- ##
        decoded_sym = torchac.decode_float_cdf(cdf_mask_token, bs_mask_token)
        # Undo the +1 symbol shift and move the mask back to the input device.
        decoded_mask = (decoded_sym - 1).to(device=imgs.device)
        latent_hat = latent_hat.squeeze(1)
        # From here on the masks and token_indices include the leading class-token column.
        x, token_indices, token_all_mask, token_drop_mask = self.pre_decoding(gt_indices, latent_hat, decoded_mask, token_drop_mask)
        logits = self.forward_decoding(x, token_drop_mask, token_all_mask)
        ## calculate loss
        task_loss, lmbda = self.cal_loss(logits, gt_indices, token_all_mask, mask_ratio)
        return_dict = {
            'logits': logits,
            'likelihoods': latent_likelihoods,
            'task_loss': task_loss,
            'token_indices': token_indices,
            'token_all_mask': token_all_mask,
            'bs_mask_token': bs_mask_token,
            'mask_ratio': mask_ratio,
            'lambda': lmbda,
            # Inverted so that 1 marks a visible (unmasked) token.
            'mask_vis': 1 - mask_vis,
        }
        return return_dict

    # def update(self, scale_table=None, force=False):
    #     if scale_table is None:
    #         scale_table = get_scale_table()
    #     updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)
    #     updated |= super().update(force=force)
    #     return updated
    
    def gen_img(self, logits, token_all_mask, token_indices, num_iter=12, choice_temperature=4.5):
        """
        Generate images at inference by iterative token re-sampling.

        The first iteration consumes the ``logits`` supplied by the caller. Every
        later iteration prepends ``self.fake_class_label`` to the current token
        sequence, re-runs ``self.token_emb`` / ``self.blocks`` / ``self.norm`` /
        ``self.forward_decoding`` and samples again. After each iteration the
        positions selected by ``mask_by_random_topk`` are reset to the mask
        token; how many is driven by a cosine schedule over the iterations.

            logits: predicted logits by model decoder; the first position along
                dim 1 is dropped and only the first 1024 classes are kept
            token_all_mask: mask token indices; summed along the last dim to
                obtain the initial number of unknown tokens
            token_indices: token indices of the input image after the vq
                tokenizer; the first column is dropped
                (NOTE(review): presumably the class token -- confirm)
            num_iter: number of iterations for sampling, must be >= 1
            choice_temperature: temperature for sampling, scaled by
                ``1 - (step + 1) / num_iter`` at each iteration

        Returns the result of ``self.vqgan.decode`` applied to the codebook
        entries of the ids sampled in the last iteration.

        Raises:
            ValueError: if ``num_iter`` is smaller than 1 (no ids would be sampled).

        Tensors created in this method are placed on ``logits.device`` rather
        than on the default CUDA device, so inputs living on a non-default GPU
        do not trigger device-mismatch errors here.
        """
        if num_iter < 1:
            raise ValueError(f"num_iter must be >= 1, got {num_iter}")

        bsz = logits.size(0)
        device = logits.device
        codebook_emb_dim = 256
        codebook_size = 1024
        mask_token_id = self.mask_token_label
        _CONFIDENCE_OF_KNOWN_TOKENS = +np.inf

        # Loop invariants: initial count of unknown tokens per sample, and the
        # lower bound of one re-masked token (float, like the original
        # torch.Tensor([1]), so the promoted dtype of mask_len is unchanged).
        unknown_number_in_the_beginning = torch.sum(
            token_all_mask, dim=-1, keepdims=True).float().detach().to(device)
        min_mask_len = torch.ones(1, device=device)

        for step in range(num_iter):
            if step == 0:
                # token_indices is the current state of the sequence (unmasked
                # tokens); keep everything from the second column onwards.
                cur_ids = token_indices.long()[:, 1:]
                step_logits = logits[:, 1:, :codebook_size]
            else:
                cur_ids = token_indices.long()  # int64, no leading column
                # Prepend one column holding the fake class label.
                model_input = torch.cat(
                    [cur_ids.new_zeros(cur_ids.size(0), 1), cur_ids], dim=1)
                model_input[:, 0] = self.fake_class_label
                step_all_mask = model_input == mask_token_id
                step_drop_mask = torch.zeros_like(model_input)

                # token embedding
                x = self.token_emb(model_input)

                # encoder
                for blk in self.blocks:
                    x = blk(x)
                x = self.norm(x)

                # decoder
                step_logits = self.forward_decoding(x, step_drop_mask, step_all_mask)
                # remove the cls token and dims > codebook_size
                step_logits = step_logits[:, 1:, :codebook_size]

            # get token prediction
            # the author said a little tricky here, "For iter=1, they use argmax and temp=0.0.
            # For iter=6, we use categorical sampling and temp=4.5."
            sample_dist = torch.distributions.categorical.Categorical(logits=step_logits)
            sampled_ids = sample_dist.sample()  # alternative: torch.argmax(step_logits, dim=-1)

            # unknown_map is a bool tensor with the same shape as cur_ids and
            # sampled_ids. Where it is True the sampled prediction is used,
            # elsewhere the already-known id is kept, so sampled_ids ends up
            # merging predictions (unknown positions) with known tokens.
            unknown_map = (cur_ids == mask_token_id)
            sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)

            # Defines the mask ratio for the next round. The number to mask out is
            # determined by mask_ratio * unknown_number_in_the_beginning.
            ratio = 1. * (step + 1) / num_iter
            # mask_ratio = cos(pi / 2 * (step + 1) / num_iter)
            mask_ratio = torch.tensor(np.cos(math.pi / 2. * ratio), device=device)

            # sample ids according to prediction confidence
            probs = torch.nn.functional.softmax(step_logits, dim=-1)
            selected_probs = torch.squeeze(
                torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
            # Known tokens get infinite confidence.
            selected_probs = torch.where(
                unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()

            # Number of mask tokens remaining after this iteration.
            mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long()
            # Keeps at least one of prediction in this round and also masks out at least
            # one and for the next iteration
            mask_len = torch.maximum(
                min_mask_len,
                torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))

            # Sample masking tokens for next iteration. Only the first sample's
            # mask_len is passed, i.e. it is shared by the whole batch.
            masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
            # Masks tokens with lower confidence.
            token_indices = torch.where(masking, mask_token_id, sampled_ids)

        # vqgan visualization
        z_q = self.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(bsz, 16, 16, codebook_emb_dim))
        gen_images = self.vqgan.decode(z_q)
        return gen_images


def mage_vit_base_patch16(**kwargs):
    """Build a ``MaskedGenerativeEncoderViT`` with the base-size preset.

    Preset: patch size 16, encoder width 768 / depth 12 / 12 heads, decoder
    width 768 / depth 8 / 16 heads, MLP ratio 4 and ``nn.LayerNorm`` with
    ``eps=1e-6``. Any extra keyword arguments are passed through to the
    constructor unchanged.
    """
    base_preset = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=768,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return MaskedGenerativeEncoderViT(**base_preset, **kwargs)


def mage_vit_large_patch16(**kwargs):
    """Build a ``MaskedGenerativeEncoderViT`` with the large-size preset.

    Preset: patch size 16, encoder width 1024 / depth 24 / 16 heads, decoder
    width 1024 / depth 8 / 16 heads, MLP ratio 4 and ``nn.LayerNorm`` with
    ``eps=1e-6``. Any extra keyword arguments are passed through to the
    constructor unchanged.
    """
    large_preset = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=1024,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return MaskedGenerativeEncoderViT(**large_preset, **kwargs)