|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
|
from functools import partial |
|
|
|
|
|
from models.blocks import Block, DecoderBlock, PatchEmbed |
|
|
from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D |
|
|
from models.masking import RandomMask |
|
|
|
|
|
from transformers import PretrainedConfig |
|
|
from transformers import PreTrainedModel |
|
|
|
|
|
|
|
|
class CrocoConfig(PretrainedConfig): |
|
|
model_type = "croco" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
img_size=224, |
|
|
patch_size=16, |
|
|
mask_ratio=0.9, |
|
|
enc_embed_dim=768, |
|
|
enc_depth=12, |
|
|
enc_num_heads=12, |
|
|
dec_embed_dim=512, |
|
|
dec_depth=8, |
|
|
dec_num_heads=16, |
|
|
mlp_ratio=4, |
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), |
|
|
norm_im2_in_dec=True, |
|
|
pos_embed="cosine", |
|
|
): |
|
|
super().__init__() |
|
|
self.img_size = img_size |
|
|
self.patch_size = patch_size |
|
|
self.mask_ratio = mask_ratio |
|
|
self.enc_embed_dim = enc_embed_dim |
|
|
self.enc_depth = enc_depth |
|
|
self.enc_num_heads = enc_num_heads |
|
|
self.dec_embed_dim = dec_embed_dim |
|
|
self.dec_depth = dec_depth |
|
|
self.dec_num_heads = dec_num_heads |
|
|
self.mlp_ratio = mlp_ratio |
|
|
self.norm_layer = norm_layer |
|
|
self.norm_im2_in_dec = norm_im2_in_dec |
|
|
self.pos_embed = pos_embed |
|
|
|
|
|
|
|
|
class CroCoNet(PreTrainedModel): |
|
|
|
|
|
config_class = CrocoConfig |
|
|
base_model_prefix = "croco" |
|
|
|
|
|
def __init__(self, config: CrocoConfig): |
|
|
|
|
|
super().__init__(config) |
|
|
|
|
|
|
|
|
self._set_patch_embed(config.img_size, config.patch_size, config.enc_embed_dim) |
|
|
|
|
|
|
|
|
self._set_mask_generator(self.patch_embed.num_patches, config.mask_ratio) |
|
|
|
|
|
self.pos_embed = config.pos_embed |
|
|
if config.pos_embed == "cosine": |
|
|
|
|
|
enc_pos_embed = get_2d_sincos_pos_embed( |
|
|
config.enc_embed_dim, |
|
|
int(self.patch_embed.num_patches**0.5), |
|
|
n_cls_token=0, |
|
|
) |
|
|
self.register_buffer( |
|
|
"enc_pos_embed", torch.from_numpy(enc_pos_embed).float() |
|
|
) |
|
|
|
|
|
dec_pos_embed = get_2d_sincos_pos_embed( |
|
|
config.dec_embed_dim, |
|
|
int(self.patch_embed.num_patches**0.5), |
|
|
n_cls_token=0, |
|
|
) |
|
|
self.register_buffer( |
|
|
"dec_pos_embed", torch.from_numpy(dec_pos_embed).float() |
|
|
) |
|
|
|
|
|
self.rope = None |
|
|
elif config.pos_embed.startswith("RoPE"): |
|
|
self.enc_pos_embed = None |
|
|
self.dec_pos_embed = None |
|
|
if RoPE2D is None: |
|
|
raise ImportError( |
|
|
"Cannot find cuRoPE2D, please install it following the README instructions" |
|
|
) |
|
|
freq = float(config.pos_embed[len("RoPE") :]) |
|
|
self.rope = RoPE2D(freq=freq) |
|
|
else: |
|
|
raise NotImplementedError("Unknown pos_embed " + config.pos_embed) |
|
|
|
|
|
|
|
|
self.enc_depth = config.enc_depth |
|
|
self.enc_embed_dim = config.enc_embed_dim |
|
|
self.enc_blocks = nn.ModuleList( |
|
|
[ |
|
|
Block( |
|
|
config.enc_embed_dim, |
|
|
config.enc_num_heads, |
|
|
config.mlp_ratio, |
|
|
qkv_bias=True, |
|
|
norm_layer=config.norm_layer, |
|
|
rope=self.rope, |
|
|
) |
|
|
for i in range(config.enc_depth) |
|
|
] |
|
|
) |
|
|
self.enc_norm = config.norm_layer(config.enc_embed_dim) |
|
|
|
|
|
|
|
|
|
|
|
self.mask_token = None |
|
|
|
|
|
|
|
|
self._set_decoder( |
|
|
config.enc_embed_dim, |
|
|
config.dec_embed_dim, |
|
|
config.dec_num_heads, |
|
|
config.dec_depth, |
|
|
config.mlp_ratio, |
|
|
config.norm_layer, |
|
|
config.norm_im2_in_dec, |
|
|
) |
|
|
|
|
|
|
|
|
self._set_prediction_head(config.dec_embed_dim, config.patch_size) |
|
|
|
|
|
|
|
|
self.initialize_weights() |
|
|
|
|
|
def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): |
|
|
self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim) |
|
|
|
|
|
def _set_mask_generator(self, num_patches, mask_ratio): |
|
|
self.mask_generator = RandomMask(num_patches, mask_ratio) |
|
|
|
|
|
def _set_mask_token(self, dec_embed_dim): |
|
|
self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim)) |
|
|
|
|
|
def _set_decoder( |
|
|
self, |
|
|
enc_embed_dim, |
|
|
dec_embed_dim, |
|
|
dec_num_heads, |
|
|
dec_depth, |
|
|
mlp_ratio, |
|
|
norm_layer, |
|
|
norm_im2_in_dec, |
|
|
): |
|
|
self.dec_depth = dec_depth |
|
|
self.dec_embed_dim = dec_embed_dim |
|
|
|
|
|
self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True) |
|
|
|
|
|
self.dec_blocks = nn.ModuleList( |
|
|
[ |
|
|
DecoderBlock( |
|
|
dec_embed_dim, |
|
|
dec_num_heads, |
|
|
mlp_ratio=mlp_ratio, |
|
|
qkv_bias=True, |
|
|
norm_layer=norm_layer, |
|
|
norm_mem=norm_im2_in_dec, |
|
|
rope=self.rope, |
|
|
) |
|
|
for i in range(dec_depth) |
|
|
] |
|
|
) |
|
|
|
|
|
self.dec_norm = norm_layer(dec_embed_dim) |
|
|
|
|
|
def _set_prediction_head(self, dec_embed_dim, patch_size): |
|
|
self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True) |
|
|
|
|
|
def initialize_weights(self): |
|
|
|
|
|
self.patch_embed._init_weights() |
|
|
|
|
|
if self.mask_token is not None: |
|
|
torch.nn.init.normal_(self.mask_token, std=0.02) |
|
|
|
|
|
self.apply(self._init_weights) |
|
|
|
|
|
def _init_weights(self, m): |
|
|
if isinstance(m, nn.Linear): |
|
|
|
|
|
torch.nn.init.xavier_uniform_(m.weight) |
|
|
if isinstance(m, nn.Linear) and m.bias is not None: |
|
|
nn.init.constant_(m.bias, 0) |
|
|
elif isinstance(m, nn.LayerNorm): |
|
|
nn.init.constant_(m.bias, 0) |
|
|
nn.init.constant_(m.weight, 1.0) |
|
|
|
|
|
def _encode_image(self, image, do_mask=False, return_all_blocks=False): |
|
|
""" |
|
|
image has B x 3 x img_size x img_size |
|
|
do_mask: whether to perform masking or not |
|
|
return_all_blocks: if True, return the features at the end of every block |
|
|
instead of just the features from the last block (eg for some prediction heads) |
|
|
""" |
|
|
|
|
|
|
|
|
x, pos = self.patch_embed(image) |
|
|
|
|
|
if self.enc_pos_embed is not None: |
|
|
x = x + self.enc_pos_embed[None, ...] |
|
|
|
|
|
B, N, C = x.size() |
|
|
if do_mask: |
|
|
masks = self.mask_generator(x) |
|
|
x = x[~masks].view(B, -1, C) |
|
|
posvis = pos[~masks].view(B, -1, 2) |
|
|
else: |
|
|
B, N, C = x.size() |
|
|
masks = torch.zeros((B, N), dtype=bool) |
|
|
posvis = pos |
|
|
|
|
|
if return_all_blocks: |
|
|
out = [] |
|
|
for blk in self.enc_blocks: |
|
|
x = blk(x, posvis) |
|
|
out.append(x) |
|
|
out[-1] = self.enc_norm(out[-1]) |
|
|
return out, pos, masks |
|
|
else: |
|
|
for blk in self.enc_blocks: |
|
|
x = blk(x, posvis) |
|
|
x = self.enc_norm(x) |
|
|
return x, pos, masks |
|
|
|
|
|
def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False): |
|
|
""" |
|
|
return_all_blocks: if True, return the features at the end of every block |
|
|
instead of just the features from the last block (eg for some prediction heads) |
|
|
|
|
|
masks1 can be None => assume image1 fully visible |
|
|
""" |
|
|
|
|
|
visf1 = self.decoder_embed(feat1) |
|
|
f2 = self.decoder_embed(feat2) |
|
|
|
|
|
B, Nenc, C = visf1.size() |
|
|
if masks1 is None: |
|
|
f1_ = visf1 |
|
|
else: |
|
|
Ntotal = masks1.size(1) |
|
|
f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype) |
|
|
f1_[~masks1] = visf1.view(B * Nenc, C) |
|
|
|
|
|
if self.dec_pos_embed is not None: |
|
|
f1_ = f1_ + self.dec_pos_embed |
|
|
f2 = f2 + self.dec_pos_embed |
|
|
|
|
|
out = f1_ |
|
|
out2 = f2 |
|
|
if return_all_blocks: |
|
|
_out, out = out, [] |
|
|
for blk in self.dec_blocks: |
|
|
_out, out2 = blk(_out, out2, pos1, pos2) |
|
|
out.append(_out) |
|
|
out[-1] = self.dec_norm(out[-1]) |
|
|
else: |
|
|
for blk in self.dec_blocks: |
|
|
out, out2 = blk(out, out2, pos1, pos2) |
|
|
out = self.dec_norm(out) |
|
|
return out |
|
|
|
|
|
def patchify(self, imgs): |
|
|
""" |
|
|
imgs: (B, 3, H, W) |
|
|
x: (B, L, patch_size**2 *3) |
|
|
""" |
|
|
p = self.patch_embed.patch_size[0] |
|
|
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 |
|
|
|
|
|
h = w = imgs.shape[2] // p |
|
|
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) |
|
|
x = torch.einsum("nchpwq->nhwpqc", x) |
|
|
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) |
|
|
|
|
|
return x |
|
|
|
|
|
def unpatchify(self, x, channels=3): |
|
|
""" |
|
|
x: (N, L, patch_size**2 *channels) |
|
|
imgs: (N, 3, H, W) |
|
|
""" |
|
|
patch_size = self.patch_embed.patch_size[0] |
|
|
h = w = int(x.shape[1] ** 0.5) |
|
|
assert h * w == x.shape[1] |
|
|
x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels)) |
|
|
x = torch.einsum("nhwpqc->nchpwq", x) |
|
|
imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size)) |
|
|
return imgs |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|