# mimc_rl / models_mage_codec.py
# Author: wangyanhui666
# Commit 9cf79cf: "fine tune decoder with mask"
from functools import partial
import torch
import torch.nn as nn
from timm.models.vision_transformer import PatchEmbed, DropPath, Mlp
from util.pos_embed import get_2d_sincos_pos_embed
from taming.models.vqgan import VQModel
from omegaconf import OmegaConf
import numpy as np
import scipy.stats as stats
from compressai.entropy_models import EntropyBottleneck
from compressai.layers import conv3x3, subpel_conv3x3
import math
from torch import Tensor
from einops import rearrange, repeat
import torch.nn.functional as F
import torchac
from typing import Any, Callable, List, Optional, Tuple, Union
# Lower bound, upper bound and number of entries of the default scale table.
SCALES_MIN, SCALES_MAX, SCALES_LEVELS = 0.11, 256, 64


def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
    """Return ``levels`` values spaced evenly in log space from ``min`` to ``max``.

    The parameter names shadow the builtins but are kept for keyword callers.
    """
    log_low = math.log(min)
    log_high = math.log(max)
    return torch.linspace(log_low, log_high, levels).exp()
def ste_round(x: Tensor) -> Tensor:
    """Round ``x`` in the forward pass; pass the gradient straight through.

    The rounding residual is detached, so d(output)/dx is the identity.
    """
    residual = torch.round(x) - x
    return x + residual.detach()
def conv(in_channels, out_channels, kernel_size=5, stride=2):
    """Build an ``nn.Conv2d`` padded by ``kernel_size // 2`` on each side."""
    half_kernel = kernel_size // 2
    layer = nn.Conv2d(in_channels, out_channels,
                      kernel_size=kernel_size, stride=stride, padding=half_kernel)
    return layer
def mask_by_random_topk(mask_len, probs, temperature=1.0):
    """Choose the ``mask_len`` least-confident positions of each row for re-masking.

    Confidence is ``log(probs)`` plus Gumbel noise scaled by ``temperature``.
    The noise is drawn from NumPy's global RNG, so a larger temperature makes
    the choice more random.

    Args:
        mask_len: number of positions to mask, shared by all rows; a scalar or
            single-element tensor, or a Python number.
        probs: (B, L) per-position probabilities.
        temperature: scale of the Gumbel noise.

    Returns:
        Bool tensor shaped like ``probs``. It is True where the confidence is
        <= the ``mask_len``-th smallest confidence of that row; ties may mark
        more than ``mask_len`` positions. It is all False when ``mask_len`` is 0.
    """
    k = int(torch.as_tensor(mask_len).squeeze().long())
    # Gumbel noise is placed on probs' device rather than the default CUDA
    # device, so this also works on CPU and on non-default GPUs.
    gumbel = torch.as_tensor(temperature * np.random.gumbel(size=probs.shape),
                             dtype=torch.float32, device=probs.device)
    confidence = torch.log(probs) + gumbel
    if k <= 0:
        # Nothing to mask; an empty cut-off slice would otherwise fail to broadcast.
        return torch.zeros_like(confidence, dtype=torch.bool)
    sorted_confidence, _ = torch.sort(confidence, dim=-1)
    # Cut-off threshold: the k-th smallest confidence of each row, kept as (B, 1).
    cut_off = sorted_confidence[:, k - 1:k]
    # Mask the tokens with lower confidence.
    return confidence <= cut_off
