| |
| |
| |
| |
| |
| |
| import torch |
| import torch.nn.functional as F |
|
|
| import mast3r.utils.path_to_dust3r |
| from dust3r.heads.postprocess import reg_dense_depth, reg_dense_conf |
| from dust3r.heads.dpt_head import PixelwiseTaskWithDPT |
| import dust3r.utils.path_to_croco |
| from models.blocks import Mlp |
|
|
|
|
def reg_desc(desc, mode):
    """Post-process raw local descriptors according to ``mode``.

    Args:
        desc: tensor of raw descriptors; the descriptor channels live on the
            last dimension.
        mode: string selecting the post-processing. Only modes containing the
            substring ``'norm'`` are supported: each descriptor is rescaled to
            unit L2 norm along the last dimension.

    Returns:
        A tensor of the same shape as ``desc`` with L2-normalized descriptors.

    Raises:
        ValueError: if ``mode`` does not contain ``'norm'``.
    """
    if 'norm' not in mode:
        raise ValueError(f"Unknown desc mode {mode}")
    # F.normalize divides by max(||desc||, eps) rather than the raw norm, so an
    # all-zero descriptor stays zero instead of turning into NaN (0 / 0).
    return F.normalize(desc, p=2, dim=-1)
|
|
|
|
def postprocess(out, depth_mode, conf_mode, desc_dim=None, desc_mode='norm', two_confs=False, desc_conf_mode=None):
    """Split the raw head output into named, post-processed predictions.

    The channel dimension of ``out`` (dim 1) is laid out as:
    3 point channels, then one confidence channel (only when ``conf_mode`` is
    not None), then ``desc_dim`` descriptor channels, then one descriptor
    confidence channel (only read when ``two_confs`` is true).

    Args:
        out: 4-D tensor with channels on dim 1 (presumably B,C,H,W - confirm
            against the caller).
        depth_mode: forwarded to ``reg_dense_depth`` for the point channels.
        conf_mode: forwarded to ``reg_dense_conf``; ``None`` means the output
            has no confidence channel.
        desc_dim: number of descriptor channels, or ``None`` for no descriptors.
        desc_mode: forwarded to ``reg_desc``.
        two_confs: whether a dedicated descriptor-confidence channel follows
            the descriptors.
        desc_conf_mode: mode for the descriptor confidence; defaults to
            ``conf_mode`` when ``None``.

    Returns:
        dict with ``'pts3d'`` and, depending on the arguments, ``'conf'``,
        ``'desc'`` and ``'desc_conf'``.
    """
    if desc_conf_mode is None:
        desc_conf_mode = conf_mode
    fmap = out.permute(0, 2, 3, 1)  # move channels last so they can be sliced with [..., a:b]
    res = dict(pts3d=reg_dense_depth(fmap[..., 0:3], mode=depth_mode))
    if conf_mode is not None:
        res['conf'] = reg_dense_conf(fmap[..., 3], mode=conf_mode)
    if desc_dim is not None:
        # descriptors start right after the points (+1 if a confidence channel is present)
        start = 3 + int(conf_mode is not None)
        res['desc'] = reg_desc(fmap[..., start:start + desc_dim], mode=desc_mode)
        if two_confs:
            res['desc_conf'] = reg_dense_conf(fmap[..., start + desc_dim], mode=desc_conf_mode)
        elif 'conf' in res:
            # Reuse the point confidence for the descriptors. Guarded because
            # 'conf' is absent when conf_mode is None; an unconditional copy
            # would raise KeyError in that configuration.
            res['desc_conf'] = res['conf'].clone()
    return res
|
|
|
|
class Cat_MLP_LocalFeatures_DPT_Pts3d(PixelwiseTaskWithDPT):
    """Hybrid prediction head: a DPT branch for 3D points plus an MLP branch for local features.

    The DPT branch (inherited) receives the full list of decoder outputs.
    The MLP branch receives the first and last entries of that list,
    concatenated along the channel dimension, and its per-token output is
    unfolded to pixel resolution with a pixel shuffle.
    """

    def __init__(self, net, has_conf=False, local_feat_dim=16, hidden_dim_factor=4., hooks_idx=None, dim_tokens=None,
                 num_channels=1, postprocess=None, feature_dim=256, last_dim=32, depth_mode=None, conf_mode=None,
                 head_type="regression", **kwargs):
        """Build the DPT branch through the parent class, then the local-feature MLP.

        ``net`` supplies ``patch_embed.patch_size``, ``desc_mode``,
        ``two_confs``, ``desc_conf_mode``, ``enc_embed_dim`` and
        ``dec_embed_dim``. Extra ``**kwargs`` are accepted and ignored.
        """
        super().__init__(
            num_channels=num_channels,
            feature_dim=feature_dim,
            last_dim=last_dim,
            hooks_idx=hooks_idx,
            dim_tokens=dim_tokens,
            depth_mode=depth_mode,
            postprocess=postprocess,
            conf_mode=conf_mode,
            head_type=head_type,
        )
        self.local_feat_dim = local_feat_dim

        # Collapse a (p, p) tuple into the single int p; only square patches are handled.
        psize = net.patch_embed.patch_size
        if isinstance(psize, tuple):
            is_int_pair = len(psize) == 2 and isinstance(psize[0], int) and isinstance(psize[1], int)
            assert is_int_pair, "What is your patchsize format? Expected a single int or a tuple of two ints."
            assert psize[0] == psize[1], "Error, non square patches not managed"
            psize = psize[0]
        self.patch_size = psize

        self.desc_mode = net.desc_mode
        self.has_conf = has_conf
        self.two_confs = net.two_confs
        self.desc_conf_mode = net.desc_conf_mode

        # The MLP consumes encoder and decoder tokens concatenated channel-wise.
        token_dim = net.enc_embed_dim + net.dec_embed_dim
        # Each token must emit every pixel of its patch: descriptor channels
        # (+1 when a separate descriptor confidence is predicted) per pixel.
        channels_per_pixel = self.local_feat_dim + self.two_confs
        self.head_local_features = Mlp(
            in_features=token_dim,
            hidden_features=int(hidden_dim_factor * token_dim),
            out_features=channels_per_pixel * self.patch_size**2,
        )

    def forward(self, decout, img_shape):
        """Run both branches and return their (optionally post-processed) combination.

        Args:
            decout: sequence of decoder outputs; the first and last entries
                feed the MLP branch, the whole sequence feeds the DPT branch.
            img_shape: ``(H, W)`` of the image.

        Returns:
            The result of ``self.postprocess`` when it is set, otherwise the
            DPT output and local features concatenated on dim 1.
        """
        # DPT branch: dense 3D point prediction.
        points = self.dpt(decout, image_size=(img_shape[0], img_shape[1]))

        # MLP branch input: first and last decoder outputs joined on the channel axis.
        tokens = torch.cat([decout[0], decout[-1]], dim=-1)
        height, width = img_shape
        batch, _, _ = tokens.shape  # tokens must be 3-D

        # Per-token prediction, laid out on the patch grid, then expanded to pixels.
        feats = self.head_local_features(tokens)
        grid_h = height // self.patch_size
        grid_w = width // self.patch_size
        feats = feats.transpose(-1, -2).view(batch, -1, grid_h, grid_w)
        feats = F.pixel_shuffle(feats, self.patch_size)

        out = torch.cat([points, feats], dim=1)
        if not self.postprocess:
            return out
        return self.postprocess(
            out,
            depth_mode=self.depth_mode,
            conf_mode=self.conf_mode,
            desc_dim=self.local_feat_dim,
            desc_mode=self.desc_mode,
            two_confs=self.two_confs,
            desc_conf_mode=self.desc_conf_mode,
        )
|
|
|
|
def mast3r_head_factory(head_type, output_mode, net, has_conf=False):
    """Build a prediction head for the decoder.

    Only ``head_type == 'catmlp+dpt'`` combined with an ``output_mode`` of the
    form ``'pts3d+desc<N>'`` is supported, where ``<N>`` is the local feature
    dimension. Any other combination raises ``NotImplementedError``.
    """
    desc_prefix = 'pts3d+desc'
    if head_type != 'catmlp+dpt' or not output_mode.startswith(desc_prefix):
        raise NotImplementedError(
            f"unexpected {head_type=} and {output_mode=}")

    # Whatever follows the prefix is the descriptor dimension.
    desc_dim = int(output_mode[len(desc_prefix):])
    assert net.dec_depth > 9
    depth = net.dec_depth
    n_features = 256
    enc_dim = net.enc_embed_dim
    dec_dim = net.dec_embed_dim
    # Hook the DPT branch at the start, half, three quarters and end of the decoder.
    hook_layers = [0, depth * 2 // 4, depth * 3 // 4, depth]
    return Cat_MLP_LocalFeatures_DPT_Pts3d(
        net,
        local_feat_dim=desc_dim,
        has_conf=has_conf,
        num_channels=3 + has_conf,  # xyz, plus one confidence channel when enabled
        feature_dim=n_features,
        last_dim=n_features // 2,
        hooks_idx=hook_layers,
        dim_tokens=[enc_dim, dec_dim, dec_dim, dec_dim],
        postprocess=postprocess,
        depth_mode=net.depth_mode,
        conf_mode=net.conf_mode,
        head_type='regression',
    )
|
|