| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from einops import rearrange |
| | from typing import List |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | |
| | from .dpt_block import DPTOutputAdapter, Interpolate, make_fusion_block |
| | from .head_modules import UnetExtractor, AppearanceTransformer, _init_weights |
| | from .postprocess import postprocess |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
class DPTOutputAdapter_fix(DPTOutputAdapter):
    """
    Adapt croco's DPTOutputAdapter implementation for dust3r:
    remove duplicated weights, and fix forward for dust3r.

    This variant also adds an image branch (``input_merger``). The raw input
    images are convolved up to the fused-feature width and added to the
    finest fused feature map (``path_1``) after that map has been resized to
    the full output resolution, just before ``self.head``.
    """

    def init(self, dim_tokens_enc=768):
        """
        Second construction phase, run once the encoder token width is known.

        Args:
            dim_tokens_enc: encoder token width, forwarded to the parent ``init``.
        """
        super().init(dim_tokens_enc)

        # forward() only goes through ``self.act_postprocess``; the individual
        # ``act_*_postprocess`` attributes created by the parent are dropped so
        # the same weights are not registered twice (presumably they are the
        # entries of that ModuleList -- confirm against the parent class).
        del self.act_1_postprocess
        del self.act_2_postprocess
        del self.act_3_postprocess
        del self.act_4_postprocess

        # Channel width of the fused DPT feature maps (path_1). The image branch
        # must match it, otherwise ``path_1 + direct_img_feat`` fails. Fall back
        # to 256 if the parent does not define ``feature_dim``.
        feature_dim = getattr(self, 'feature_dim', 256)

        self.feat_up = Interpolate(scale_factor=2, mode="bilinear", align_corners=True)
        # 7x7 conv lifting 3-channel images to ``feature_dim`` channels so the
        # result can be summed with path_1 in forward().
        self.input_merger = nn.Sequential(
            nn.Conv2d(3, feature_dim, 7, 1, 3),
            nn.ReLU(),
        )

    def forward(self, encoder_tokens: List[torch.Tensor], depths, imgs, image_size=None, conf=None):
        """
        Decode hooked encoder tokens, plus the input images, into a dense map.

        Args:
            encoder_tokens: per-layer token tensors. The entries at
                ``self.hooks`` are used; after ``adapt_tokens`` each must be
                (B, N_H * N_W, C) for the rearrange below.
            depths: unused here; kept for interface compatibility with callers.
            imgs: images fed to ``input_merger`` (3 input channels). Their
                spatial size must be (H, W) so the result can be added to the
                resized path_1. Presumably (B, 3, H, W) -- TODO confirm.
            image_size: output (H, W); defaults to ``self.image_size``.
            conf: unused here.

        Returns:
            ``(out, [path_4, path_3, path_2])``: the head output and the three
            coarser fused feature maps.
        """
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'

        image_size = self.image_size if image_size is None else image_size
        H, W = image_size

        # Size of the token grid.
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        layers = [encoder_tokens[hook] for hook in self.hooks]
        layers = [self.adapt_tokens(tok) for tok in layers]
        # Token sequences -> 2D feature maps.
        layers = [rearrange(tok, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for tok in layers]
        layers = [self.act_postprocess[idx](feat) for idx, feat in enumerate(layers)]
        layers = [self.scratch.layer_rn[idx](feat) for idx, feat in enumerate(layers)]

        # Top-down fusion. refinenet4's output is cropped to the spatial size
        # of layers[2] before being fused with it.
        path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
        path_3 = self.scratch.refinenet3(path_4, layers[2])
        path_2 = self.scratch.refinenet2(path_3, layers[1])
        path_1 = self.scratch.refinenet1(path_2, layers[0])

        direct_img_feat = self.input_merger(imgs)

        # Bring the finest fused map to full resolution and inject image features.
        path_1 = F.interpolate(path_1, size=(H, W), mode='bilinear', align_corners=True)
        path_1 = path_1 + direct_img_feat

        out = self.head(path_1)

        return out, [path_4, path_3, path_2]
| |
|
| |
|
class PixelwiseTaskWithDPT(nn.Module):
    """
    DPT module for dust3r: a dense (per-pixel) prediction head.

    Wraps a ``DPTOutputAdapter_fix`` that decodes transformer tokens into a
    ``num_channels``-channel map (e.g. 3D points plus confidence), then
    optionally runs a ``postprocess`` callable on the result.
    """

    def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
                 output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
        """
        Args:
            n_cls_token: number of class tokens; only 0 is supported.
            hooks_idx: optional layer indices, forwarded to the adapter as ``hooks``.
            dim_tokens: optional encoder token widths, forwarded to the
                adapter's ``init`` as ``dim_tokens_enc``.
            output_width_ratio, num_channels, **kwargs: forwarded to the
                adapter constructor.
            postprocess: optional callable ``postprocess(out, depth_mode, conf_mode)``.
            depth_mode, conf_mode: passed through to ``postprocess``.
        """
        super().__init__()
        self.return_all_layers = True
        self.postprocess = postprocess
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode

        assert n_cls_token == 0, "Not implemented"

        adapter_kwargs = {
            'output_width_ratio': output_width_ratio,
            'num_channels': num_channels,
            **kwargs,
        }
        if hooks_idx is not None:
            adapter_kwargs['hooks'] = hooks_idx
        self.dpt = DPTOutputAdapter_fix(**adapter_kwargs)

        # The adapter is built in two phases: construct, then init() with the
        # encoder token width (or the adapter's own default when not given).
        if dim_tokens is None:
            self.dpt.init()
        else:
            self.dpt.init(dim_tokens_enc=dim_tokens)

    def forward(self, x, depths, imgs, img_info, conf=None):
        """
        Run the DPT adapter, then the optional postprocess.

        ``img_info[0]`` and ``img_info[1]`` are used as the output (H, W).
        Returns ``(out, interm_feats)``.
        """
        height, width = img_info[0], img_info[1]
        out, interm_feats = self.dpt(x, depths, imgs, image_size=(height, width), conf=conf)
        if not self.postprocess:
            return out, interm_feats
        return self.postprocess(out, self.depth_mode, self.conf_mode), interm_feats
| | |
class AttnBasedAppearanceHead(nn.Module):
    """
    Attention head for appearance reconstruction.

    Input images are cut into patches and linearly embedded. Each image token
    is concatenated channel-wise with the (detached) token at the same position
    from the last hooked layer. The result is processed by an
    ``AppearanceTransformer`` and decoded back to ``num_channels`` values per
    pixel.
    """

    def __init__(self, num_channels, patch_size, feature_dim, last_dim, hooks_idx, dim_tokens, postprocess, depth_mode, conf_mode, head_type='gs_params'):
        """
        Args:
            num_channels: number of output values predicted per pixel.
            patch_size: patch size, either an int (square) or a (ph, pw) pair.
            feature_dim: passed as ``head_dim`` to the appearance transformer.
            last_dim, postprocess, depth_mode, conf_mode, head_type: accepted
                for interface compatibility; not used by this head.
            hooks_idx: indices into the layer list given to forward(); only
                the last one is consumed.
            dim_tokens: per-hook token widths; all must be equal.
        """
        super().__init__()

        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.num_channels = num_channels
        self.patch_size = patch_size

        self.hooks = hooks_idx

        assert len(set(dim_tokens)) == 1
        token_dim = dim_tokens[0]
        # ph * pw (not ph ** 2) so non-square patches size the layers correctly.
        pixels_per_patch = patch_size[0] * patch_size[1]

        # RGB patch (ph * pw * 3 values) -> token embedding.
        self.tokenizer = nn.Linear(3 * pixels_per_patch, token_dim, bias=False)

        # Operates on [image token | layer token], hence twice the token width.
        self.attn_processor = AppearanceTransformer(num_layers=4, attn_dim=token_dim * 2, head_dim=feature_dim)

        self.token_decoder = nn.Sequential(
            nn.LayerNorm(token_dim * 2, bias=False),
            nn.Linear(
                token_dim * 2, self.num_channels * pixels_per_patch,
                bias=False,
            )
        )
        self.token_decoder.apply(_init_weights)

    def img_pts_tokenizer(self, imgs, pts3d):
        """
        Split images into non-overlapping patches and embed each patch.

        Args:
            imgs: tensor of shape (B, V, 3, H, W). H and W are expected to be
                multiples of the patch height/width.
            pts3d: accepted for interface compatibility; the tokenizer only
                embeds RGB values, so it is not used.

        Returns:
            Tensor of shape (B * V, (H // ph) * (W // pw), token_dim).
        """
        H, W = imgs.shape[-2:]
        ph, pw = self.patch_size
        # (B, V, 3, H, W) -> (B, V, H * W, 3), pixels in row-major order.
        pixels = imgs.permute(0, 1, 3, 4, 2).flatten(2, 3)
        patches = rearrange(
            pixels, "b v (hh ph ww pw) d -> (b v) (hh ww) (ph pw d)",
            hh=H // ph, ww=W // pw, ph=ph, pw=pw,
        )
        return self.tokenizer(patches)

    def forward(self, x, depths, imgs, img_info, conf=None):
        """
        Predict per-pixel appearance parameters.

        Args:
            x: list of per-layer token tensors of shape (B * V, L, C); only
                ``x[self.hooks[-1]]`` is used.
            depths: forwarded to ``img_pts_tokenizer`` (unused there).
            imgs: tensor of shape (B, V, 3, H, W).
            img_info: ``(B, V, H, W)``.
            conf: unused.

        Returns:
            Tensor of shape (B, V, H * W, num_channels).
        """
        B, V, H, W = img_info
        ph, pw = self.patch_size

        input_tokens = rearrange(self.img_pts_tokenizer(imgs, depths), "(b v) l d -> b (v l) d", b=B, v=V)

        # Only the last hooked layer is consumed. detach() stops gradients of
        # this head from flowing into x.
        layer_tokens = rearrange(x[self.hooks[-1]].detach(), "(b v) l d -> b (v l) d", b=B, v=V)

        tokens = self.attn_processor(torch.cat([input_tokens, layer_tokens], dim=-1))

        gaussian_params = self.token_decoder(tokens)

        # Unpatchify: each token's (ph * pw * d) values go back to its pixels.
        gaussians = rearrange(
            gaussian_params, "b (v hh ww) (ph pw d) -> b (v hh ph ww pw) d",
            v=V, hh=H // ph, ww=W // pw, ph=ph, pw=pw,
        )
        return gaussians.view(B, V, H * W, -1)
| |
|
def create_gs_dpt_head(net, has_conf=False, out_nchan=3, postprocess_func=postprocess):
    """
    Return a PixelwiseTaskWithDPT configured from ``net``'s attributes.

    Args:
        net: model exposing ``dec_depth`` (must be > 9), ``feature_dim``,
            ``enc_embed_dim``, ``dec_embed_dim``, ``depth_mode``,
            ``conf_mode`` and, optionally, ``patch_size``.
        has_conf: if True, one extra output channel is added.
        out_nchan: number of output channels before the optional extra one.
        postprocess_func: forwarded to the head as ``postprocess``.

    Returns:
        A PixelwiseTaskWithDPT instance.
    """
    assert net.dec_depth > 9
    l2 = net.dec_depth
    feature_dim = net.feature_dim
    last_dim = feature_dim // 2
    ed = net.enc_embed_dim
    dd = net.dec_embed_dim
    # Fall back to 16x16 patches only when the net has no ``patch_size``;
    # any other error while reading it is not silently swallowed.
    patch_size = getattr(net, 'patch_size', (16, 16))

    return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
                                patch_size=patch_size,
                                feature_dim=feature_dim,
                                last_dim=last_dim,
                                # Encoder output plus three decoder depths.
                                hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
                                dim_tokens=[ed, dd, dd, dd],
                                postprocess=postprocess_func,
                                depth_mode=net.depth_mode,
                                conf_mode=net.conf_mode,
                                head_type='gs_params')