"""MAGE-style masked generative ViT over VQGAN tokens.

A learned ``ImportancePredictor`` chooses which tokens to mask, the unmasked
token indices are scored by a factorized entropy model, and a BERT-style
encoder/decoder predicts the masked tokens.
"""
import math
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import scipy.stats as stats
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchac
from compressai.entropy_models import EntropyBottleneck
from compressai.layers import conv3x3, subpel_conv3x3
from einops import rearrange, repeat
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import PatchEmbed, DropPath, Mlp
from torch import Tensor

from util.pos_embed import get_2d_sincos_pos_embed
from util.rle import rle_encode, rle_decode


def mask_by_random_topk(mask_len, probs, temperature=1.0):
    """Return a boolean mask selecting the least-confident positions of each row.

    Confidence is ``log(probs)`` plus Gumbel noise scaled by ``temperature``;
    the Gumbel perturbation makes the selection stochastic (a larger
    temperature means more randomness).

    Args:
        mask_len: number of positions to mask. It is squeezed and cast to
            ``long`` for slicing, so a one-element tensor is expected; the same
            count is used for every row of ``probs``. Must be >= 1.
        probs: per-position probabilities, shape ``(B, L)``.
        temperature: scale applied to the Gumbel noise.

    Returns:
        Bool tensor shaped like ``probs``; ``True`` marks positions whose
        perturbed confidence is <= the ``mask_len``-th smallest value in the row.
    """
    mask_len = mask_len.squeeze()
    # Noise is drawn with NumPy (so np.random seeding controls it) and placed
    # on the same device as ``probs`` rather than a hard-coded CUDA device.
    gumbel = temperature * np.random.gumbel(size=probs.shape)
    confidence = torch.log(probs) + torch.as_tensor(gumbel, dtype=torch.float32, device=probs.device)
    sorted_confidence, _ = torch.sort(confidence, dim=-1)
    # Cut-off threshold: the mask_len-th smallest confidence of every row.
    k = mask_len.long()
    cut_off = sorted_confidence[:, k - 1:k]
    # Mask the tokens whose confidence is at or below the threshold.
    masking = confidence <= cut_off
    return masking