def adjust_mask_and_drop_embeddings(token_keep_mask):
    """Trim ``token_keep_mask`` so its number of True entries is a perfect square.

    True entries chosen at random (``torch.randperm``) are set to False until
    the count equals the largest perfect square not exceeding it (e.g. 10 -> 9).
    The mask is modified in place and also returned.

    Args:
        token_keep_mask: boolean (or 0/1) tensor of any rank marking the tokens
            to keep.

    Returns:
        The same ``token_keep_mask`` tensor after trimming.
    """
    nonzero = token_keep_mask.nonzero(as_tuple=True)
    count = nonzero[0].size(0)
    # Exact integer square root; avoids float rounding of math.sqrt on large counts.
    remove_count = count - math.isqrt(count) ** 2
    if remove_count > 0:
        # randperm draws from the CPU generator; move the chosen positions onto
        # the mask's device before indexing.
        chosen = torch.randperm(count)[:remove_count].to(nonzero[0].device)
        # Clear all chosen entries in one indexed assignment (works for any rank).
        token_keep_mask[tuple(dim_idx[chosen] for dim_idx in nonzero)] = False
    return token_keep_mask
class FactorizedEntropyModel(EntropyBottleneck):
    """Factorized entropy bottleneck for (B, C, seq_len) token sequences.

    Wraps compressai's ``EntropyBottleneck``. Quantization always rounds around
    the integer-cast medians ("dequantize" mode), in training as well as at
    inference.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, training: Optional[bool] = None) -> Tuple[Tensor, Tensor]:
        """Quantize ``x`` and return ``(outputs, likelihood)``.

        ``training`` is accepted for interface compatibility but ignored. Hard
        rounding is used instead of additive noise; the noise variant was
        disabled in the original code.
        """
        # The input is already [b, c, seq_len]; no transpose is needed.
        means = self._get_medians()
        # Medians are cast to integers, so integer-valued inputs round-trip exactly.
        outputs = self.quantize(x, "dequantize", means.long())
        if torch.jit.is_scripting():
            raise NotImplementedError("TorchScript is not yet supported")
        likelihood = self._likelihood(outputs)
        if self.use_likelihood_bound:
            likelihood = self.likelihood_lower_bound(likelihood)
        return outputs, likelihood

    def compress(self, x):
        """Entropy-code ``x`` (B, C, seq_len) into one byte string per sample."""
        indexes = self._build_indexes(x.size())
        medians = self._get_medians().detach().expand_as(x)
        # EntropyBottleneck.compress(x) takes only ``x``. Skip it in the MRO and
        # call the generic EntropyModel.compress(inputs, indexes, means).
        return super(EntropyBottleneck, self).compress(x, indexes, medians)

    def decompress(self, strings, size):
        """Decode ``strings`` into a (len(strings), 1, *size) tensor.

        ``size`` is the size without batch/channel dims, e.g. ``(seq_len,)``.
        A bare int is accepted as shorthand for ``(size,)``.
        """
        if isinstance(size, int):
            size = (size,)
        output_size = (len(strings), 1, *size)
        indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
        # Reshape the medians to broadcast to the expected output shape.
        medians = self._extend_ndims(self._get_medians().detach(), len(size))
        medians = medians.expand(len(strings), 1, *([-1] * len(size)))
        # EntropyBottleneck.decompress(strings, size) has a different signature;
        # call EntropyModel.decompress(strings, indexes, dtype, means) directly.
        return super(EntropyBottleneck, self).decompress(strings, indexes, medians.dtype, medians)

    def _preprocess(self, x):
        # Moves dim 1 to the end for 4-D (B, C, H, W) input; not called within this class.
        x = x.permute(0, 2, 3, 1).contiguous()
        return x
class Attention(nn.Module):
    """Multi-head self-attention that also returns its attention weights.

    The query-key product runs in float32 with CUDA autocast disabled.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head_dim = dim // num_heads
        # NOTE: qk_scale can be set manually to stay compatible with weights
        # trained with a different scale factor.
        self.scale = qk_scale or per_head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return ``(output, attn)`` for ``x`` (B, N, C); ``attn`` is (B, heads, N, N)."""
        batch, tokens, channels = x.shape
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, heads, N, head_dim)
        with torch.cuda.amp.autocast(enabled=False):
            scores = (query.float() @ key.float().transpose(-2, -1)) * self.scale
        # Subtract the row max before the softmax (the softmax is shift-invariant).
        scores = scores - scores.max(dim=-1, keepdim=True).values
        weights = self.attn_drop(scores.softmax(dim=-1))
        mixed = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj_drop(self.proj(mixed))
        # out: attended features; weights: attention map between sequence elements.
        return out, weights
class Block(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x))`` then ``x + mlp(norm2(x))``."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth on both residual branches; identity unless drop_path > 0.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        """Apply the block; with ``return_attention=True`` return only the attention map."""
        attended, attn_map = self.attn(self.norm1(x))
        if return_attention:
            return attn_map
        x = x + self.drop_path(attended)
        return x + self.drop_path(self.mlp(self.norm2(x)))
class LabelSmoothingCrossEntropy(nn.Module):
    """Per-sample NLL loss with label smoothing (no reduction over samples)."""

    def __init__(self, smoothing=0.1):
        super().__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return ``confidence * NLL(target) + smoothing * mean(-log p)`` for each row of ``x`` (N, C)."""
        log_p = torch.log_softmax(x, dim=-1)
        target_nll = -log_p.gather(-1, target.unsqueeze(1)).squeeze(1)
        uniform_nll = -log_p.mean(dim=-1)
        return self.confidence * target_nll + self.smoothing * uniform_nll
