| | from copy import deepcopy |
| | from dataclasses import dataclass |
| | from typing import Literal |
| |
|
| | import torch |
| | from einops import rearrange |
| | from torch import nn |
| |
|
| | from .croco.blocks import DecoderBlock |
| | from .croco.croco import CroCoNet |
| | from .croco.misc import fill_default_args, freeze_all_params, transpose_to_landscape, is_symmetrized, interleave, \ |
| | make_batch_symmetric |
| | from .croco.patch_embed import get_patch_embed |
| | from .backbone import Backbone |
| | from src.geometry.camera_emb import get_intrinsic_embedding |
| |
|
# Positive infinity, used as an open bound in the mode tuples below.
inf = float("inf")

# CroCo constructor keyword arguments, keyed by model name.
# AsymmetricCroCoMulti.__init__ looks its configuration up here via cfg.model.
croco_params = dict(
    ViTLarge_BaseDecoder=dict(
        enc_depth=24,
        dec_depth=12,
        enc_embed_dim=1024,
        dec_embed_dim=768,
        enc_num_heads=16,
        dec_num_heads=12,
        pos_embed="RoPE100",
        img_size=(512, 512),
    ),
)

# Default DUSt3R-style hyper-parameters: the same architecture values as the
# ViTLarge_BaseDecoder entry above, plus patch-embedding, head and output-mode
# settings.
default_dust3r_params = dict(
    enc_depth=24,
    dec_depth=12,
    enc_embed_dim=1024,
    dec_embed_dim=768,
    enc_num_heads=16,
    dec_num_heads=12,
    pos_embed="RoPE100",
    patch_embed_cls="PatchEmbedDust3R",
    img_size=(512, 512),
    head_type="dpt",
    output_mode="pts3d",
    depth_mode=("exp", -inf, inf),
    conf_mode=("exp", 1, inf),
)
| |
|
| |
|
@dataclass
class BackboneCrocoCfg:
    """Configuration for the CroCo-based backbone."""

    # Discriminator identifying this backbone configuration.
    name: Literal["croco"]
    # Model name; used as the lookup key into `croco_params`.
    model: Literal[
        "ViTLarge_BaseDecoder",
        "ViTBase_SmallDecoder",
        "ViTBase_BaseDecoder",
    ]
    # Name of the patch-embedding class handed to `get_patch_embed`.
    patch_embed_cls: str = "PatchEmbedDust3R"
    # When True the backbone creates a second, independent copy of the decoder blocks.
    asymmetry_decoder: bool = True
    # Where the camera-intrinsics embedding is injected, if anywhere.
    intrinsics_embed_loc: Literal["encoder", "decoder", "none"] = "none"
    # Degree for the pixelwise embedding: (degree + 1) ** 2 channels when > 0, else 3.
    intrinsics_embed_degree: int = 0
    # How the intrinsics are injected: extra image channels, added vector, or extra token.
    intrinsics_embed_type: Literal["pixelwise", "linear", "token"] = "token"
