Spaces:
Runtime error
Runtime error
| from functools import partial | |
| import numpy as np | |
| from tqdm import tqdm | |
| import scipy.stats as stats | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| from torch.utils.checkpoint import checkpoint | |
| from timm.models.vision_transformer import Block | |
| from .diffloss import DiffLoss | |
def mask_by_order(mask_len, order, bsz, seq_len):
    """Return a boolean mask marking the first ``mask_len`` positions of each order.

    ``order`` is indexed as a 2-D integer tensor whose rows list token
    positions; ``mask_len`` must be a single-element tensor because
    ``.long()`` is called on it.  The result has shape ``(bsz, seq_len)`` and
    is True at every position that appears among the first ``mask_len``
    entries of the matching row of ``order``.
    """
    target_device = order.device
    selected = order[:, :mask_len.long()]
    blank = torch.zeros(bsz, seq_len).to(target_device)
    fill = torch.ones(bsz, seq_len).to(target_device)
    # Write ones at the selected positions, then convert 0/1 floats to bool.
    return blank.scatter(-1, selected, fill).bool()
class MAR(nn.Module):
    """Masked autoregressive model with a VisionTransformer encoder/decoder.

    Image latents are patchified into a token sequence, a random subset of
    tokens is masked, the unmasked tokens (plus ``buffer_size`` leading buffer
    positions that carry the class embedding) go through the encoder, the
    decoder fills masked positions with a learned mask token, and a
    ``DiffLoss`` head is trained on the decoder output.  ``sample_tokens``
    generates tokens iteratively in a random order.
    """

    def __init__(self, img_size=256, vae_stride=16, patch_size=1,
                 encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm,
                 vae_embed_dim=16,
                 mask_ratio_min=0.7,
                 label_drop_prob=0.1,
                 class_num=1000,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 diffloss_d=3,
                 diffloss_w=1024,
                 num_sampling_steps='100',
                 diffusion_batch_mul=4,
                 grad_checkpointing=False,
                 ):
        super().__init__()

        # --------------------------------------------------------------------------
        # VAE and patchify specifics
        self.vae_embed_dim = vae_embed_dim
        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        self.seq_len = self.seq_h * self.seq_w
        self.token_embed_dim = vae_embed_dim * patch_size**2
        self.grad_checkpointing = grad_checkpointing

        # --------------------------------------------------------------------------
        # Class Embedding
        self.num_classes = class_num
        self.class_emb = nn.Embedding(class_num, encoder_embed_dim)
        self.label_drop_prob = label_drop_prob
        # Fake class embedding for CFG's unconditional generation
        self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))

        # --------------------------------------------------------------------------
        # MAR variant masking ratio, a left-half truncated Gaussian centered at
        # 100% masking ratio with std 0.25
        self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)

        # --------------------------------------------------------------------------
        # MAR encoder specifics
        self.encoder_embed_dim = encoder_embed_dim
        self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
        self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
        self.buffer_size = buffer_size
        self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim))
        self.encoder_blocks = nn.ModuleList([
            Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
        self.encoder_norm = norm_layer(encoder_embed_dim)

        # --------------------------------------------------------------------------
        # MAR decoder specifics
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))

        # Runs before the diffusion head is attached, so ``self.apply`` inside
        # it only visits the modules created above.
        self.initialize_weights()

        # --------------------------------------------------------------------------
        # Diffusion Loss
        self.diffloss = DiffLoss(
            target_channels=self.token_embed_dim,
            z_channels=decoder_embed_dim,
            width=diffloss_w,
            depth=diffloss_d,
            num_sampling_steps=num_sampling_steps,
            grad_checkpointing=self.grad_checkpointing
        )
        self.diffusion_batch_mul = diffusion_batch_mul

    # ------------------------------------------------------------------
    # Position embeddings
    def _resize_image_pos_embed(self, image_pe, h, w):
        """Bilinearly resample a ``(b, seq_h * seq_w, c)`` embedding to ``h * w`` tokens."""
        image_pe = rearrange(image_pe, 'b (h w) c -> b c h w',
                             h=self.seq_h, w=self.seq_w)
        image_pe = F.interpolate(image_pe, size=(h, w), mode='bilinear')
        return rearrange(image_pe, 'b c h w -> b (h w) c')

    def _buffered_pos_embed(self, pos_embed, h, w):
        """Resize the image part of a buffer+image embedding; buffer entries are kept as is."""
        if h == self.seq_h and w == self.seq_w:
            return pos_embed
        buffer_pe, image_pe = pos_embed.split(
            [self.buffer_size, self.seq_len], dim=1)
        return torch.cat([buffer_pe, self._resize_image_pos_embed(image_pe, h, w)], dim=1)

    def get_encoder_pos_embed(self, h, w):
        """Return the encoder position embedding for an ``h`` x ``w`` token grid."""
        return self._buffered_pos_embed(self.encoder_pos_embed_learned, h, w)

    def get_decoder_pos_embed(self, h, w):
        """Return the decoder position embedding for an ``h`` x ``w`` token grid."""
        return self._buffered_pos_embed(self.decoder_pos_embed_learned, h, w)

    def get_diffusion_pos_embed(self, h, w):
        """Return the diffusion position embedding (no buffer part) for an ``h`` x ``w`` grid."""
        if h == self.seq_h and w == self.seq_w:
            return self.diffusion_pos_embed_learned
        return self._resize_image_pos_embed(self.diffusion_pos_embed_learned, h, w)

    # ------------------------------------------------------------------
    # Initialization
    def initialize_weights(self):
        """Initialise learned embeddings with N(0, 0.02) and apply ``_init_weights`` to submodules."""
        # parameters
        torch.nn.init.normal_(self.class_emb.weight, std=.02)
        torch.nn.init.normal_(self.fake_latent, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform for Linear weights, zero biases, unit LayerNorm weights."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    # ``device`` and ``dtype`` are read as attributes throughout this class
    # (e.g. ``.to(self.device)``), so they must be properties.
    @property
    def device(self):
        """Device on which ``fake_latent`` lives."""
        return self.fake_latent.data.device

    @property
    def dtype(self):
        """Dtype of ``fake_latent``."""
        return self.fake_latent.data.dtype

    # ------------------------------------------------------------------
    # Patch <-> token conversion
    def patchify(self, x):
        """Turn ``[n, c, h, w]`` latents into ``[n, l, d]`` tokens with ``d = c * patch_size**2``."""
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(bsz, h_ * w_, c * p ** 2)
        return x  # [n, l, d]

    def unpatchify(self, x):
        """Inverse of ``patchify`` for the configured ``seq_h`` x ``seq_w`` grid."""
        bsz = x.shape[0]
        p = self.patch_size
        c = self.vae_embed_dim
        h_, w_ = self.seq_h, self.seq_w

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum('nhwcpq->nchpwq', x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x  # [n, c, h, w]

    # ------------------------------------------------------------------
    # Masking
    def sample_orders(self, bsz, seq_len=None):
        """Sample ``bsz`` random permutations of ``range(seq_len)`` as a long tensor.

        ``seq_len`` defaults to ``self.seq_len``.  Uses NumPy's global RNG.
        """
        if seq_len is None:
            seq_len = self.seq_len
        # generate a batch of random generation orders
        orders = []
        for _ in range(bsz):
            order = np.arange(seq_len)
            np.random.shuffle(order)
            orders.append(order)
        return torch.Tensor(np.array(orders)).to(self.device).long()

    def random_masking(self, x, orders):
        """Build a 0/1 float mask (1 = masked) with one shared mask rate for the batch.

        The rate is drawn from ``mask_ratio_generator``; the first
        ``ceil(seq_len * rate)`` positions of each row of ``orders`` are masked.
        """
        bsz, seq_len, embed_dim = x.shape
        assert seq_len == orders.shape[1]
        mask_rate = self.mask_ratio_generator.rvs(1)[0]
        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask = torch.zeros(bsz, seq_len, device=x.device)
        mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
                             src=torch.ones(bsz, seq_len, device=x.device))
        return mask

    # ------------------------------------------------------------------
    # Encoder
    def forward_mae_encoder(self, x, mask, class_embedding, image_shape=None):
        """Encode the unmasked tokens together with the class-conditioned buffer.

        Args:
            x: token latents, ``(bsz, seq_len, token_embed_dim)``.
            mask: ``(bsz, seq_len)`` with 1 at masked positions.
            class_embedding: reshaped to ``(bsz, -1, embed_dim)`` and written
                (with broadcasting) into the ``buffer_size`` leading positions.
            image_shape: optional ``(h, w)`` token grid; when given, position
                embeddings are obtained via ``get_encoder_pos_embed``.

        Returns:
            Encoder output for the buffer plus the unmasked tokens only.
        """
        x = x.to(self.dtype)
        x = self.z_proj(x)
        bsz, seq_len, embed_dim = x.shape

        # concat buffer
        x = torch.cat([x.new_zeros(bsz, self.buffer_size, embed_dim), x], dim=1)
        mask_with_buffer = torch.cat([mask.new_zeros(x.size(0), self.buffer_size), mask], dim=1)

        # NOTE: no random dropping of the class embedding (label_drop_prob /
        # fake_latent) is applied here; the embedding is used exactly as given.
        x[:, :self.buffer_size] = class_embedding.view(bsz, -1, embed_dim)

        # encoder position embedding
        if image_shape is None:
            x = x + self.encoder_pos_embed_learned
        else:
            h, w = image_shape
            assert h * w == seq_len
            x = x + self.get_encoder_pos_embed(h=h, w=w)
        x = self.z_proj_ln(x)

        # dropping: keep buffer + unmasked tokens; the reshape requires every
        # sample to keep the same number of tokens
        x = x[(1 - mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)

        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.encoder_blocks:
                x = checkpoint(block, x, use_reentrant=False)
        else:
            for block in self.encoder_blocks:
                x = block(x)
        x = self.encoder_norm(x)

        return x

    # ------------------------------------------------------------------
    # Decoder
    def _embed_and_pad(self, x, mask, x_con=None):
        """Project encoder output and scatter it into a full-length decoder input.

        Masked positions hold the learned mask token, or ``decoder_embed(x_con)``
        when ``x_con`` is supplied.
        """
        x = self.decoder_embed(x)
        mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

        if x_con is not None:
            x_after_pad = self.decoder_embed(x_con)
        else:
            # pad mask tokens; repeat() allocates a fresh tensor, so the
            # in-place write below never touches the mask_token parameter
            x_after_pad = self.mask_token.repeat(
                mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
        x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
        return x_after_pad

    def _run_decoder_blocks(self, x):
        """Apply decoder blocks and the final norm, then strip the buffer positions."""
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.decoder_blocks:
                # Unlike the encoder, use_reentrant is not passed here, so the
                # PyTorch default applies.
                x = checkpoint(block, x)
        else:
            for block in self.decoder_blocks:
                x = block(x)
        x = self.decoder_norm(x)
        return x[:, self.buffer_size:]

    def forward_mae_decoder(self, x, mask, image_shape=None, x_con=None):
        """Decode encoder output into one conditioning vector per token.

        Args:
            x: encoder output (buffer + unmasked tokens).
            mask: ``(bsz, seq_len)`` with 1 at masked positions.
            image_shape: optional ``(h, w)`` token grid used to look up
                resized position embeddings.
            x_con: optional full-length sequence that, after ``decoder_embed``,
                replaces the mask tokens as the padded base.

        Returns:
            Tensor of shape ``(bsz, seq_len, decoder_embed_dim)``.
        """
        bsz, seq_len = mask.shape
        x_after_pad = self._embed_and_pad(x, mask, x_con=x_con)

        # decoder position embedding
        if image_shape is None:
            x = x_after_pad + self.decoder_pos_embed_learned
        else:
            h, w = image_shape
            assert h * w == seq_len
            x = x_after_pad + self.get_decoder_pos_embed(h=h, w=w)

        x = self._run_decoder_blocks(x)

        # diffusion position embedding
        if image_shape is None:
            x = x + self.diffusion_pos_embed_learned
        else:
            h, w = image_shape
            assert h * w == seq_len
            x = x + self.get_diffusion_pos_embed(h=h, w=w)
        return x

    def mae_decoder_prepare(self, x, mask):
        """First half of the decoder: embed, pad with mask tokens, add position embedding."""
        return self._embed_and_pad(x, mask) + self.decoder_pos_embed_learned

    def mae_decoder_forward(self, x):
        """Second half of the decoder: blocks, norm, buffer removal, diffusion position embedding."""
        return self._run_decoder_blocks(x) + self.diffusion_pos_embed_learned

    # ------------------------------------------------------------------
    # Training
    def forward_loss(self, z, target, mask):
        """Flatten tokens, repeat them ``diffusion_batch_mul`` times and evaluate ``diffloss``."""
        bsz, seq_len, _ = target.shape
        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        mask = mask.reshape(bsz * seq_len).repeat(self.diffusion_batch_mul)
        loss = self.diffloss(z=z, target=target, mask=mask)
        return loss

    def forward(self, imgs, labels):
        """Compute the training loss for latents ``imgs`` and class ``labels``."""
        # class embed
        class_embedding = self.class_emb(labels)

        # patchify and mask (drop) tokens
        x = self.patchify(imgs)
        gt_latents = x.clone().detach()
        orders = self.sample_orders(bsz=x.size(0))
        mask = self.random_masking(x, orders)

        # mae encoder
        x = self.forward_mae_encoder(x, mask, class_embedding)

        # mae decoder
        z = self.forward_mae_decoder(x, mask)

        # diffloss
        loss = self.forward_loss(z=z, target=gt_latents, mask=mask)

        return loss

    # ------------------------------------------------------------------
    # Sampling
    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
        """Iteratively generate token latents and return them unpatchified.

        Args:
            bsz: number of samples.
            num_iter: number of generation iterations.
            cfg: classifier-free guidance scale; ``1.0`` disables the
                unconditional branch.
            cfg_schedule: ``"linear"`` or ``"constant"``.
            labels: class indices for ``class_emb``; ``None`` uses ``fake_latent``.
            temperature: forwarded to ``diffloss.sample``.
            progress: wrap the iteration in ``tqdm`` when True.

        Returns:
            Latents of shape ``[n, c, h, w]`` as produced by ``unpatchify``.

        Raises:
            NotImplementedError: for an unknown ``cfg_schedule``.
        """
        use_cfg = not cfg == 1.0

        # init and sample generation orders
        mask = torch.ones(bsz, self.seq_len).to(self.device)
        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).to(self.device)
        orders = self.sample_orders(bsz)

        # class embedding and CFG: the conditioning does not change between
        # iterations, so build it once
        if labels is not None:
            class_embedding = self.class_emb(labels)
        else:
            class_embedding = self.fake_latent.repeat(bsz, 1)
        if use_cfg:
            class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)

        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)

        # generate latents
        for step in indices:
            cur_tokens = tokens.clone()

            if use_cfg:
                # duplicate the batch: first half conditional, second half unconditional
                tokens = torch.cat([tokens, tokens], dim=0)
                mask = torch.cat([mask, mask], dim=0)

            # mae encoder
            x = self.forward_mae_encoder(tokens, mask.to(self.dtype), class_embedding)

            # mae decoder
            z = self.forward_mae_decoder(x, mask.to(self.dtype))

            # mask ratio for the next round, following MaskGIT and MAGE.
            mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
            mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).to(self.device)

            # masks out at least one for the next iteration
            mask_len = torch.maximum(torch.Tensor([1]).to(self.device),
                                     torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))

            # get masking for next iteration and locations to be predicted in this iteration
            mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
            if step >= num_iter - 1:
                mask_to_pred = mask[:bsz].bool()
            else:
                mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
            mask = mask_next
            if use_cfg:
                mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)

            # sample token latents for this step
            z = z[mask_to_pred.nonzero(as_tuple=True)]

            # cfg schedule follow Muse
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError
            sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)
            if use_cfg:
                sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # Remove null class samples
                mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)

            cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
            tokens = cur_tokens.clone()

        # unpatchify
        tokens = self.unpatchify(tokens)
        return tokens

    def gradient_checkpointing_enable(self):
        """Turn on activation checkpointing for the encoder and decoder blocks."""
        self.grad_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Turn off activation checkpointing for the encoder and decoder blocks."""
        self.grad_checkpointing = False
def mar_base(**kwargs):
    """Build a MAR with width 768, depth 12 and 12 heads for both encoder and decoder.

    Extra keyword arguments are forwarded to ``MAR`` unchanged.
    """
    width, depth, heads = 768, 12, 12
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MAR(
        encoder_embed_dim=width, encoder_depth=depth, encoder_num_heads=heads,
        decoder_embed_dim=width, decoder_depth=depth, decoder_num_heads=heads,
        mlp_ratio=4, norm_layer=layer_norm, **kwargs)
def mar_large(**kwargs):
    """Build a MAR with width 1024, depth 16 and 16 heads for both encoder and decoder.

    Extra keyword arguments are forwarded to ``MAR`` unchanged.
    """
    width, depth, heads = 1024, 16, 16
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MAR(
        encoder_embed_dim=width, encoder_depth=depth, encoder_num_heads=heads,
        decoder_embed_dim=width, decoder_depth=depth, decoder_num_heads=heads,
        mlp_ratio=4, norm_layer=layer_norm, **kwargs)
def mar_huge(**kwargs):
    """Build a MAR with width 1280, depth 20 and 16 heads for both encoder and decoder.

    Extra keyword arguments are forwarded to ``MAR`` unchanged.
    """
    width, depth, heads = 1280, 20, 16
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MAR(
        encoder_embed_dim=width, encoder_depth=depth, encoder_num_heads=heads,
        decoder_embed_dim=width, decoder_depth=depth, decoder_num_heads=heads,
        mlp_ratio=4, norm_layer=layer_norm, **kwargs)
def mar_max(**kwargs):
    """Build a MAR with width 1536, depth 24 and 16 heads for both encoder and decoder.

    Extra keyword arguments are forwarded to ``MAR`` unchanged.
    """
    width, depth, heads = 1536, 24, 16
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MAR(
        encoder_embed_dim=width, encoder_depth=depth, encoder_num_heads=heads,
        decoder_embed_dim=width, decoder_depth=depth, decoder_num_heads=heads,
        mlp_ratio=4, norm_layer=layer_norm, **kwargs)