| import torch |
| import torch.nn as nn |
| from torch.utils.checkpoint import checkpoint |
| from timm.models.vision_transformer import Block |
| from functools import partial |
|
|
|
|
class MARDecoder(nn.Module):
    """Transformer decoder half of a MAR-style masked autoencoder.

    ``forward`` takes features for the kept (unmasked) positions, scatters them
    into a full-length sequence whose remaining slots hold a learned mask
    token, and prepends ``buffer_size`` always-kept buffer slots. It adds a
    learned positional embedding and runs the sequence through a stack of timm
    ViT ``Block`` layers plus a final LayerNorm. Finally it drops the buffer
    slots and adds a second learned positional embedding
    (``diffusion_pos_embed_learned``).
    """

    def __init__(self, img_size=256, vae_stride=16,
                 patch_size=1,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 mlp_ratio=4.,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 grad_checkpointing=False,
                 ):
        """Build the decoder.

        Args:
            img_size: Image side length; together with ``vae_stride`` and
                ``patch_size`` it fixes the square token grid
                ``seq_h = seq_w = img_size // vae_stride // patch_size``.
            vae_stride: Spatial divisor applied to ``img_size`` (presumably the
                VAE downsampling factor -- confirm with the caller).
            patch_size: Further divisor applied after ``vae_stride``.
            decoder_embed_dim: Token width ``D`` of every decoder layer.
            decoder_depth: Number of transformer ``Block`` layers.
            decoder_num_heads: Attention heads per block.
            mlp_ratio: MLP hidden width as a multiple of ``decoder_embed_dim``.
            attn_dropout: Attention dropout passed to each block.
            proj_dropout: Projection dropout passed to each block.
            buffer_size: Number of extra leading sequence slots that are never
                masked and are removed from the output.
            grad_checkpointing: If True, ``forward`` recomputes block
                activations in backward via ``torch.utils.checkpoint``.
        """
        super().__init__()

        self.img_size = img_size
        self.vae_stride = vae_stride

        # Square token grid; seq_len is the number of non-buffer positions.
        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        self.seq_len = self.seq_h * self.seq_w

        self.grad_checkpointing = grad_checkpointing
        self.buffer_size = buffer_size

        # Parameters are created in the original order so that seeded RNG
        # consumption (e.g. default nn.Linear init inside Block) is unchanged.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed_learned = nn.Parameter(
            torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True,
                  norm_layer=partial(nn.LayerNorm, eps=1e-6),
                  proj_drop=proj_dropout, attn_drop=attn_dropout)
            for _ in range(decoder_depth)])

        self.decoder_norm = nn.LayerNorm(decoder_embed_dim, eps=1e-6)
        self.diffusion_pos_embed_learned = nn.Parameter(
            torch.zeros(1, self.seq_len, decoder_embed_dim))

        self.initialize_weights()

    def initialize_weights(self):
        """Re-initialize all parameters.

        The mask token and both positional embeddings are drawn from
        N(0, 0.02^2); every submodule is then passed through ``_init_weights``.
        """
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one module (applied recursively via ``self.apply``).

        ``nn.Linear``: Xavier-uniform weight, zero bias (if present).
        ``nn.LayerNorm``: unit weight and zero bias (each only if present).
        Other module types are left untouched.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x, mask):
        """Decode kept-token features into a full-length token sequence.

        Args:
            x: 3-D tensor ``(B, N, D)`` with features for every position whose
                buffer-extended mask value is not 1. That is the
                ``buffer_size`` leading buffer slots followed by the kept
                image positions, consumed in row-major (batch, position)
                order. Assumes each sample keeps the same number ``N`` of
                positions -- TODO confirm with the caller.
            mask: 2-D tensor ``(B, L)``; positions equal to 1 are filled with
                ``self.mask_token``. ``L`` must equal ``self.seq_len`` for the
                positional embeddings to broadcast.

        Returns:
            Tensor ``(B, seq_len, D)``: the normed decoder output with the
            buffer slots removed and ``diffusion_pos_embed_learned`` added.
        """
        # Buffer slots are never masked, so they always receive tokens from x.
        buffer_mask = torch.zeros(x.size(0), self.buffer_size, device=x.device)
        mask_with_buffer = torch.cat([buffer_mask, mask], dim=1)

        # Full-length sequence of mask tokens. Tensor.repeat already returns a
        # fresh copy, so it can be written into directly without a clone.
        x_after_pad = self.mask_token.repeat(
            mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)

        # Scatter the kept tokens into every slot whose mask value is not 1.
        keep_idx = (1 - mask_with_buffer).nonzero(as_tuple=True)
        x_after_pad[keep_idx] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])

        x = x_after_pad + self.decoder_pos_embed_learned

        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.decoder_blocks:
                # use_reentrant passed explicitly: same as the historical
                # default, but avoids the warning newer PyTorch emits.
                x = checkpoint(block, x, use_reentrant=True)
        else:
            for block in self.decoder_blocks:
                x = block(x)
        x = self.decoder_norm(x)

        # Drop the buffer slots, then add the second positional embedding.
        x = x[:, self.buffer_size:]
        return x + self.diffusion_pos_embed_learned

    def gradient_checkpointing_enable(self):
        """Turn on activation checkpointing for the decoder blocks."""
        self.grad_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Turn off activation checkpointing for the decoder blocks."""
        self.grad_checkpointing = False
|
|