class BertEmbeddings(nn.Module):
    """Word embeddings plus learned absolute position embeddings, then LayerNorm and dropout."""

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        # Attribute name kept as ``LayerNorm`` (not snake_case) so existing
        # TensorFlow-style checkpoints still load.
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
        # (1, max_position_embeddings) buffer holding 0..max_position_embeddings-1;
        # forward slices it to look up the position embeddings.
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))
        for table in (self.word_embeddings, self.position_embeddings):
            torch.nn.init.normal_(table.weight, std=.02)

    def forward(
        self, input_ids
    ):
        """Embed ``input_ids`` (B, N) into a (B, N, hidden_size) tensor."""
        n_tokens = input_ids.size()[1]
        positions = self.position_ids[:, :n_tokens]
        summed = self.word_embeddings(input_ids) + self.position_embeddings(positions)
        return self.dropout(self.LayerNorm(summed))
class MlmLayer(nn.Module):
    """Token-prediction head whose output projection reuses the word-embedding matrix."""

    def __init__(self, feat_emb_dim, word_emb_dim, vocab_size):
        super().__init__()
        self.fc = nn.Linear(feat_emb_dim, word_emb_dim)
        self.gelu = nn.GELU()
        self.ln = nn.LayerNorm(word_emb_dim)
        self.bias = nn.Parameter(torch.zeros(1, 1, vocab_size))

    def forward(self, x, word_embeddings):
        """Return unnormalized logits (B, seq_len, vocab_size) for ``x`` (B, seq_len, feat_emb_dim).

        ``word_embeddings`` is the (vocab_size, word_emb_dim) embedding matrix.
        """
        hidden = self.ln(self.gelu(self.fc(x)))
        return hidden @ word_embeddings.transpose(0, 1) + self.bias
