| |
| |
| |
| |
| |
| |
| from copy import deepcopy |
| import torch |
| import os |
| from packaging import version |
| import huggingface_hub |
|
|
| from .utils.misc import fill_default_args, freeze_all_params, is_symmetrized, interleave, transpose_to_landscape |
| from .heads import head_factory |
| from dust3r.patch_embed import get_patch_embed |
|
|
| import dust3r.utils.path_to_croco |
| from models.croco import CroCoNet |
|
|
# Positive infinity; used as an open bound in the default depth/conf modes
# of the model constructor.
inf = float("inf")


# Fail at import time when the installed huggingface_hub predates 0.22.0.
hf_version_number = huggingface_hub.__version__
assert version.parse(hf_version_number) >= version.parse("0.22.0"), (
    "Outdated huggingface_hub version, please reinstall requirements.txt"
)
|
|
|
|
def load_model(model_path, device, verbose=True):
    """Rebuild a network from a training checkpoint and move it to `device`.

    The checkpoint must be a mapping with two entries:
    ``'args'``, an object whose ``model`` attribute is the Python expression
    that constructs the network, and ``'model'``, the state dict to load.

    Args:
        model_path: checkpoint file, handed to ``torch.load``.
        device: target passed to ``net.to(...)`` after the weights are loaded.
        verbose: when true, print the path, the constructor expression and
            the value returned by ``load_state_dict``.

    Returns:
        The instantiated network after ``.to(device)``.

    Raises:
        AssertionError: if the rewritten constructor expression does not
            contain ``landscape_only=False``.
    """
    if verbose:
        print('... loading model from', model_path)

    # The checkpoint holds more than tensors (``ckpt['args']`` is a plain
    # Python object), which the restricted unpickler used by default since
    # torch 2.6 refuses to load. Ask for the full unpickler explicitly.
    # SECURITY: full unpickling executes code embedded in the file; only load
    # checkpoints that come from a trusted source.
    ckpt = torch.load(model_path, map_location='cpu', weights_only=False)

    # Checkpoints may still reference the former patch-embed class name.
    args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")

    # Force `landscape_only=False` in the constructor call, either by
    # appending it before the closing parenthesis or by flipping an explicit
    # `landscape_only=True` (spaces are stripped first so the match is exact).
    if 'landscape_only' not in args:
        args = args[:-1] + ', landscape_only=False)'
    else:
        args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
    assert "landscape_only=False" in args

    if verbose:
        print(f"instantiating : {args}")
    # SECURITY: `args` is a string read from the checkpoint and is evaluated
    # as Python code. Same trust requirement as the unpickling above.
    net = eval(args)

    # strict=False: missing/unexpected keys are reported in `s`, not raised.
    s = net.load_state_dict(ckpt['model'], strict=False)
    if verbose:
        print(s)
    return net.to(device)
