| |
| |
|
|
|
|
| |
| |
| |
|
|
|
|
|
|
| import torch |
| import torch.nn as nn |
| torch.backends.cuda.matmul.allow_tf32 = True |
| from functools import partial |
|
|
| from .blocks import Block, DecoderBlock, PatchEmbed |
| from .pos_embed import get_2d_sincos_pos_embed, RoPE2D |
| from .masking import RandomMask |
|
|
|
|
class CroCoNet(nn.Module):
    """Two-image masked reconstruction network: encoder, decoder, pixel head.

    ``forward`` patch-embeds and randomly masks the first image, encodes the
    visible tokens, encodes the second image without masking, lets the
    decoder combine both token sets, and projects every decoder token to
    ``patch_size**2 * 3`` values that are compared against the patchified
    first image.

    The ``_set_*`` methods are small construction hooks called from
    ``__init__`` in a fixed order; each one only creates the attributes named
    in its body.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 mask_ratio=0.9,
                 enc_embed_dim=768,
                 enc_depth=12,
                 enc_num_heads=12,
                 dec_embed_dim=512,
                 dec_depth=8,
                 dec_num_heads=16,
                 mlp_ratio=4,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 norm_im2_in_dec=True,
                 pos_embed='cosine',
                 ):
        """Build the patch embedding, encoder, decoder and prediction head.

        Args:
            img_size: forwarded to ``PatchEmbed``.
            patch_size: forwarded to ``PatchEmbed``; also sizes the head output.
            mask_ratio: forwarded to ``RandomMask``.
            enc_embed_dim: token width inside the encoder.
            enc_depth: number of encoder blocks.
            enc_num_heads: heads per encoder block.
            dec_embed_dim: token width inside the decoder.
            dec_depth: number of decoder blocks.
            dec_num_heads: heads per decoder block.
            mlp_ratio: forwarded to every encoder/decoder block.
            norm_layer: callable ``dim -> nn.Module`` used for all norms.
            norm_im2_in_dec: forwarded to ``DecoderBlock`` as ``norm_mem``.
            pos_embed: ``'cosine'`` for fixed 2D sin-cos embeddings added to
                the tokens, or ``'RoPE<freq>'`` (e.g. ``'RoPE100'``) to hand a
                ``RoPE2D(freq=<freq>)`` object to every block instead.

        Raises:
            ImportError: RoPE requested but ``RoPE2D`` is unavailable.
            ValueError: the ``'RoPE...'`` suffix is not a valid float.
            NotImplementedError: ``pos_embed`` is neither of the above.
        """
        super().__init__()

        # Patch embedding (creates self.patch_embed).
        self._set_patch_embed(img_size, patch_size, enc_embed_dim)

        # Mask generator (creates self.mask_generator).
        self._set_mask_generator(self.patch_embed.num_patches, mask_ratio)

        # Positional embedding: either fixed buffers or a rope object.
        self.pos_embed = pos_embed
        if pos_embed == 'cosine':
            # The patch grid is taken to be square: side = sqrt(num_patches).
            grid_size = int(self.patch_embed.num_patches**.5)
            enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, grid_size, n_cls_token=0)
            self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float())
            dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, grid_size, n_cls_token=0)
            self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float())
            # Blocks receive rope=None in this mode.
            self.rope = None
        elif pos_embed.startswith('RoPE'):
            # No additive embeddings in this mode; positions go to the blocks.
            self.enc_pos_embed = None
            self.dec_pos_embed = None
            if RoPE2D is None:
                raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
            freq_text = pos_embed[len('RoPE'):]
            try:
                freq = float(freq_text)
            except ValueError as err:
                raise ValueError(
                    "pos_embed must look like 'RoPE<freq>' with a numeric <freq>, got "
                    + repr(pos_embed)
                ) from err
            self.rope = RoPE2D(freq=freq)
        else:
            raise NotImplementedError('Unknown pos_embed '+pos_embed)

        # Encoder blocks and final encoder norm.
        self.enc_depth = enc_depth
        self.enc_embed_dim = enc_embed_dim
        self.enc_blocks = nn.ModuleList([
            Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
            for _ in range(enc_depth)])
        self.enc_norm = norm_layer(enc_embed_dim)

        # Learned token that stands in for masked patches in the decoder.
        self._set_mask_token(dec_embed_dim)

        # Decoder (must come after self.rope is defined).
        self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec)

        # Per-token pixel prediction head.
        self._set_prediction_head(dec_embed_dim, patch_size)

        # Weight initialisation runs last, once every submodule exists.
        self.initialize_weights()

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        """Create ``self.patch_embed`` for 3-channel inputs."""
        self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim)

    def _set_mask_generator(self, num_patches, mask_ratio):
        """Create ``self.mask_generator`` as a ``RandomMask``."""
        self.mask_generator = RandomMask(num_patches, mask_ratio)

    def _set_mask_token(self, dec_embed_dim):
        """Create the learnable ``self.mask_token`` of shape (1, 1, dec_embed_dim)."""
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))

    def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
        """Create the encoder-to-decoder projection, decoder blocks and final norm."""
        self.dec_depth = dec_depth
        self.dec_embed_dim = dec_embed_dim
        # Linear map from encoder width to decoder width.
        self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
        self.dec_blocks = nn.ModuleList([
            DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
            for _ in range(dec_depth)])
        self.dec_norm = norm_layer(dec_embed_dim)

    def _set_prediction_head(self, dec_embed_dim, patch_size):
        """Create the linear head mapping a decoder token to ``patch_size**2 * 3`` values."""
        self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True)

    def initialize_weights(self):
        """Initialise the patch embedding, the mask token and all Linear/LayerNorm layers."""
        self.patch_embed._init_weights()
        # mask_token may have been set to None by an overriding _set_mask_token.
        if self.mask_token is not None:
            torch.nn.init.normal_(self.mask_token, std=.02)
        # Applied recursively to every submodule.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module init used by ``self.apply``.

        ``nn.Linear``: Xavier-uniform weight, zero bias (when present).
        ``nn.LayerNorm``: unit weight, zero bias; either may be absent when the
        layer was built without affine parameters, in which case it is skipped.
        Any other module type is left untouched.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def _encode_image(self, image, do_mask=False, return_all_blocks=False):
        """
        image has B x 3 x img_size x img_size
        do_mask: whether to perform masking or not
        return_all_blocks: if True, return the features at the end of every block
                           instead of just the features from the last block (eg for some prediction heads)

        Returns (features, pos, masks): ``pos`` is the full position tensor
        produced by the patch embedding (not restricted to visible tokens),
        ``masks`` is a B x N boolean tensor where True marks a masked patch
        (all False when do_mask is False).
        """
        # Patch embedding returns the tokens and their positions.
        x, pos = self.patch_embed(image)
        # Additive positional embedding (absent in RoPE mode).
        if self.enc_pos_embed is not None:
            x = x + self.enc_pos_embed[None, ...]

        B, N, C = x.size()
        if do_mask:
            masks = self.mask_generator(x)
            # Keep only the visible tokens and their positions; the view
            # requires the same number of visible tokens in every sample.
            x = x[~masks].view(B, -1, C)
            posvis = pos[~masks].view(B, -1, 2)
        else:
            # Nothing is masked; build the mask on the tokens' device so it
            # can be combined with them downstream without a device mismatch.
            masks = torch.zeros((B, N), dtype=torch.bool, device=x.device)
            posvis = pos

        if return_all_blocks:
            out = []
            for blk in self.enc_blocks:
                x = blk(x, posvis)
                out.append(x)
            # Only the last block's output goes through the final norm.
            out[-1] = self.enc_norm(out[-1])
            return out, pos, masks

        for blk in self.enc_blocks:
            x = blk(x, posvis)
        x = self.enc_norm(x)
        return x, pos, masks

    def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False):
        """
        return_all_blocks: if True, return the features at the end of every block
                           instead of just the features from the last block (eg for some prediction heads)

        masks1 can be None => assume image1 fully visible
        """
        # Project both token sets to the decoder width.
        visf1 = self.decoder_embed(feat1)
        f2 = self.decoder_embed(feat2)

        B, Nenc, C = visf1.size()
        if masks1 is None:
            f1_ = visf1
        else:
            # Start from mask tokens everywhere, then scatter the visible
            # tokens back into the slots where masks1 is False.
            Ntotal = masks1.size(1)
            f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
            f1_[~masks1] = visf1.view(B * Nenc, C)

        # Additive positional embedding (absent in RoPE mode).
        if self.dec_pos_embed is not None:
            f1_ = f1_ + self.dec_pos_embed
            f2 = f2 + self.dec_pos_embed

        out = f1_
        out2 = f2
        if return_all_blocks:
            _out, out = out, []
            for blk in self.dec_blocks:
                _out, out2 = blk(_out, out2, pos1, pos2)
                out.append(_out)
            # Only the last block's output goes through the final norm.
            out[-1] = self.dec_norm(out[-1])
        else:
            for blk in self.dec_blocks:
                out, out2 = blk(out, out2, pos1, pos2)
            out = self.dec_norm(out)
        return out

    def patchify(self, imgs):
        """
        imgs: (B, 3, H, W)
        x: (B, L, patch_size**2 *3)

        Requires H == W and H divisible by the patch size. Within a patch the
        values are ordered row, column, channel.
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x, channels=3):
        """
        x: (N, L, patch_size**2 *channels)
        imgs: (N, channels, H, W)

        Inverse of ``patchify``; L must be a perfect square.
        """
        patch_size = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, w * patch_size))
        return imgs

    def forward(self, img1, img2):
        """
        img1: tensor of size B x 3 x img_size x img_size
        img2: tensor of size B x 3 x img_size x img_size

        out will be    B x N x (3*patch_size*patch_size)
        masks are also returned as B x N just in case

        Returns (out, mask1, target) where target is ``patchify(img1)``.
        """
        # Encode the first image with masking.
        feat1, pos1, mask1 = self._encode_image(img1, do_mask=True)
        # Encode the second image fully visible.
        feat2, pos2, _ = self._encode_image(img2, do_mask=False)
        # Decode the first image's tokens using the second image's tokens.
        decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2)
        # Per-token pixel prediction.
        out = self.prediction_head(decfeat)
        # Reconstruction target in the same per-patch layout.
        target = self.patchify(img1)
        return out, mask1, target
|
|