import math
from functools import partial

import numpy as np
import scipy.stats as stats
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.vision_transformer import Block
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm

from .diffloss import DiffLoss

def mask_by_order(mask_len, order, bsz, seq_len):
    # Mark the first `mask_len` positions of each per-sample random order as masked.
    masking = torch.zeros(bsz, seq_len).to(order.device)
    masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()],
                            src=torch.ones(bsz, seq_len).to(order.device)).bool()
    return masking

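# Illustrative example (hypothetical values): with a single sample whose random
# order is [[2, 0, 3, 1]] and mask_len = tensor(2.), the first two entries of the
# order (positions 2 and 0) are marked as masked:
#   mask_by_order(torch.tensor(2.), torch.tensor([[2, 0, 3, 1]]), bsz=1, seq_len=4)
#   -> tensor([[ True, False,  True, False]])
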
class MAR(nn.Module):
    """ Masked Autoregressive (MAR) model with a VisionTransformer backbone
    """
    def __init__(self, img_size=256, vae_stride=16, patch_size=1,
                 encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm,
                 vae_embed_dim=16,
                 mask_ratio_min=0.7,
                 label_drop_prob=0.1,
                 class_num=1000,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 diffloss_d=3,
                 diffloss_w=1024,
                 num_sampling_steps='100',
                 diffusion_batch_mul=4,
                 grad_checkpointing=False,
                 ):
        super().__init__()

        # --------------------------------------------------------------------------
        # VAE and patchify specifics
        self.vae_embed_dim = vae_embed_dim

        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        self.seq_len = self.seq_h * self.seq_w
        self.token_embed_dim = vae_embed_dim * patch_size**2
        self.grad_checkpointing = grad_checkpointing

        # --------------------------------------------------------------------------
        # Class embedding (the fake latent stands in for the class embedding when the
        # label is dropped, and provides the unconditional branch for CFG at sampling)
        self.num_classes = class_num
        self.class_emb = nn.Embedding(class_num, encoder_embed_dim)
        self.label_drop_prob = label_drop_prob
        self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))

        # --------------------------------------------------------------------------
        # Masking ratio generator: a truncated Gaussian centered at a 100% masking
        # ratio with std 0.25, truncated to [mask_ratio_min, 1.0]
        self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)

        # --------------------------------------------------------------------------
        # MAR encoder specifics
        self.encoder_embed_dim = encoder_embed_dim
        self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
        self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
        self.buffer_size = buffer_size
        self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim))

        self.encoder_blocks = nn.ModuleList([
            Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
        self.encoder_norm = norm_layer(encoder_embed_dim)

        # --------------------------------------------------------------------------
        # MAR decoder specifics
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))

        self.initialize_weights()

        # --------------------------------------------------------------------------
        # Diffusion Loss
        self.diffloss = DiffLoss(
            target_channels=self.token_embed_dim,
            z_channels=decoder_embed_dim,
            width=diffloss_w,
            depth=diffloss_d,
            num_sampling_steps=num_sampling_steps,
            grad_checkpointing=self.grad_checkpointing
        )
        self.diffusion_batch_mul = diffusion_batch_mul

    def get_encoder_pos_embed(self, h, w):
        # Return the learned encoder position embedding; when (h, w) differs from the
        # training grid, bilinearly interpolate the image part and keep the buffer part.
        if h == self.seq_h and w == self.seq_w:
            return self.encoder_pos_embed_learned
        buffer_pe, image_pe = self.encoder_pos_embed_learned.split(
            [self.buffer_size, self.seq_len], dim=1)
        image_pe = rearrange(image_pe, 'b (h w) c -> b c h w',
                             h=self.seq_h, w=self.seq_w)
        image_pe = F.interpolate(image_pe, size=(h, w), mode='bilinear')
        image_pe = rearrange(image_pe, 'b c h w -> b (h w) c')
        return torch.cat([buffer_pe, image_pe], dim=1)

    def get_decoder_pos_embed(self, h, w):
        if h == self.seq_h and w == self.seq_w:
            return self.decoder_pos_embed_learned
        buffer_pe, image_pe = self.decoder_pos_embed_learned.split(
            [self.buffer_size, self.seq_len], dim=1)
        image_pe = rearrange(image_pe, 'b (h w) c -> b c h w',
                             h=self.seq_h, w=self.seq_w)
        image_pe = F.interpolate(image_pe, size=(h, w), mode='bilinear')
        image_pe = rearrange(image_pe, 'b c h w -> b (h w) c')
        return torch.cat([buffer_pe, image_pe], dim=1)

    def get_diffusion_pos_embed(self, h, w):
        if h == self.seq_h and w == self.seq_w:
            return self.diffusion_pos_embed_learned
        image_pe = self.diffusion_pos_embed_learned
        image_pe = rearrange(image_pe, 'b (h w) c -> b c h w',
                             h=self.seq_h, w=self.seq_w)
        image_pe = F.interpolate(image_pe, size=(h, w), mode='bilinear')
        image_pe = rearrange(image_pe, 'b c h w -> b (h w) c')
        return image_pe

    def initialize_weights(self):
        # parameters
        torch.nn.init.normal_(self.class_emb.weight, std=.02)
        torch.nn.init.normal_(self.fake_latent, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # xavier_uniform initialization for Linear weights
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @property
    def device(self):
        return self.fake_latent.data.device

    @property
    def dtype(self):
        return self.fake_latent.data.dtype

    def patchify(self, x):
        # (bsz, c, h, w) latents -> (bsz, num_patches, c * p**2) token sequence
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(bsz, h_ * w_, c * p ** 2)
        return x

    def unpatchify(self, x):
        # (bsz, num_patches, c * p**2) token sequence -> (bsz, c, h, w) latents
        bsz = x.shape[0]
        p = self.patch_size
        c = self.vae_embed_dim
        h_, w_ = self.seq_h, self.seq_w

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum('nhwcpq->nchpwq', x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x

    def sample_orders(self, bsz, seq_len=None):
        # generate a batch of random token-generation orders
        if seq_len is None:
            seq_len = self.seq_len
        orders = []
        for _ in range(bsz):
            order = np.array(list(range(seq_len)))
            np.random.shuffle(order)
            orders.append(order)
        orders = torch.Tensor(np.array(orders)).to(self.device).long()
        return orders

    def random_masking(self, x, orders):
        # generate a token mask: the masking rate is drawn from the truncated Gaussian,
        # and the first `num_masked_tokens` positions of each random order are masked
        bsz, seq_len, embed_dim = x.shape
        assert seq_len == orders.shape[1]
        mask_rate = self.mask_ratio_generator.rvs(1)[0]
        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask = torch.zeros(bsz, seq_len, device=x.device)
        mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
                             src=torch.ones(bsz, seq_len, device=x.device))
        return mask

    def forward_mae_encoder(self, x, mask, class_embedding, image_shape=None):
        x = x.to(self.dtype)
        x = self.z_proj(x)
        bsz, seq_len, embed_dim = x.shape

        # concat buffer tokens in front of the image tokens
        x = torch.cat([x.new_zeros(bsz, self.buffer_size, embed_dim), x], dim=1)
        mask_with_buffer = torch.cat([mask.new_zeros(x.size(0), self.buffer_size), mask], dim=1)

        # randomly drop the class embedding during training (replacing it with the
        # fake latent) so the unconditional branch used for CFG at sampling is trained
        if self.training:
            drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
            drop_latent_mask = drop_latent_mask.unsqueeze(-1).to(x.device).to(x.dtype)
            class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding

        x[:, :self.buffer_size] = class_embedding.view(bsz, -1, embed_dim)

        # encoder position embedding
        if image_shape is None:
            x = x + self.encoder_pos_embed_learned
        else:
            h, w = image_shape
            assert h * w == seq_len
            x = x + self.get_encoder_pos_embed(h=h, w=w)

        x = self.z_proj_ln(x)

        # dropping: the encoder only sees the visible (unmasked) positions
        x = x[(1 - mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)

        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.encoder_blocks:
                x = checkpoint(block, x, use_reentrant=False)
        else:
            for block in self.encoder_blocks:
                x = block(x)
        x = self.encoder_norm(x)

        return x

    def forward_mae_decoder(self, x, mask, image_shape=None, x_con=None):
        bsz, seq_len = mask.shape

        x = self.decoder_embed(x)
        mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

        # pad mask tokens
        mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)

        if x_con is not None:
            x_after_pad = self.decoder_embed(x_con)
        else:
            x_after_pad = mask_tokens.clone()
            x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])

        # decoder position embedding
        if image_shape is None:
            x = x_after_pad + self.decoder_pos_embed_learned
        else:
            h, w = image_shape
            assert h * w == seq_len
            x = x_after_pad + self.get_decoder_pos_embed(h=h, w=w)

        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.decoder_blocks:
                x = checkpoint(block, x, use_reentrant=False)
        else:
            for block in self.decoder_blocks:
                x = block(x)
        x = self.decoder_norm(x)

        # drop the buffer tokens and add the diffusion position embedding
        x = x[:, self.buffer_size:]
        if image_shape is None:
            x = x + self.diffusion_pos_embed_learned
        else:
            h, w = image_shape
            assert h * w == seq_len
            x = x + self.get_diffusion_pos_embed(h=h, w=w)
        return x

    def mae_decoder_prepare(self, x, mask):
        x = self.decoder_embed(x)
        mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

        # pad mask tokens
        mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
        x_after_pad = mask_tokens.clone()
        x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])

        # decoder position embedding
        x = x_after_pad + self.decoder_pos_embed_learned

        return x

    def mae_decoder_forward(self, x):
        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.decoder_blocks:
                x = checkpoint(block, x, use_reentrant=False)
        else:
            for block in self.decoder_blocks:
                x = block(x)
        x = self.decoder_norm(x)

        x = x[:, self.buffer_size:]
        x = x + self.diffusion_pos_embed_learned
        return x

    def forward_loss(self, z, target, mask):
        # Repeat tokens `diffusion_batch_mul` times so each token contributes several
        # diffusion samples per step; the token mask is forwarded to DiffLoss so the
        # loss can be restricted to masked positions.
        bsz, seq_len, _ = target.shape
        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        mask = mask.reshape(bsz * seq_len).repeat(self.diffusion_batch_mul)
        loss = self.diffloss(z=z, target=target, mask=mask)
        return loss

    def forward(self, imgs, labels):
        # class embedding
        class_embedding = self.class_emb(labels)

        # patchify and mask (drop) tokens
        x = self.patchify(imgs)
        gt_latents = x.clone().detach()
        orders = self.sample_orders(bsz=x.size(0))
        mask = self.random_masking(x, orders)

        # mae encoder
        x = self.forward_mae_encoder(x, mask, class_embedding)

        # mae decoder
        z = self.forward_mae_decoder(x, mask)

        # diffusion loss
        loss = self.forward_loss(z=z, target=gt_latents, mask=mask)

        return loss

    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
        # init fully masked tokens and sample generation orders
        mask = torch.ones(bsz, self.seq_len).to(self.device)
        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).to(self.device)
        orders = self.sample_orders(bsz)

        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)
        # generate latents, predicting a subset of the masked tokens at every iteration
        for step in indices:
            cur_tokens = tokens.clone()

            # class embedding (CFG doubles the batch with an unconditional branch)
            if labels is not None:
                class_embedding = self.class_emb(labels)
            else:
                class_embedding = self.fake_latent.repeat(bsz, 1)
            if not cfg == 1.0:
                tokens = torch.cat([tokens, tokens], dim=0)
                class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
                mask = torch.cat([mask, mask], dim=0)

            # mae encoder
            x = self.forward_mae_encoder(tokens, mask.to(self.dtype), class_embedding)

            # mae decoder
            z = self.forward_mae_decoder(x, mask.to(self.dtype))

            # mask ratio for the next round, following a cosine schedule
            mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
            mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).to(self.device)

            # keep at least one token masked for the next iteration,
            # and predict at least one token in this iteration
            mask_len = torch.maximum(torch.Tensor([1]).to(self.device),
                                     torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))

            # get masking for the next iteration and locations to be predicted in this iteration
            mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
            if step >= num_iter - 1:
                mask_to_pred = mask[:bsz].bool()
            else:
                mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
            mask = mask_next
            if not cfg == 1.0:
                mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)

            # sample token latents for this step
            z = z[mask_to_pred.nonzero(as_tuple=True)]
            # cfg schedule: ramp guidance linearly with the number of decoded tokens, or keep it constant
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError
            sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)
            if not cfg == 1.0:
                sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # keep the conditional half
                mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
            cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
            tokens = cur_tokens.clone()

        # unpatchify back to the latent canvas
        tokens = self.unpatchify(tokens)
        return tokens

    def gradient_checkpointing_enable(self):
        self.grad_checkpointing = True

    def gradient_checkpointing_disable(self):
        self.grad_checkpointing = False

def mar_base(**kwargs):
    model = MAR(
        encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
        decoder_embed_dim=768, decoder_depth=12, decoder_num_heads=12,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mar_large(**kwargs):
    model = MAR(
        encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
        decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mar_huge(**kwargs):
    model = MAR(
        encoder_embed_dim=1280, encoder_depth=20, encoder_num_heads=16,
        decoder_embed_dim=1280, decoder_depth=20, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
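

# Minimal usage sketch (illustrative): it assumes this file sits inside a package so
# the relative `.diffloss` import resolves (e.g. invoked as `python -m <package>.<module>`),
# and that a VAE elsewhere produces 16-channel latents at 1/16 resolution, matching the
# defaults above (256x256 images, patch_size=1 -> a 16x16 token grid).
if __name__ == "__main__":
    model = mar_base()
    latents = torch.randn(2, 16, 16, 16)    # (bsz, vae_embed_dim, H // vae_stride, W // vae_stride)
    labels = torch.randint(0, 1000, (2,))   # ImageNet-style class labels
    loss = model(latents, labels)           # diffusion loss on masked tokens
    model.eval()
    with torch.no_grad():
        samples = model.sample_tokens(bsz=2, num_iter=8, cfg=1.0, labels=labels)
    print(loss.item(), samples.shape)       # samples: (2, 16, 16, 16) latent canvas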