class MaskedGenerativeEncoderViT(nn.Module):
    """Masked Autoencoder with VisionTransformer backbone, used as a token codec.

    A frozen VQGAN tokenizes the image and a random subset of tokens is masked.
    The unmasked token ids go through ``entropy_bottleneck`` and the mask is
    arithmetic-coded with ``torchac``. The decoding side rebuilds the sequence,
    runs the MAGE encoder/decoder and predicts logits for every token.
    """

    def __init__(self, img_size=256, patch_size=16, in_chans=3,  # need to change the default value of img_size
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 mask_ratio_min=0.5, mask_ratio_max=0.8, mask_ratio_mu=0.55, mask_ratio_std=0.25,
                 vqgan_ckpt_path='vqgan_jax_strongaug.ckpt'):
        super().__init__()
        # --------------------------------------------------------------------------
        # VQGAN specifics (frozen tokenizer)
        config = OmegaConf.load('config/vqgan.yaml').model
        self.vqgan = VQModel(ddconfig=config.params.ddconfig,
                             n_embed=config.params.n_embed,
                             embed_dim=config.params.embed_dim,
                             ckpt_path=vqgan_ckpt_path)
        for param in self.vqgan.parameters():
            param.requires_grad = False

        self.codebook_size = config.params.n_embed
        vocab_size = self.codebook_size + 1000 + 1  # codebook ids, 1000 classes, 1 for mask token.
        self.fake_class_label = self.codebook_size + 1100 - 1024
        self.mask_token_label = vocab_size - 1
        # NOTE(review): img_size + 1 equals num_patches + 1 only for img_size=256
        # with patch_size=16; kept unchanged so existing checkpoints still load.
        self.token_emb = BertEmbeddings(vocab_size=vocab_size,
                                        hidden_size=embed_dim,
                                        max_position_embeddings=img_size + 1,
                                        dropout=0.1)

        # MAGE variant masking ratio. mask_ratio_mu / mask_ratio_std are only kept
        # for interface compatibility: the rate is drawn uniformly in
        # random_sample_mask_rate.
        self.mask_ratio_min = mask_ratio_min
        self.mask_ratio_max = mask_ratio_max

        # --------------------------------------------------------------------------
        # MAGE encoder specifics
        dropout_rate = 0.1
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # --------------------------------------------------------------------------
        # MAGE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.pad_with_cls_token = True
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim))  # learnable pos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch

        # --------------------------------------------------------------------------
        # MlmLayer
        self.mlm_layer = MlmLayer(feat_emb_dim=decoder_embed_dim, word_emb_dim=embed_dim, vocab_size=vocab_size)

        self.norm_pix_loss = norm_pix_loss
        self.criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
        # --------------------------------------------------------------------------
        self.entropy_bottleneck = FactorizedEntropyModel(1)

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the fixed sin-cos position embeddings and all trainable weights."""
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform for Linear weights, zero biases, unit/zero LayerNorm affine."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_sample_mask_rate(self):
        """Draw a mask rate uniformly from (mask_ratio_min, mask_ratio_max] as a Python float."""
        # 1 - U[0, 1) lies in (0, 1].
        random_sample = 1 - torch.rand(1)
        # Map it onto the [mask_ratio_min, mask_ratio_max] range.
        mask_rate = self.mask_ratio_min + random_sample * (self.mask_ratio_max - self.mask_ratio_min)
        return mask_rate.item()

    def get_cdf_token_mask(self, token_all_mask):
        """Build the fixed float CDF table used by torchac to code the mask.

        Returns a (B, seq_len, seq_len + 1) tensor. Each row is a leading 0
        followed by ``2 * (cdf(i) - 0.5)`` for i = 1..seq_len, where cdf is that
        of Normal(0, 2). Only the shape of ``token_all_mask`` is used.
        """
        bsz, seq_len = token_all_mask.size()
        # --- use Normal distribution.
        dist_normal = torch.distributions.Normal(0, 2)
        cdf_mask_token = dist_normal.cdf(torch.arange(1, seq_len + 1))
        cdf_mask_token = (cdf_mask_token - .5) * 2
        cdf_mask_token = repeat(cdf_mask_token, 'Lp -> b s Lp',
                                b=bsz, s=seq_len)
        cdf_mask_token = F.pad(cdf_mask_token, (1, 0))
        return cdf_mask_token

    @staticmethod
    def _noise_cutoff(sorted_noise, k):
        """Return the k-th smallest value of each row as a (B, 1) threshold; zeros when k == 0."""
        if k > 0:
            return sorted_noise[:, k - 1:k]
        return torch.zeros((sorted_noise.size(0), 1), device=sorted_noise.device)

    def pre_encoding(self, x, is_training=False, manual_mask_rate=None):
        """Tokenize ``x`` (B, 3, H, W) with the frozen VQGAN and sample a random token mask.

        Args:
            x: image batch.
            is_training: if True, draw the mask rate with ``random_sample_mask_rate``
                and drop ``ceil(seq_len * mask_ratio_min)`` tokens before the encoder.
                Otherwise use ``manual_mask_rate`` and drop nothing.
            manual_mask_rate: mask rate used when ``is_training`` is False.

        Returns:
            A tuple ``(gt_indices, token_indices, unmasked_token_indices,
            token_all_mask, token_drop_mask, mask_ratio)``:
            - ``token_all_mask`` / ``token_drop_mask`` are float (B, seq_len) with
              1 at masked / dropped tokens.
            - ``unmasked_token_indices`` holds the unmasked ids in row-major order.
            - ``mask_ratio`` is the realised fraction of masked tokens.

        Raises:
            ValueError: if ``is_training`` is False and ``manual_mask_rate`` is None.
        """
        # ============ 1. tokenization ============ #
        with torch.no_grad():
            z_q, _, token_tuple = self.vqgan.encode(x)
        _, _, token_indices = token_tuple
        token_indices = token_indices.reshape(z_q.size(0), -1)  # (B, h*w)
        gt_indices = token_indices.clone().detach().long()

        # ============ 2. masking process ============ #
        bsz, seq_len = token_indices.size()
        if is_training:
            mask_rate = self.random_sample_mask_rate()
            num_dropped_tokens = int(np.ceil(seq_len * self.mask_ratio_min))
        else:
            num_dropped_tokens = 0
            if manual_mask_rate is None:
                raise ValueError("mask_rate should be provided for inference!")
            mask_rate = manual_mask_rate
        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask_ratio = num_masked_tokens / seq_len  # for calculate vbr lambda

        # Two noise values may coincide and select too many tokens, so resample
        # until both masks have exactly the requested sizes.
        while True:
            noise = torch.rand(bsz, seq_len, device=x.device)
            sorted_noise, _ = torch.sort(noise, dim=1)  # ascending: small is removed, large is kept
            cutoff_drop = self._noise_cutoff(sorted_noise, num_dropped_tokens)
            # A 0 mask rate previously crashed on an empty slice; it now yields an all-zero mask.
            cutoff_mask = self._noise_cutoff(sorted_noise, num_masked_tokens)
            token_drop_mask = (noise <= cutoff_drop).float()  # 1 marks tokens dropped before the encoder
            token_all_mask = (noise <= cutoff_mask).float()  # 1 marks masked tokens
            if token_drop_mask.sum() == bsz * num_dropped_tokens and token_all_mask.sum() == bsz * num_masked_tokens:
                break
            print("Rerandom the noise!")

        # Gather the unmasked tokens in row-major order.
        unmasked_pos = token_all_mask == 0
        unmasked_token_indices = token_indices[unmasked_pos].reshape(bsz, -1)
        return gt_indices, token_indices, unmasked_token_indices, token_all_mask, token_drop_mask, mask_ratio

    def pre_decoding(self, gt_indices, unmaksed_token_indices, token_all_mask, token_drop_mask):
        """Rebuild the token sequence from the transmitted tokens and run the MAGE encoder.

        Args:
            gt_indices: (B, seq_len) long tensor; only its shape, dtype and device are used.
            unmaksed_token_indices: (B, n_keep) ids of the unmasked tokens in
                row-major order. They may be float (e.g. entropy-model output)
                and are cast to ``gt_indices``' dtype.
            token_all_mask: (B, seq_len), zero at unmasked positions.
            token_drop_mask: (B, seq_len), non-zero at tokens dropped before the encoder.

        Returns:
            ``(x, token_indices, token_all_mask, token_drop_mask)``. ``x`` is the
            normalized encoder output; the other three gain a leading [CLS]
            column (length seq_len + 1) that is never masked or dropped.

        Raises:
            ValueError: if a row's number of unmasked positions differs from
                ``unmaksed_token_indices``' column count.
        """
        bsz, seq_len = gt_indices.size()
        device = gt_indices.device

        # ============ 1. Put the unmasked tokens back ============ #
        # Masked positions keep mask_token_label. Boolean-mask assignment fills the
        # unmasked positions in row-major order, the same order used to gather them.
        padded_token_indices = torch.full_like(gt_indices, fill_value=self.mask_token_label)
        unmasked_pos = token_all_mask == 0
        kept = unmaksed_token_indices.reshape(bsz, -1)
        if bool((unmasked_pos.sum(dim=1) != kept.size(1)).any()):
            raise ValueError("each row of token_all_mask must have exactly as many unmasked "
                             "positions as unmaksed_token_indices has columns")
        padded_token_indices[unmasked_pos] = kept.reshape(-1).to(padded_token_indices.dtype)

        # ============ 2. Adding class token ============ #
        # The [CLS] token aggregates sequence-level information.
        cls_ids = torch.full((bsz, 1), self.fake_class_label, dtype=torch.long, device=device)
        token_indices = torch.cat([cls_ids, padded_token_indices.long()], dim=1)  # (B, seq_len + 1)
        # A zero column marks [CLS] as never masked / dropped. It is created on the
        # inputs' device instead of the default CUDA device.
        cls_zeros = torch.zeros(bsz, 1, device=device)
        token_drop_mask = torch.cat([cls_zeros, token_drop_mask], dim=1)
        token_all_mask = torch.cat([cls_zeros, token_all_mask], dim=1)

        # ============ 3. Embedding and Dropout ============ #
        input_embeddings = self.token_emb(token_indices)  # (B, seq_len + 1, embed_dim)
        bsz, seq_len, emb_dim = input_embeddings.shape
        token_keep_mask = 1 - token_drop_mask
        input_embeddings_after_drop = input_embeddings[token_keep_mask.nonzero(as_tuple=True)].reshape(bsz, -1, emb_dim)

        # ============ 4. Transformer encoding ============ #
        x = input_embeddings_after_drop
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x, token_indices, token_all_mask, token_drop_mask

    def forward_decoding(self, x, token_drop_mask, token_all_mask):
        """Decode encoder output into per-token vocabulary logits.

        Args:
            x: encoder output (B, n_kept, embed_dim) for the non-dropped tokens.
            token_drop_mask: (B, seq_len + 1), non-zero at dropped positions.
            token_all_mask: (B, seq_len + 1), non-zero at masked positions.

        Returns:
            Logits (B, seq_len + 1, vocab_size) from ``mlm_layer``.
        """
        # ============ 1. Prepare embedding and padding tokens ============ #
        x = self.decoder_embed(x)
        # Filler for masked/dropped positions: the [CLS] embedding, or the learned mask token.
        if self.pad_with_cls_token:
            mask_tokens = x[:, 0:1].repeat(1, token_all_mask.shape[1], 1)
        else:
            mask_tokens = self.mask_token.repeat(token_all_mask.shape[0], token_all_mask.shape[1], 1)

        # ============ 2. Restore sequence order and add positions ============ #
        # Put the non-dropped tokens back at their original positions.
        x_after_pad = mask_tokens.clone()
        x_after_pad[(1 - token_drop_mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
        # Positions that were kept but masked also get the filler.
        x_after_pad = torch.where(token_all_mask.unsqueeze(-1).bool(), mask_tokens, x_after_pad)
        x = x_after_pad + self.decoder_pos_embed_learned

        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        word_embeddings = self.token_emb.word_embeddings.weight.data.detach()
        return self.mlm_layer(x, word_embeddings)

    def forward_loss(self, gt_indices, logits, mask):
        """Mean smoothed cross-entropy over the masked tokens.

        ``logits`` and ``mask`` include the [CLS] column; ``gt_indices`` does not.
        Returns 0 (not NaN) when no token is masked.
        """
        bsz, seq_len = gt_indices.size()
        loss = self.criterion(logits[:, 1:, :self.codebook_size].reshape(bsz * seq_len, -1),
                              gt_indices.reshape(bsz * seq_len))
        loss = loss.reshape(bsz, seq_len)
        token_mask = mask[:, 1:]
        # Clamp the count so a batch without masked tokens does not divide by zero.
        return (loss * token_mask).sum() / token_mask.sum().clamp(min=1)

    def cal_lmbda(self, mask_ratio, A=5e-1, B=8):
        """Rate-distortion weight ``A * exp(B * (1 - mask_ratio))``."""
        lmbda = A * torch.exp(B * (1 - mask_ratio))
        return lmbda

    def cal_loss(self, logits, gt_indices, mask, mask_ratio):
        """Return ``(task_loss, lmbda)`` for the given logits and mask ratio."""
        mask_ratio = torch.tensor(mask_ratio)
        task_loss = self.forward_loss(gt_indices, logits, mask)
        lmbda = self.cal_lmbda(mask_ratio)
        return task_loss, lmbda

    def forward(self, imgs, is_training=False, manual_mask_rate=None):
        """Encode ``imgs``, entropy-code tokens and mask, decode, and compute the loss.

        Returns a dict with keys ``logits``, ``likelihoods``, ``task_loss``,
        ``token_indices``, ``token_all_mask``, ``bs_mask_token``, ``mask_ratio``,
        ``lambda`` and ``mask_vis``. ``mask_vis`` is (B, 1, h, w) with 1 at
        unmasked tokens.
        """
        # ---------- encoding process ---------- #
        gt_indices, token_indices, latent, token_all_mask, token_drop_mask, mask_ratio = \
            self.pre_encoding(imgs, is_training, manual_mask_rate)
        latent = latent.unsqueeze(1)  # (B, 1, n_unmasked): single channel for the entropy model
        latent_hat, latent_likelihoods = self.entropy_bottleneck(latent)

        # Arithmetic-code the mask with a fixed CDF; symbol 1 = unmasked, 2 = masked.
        cdf_mask_token = self.get_cdf_token_mask(token_all_mask).cpu()
        sym = (token_all_mask.short() + 1).cpu()
        bs_mask_token = torchac.encode_float_cdf(cdf_mask_token, sym, check_input_bounds=True)
        # Token grid side, derived from the sequence length instead of a hard-coded 16.
        grid = math.isqrt(token_all_mask.size(1))
        mask_vis = rearrange(token_all_mask, 'b (h w) -> b h w', h=grid, w=grid).unsqueeze(1)

        # ---------- decoding process ---------- #
        decoded_sym = torchac.decode_float_cdf(cdf_mask_token, bs_mask_token)
        decoded_mask = (decoded_sym - 1).to(device=imgs.device)
        latent_hat = latent_hat.squeeze(1)
        x, token_indices, token_all_mask, token_drop_mask = self.pre_decoding(
            gt_indices, latent_hat, decoded_mask, token_drop_mask)
        logits = self.forward_decoding(x, token_drop_mask, token_all_mask)

        # ---------- loss ---------- #
        task_loss, lmbda = self.cal_loss(logits, gt_indices, token_all_mask, mask_ratio)
        return_dict = {
            'logits': logits,
            'likelihoods': latent_likelihoods,
            'task_loss': task_loss,
            'token_indices': token_indices,
            'token_all_mask': token_all_mask,
            'bs_mask_token': bs_mask_token,
            'mask_ratio': mask_ratio,
            'lambda': lmbda,
            'mask_vis': 1 - mask_vis,
        }
        return return_dict

    def _logits_from_token_ids(self, ids):
        """Prepend [CLS] to ``ids`` (B, seq_len), run encoder and decoder, return full logits."""
        bsz = ids.size(0)
        cls_ids = torch.full((bsz, 1), self.fake_class_label, dtype=torch.long, device=ids.device)
        token_indices = torch.cat([cls_ids, ids.long()], dim=1)
        token_all_mask = token_indices == self.mask_token_label
        token_drop_mask = torch.zeros_like(token_indices)  # nothing is dropped at generation time
        x = self.token_emb(token_indices)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return self.forward_decoding(x, token_drop_mask, token_all_mask)

    def gen_img(self, logits, token_all_mask, token_indices, num_iter=12, choice_temperature=4.5):
        """Iteratively fill in masked tokens and decode the result with the VQGAN.

        Step 0 reuses ``logits``. Every later step re-runs the encoder/decoder on
        the current tokens. Each step samples the masked positions from the
        predicted distribution, then re-masks the least confident of them
        (Gumbel-perturbed). The number re-masked follows a cosine schedule.

        Args:
            logits: (B, seq_len + 1, vocab) decoder logits, [CLS] first.
            token_all_mask: (B, seq_len + 1) mask; its row sums give the initial
                number of unknown tokens.
            token_indices: (B, seq_len + 1) ids with [CLS] first and
                ``mask_token_label`` at masked positions.
            num_iter: number of sampling iterations (must be >= 1).
            choice_temperature: initial Gumbel temperature, annealed linearly to 0.

        Returns:
            Images produced by ``self.vqgan.decode`` from the final sampled ids.

        Raises:
            ValueError: if ``num_iter`` < 1.
        """
        if num_iter < 1:
            raise ValueError("num_iter must be >= 1")
        bsz = logits.size(0)
        device = logits.device
        # NOTE(review): VQGAN codebook vector size is hard-coded; confirm it matches config/vqgan.yaml.
        codebook_emb_dim = 256
        codebook_size = self.codebook_size
        mask_token_id = self.mask_token_label
        _CONFIDENCE_OF_KNOWN_TOKENS = +np.inf
        unknown_number_in_the_beginning = torch.sum(token_all_mask, dim=-1, keepdim=True).float().to(device)

        for step in range(num_iter):
            if step == 0:
                # token_indices still carries the [CLS] column; drop it.
                cur_ids = token_indices.clone().long()[:, 1:]
                step_logits = logits[:, 1:, :codebook_size]
            else:
                cur_ids = token_indices.clone().long()
                step_logits = self._logits_from_token_ids(cur_ids)[:, 1:, :codebook_size]

            # Sample a candidate id for every position (categorical sampling, not argmax).
            sample_dist = torch.distributions.categorical.Categorical(logits=step_logits)
            sampled_ids = sample_dist.sample()
            # Only masked positions take the sampled value; known ids are kept.
            unknown_map = cur_ids == mask_token_id
            sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)

            # Cosine schedule: fraction of the initially unknown tokens still masked after this step.
            ratio = 1. * (step + 1) / num_iter
            mask_ratio = np.cos(math.pi / 2. * ratio)

            # Confidence of each chosen id; known tokens get +inf so they are never re-masked.
            probs = torch.nn.functional.softmax(step_logits, dim=-1)
            selected_probs = torch.squeeze(
                torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
            selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()

            mask_ratio = torch.tensor(mask_ratio, device=device)
            mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long()
            # Keep at least one prediction this round and mask at least one for the next.
            mask_len = torch.maximum(torch.ones(1, device=device),
                                     torch.minimum(torch.sum(unknown_map, dim=-1, keepdim=True) - 1, mask_len))

            # The first sample's mask length is used for the whole batch.
            masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
            # Re-mask the low-confidence tokens for the next iteration.
            token_indices = torch.where(masking, mask_token_id, sampled_ids)

        # vqgan visualization
        grid = math.isqrt(sampled_ids.size(1))
        z_q = self.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(bsz, grid, grid, codebook_emb_dim))
        gen_images = self.vqgan.decode(z_q)
        return gen_images
def mage_vit_base_patch16(**kwargs):
    """Build the ViT-Base codec: 768-d, 12 layers, 12 heads, plus an 8-layer 768-d decoder.

    Extra keyword arguments are forwarded to ``MaskedGenerativeEncoderViT``.
    """
    return MaskedGenerativeEncoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=768,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
def mage_vit_large_patch16(**kwargs):
    """Build the ViT-Large codec: 1024-d, 24 layers, 16 heads, plus an 8-layer 1024-d decoder.

    Extra keyword arguments are forwarded to ``MaskedGenerativeEncoderViT``.
    """
    return MaskedGenerativeEncoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=1024,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )