# NOTE(review): removed paste-viewer artifacts ("Spaces:" / "Runtime error"
# banners) that were prepended to this file and are not part of the module.
| from copy import deepcopy | |
| from dataclasses import dataclass | |
| from typing import Literal | |
| import torch | |
| from einops import rearrange | |
| from torch import nn | |
| from .croco.blocks import DecoderBlock | |
| from .croco.croco import CroCoNet | |
| from .croco.misc import fill_default_args, freeze_all_params, transpose_to_landscape, is_symmetrized, interleave, \ | |
| make_batch_symmetric | |
| from .croco.patch_embed import get_patch_embed | |
| from .backbone import Backbone | |
| from src.geometry.camera_emb import get_intrinsic_embedding | |
# Sentinel for unbounded depth / confidence ranges below.
inf = float('inf')

# Architecture hyper-parameters for the supported CroCo backbone variants,
# keyed by model name. Only ViTLarge_BaseDecoder is currently supported.
croco_params = {
    'ViTLarge_BaseDecoder': {
        'enc_depth': 24,
        'dec_depth': 12,
        'enc_embed_dim': 1024,
        'dec_embed_dim': 768,
        'enc_num_heads': 16,
        'dec_num_heads': 12,
        'pos_embed': 'RoPE100',
        'img_size': (512, 512),
    },
}

# Defaults matching the released DUSt3R configuration (ViT-Large encoder,
# Base decoder, DPT head predicting 3D points with an exp confidence mode).
default_dust3r_params = {
    'enc_depth': 24,
    'dec_depth': 12,
    'enc_embed_dim': 1024,
    'dec_embed_dim': 768,
    'enc_num_heads': 16,
    'dec_num_heads': 12,
    'pos_embed': 'RoPE100',
    'patch_embed_cls': 'PatchEmbedDust3R',
    'img_size': (512, 512),
    'head_type': 'dpt',
    'output_mode': 'pts3d',
    'depth_mode': ('exp', -inf, inf),
    'conf_mode': ('exp', 1, inf),
}
@dataclass
class BackboneCrocoCfg:
    """Configuration for the CroCo-based backbone.

    BUGFIX: the ``@dataclass`` decorator was missing, so no ``__init__`` was
    generated and the annotated fields without defaults (``name``, ``model``)
    never became attributes; ``dataclass`` was imported above but unused.
    """

    name: Literal["croco"]
    model: Literal["ViTLarge_BaseDecoder", "ViTBase_SmallDecoder", "ViTBase_BaseDecoder"]  # keep interface for the last two models, but they are not supported
    patch_embed_cls: str = 'PatchEmbedDust3R'  # PatchEmbedDust3R or ManyAR_PatchEmbed
    asymmetry_decoder: bool = True
    # Where (if anywhere) camera intrinsics are injected into the network.
    intrinsics_embed_loc: Literal["encoder", "decoder", "none"] = 'none'
    intrinsics_embed_degree: int = 0
    intrinsics_embed_type: Literal["pixelwise", "linear", "token"] = 'token'  # linear or dpt