| |
|
| |
|
class AsymmetricCroCoMulti(CroCoNet):
    """Multi-view CroCo backbone with an (optionally) asymmetric decoder.

    Two siamese encoders, followed by two decoders.
    The goal is to output 3d points directly, both images in view1's frame
    (hence the asymmetry).

    All views are encoded with the same encoder blocks.  In the decoder,
    view 0 goes through ``dec_blocks`` while every other view goes through
    ``dec_blocks2``; each view cross-attends to the tokens of all the
    remaining views.  When ``cfg.asymmetry_decoder`` is False no separate
    ``dec_blocks2`` is created and ``dec_blocks`` is shared by all views.
    """

    def __init__(self, cfg: BackboneCrocoCfg, d_in: int) -> None:
        """Build the backbone from ``cfg``.

        Args:
            cfg: Backbone configuration; ``cfg.model`` must be a key of
                ``croco_params`` (a ``KeyError`` is raised otherwise).
            d_in: Not used by this constructor.
        """
        # The attributes below are plain Python values that are read by
        # `_set_patch_embed` / `_set_decoder`.  They are assigned before
        # `super().__init__()` because those hooks are presumably invoked
        # from the CroCoNet constructor -- TODO confirm against CroCoNet.
        self.intrinsics_embed_loc = cfg.intrinsics_embed_loc
        self.intrinsics_embed_degree = cfg.intrinsics_embed_degree
        self.intrinsics_embed_type = cfg.intrinsics_embed_type
        self.intrinsics_embed_encoder_dim = 0
        self.intrinsics_embed_decoder_dim = 0

        # Number of channels of the pixelwise intrinsics embedding.
        degree = self.intrinsics_embed_degree
        pixelwise_dim = (degree + 1) ** 2 if degree > 0 else 3
        if self.intrinsics_embed_type == 'pixelwise':
            if self.intrinsics_embed_loc == 'encoder':
                self.intrinsics_embed_encoder_dim = pixelwise_dim
            elif self.intrinsics_embed_loc == 'decoder':
                self.intrinsics_embed_decoder_dim = pixelwise_dim

        self.patch_embed_cls = cfg.patch_embed_cls
        self.croco_args = fill_default_args(croco_params[cfg.model], CroCoNet.__init__)

        super().__init__(**croco_params[cfg.model])

        if cfg.asymmetry_decoder:
            # Independent decoder for the non-reference views.
            self.dec_blocks2 = deepcopy(self.dec_blocks)

        if self.intrinsics_embed_type in ('linear', 'token'):
            # Maps the 9 flattened intrinsics values to a 1024-d vector.
            # NOTE(review): 1024 is hard-coded; it matches 'enc_embed_dim'
            # of the only entry currently present in `croco_params`.
            self.intrinsic_encoder = nn.Linear(9, 1024)

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768, in_chans=3):
        """Create the patch embedding, widened by any encoder-side pixelwise intrinsics channels."""
        in_chans = in_chans + self.intrinsics_embed_encoder_dim
        self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans)

    def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
        """Create the encoder-to-decoder projection, the decoder blocks and the final norm."""
        self.dec_depth = dec_depth
        self.dec_embed_dim = dec_embed_dim

        # Projection from encoder width (plus any decoder-side pixelwise
        # intrinsics channels) to decoder width.
        enc_embed_dim = enc_embed_dim + self.intrinsics_embed_decoder_dim
        self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)

        self.dec_blocks = nn.ModuleList([
            DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True,
                         norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
            for _ in range(dec_depth)
        ])

        self.dec_norm = norm_layer(dec_embed_dim)

    def load_state_dict(self, ckpt, **kw):
        """Load ``ckpt``, initialising ``dec_blocks2`` from ``dec_blocks`` if needed.

        If this model has a second decoder and the checkpoint contains no
        ``dec_blocks2*`` entries, every ``dec_blocks*`` entry is duplicated
        under the corresponding ``dec_blocks2*`` key.  The input mapping is
        not modified.

        Args:
            ckpt: Mapping from parameter names to values.
            **kw: Forwarded to the parent ``load_state_dict``.

        Returns:
            Whatever the parent ``load_state_dict`` returns.
        """
        new_ckpt = dict(ckpt)
        # Only duplicate when the model really owns `dec_blocks2`; otherwise
        # the extra keys would be reported as unexpected by a strict load.
        has_second_decoder = hasattr(self, 'dec_blocks2')
        ckpt_has_second_decoder = any(k.startswith('dec_blocks2') for k in ckpt)
        if has_second_decoder and not ckpt_has_second_decoder:
            for key, value in ckpt.items():
                if key.startswith('dec_blocks'):
                    # Rewrite the prefix only (count=1).
                    new_ckpt[key.replace('dec_blocks', 'dec_blocks2', 1)] = value
        return super().load_state_dict(new_ckpt, **kw)

    def set_freeze(self, freeze):
        """Freeze a group of sub-modules.

        Args:
            freeze: One of ``'none'``, ``'mask'``, ``'encoder'`` or
                ``'encoder_decoder'``.

        Raises:
            AssertionError: If ``freeze`` is not one of the values above.
        """
        valid_modes = ('none', 'mask', 'encoder', 'encoder_decoder')
        assert freeze in valid_modes, f"unexpected freeze={freeze}"

        # Build the list lazily so that attributes which may not exist
        # (e.g. `dec_blocks2` for a symmetric decoder) are only touched
        # when the selected mode actually needs them.
        to_be_frozen = []
        if freeze != 'none':
            to_be_frozen.append(self.mask_token)
        if freeze in ('encoder', 'encoder_decoder'):
            to_be_frozen += [self.patch_embed, self.enc_blocks]
        if freeze == 'encoder_decoder':
            to_be_frozen += [self.enc_norm, self.decoder_embed, self.dec_blocks]
            if hasattr(self, 'dec_blocks2'):
                to_be_frozen.append(self.dec_blocks2)
            to_be_frozen.append(self.dec_norm)
        freeze_all_params(to_be_frozen)

    def _set_prediction_head(self, *args, **kwargs):
        """ No prediction head """
        return

    def _encode_image(self, image, true_shape, intrinsics_embed=None):
        """Patch-embed ``image`` and run it through the encoder.

        Args:
            image: Batch of images passed to ``self.patch_embed``.
            true_shape: Per-image shape passed to ``self.patch_embed``.
            intrinsics_embed: Optional intrinsics embedding.  For type
                ``'linear'`` it is added to the patch tokens; for type
                ``'token'`` it is appended as an extra token along dim 1.

        Returns:
            ``(x, pos, None)``: normalised encoder tokens and their positions.
        """
        # Embed the image into patch tokens and patch positions.
        x, pos = self.patch_embed(image, true_shape=true_shape)

        if intrinsics_embed is not None:
            if self.intrinsics_embed_type == 'linear':
                x = x + intrinsics_embed
            elif self.intrinsics_embed_type == 'token':
                x = torch.cat((x, intrinsics_embed), dim=1)
                # Give the extra token its own position: copy the first
                # position and shift its coordinate 0 to one past the
                # coordinate 0 of the last patch position.
                add_pose = pos[:, 0:1, :].clone()
                add_pose[:, :, 0] += (pos[:, -1, 0].unsqueeze(-1) + 1)
                pos = torch.cat((pos, add_pose), dim=1)

        # No additive positional embedding is supported here.
        assert self.enc_pos_embed is None

        # Transformer encoder followed by normalisation.
        for blk in self.enc_blocks:
            x = blk(x, pos)

        x = self.enc_norm(x)
        return x, pos, None

    def _decoder(self, feat, pose, extra_embed=None):
        """Run the multi-view decoder.

        Args:
            feat: Encoder tokens, 4-D ``(b, v, l, c)``.
            pose: Token positions, 4-D with the same leading ``(b, v, l)``.
            extra_embed: Optional tensor concatenated to ``feat`` along the
                last dim before the projection to decoder width.

        Returns:
            List ``[feat, out_1, ..., out_N]`` where ``out_i`` is the output
            of the i-th decoder block for all views (``(b, v, l, c_dec)``)
            and the last entry has been passed through ``self.dec_norm``.
        """
        b, v, l, c = feat.shape
        final_output = [feat]  # encoder features, before projection
        if extra_embed is not None:
            feat = torch.cat((feat, extra_embed), dim=-1)

        # Project to decoder width.
        f = rearrange(feat, "b v l c -> (b v) l c")
        f = self.decoder_embed(f)
        f = rearrange(f, "(b v) l c -> b v l c", b=b, v=v)
        final_output.append(f)

        def generate_ctx_views(x):
            """For each view, concatenate the tokens of all *other* views."""
            b, v, l, c = x.shape
            ctx_views = x.unsqueeze(1).expand(b, v, v, l, c)
            # Off-diagonal mask: drops the (i, i) pair for every view i.
            mask = torch.arange(v).unsqueeze(0) != torch.arange(v).unsqueeze(1)
            ctx_views = ctx_views[:, mask].reshape(b, v, v - 1, l, c)
            ctx_views = ctx_views.flatten(2, 3)
            return ctx_views.contiguous()

        # With a symmetric decoder there is no `dec_blocks2`; the
        # non-reference views then share `dec_blocks`.
        other_view_blocks = getattr(self, 'dec_blocks2', self.dec_blocks)

        pos_ctx = generate_ctx_views(pose)
        for blk1, blk2 in zip(self.dec_blocks, other_view_blocks):
            feat_current = final_output[-1]
            feat_current_ctx = generate_ctx_views(feat_current)

            # Reference view (index 0).
            f1, _ = blk1(feat_current[:, 0].contiguous(), feat_current_ctx[:, 0].contiguous(),
                         pose[:, 0].contiguous(), pos_ctx[:, 0].contiguous())
            f1 = f1.unsqueeze(1)

            # All remaining views, folded into the batch dimension.
            f2, _ = blk2(rearrange(feat_current[:, 1:], "b v l c -> (b v) l c"),
                         rearrange(feat_current_ctx[:, 1:], "b v l c -> (b v) l c"),
                         rearrange(pose[:, 1:], "b v l c -> (b v) l c"),
                         rearrange(pos_ctx[:, 1:], "b v l c -> (b v) l c"))
            f2 = rearrange(f2, "(b v) l c -> b v l c", b=b, v=v - 1)

            final_output.append(torch.cat((f1, f2), dim=1))

        # Drop the projected input so the list holds the encoder features
        # followed by one entry per decoder block, then normalise the last.
        del final_output[1]
        last_feat = rearrange(final_output[-1], "b v l c -> (b v) l c")
        last_feat = self.dec_norm(last_feat)
        final_output[-1] = rearrange(last_feat, "(b v) l c -> b v l c", b=b, v=v)
        return final_output

    def forward(self,
                context: dict,
                symmetrize_batch=False,
                return_views=False,
                ):
        """Encode and decode all context views.

        Args:
            context: Dict with an ``"image"`` tensor of 5 dims
                ``(b, v, c, h, w)``; ``"intrinsics"`` is also read when the
                intrinsics are embedded in the encoder as ``'token'`` or
                ``'linear'`` (assumed to flatten to 9 values per view, as
                required by ``intrinsic_encoder`` -- TODO confirm).
            symmetrize_batch: Not used by this implementation.
            return_views: Not used by this implementation.

        Returns:
            ``(dec_feat, shape, images)``: the list returned by ``_decoder``
            (with the extra intrinsics token removed in ``'token'`` mode),
            the per-view image shape ``(b, v, 2)`` and the images actually
            fed to the encoder, reshaped to ``(b, v, c, h, w)``.
        """
        # Unpacking five values also enforces a 5-D input.
        b, v, _, _, _ = context["image"].shape
        images_all = context["image"]

        embed_in_encoder = self.intrinsics_embed_loc == 'encoder'

        # Pixelwise camera embedding: appended as extra image channels.
        if embed_in_encoder and self.intrinsics_embed_type == 'pixelwise':
            intrinsic_embedding = get_intrinsic_embedding(context, degree=self.intrinsics_embed_degree)
            images_all = torch.cat((images_all, intrinsic_embedding), dim=2)

        # Token / linear camera embedding: one vector per view.
        intrinsic_embedding_all = None
        if embed_in_encoder and self.intrinsics_embed_type in ('token', 'linear'):
            intrinsic_embedding = self.intrinsic_encoder(context["intrinsics"].flatten(2))
            intrinsic_embedding_all = rearrange(intrinsic_embedding, "b v c -> (b v) c").unsqueeze(1)

        # Step 1: encode all views as one flat batch.
        images_all = rearrange(images_all, "b v c h w -> (b v) c h w")
        shape_all = torch.tensor(images_all.shape[-2:])[None].repeat(b*v, 1)

        feat, pose, _ = self._encode_image(images_all, shape_all, intrinsic_embedding_all)

        feat = rearrange(feat, "(b v) l c -> b v l c", b=b, v=v)
        pose = rearrange(pose, "(b v) l c -> b v l c", b=b, v=v)

        # Step 2: decode.
        # NOTE(review): no `extra_embed` is passed, so a configuration with
        # intrinsics_embed_loc='decoder' and type 'pixelwise' (which widens
        # `decoder_embed` in `_set_decoder`) is not wired up here.
        dec_feat = self._decoder(feat, pose)
        shape = rearrange(shape_all, "(b v) c -> b v c", b=b, v=v)
        images = rearrange(images_all, "(b v) c h w -> b v c h w", b=b, v=v)

        if embed_in_encoder and self.intrinsics_embed_type == 'token':
            # Strip the trailing intrinsics token from every feature map.
            dec_feat = [layer_feat[:, :, :-1] for layer_feat in dec_feat]

        return dec_feat, shape, images

    @property
    def patch_size(self) -> int:
        """Patch size reported by this backbone (fixed at 16)."""
        return 16

    @property
    def d_out(self) -> int:
        """Output feature width reported by this backbone (fixed at 1024)."""
        return 1024
| |
|