class FactorizedEntropyModel(EntropyBottleneck):
    """``EntropyBottleneck`` applied directly to ``(b, c, seq_len)`` sequences.

    Unlike the parent's ``forward``, the input is not transposed; values are
    rounded around the learned medians and their likelihoods are returned.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, training: Optional[bool] = None) -> Tuple[Tensor, Tensor]:
        """Quantize ``x`` around the medians and return ``(outputs, likelihood)``.

        ``training`` is accepted for signature compatibility with the parent
        class; quantization here always uses the ``"dequantize"`` mode.
        """
        if training is None:
            training = self.training  # kept for API compatibility; not used below
        # Input is already [b, c, seq_len], so no transpose is needed.
        means = self._get_medians()
        outputs = self.quantize(x, "dequantize", means.long())

        if not torch.jit.is_scripting():
            likelihood = self._likelihood(outputs)
            if self.use_likelihood_bound:
                likelihood = self.likelihood_lower_bound(likelihood)
        else:
            raise NotImplementedError("TorchScript is not yet supported")

        return outputs, likelihood

    def compress(self, x):
        """Entropy-code ``x`` using the learned medians as the offset."""
        # Build CDF indexes for the (single-channel) sequence data.
        indexes = self._build_indexes(x.size())
        # Medians, broadcast to the shape of x.
        medians = self._get_medians().detach()
        medians = medians.expand_as(x)
        return super().compress(x, indexes, medians)

    def decompress(self, strings, size):
        """Decode ``strings`` into a ``(len(strings), 1, *size)`` tensor."""
        # Expected output has a single channel; ``size`` is the sequence shape.
        output_size = (len(strings), 1, *size)
        indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
        # Medians reshaped/broadcast to the expected output shape.
        medians = self._extend_ndims(self._get_medians().detach(), len(size))
        medians = medians.expand(len(strings), 1, *([-1] * len(size)))
        return super().decompress(strings, indexes, medians.dtype, medians)

    def _preprocess(self, x):
        # NOTE(review): this permute assumes a 4-D (N, C, H, W) input, unlike
        # the 3-D sequences handled above — confirm whether it is still used.
        x = x.permute(0, 2, 3, 1).contiguous()
        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


# Prefer apex's fused kernel when apex is installed; otherwise fall back to
# the fp16-safe LayerNorm above.
try:
    from apex.normalization import FusedLayerNorm
except ImportError:
    FusedLayerNorm = LayerNorm


class ImportancePredictor(nn.Module):
    """Score every token's importance.

    Input:
        z_q: [b, (h*w), c]
    Output:
        importance_score: [b, N]
    """

    def __init__(self, embed_dim=768):
        super().__init__()
        self.in_conv = nn.Sequential(
            FusedLayerNorm(embed_dim, eps=1e-5),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU()
        )

        self.out_conv = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Linear(embed_dim // 2, embed_dim // 4),
            nn.GELU(),
            nn.Linear(embed_dim // 4, 1)
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, policy):
        """Return per-token scores, shape ``[b, N]``.

        The first half of the channels is kept per token ("local"); the second
        half is averaged over tokens weighted by ``policy`` ("global") and
        broadcast back to every token before the output MLP.
        """
        x = self.in_conv(x)
        B, N, C = x.size()
        local_x = x[:, :, :C // 2]
        global_x = (x[:, :, C // 2:] * policy).sum(dim=1, keepdim=True) / torch.sum(policy, dim=1, keepdim=True)
        x = torch.cat([local_x, global_x.expand(B, N, C // 2)], dim=-1)
        x = self.out_conv(x)
        return x.squeeze(-1)  # [b, N, 1] -> [b, N]


class Attention(nn.Module):
    """Multi-head self-attention that also returns the attention weights."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return ``(output, attn)``: the attended features and the softmax attention map."""
        B, N, C = x.shape
        # (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # Attention logits are computed in fp32 even under autocast.
        with torch.cuda.amp.autocast(enabled=False):
            attn = (q.float() @ k.float().transpose(-2, -1)) * self.scale

        # Subtract the row max before softmax for numerical stability.
        attn = attn - torch.max(attn, dim=-1, keepdim=True)[0]
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):
    """Pre-norm transformer block (attention + MLP, each with a residual)."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        """Return the block output, or only the attention map if ``return_attention``."""
        if return_attention:
            _, attn = self.attn(self.norm1(x))
            return attn
        else:
            y, _ = self.attn(self.norm1(x))
            x = x + self.drop_path(y)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
            return x


class LabelSmoothingCrossEntropy(nn.Module):
    """ NLL loss with label smoothing.

    Returns the unreduced per-sample loss.
    """

    def __init__(self, smoothing=0.1):
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized.
        # Buffer of shape (1, max_position_embeddings) holding 0..max_position_embeddings-1,
        # sliced in forward to look up the position embeddings.
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))

        torch.nn.init.normal_(self.word_embeddings.weight, std=.02)
        torch.nn.init.normal_(self.position_embeddings.weight, std=.02)

    def forward(
            self, input_ids
    ):
        """Embed ``input_ids`` of shape (B, seq_len) -> (B, seq_len, hidden_size)."""
        input_shape = input_ids.size()

        seq_length = input_shape[1]

        position_ids = self.position_ids[:, :seq_length]

        inputs_embeds = self.word_embeddings(input_ids)  # (B, seq_len, embed_dim)

        position_embeddings = self.position_embeddings(position_ids)  # (1, seq_len, embed_dim)
        embeddings = inputs_embeds + position_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class MlmLayer(nn.Module):
    """Project decoder features onto the (tied) word-embedding matrix."""

    def __init__(self, feat_emb_dim, word_emb_dim, vocab_size):
        super().__init__()
        self.fc = nn.Linear(feat_emb_dim, word_emb_dim)
        self.gelu = nn.GELU()
        self.ln = nn.LayerNorm(word_emb_dim)
        self.bias = nn.Parameter(torch.zeros(1, 1, vocab_size))

    def forward(self, x, word_embeddings):
        """Return unnormalized logits (b, seq_len, vocab_size) for every position."""
        mlm_hidden = self.fc(x)
        mlm_hidden = self.gelu(mlm_hidden)
        mlm_hidden = self.ln(mlm_hidden)
        word_embeddings = word_embeddings.transpose(0, 1)
        logits = torch.matmul(mlm_hidden, word_embeddings)
        logits = logits + self.bias
        return logits


class MaskedGenerativeEncoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """

    def __init__(self, img_size=256, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 mask_ratio_min=0.5, mask_ratio_max=0.8,
                 vqgan_ckpt_path='vqgan_jax_strongaug.ckpt'):
        super().__init__()

        # --------------------------------------------------------------------------
        # VQGAN specifics (frozen tokenizer)
        config = OmegaConf.load('config/vqgan.yaml').model
        self.vqgan = VQModel(ddconfig=config.params.ddconfig,
                             n_embed=config.params.n_embed,
                             embed_dim=config.params.embed_dim,
                             ckpt_path=vqgan_ckpt_path)
        for param in self.vqgan.parameters():
            param.requires_grad = False

        self.codebook_size = config.params.n_embed
        vocab_size = self.codebook_size + 1000 + 1  # 1024 codebook size, 1000 classes, 1 for mask token.
        self.fake_class_label = self.codebook_size + 1100 - 1024
        self.mask_token_label = vocab_size - 1
        # NOTE(review): the position table is sized from img_size, while
        # decoder_pos_embed_learned below uses num_patches + 1; both are 257 for
        # the default img_size=256 / patch_size=16 — confirm for other sizes.
        self.token_emb = BertEmbeddings(vocab_size=vocab_size,
                                        hidden_size=embed_dim,
                                        max_position_embeddings=img_size + 1,
                                        dropout=0.1)

        # MAGE variant masking ratio
        self.mask_ratio_min = mask_ratio_min
        self.mask_ratio_max = mask_ratio_max

        # --------------------------------------------------------------------------
        # MAGE encoder specifics
        dropout_rate = 0.1
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.token_predictor = ImportancePredictor(config.params.embed_dim)  # predicts token importance

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # --------------------------------------------------------------------------
        # MAGE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.pad_with_cls_token = True

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim))  # learnable pos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------
        # NOTE(review): sized for 256 tokens; only referenced by a commented-out
        # line in freeze_parameters. Kept so existing state_dicts still load.
        self.token_all_mask_param = nn.Parameter(torch.zeros((1, 256)))
        # --------------------------------------------------------------------------
        # MlmLayer
        self.mlm_layer = MlmLayer(feat_emb_dim=decoder_embed_dim, word_emb_dim=embed_dim, vocab_size=vocab_size)

        self.norm_pix_loss = norm_pix_loss

        self.criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
        # --------------------------------------------------------------------------
        self.entropy_bottleneck = FactorizedEntropyModel(1)

        self.initialize_weights()
        self.freeze_parameters()

    def initialize_weights(self):
        """Fill fixed sin-cos position embeddings and initialize learnable weights."""
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        # NOTE(review): self.apply also visits self.vqgan (already loaded from its
        # checkpoint); any nn.Linear / nn.LayerNorm inside it would be
        # re-initialized here — confirm VQModel contains none.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_sample_mask_rate(self):
        """Draw a mask rate in (mask_ratio_min, mask_ratio_max] as a Python float."""
        # 1 - U[0, 1) lies in (0, 1].
        random_sample = 1 - torch.rand(1)
        # Map onto the [mask_ratio_min, mask_ratio_max] range.
        mask_rate = self.mask_ratio_min + random_sample * (self.mask_ratio_max - self.mask_ratio_min)
        return mask_rate.item()

    def get_cdf_token_mask(self, token_all_mask):
        """Build a (bsz, seq_len, seq_len + 1) CDF table from a Normal(0, 2).

        Not called within this file.
        """
        bsz, seq_len = token_all_mask.size()
        # --- use Normal distribution.
        dist_normal = torch.distributions.Normal(0, 2)
        cdf_mask_token = dist_normal.cdf(torch.arange(1, seq_len + 1))
        cdf_mask_token = (cdf_mask_token - .5) * 2
        cdf_mask_token = repeat(cdf_mask_token, 'Lp -> b s Lp', b=bsz, s=seq_len)
        cdf_mask_token = F.pad(cdf_mask_token, (1, 0))
        return cdf_mask_token

    def pre_encoding(self, x, is_training=False, manual_mask_rate=None):
        """Tokenize ``x`` with the VQGAN and choose which tokens to mask and drop.

        Args:
            x: images, (B, 3, H, W).
            is_training: if True, sample a random mask rate and drop 20% of the
                tokens (rounded up); otherwise drop nothing.
            manual_mask_rate: mask rate to use when not training (required then).

        Returns:
            gt_indices: original token ids (long), (B, seq_len).
            token_indices: token ids, (B, seq_len); used by decoding.
            unmasked_token_indices: ids at unmasked positions, (B, -1); used for
                the rate (entropy model) term.
            token_all_mask: 1.0 at masked positions; used by decoding.
            token_drop_mask: 1.0 at dropped positions; used by decoding.
            mask_ratio: fraction of masked tokens (python float).
            importance_scores: predictor output, (B, seq_len); for visualization.

        Raises:
            ValueError: if not training and ``manual_mask_rate`` is None.
        """
        # ============ 1. tokenization ============ #
        with torch.no_grad():
            z_q, _, token_tuple = self.vqgan.encode(x)

        _, _, token_indices = token_tuple
        token_indices = token_indices.reshape(z_q.size(0), -1)  # (B, H*W)
        gt_indices = token_indices.clone().detach().long()

        # ============ 2. masking process ============ #
        bsz, seq_len = token_indices.size()
        if is_training:
            mask_rate = self.random_sample_mask_rate()
            num_dropped_tokens = int(np.ceil(seq_len * 0.2))
        else:
            num_dropped_tokens = 0
            if manual_mask_rate is not None:
                mask_rate = manual_mask_rate
            else:
                raise ValueError("mask_rate should be provided for inference!")

        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask_ratio = num_masked_tokens / seq_len  # used for the rate-dependent lambda

        z_q = rearrange(z_q, 'b c h w -> b (h w) c')
        B, num_patches, _ = z_q.shape
        mask = torch.ones(B, num_patches, 1, dtype=z_q.dtype, device=z_q.device)
        importance_scores = self.token_predictor(z_q, mask)
        # The num_masked_tokens lowest-scoring tokens are masked.
        _, topk_indices = torch.topk(importance_scores, num_masked_tokens, dim=1, largest=False)

        token_all_mask = torch.zeros_like(importance_scores, requires_grad=True)
        token_all_mask = token_all_mask.scatter(1, topk_indices, 1.0)

        # Noise is always drawn so the RNG stream is the same in both modes.
        noise = torch.rand(bsz, seq_len, device=x.device)  # noise in [0, 1)
        sorted_noise, _ = torch.sort(noise, dim=1)
        if num_dropped_tokens > 0:
            cutoff_drop = sorted_noise[:, num_dropped_tokens - 1:num_dropped_tokens]
            token_drop_mask = (noise <= cutoff_drop).float()  # 1.0 marks dropped tokens
        else:
            # Explicitly drop nothing: ``noise <= 0`` could still select a token
            # whose noise is exactly 0.0, making per-row drop counts unequal and
            # breaking the reshape in pre_decoding.
            token_drop_mask = torch.zeros_like(noise)

        # Unmasked tokens; every row has the same count, so the reshape is valid.
        unmasked_pos = token_all_mask == 0
        unmasked_token_indices = token_indices[unmasked_pos].reshape(bsz, -1)

        return gt_indices, token_indices, unmasked_token_indices, token_all_mask, token_drop_mask, mask_ratio, importance_scores

    def pre_decoding(self, gt_indices, token_indices, token_all_mask, token_drop_mask):
        """Replace masked ids, prepend a [CLS] slot, drop tokens and run the encoder.

        ``gt_indices`` is accepted for interface symmetry and not used.

        Returns:
            (encoder_output, token_indices, token_all_mask, token_drop_mask), the
            last three extended with a leading [CLS] column (mask/drop = 0 there).
        """
        bsz, seq_len = token_indices.size()
        device = token_indices.device
        padded_token_indices = torch.full_like(token_indices, fill_value=self.mask_token_label)
        token_indices = token_indices * (1 - token_all_mask) + padded_token_indices * token_all_mask

        # ============ 3. Adding class token ============ #
        # Prepend a [CLS] slot to aggregate sequence-level information.
        token_indices = torch.cat([torch.zeros(bsz, 1, device=device), token_indices], dim=1)
        token_indices[:, 0] = self.fake_class_label

        # The [CLS] column is never masked or dropped (0 = keep).
        token_drop_mask = torch.cat([torch.zeros(bsz, 1, device=device), token_drop_mask], dim=1)
        token_all_mask = torch.cat([torch.zeros(bsz, 1, device=device), token_all_mask], dim=1)
        token_indices = token_indices.long()

        # ============ 4. Embedding and Dropout ============ #
        input_embeddings = self.token_emb(token_indices)
        bsz, seq_len, emb_dim = input_embeddings.shape

        # Keep only the non-dropped tokens; assumes an equal count per row.
        token_keep_mask = 1 - token_drop_mask
        input_embeddings_after_drop = input_embeddings[token_keep_mask.nonzero(as_tuple=True)].reshape(bsz, -1, emb_dim)

        # ============ 5. Transformer encoding ============ #
        x = input_embeddings_after_drop
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, token_indices, token_all_mask, token_drop_mask

    def forward_decoding(self, x, token_drop_mask, token_all_mask):
        """Decode encoder output into vocabulary logits for every position.

        x: output x of pre_decoding()
        token_drop_mask: positions for dropped tokens
        token_all_mask: positions for masked tokens
        """
        # ============ 1. Prepare Embedding and padding tokens ============ #
        x = self.decoder_embed(x)

        # Filler embedding for masked/dropped slots: either the [CLS] embedding
        # repeated across the sequence, or the learned mask token.
        if self.pad_with_cls_token:
            mask_tokens = x[:, 0:1].repeat(1, token_all_mask.shape[1], 1)
        else:
            mask_tokens = self.mask_token.repeat(token_all_mask.shape[0], token_all_mask.shape[1], 1)

        # ============ 2. Prepare positional embedding ============ #
        # Scatter the kept tokens back to their original positions.
        x_after_pad = mask_tokens.clone()
        x_after_pad[(1 - token_drop_mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
        # Kept-but-masked positions also get the filler embedding.
        x_after_pad = torch.where(token_all_mask.unsqueeze(-1).bool(), mask_tokens, x_after_pad)

        x = x_after_pad + self.decoder_pos_embed_learned

        for blk in self.decoder_blocks:
            x = blk(x)

        x = self.decoder_norm(x)

        word_embeddings = self.token_emb.word_embeddings.weight.data.detach()
        logits = self.mlm_layer(x, word_embeddings)
        return logits

    def forward_loss(self, gt_indices, logits, mask):
        """Label-smoothed CE over codebook ids, averaged over masked positions."""
        bsz, seq_len = gt_indices.size()
        # logits and mask are with seq_len+1 but gt_indices is with seq_len
        loss = self.criterion(logits[:, 1:, :self.codebook_size].reshape(bsz * seq_len, -1), gt_indices.reshape(bsz * seq_len))
        loss = loss.reshape(bsz, seq_len)
        loss = (loss * mask[:, 1:]).sum() / mask[:, 1:].sum()  # mean loss on removed patches
        return loss

    def cal_lmbda(self, mask_ratio, A=5e-1, B=8):
        """Rate-distortion weight ``A * exp(B * (1 - mask_ratio))``."""
        lmbda = A * torch.exp(B * (1 - mask_ratio))
        return lmbda

    def cal_loss(self, logits, gt_indices, mask, mask_ratio):
        """Return ``(task_loss, lmbda)`` for the given predictions and mask ratio."""
        mask_ratio = torch.tensor(mask_ratio)
        task_loss = self.forward_loss(gt_indices, logits, mask)
        lmbda = self.cal_lmbda(mask_ratio)
        return task_loss, lmbda

    def freeze_parameters(self):
        """Freeze every parameter except those of ``token_predictor``."""
        for name, param in self.named_parameters():
            if 'token_predictor' not in name:
                param.requires_grad = False
        # self.token_all_mask_param.requires_grad = True

    def forward(self, imgs, is_training=False, manual_mask_rate=None):
        """Run encoding, rate estimation and masked-token prediction.

        Returns a dict with logits, likelihoods of the unmasked token ids, the
        task loss, the [CLS]-extended token ids and mask, the RLE mask length,
        the mask ratio, and visualization maps of the kept mask and scores.
        """
        ## ---------- encoding process ---------- ##
        (gt_indices, token_indices, latent, token_all_mask, token_drop_mask,
         mask_ratio, im_scores) = self.pre_encoding(imgs, is_training, manual_mask_rate)
        latent = latent.unsqueeze(1)
        _, latent_likelihoods = self.entropy_bottleneck(latent)

        mask_stream, mask_len = rle_encode(token_all_mask)

        # Square token grid side (16 for 256 tokens).
        grid = int(round(token_all_mask.size(1) ** 0.5))
        mask_vis = rearrange(token_all_mask, 'b (h w) -> b h w', h=grid, w=grid).unsqueeze(1)
        im_scores_vis = rearrange(im_scores, 'b (h w) -> b h w', h=grid, w=grid).unsqueeze(1)

        ## ---------- decoding process ---------- ##
        # decoded_mask = rle_decode(mask_stream, token_all_mask.shape).float()
        # decoded_mask = decoded_mask.to(device=imgs.device)
        x, token_indices, token_all_mask, token_drop_mask = self.pre_decoding(gt_indices, token_indices, token_all_mask, token_drop_mask)
        logits = self.forward_decoding(x, token_drop_mask, token_all_mask)

        ## calculate loss
        task_loss, lmbda = self.cal_loss(logits, gt_indices, token_all_mask, mask_ratio)

        return_dict = {
            'logits': logits,
            'likelihoods': latent_likelihoods,
            'task_loss': task_loss,
            'token_indices': token_indices,
            'token_all_mask': token_all_mask,
            'mask_len': mask_len,
            # 'bs_mask_token': bs_mask_token,
            'mask_ratio': mask_ratio,
            # 'lambda': lmbda,
            'mask_vis': 1 - mask_vis,
            'im_score_vis': im_scores_vis,
        }
        return return_dict

    def _gen_img_predict(self, cur_ids):
        """Re-run the encoder and decoder on ``cur_ids`` and return full logits.

        A [CLS] column is prepended, nothing is dropped, and positions equal to
        ``mask_token_label`` are treated as masked.
        """
        bsz = cur_ids.size(0)
        cls_column = torch.full((bsz, 1), self.fake_class_label, dtype=torch.long, device=cur_ids.device)
        token_indices = torch.cat([cls_column, cur_ids], dim=1)
        token_all_mask = token_indices == self.mask_token_label
        token_drop_mask = torch.zeros_like(token_indices)

        x = self.token_emb(token_indices)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return self.forward_decoding(x, token_drop_mask, token_all_mask)

    def gen_img(self, logits, token_all_mask, token_indices, num_iter=6, choice_temperature=4.5):
        """Iteratively fill masked tokens and decode the result with the VQGAN.

        Each iteration samples ids at masked positions, keeps the most confident
        ones and re-masks the rest, with a cosine schedule over ``num_iter``
        steps.

        Args:
            logits: decoder logits from ``forward`` (with the [CLS] position).
            token_all_mask: mask from ``forward`` (with the [CLS] column); its row
                sums give the initial number of unknown tokens.
            token_indices: token ids from ``forward`` (with the [CLS] column),
                masked positions set to ``mask_token_label``.
            num_iter: number of sampling iterations (>= 1).
            choice_temperature: initial Gumbel temperature for re-masking; it
                decays linearly to 0 over the iterations.

        Returns:
            Images decoded by ``self.vqgan.decode``.

        Raises:
            ValueError: if ``num_iter`` < 1.
        """
        if num_iter < 1:
            raise ValueError("num_iter must be >= 1")

        bsz = logits.size(0)
        device = logits.device
        # NOTE(review): matches the VQGAN codebook embedding width (256 per the
        # config comments); it is not read from the config here.
        codebook_emb_dim = 256
        codebook_size = self.codebook_size
        mask_token_id = self.mask_token_label
        _CONFIDENCE_OF_KNOWN_TOKENS = +np.inf
        unknown_number_in_the_beginning = torch.sum(token_all_mask, dim=-1, keepdim=True).float().detach().to(device)

        for step in range(num_iter):
            if step == 0:
                # The caller's logits already match token_indices, so reuse them;
                # strip the [CLS] column and the non-codebook vocabulary.
                cur_ids = token_indices.clone().long()[:, 1:]
                step_logits = logits[:, 1:, :codebook_size]
            else:
                cur_ids = token_indices.clone().long()
                step_logits = self._gen_img_predict(cur_ids)[:, 1:, :codebook_size]

            # "For iter=1, they use argmax and temp=0.0. For iter=6, we use
            # categorical sampling and temp=4.5."
            sampled_ids = torch.distributions.categorical.Categorical(logits=step_logits).sample()

            # Take sampled ids only where the token is still unknown; keep the
            # already-known ids elsewhere.
            unknown_map = cur_ids == mask_token_id
            sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)

            # Cosine schedule: fraction of the initial unknown tokens that stays
            # masked for the next round.
            ratio = 1. * (step + 1) / num_iter
            mask_ratio = np.cos(math.pi / 2. * ratio)

            # Confidence of each chosen id; known tokens get +inf so they are
            # never re-masked.
            probs = F.softmax(step_logits, dim=-1)
            selected_probs = torch.gather(probs, dim=-1, index=sampled_ids.unsqueeze(-1)).squeeze(-1)
            selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()

            mask_ratio = torch.tensor(mask_ratio, device=device)
            # Number of tokens still masked after this iteration.
            mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long()
            # Keep at least one prediction this round and mask at least one for the next.
            mask_len = torch.maximum(torch.ones(1, device=device),
                                     torch.minimum(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len))

            # NOTE(review): the first sample's count is applied to the whole batch.
            masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
            # Re-mask the lower-confidence tokens.
            token_indices = torch.where(masking, mask_token_id, sampled_ids)

        # vqgan visualization
        grid = int(round(sampled_ids.size(1) ** 0.5))
        z_q = self.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(bsz, grid, grid, codebook_emb_dim))
        gen_images = self.vqgan.decode(z_q)
        return gen_images


def mage_vit_base_patch16(**kwargs):
    """MAGE ViT-Base/16 configuration."""
    model = MaskedGenerativeEncoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=768, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mage_vit_large_patch16(**kwargs):
    """MAGE ViT-Large/16 configuration."""
    model = MaskedGenerativeEncoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=1024, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model