|
|
|
|
class AsymmetricCroCo3DStereo (
    CroCoNet,
    huggingface_hub.PyTorchModelHubMixin,
    library_name="dust3r",
    repo_url="https://github.com/naver/dust3r",
    tags=["image-to-3d"],
):
    """ Two siamese encoders, followed by two decoders.
    The goal is to output 3d points directly, both images in view1's frame
    (hence the asymmetry).
    """

    def __init__(self,
                 output_mode='pts3d',
                 head_type='linear',
                 depth_mode=('exp', -inf, inf),
                 conf_mode=('exp', 1, inf),
                 freeze='none',
                 landscape_only=True,
                 patch_embed_cls='PatchEmbedDust3R',
                 **croco_kwargs):
        """ Build the CroCo backbone, duplicate its decoder and attach the heads.

        Args:
            output_mode, head_type, depth_mode, conf_mode: stored on the
                instance and forwarded to `set_downstream_head`.
            freeze: key selecting which sub-modules `set_freeze` freezes
                ('none', 'mask' or 'encoder').
            landscape_only: forwarded as `activate=` to `transpose_to_landscape`.
            patch_embed_cls: name handed to `get_patch_embed` by
                `_set_patch_embed`.
            **croco_kwargs: forwarded to the parent constructor and to
                `set_downstream_head`.
        """
        # Stored before the parent constructor runs because `_set_patch_embed`
        # reads it (presumably the parent __init__ calls that hook -- TODO confirm).
        self.patch_embed_cls = patch_embed_cls
        # NOTE(review): `set_downstream_head` below needs `patch_size` and
        # `img_size` out of `croco_kwargs`; presumably `fill_default_args`
        # fills the parent's defaults into that dict in place -- verify.
        self.croco_args = fill_default_args(croco_kwargs, super().__init__)
        super().__init__(**croco_kwargs)

        # Second decoder: an independent copy of the first one.
        self.dec_blocks2 = deepcopy(self.dec_blocks)
        self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs)
        self.set_freeze(freeze)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """ Load a model from a local checkpoint file or through the hub mixin.

        If the argument is an existing file it is loaded on CPU with
        `load_model` (and `**kw` is not used). Otherwise the call is delegated
        to the parent `from_pretrained` with `**kw`.

        Raises:
            Exception: when the delegated call raises `TypeError`; the
                original error is attached as the cause.
        """
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device='cpu')
        else:
            try:
                model = super(AsymmetricCroCo3DStereo, cls).from_pretrained(pretrained_model_name_or_path, **kw)
            except TypeError as e:
                # Chain explicitly so the underlying TypeError stays visible
                # as the direct cause of the reported failure.
                raise Exception(
                    f'tried to load {pretrained_model_name_or_path} from huggingface, but failed') from e
            return model

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        """ Create `self.patch_embed` from the class name stored in `self.patch_embed_cls`. """
        self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)

    def load_state_dict(self, ckpt, **kw):
        """ Load `ckpt`, duplicating decoder weights when it has no second decoder.

        When no key of `ckpt` starts with 'dec_blocks2', every key starting
        with 'dec_blocks' is copied under the matching 'dec_blocks2' name.
        `ckpt` itself is left untouched (a shallow copy is modified).

        Returns:
            Whatever the parent `load_state_dict` returns.
        """
        new_ckpt = dict(ckpt)
        if not any(k.startswith('dec_blocks2') for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith('dec_blocks'):
                    new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value
        return super().load_state_dict(new_ckpt, **kw)

    def set_freeze(self, freeze):  # this is for use by downstream models
        """ Record `freeze` and freeze the sub-modules selected by that key.

        Raises:
            KeyError: if `freeze` is not 'none', 'mask' or 'encoder'.
        """
        self.freeze = freeze
        to_be_frozen = {
            'none': [],
            'mask': [self.mask_token],
            'encoder': [self.mask_token, self.patch_embed, self.enc_blocks],
        }
        freeze_all_params(to_be_frozen[freeze])

    def _set_prediction_head(self, *args, **kwargs):
        """ No prediction head """
        return

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size,
                            **kw):
        """ Store the head configuration and build one head per view.

        `img_size` must have both of its first two entries divisible by
        `patch_size`. Extra keyword arguments are accepted and ignored.
        """
        assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, \
            f'{img_size=} must be multiple of {patch_size=}'
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        # Raw heads, one per view; confidence is enabled when conf_mode is truthy.
        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        # Wrapped heads, which are the ones looked up by `_downstream_head`.
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)

    def _encode_image(self, image, true_shape):
        """ Run the patch embedding, every encoder block and the final norm.

        Returns:
            `(x, pos, None)`: the normalised tokens, the positions returned
            by the patch embedding, and a `None` placeholder.
        """
        # Tokens and their positions.
        x, pos = self.patch_embed(image, true_shape=true_shape)

        # This model does not support an additive encoder positional embedding.
        assert self.enc_pos_embed is None

        # Each encoder block receives the positions alongside the tokens.
        for blk in self.enc_blocks:
            x = blk(x, pos)

        x = self.enc_norm(x)
        return x, pos, None

    def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):
        """ Encode two images, in a single pass when their last two dims match.

        Returns:
            `(out, out2, pos, pos2)`: features and positions of each image.
        """
        if img1.shape[-2:] == img2.shape[-2:]:
            # Same spatial size: stack along dim 0, encode once, split in two.
            out, pos, _ = self._encode_image(torch.cat((img1, img2), dim=0),
                                             torch.cat((true_shape1, true_shape2), dim=0))
            out, out2 = out.chunk(2, dim=0)
            pos, pos2 = pos.chunk(2, dim=0)
        else:
            out, pos, _ = self._encode_image(img1, true_shape1)
            out2, pos2, _ = self._encode_image(img2, true_shape2)
        return out, out2, pos, pos2

    def _encode_symmetrized(self, view1, view2):
        """ Encode both views, encoding only every other item of a symmetrized batch.

        Args:
            view1, view2: mappings with an 'img' entry and an optional
                'true_shape' entry.

        Returns:
            `((shape1, shape2), (feat1, feat2), (pos1, pos2))`.
        """
        img1 = view1['img']
        img2 = view2['img']
        B = img1.shape[0]
        # Without an explicit 'true_shape', use the image's last two dims,
        # repeated once per batch item.
        shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1))
        shape2 = view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1))

        if is_symmetrized(view1, view2):
            # Encode the even-indexed items only, then rebuild full batches
            # with `interleave` (presumably odd items are the swapped copies
            # of the even ones -- verify against `is_symmetrized`).
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1[::2], img2[::2], shape1[::2], shape2[::2])
            feat1, feat2 = interleave(feat1, feat2)
            pos1, pos2 = interleave(pos1, pos2)
        else:
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1, img2, shape1, shape2)

        return (shape1, shape2), (feat1, feat2), (pos1, pos2)

    def _decoder(self, f1, pos1, f2, pos2):
        """ Run both decoders in lockstep and collect the per-layer outputs.

        Returns:
            An iterator (from `zip`) yielding two tuples: the outputs for
            view 1, then for view 2. Each tuple starts with the encoder
            features, followed by one entry per decoder block; the last entry
            has gone through `dec_norm`.
        """
        final_output = [(f1, f2)]  # entry 0: features before projection

        # Project the encoder features with the shared `decoder_embed`.
        f1 = self.decoder_embed(f1)
        f2 = self.decoder_embed(f2)

        final_output.append((f1, f2))
        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
            # Both blocks read the previous layer's pair, so neither sees the
            # other's update from the current layer.
            # view 1 block: (f1, f2) with (pos1, pos2)
            f1, _ = blk1(*final_output[-1][::+1], pos1, pos2)
            # view 2 block: (f2, f1) with (pos2, pos1)
            f2, _ = blk2(*final_output[-1][::-1], pos2, pos1)
            # Store the results of this layer.
            final_output.append((f1, f2))

        # Drop the projected-features entry; only the raw encoder features
        # and the block outputs are returned.
        del final_output[1]
        # Normalise the last layer's output.
        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
        return zip(*final_output)

    def _downstream_head(self, head_num, decout, img_shape):
        """ Apply `self.head{head_num}` to the decoder outputs `decout`. """
        # Unused beyond the unpacking itself, which raises unless the last
        # decoder output has exactly three dimensions.
        B, S, D = decout[-1].shape
        head = getattr(self, f'head{head_num}')
        return head(decout, img_shape)

    def forward(self, view1, view2):
        """ Encode, decode and run the two heads on a pair of views.

        Returns:
            `(res1, res2)`: the outputs of head 1 and head 2. In `res2` the
            'pts3d' entry is renamed to 'pts3d_in_other_view'.
        """
        # Encode the two images.
        (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(view1, view2)

        # Per-layer decoder outputs for each view.
        dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)

        # The heads are fed float32 copies of every decoder output.
        res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
        res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)

        # Rename the second head's 'pts3d' entry.
        res2['pts3d_in_other_view'] = res2.pop('pts3d')
        return res1, res2
|
|