class AsymmetricCroCoMulti(CroCoNet):
    """ Two siamese encoders, followed by two decoders.
    The goal is to output 3d points directly, both images in view1's frame
    (hence the asymmetry).
    """

    def __init__(self, cfg: BackboneCrocoCfg, d_in: int) -> None:
        """Build the backbone.

        Plain attributes are assigned *before* ``super().__init__`` on
        purpose: the CroCoNet constructor calls ``_set_patch_embed`` and
        ``_set_decoder``, which read them. (nn.Module.__setattr__ falls
        through to object.__setattr__ for non-module attributes, so this
        is safe.)
        """
        self.intrinsics_embed_loc = cfg.intrinsics_embed_loc
        self.intrinsics_embed_degree = cfg.intrinsics_embed_degree
        self.intrinsics_embed_type = cfg.intrinsics_embed_type
        # Extra input channels appended when intrinsics are embedded
        # pixelwise at the encoder (resp. decoder) input: (degree+1)^2
        # channels for degree > 0, otherwise 3 (raw ray directions,
        # presumably — see get_intrinsic_embedding).
        self.intrinsics_embed_encoder_dim = 0
        self.intrinsics_embed_decoder_dim = 0
        if self.intrinsics_embed_loc == 'encoder' and self.intrinsics_embed_type == 'pixelwise':
            self.intrinsics_embed_encoder_dim = (self.intrinsics_embed_degree + 1) ** 2 if self.intrinsics_embed_degree > 0 else 3
        elif self.intrinsics_embed_loc == 'decoder' and self.intrinsics_embed_type == 'pixelwise':
            self.intrinsics_embed_decoder_dim = (self.intrinsics_embed_degree + 1) ** 2 if self.intrinsics_embed_degree > 0 else 3

        self.patch_embed_cls = cfg.patch_embed_cls
        self.croco_args = fill_default_args(croco_params[cfg.model], CroCoNet.__init__)
        super().__init__(**croco_params[cfg.model])

        if cfg.asymmetry_decoder:
            self.dec_blocks2 = deepcopy(self.dec_blocks)  # This is used in DUSt3R and MASt3R

        if self.intrinsics_embed_type == 'linear' or self.intrinsics_embed_type == 'token':
            # Flattened 3x3 intrinsics (9 values) -> one encoder-dim vector.
            # NOTE(review): 1024 is hard-coded to match enc_embed_dim of
            # ViTLarge_BaseDecoder in croco_params — revisit if more models
            # are ever supported.
            self.intrinsic_encoder = nn.Linear(9, 1024)

        # self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs)
        # self.set_freeze(freeze)

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768, in_chans=3):
        """Create the patch embedding, widening the input channel count when
        intrinsics are concatenated pixelwise to the image."""
        in_chans = in_chans + self.intrinsics_embed_encoder_dim
        self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans)

    def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
        """Build the encoder->decoder projection and the decoder blocks."""
        self.dec_depth = dec_depth
        self.dec_embed_dim = dec_embed_dim
        # transfer from encoder to decoder; widen input if intrinsics are
        # concatenated pixelwise at the decoder input
        enc_embed_dim = enc_embed_dim + self.intrinsics_embed_decoder_dim
        self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
        # transformer for the decoder
        self.dec_blocks = nn.ModuleList([
            DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
            for _ in range(dec_depth)])
        # final norm layer
        self.dec_norm = norm_layer(dec_embed_dim)

    def load_state_dict(self, ckpt, **kw):
        """Load weights, duplicating the first decoder's weights into
        ``dec_blocks2`` when the checkpoint predates the asymmetric decoder."""
        new_ckpt = dict(ckpt)
        if not any(k.startswith('dec_blocks2') for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith('dec_blocks'):
                    new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value
        return super().load_state_dict(new_ckpt, **kw)

    def set_freeze(self, freeze):  # this is for use by downstream models
        """Freeze the requested parameter groups.

        BUGFIX: 'encoder_decoder' was listed in the table below but rejected
        by the assert, making that option unreachable; it is now accepted.
        """
        assert freeze in ['none', 'mask', 'encoder', 'encoder_decoder'], f"unexpected freeze={freeze}"
        to_be_frozen = {
            'none': [],
            'mask': [self.mask_token],
            'encoder': [self.mask_token, self.patch_embed, self.enc_blocks],
            'encoder_decoder': [self.mask_token, self.patch_embed, self.enc_blocks, self.enc_norm, self.decoder_embed, self.dec_blocks, self.dec_blocks2, self.dec_norm],
        }
        freeze_all_params(to_be_frozen[freeze])

    def _set_prediction_head(self, *args, **kwargs):
        """ No prediction head """
        return

    def _encode_image(self, image, true_shape, intrinsics_embed=None):
        """Encode one batch of images into patch tokens.

        Returns (tokens, positions, None); the trailing None keeps the
        CroCoNet encoder interface.
        """
        # embed the image into patches (x has size B x Npatches x C)
        x, pos = self.patch_embed(image, true_shape=true_shape)

        if intrinsics_embed is not None:
            if self.intrinsics_embed_type == 'linear':
                # broadcast-add one embedding vector to every patch token
                x = x + intrinsics_embed
            elif self.intrinsics_embed_type == 'token':
                # append the intrinsics as an extra token, giving it a
                # position index one past the last patch position
                x = torch.cat((x, intrinsics_embed), dim=1)
                add_pose = pos[:, 0:1, :].clone()
                add_pose[:, :, 0] += (pos[:, -1, 0].unsqueeze(-1) + 1)
                pos = torch.cat((pos, add_pose), dim=1)

        # add positional embedding without cls token
        assert self.enc_pos_embed is None

        # now apply the transformer encoder and normalization
        for blk in self.enc_blocks:
            x = blk(x, pos)
        x = self.enc_norm(x)
        return x, pos, None

    def _decoder(self, feat, pose, extra_embed=None):
        """Run the two asymmetric decoders over all views jointly.

        ``feat``/``pose`` are (B, V, L, C); view 0 goes through
        ``dec_blocks`` and views 1..V-1 through ``dec_blocks2``, each
        cross-attending to the tokens of all *other* views. Returns the
        list of per-layer features (input features first, final normed
        features last), each (B, V, L, C).
        """
        b, v, l, c = feat.shape
        final_output = [feat]  # before projection

        if extra_embed is not None:
            feat = torch.cat((feat, extra_embed), dim=-1)

        # project to decoder dim
        f = rearrange(feat, "b v l c -> (b v) l c")
        f = self.decoder_embed(f)
        f = rearrange(f, "(b v) l c -> b v l c", b=b, v=v)
        final_output.append(f)

        def generate_ctx_views(x):
            # For each view, gather the tokens of every other view:
            # (B, V, L, C) -> (B, V, (V-1)*L, C)
            b, v, l, c = x.shape
            ctx_views = x.unsqueeze(1).expand(b, v, v, l, c)
            mask = torch.arange(v).unsqueeze(0) != torch.arange(v).unsqueeze(1)
            ctx_views = ctx_views[:, mask].reshape(b, v, v - 1, l, c)  # B, V, V-1, L, C
            ctx_views = ctx_views.flatten(2, 3)  # B, V, (V-1)*L, C
            return ctx_views.contiguous()

        pos_ctx = generate_ctx_views(pose)
        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
            feat_current = final_output[-1]
            feat_current_ctx = generate_ctx_views(feat_current)
            # img1 side (reference view 0)
            f1, _ = blk1(feat_current[:, 0].contiguous(), feat_current_ctx[:, 0].contiguous(), pose[:, 0].contiguous(), pos_ctx[:, 0].contiguous())
            f1 = f1.unsqueeze(1)
            # img2 side (all remaining views, folded into the batch dim)
            f2, _ = blk2(rearrange(feat_current[:, 1:], "b v l c -> (b v) l c"),
                         rearrange(feat_current_ctx[:, 1:], "b v l c -> (b v) l c"),
                         rearrange(pose[:, 1:], "b v l c -> (b v) l c"),
                         rearrange(pos_ctx[:, 1:], "b v l c -> (b v) l c"))
            f2 = rearrange(f2, "(b v) l c -> b v l c", b=b, v=v - 1)
            # store the result
            final_output.append(torch.cat((f1, f2), dim=1))

        # normalize last output
        del final_output[1]  # duplicate with final_output[0]
        last_feat = rearrange(final_output[-1], "b v l c -> (b v) l c")
        last_feat = self.dec_norm(last_feat)
        final_output[-1] = rearrange(last_feat, "(b v) l c -> b v l c", b=b, v=v)
        return final_output

    def forward(self,
                context: dict,
                symmetrize_batch=False,
                return_views=False,
                ):
        """Encode and decode all context views.

        Args:
            context: dict with "image" of shape (B, V, C, H, W) and, when
                intrinsics embedding is enabled, "intrinsics" (B, V, 3, 3).
            symmetrize_batch, return_views: kept for interface compatibility;
                unused here.

        Returns:
            (dec_feat, shape, images): per-layer decoder features
            (each (B, V, L, C)), per-view true shapes (B, V, 2), and the
            images reshaped back to (B, V, C, H, W).
        """
        b, v, _, h, w = context["image"].shape
        images_all = context["image"]

        # camera embedding in the encoder (pixelwise: extra input channels)
        if self.intrinsics_embed_loc == 'encoder' and self.intrinsics_embed_type == 'pixelwise':
            intrinsic_embedding = get_intrinsic_embedding(context, degree=self.intrinsics_embed_degree)
            images_all = torch.cat((images_all, intrinsic_embedding), dim=2)

        intrinsic_embedding_all = None
        if self.intrinsics_embed_loc == 'encoder' and (self.intrinsics_embed_type == 'token' or self.intrinsics_embed_type == 'linear'):
            intrinsic_embedding = self.intrinsic_encoder(context["intrinsics"].flatten(2))
            intrinsic_embedding_all = rearrange(intrinsic_embedding, "b v c -> (b v) c").unsqueeze(1)

        # step 1: encode input images (views folded into the batch dim)
        images_all = rearrange(images_all, "b v c h w -> (b v) c h w")
        shape_all = torch.tensor(images_all.shape[-2:])[None].repeat(b * v, 1)
        feat, pose, _ = self._encode_image(images_all, shape_all, intrinsic_embedding_all)
        feat = rearrange(feat, "(b v) l c -> b v l c", b=b, v=v)
        pose = rearrange(pose, "(b v) l c -> b v l c", b=b, v=v)

        # step 2: decoder
        dec_feat = self._decoder(feat, pose)

        shape = rearrange(shape_all, "(b v) c -> b v c", b=b, v=v)
        images = rearrange(images_all, "(b v) c h w -> b v c h w", b=b, v=v)

        # drop the appended intrinsics token so downstream heads see only
        # patch tokens
        if self.intrinsics_embed_loc == 'encoder' and self.intrinsics_embed_type == 'token':
            dec_feat = list(dec_feat)
            for i in range(len(dec_feat)):
                dec_feat[i] = dec_feat[i][:, :, :-1]

        return dec_feat, shape, images

    def patch_size(self) -> int:
        # fixed by the DUSt3R patch embedding used above
        return 16

    def d_out(self) -> int:
        # encoder embedding dim of ViTLarge_BaseDecoder
        return 1024