| |
| |
| |
| |
| |
| |
| import torch |
| import torch.nn.functional as F |
| import os |
|
|
| from mast3r.catmlp_dpt_head import mast3r_head_factory |
|
|
| import mast3r.utils.path_to_dust3r |
| from ..dust3r.dust3r.model import AsymmetricCroCo3DStereo |
| from ..dust3r.dust3r.utils.misc import transpose_to_landscape |
|
|
|
|
# Positive infinity under a short module-level name. Nothing in the visible
# code references it directly; presumably it exists so that constructor
# strings evaluated by `load_model` (via `eval`, which resolves names in this
# module's globals) can mention `inf` -- TODO confirm against real checkpoints.
inf = float('inf')
|
|
|
|
def load_model(model_path, device, verbose=True):
    """Rebuild a network from a checkpoint file and move it to ``device``.

    The checkpoint is read as a mapping with two entries used here:
    ``ckpt['args'].model``, a string holding the Python expression that
    constructs the network, and ``ckpt['model']``, the state dict loaded
    into the constructed network.

    Args:
        model_path: Path of the checkpoint file, passed to ``torch.load``.
        device: Target passed to ``net.to(...)`` once the weights are loaded.
        verbose: When true, print the path, the constructor expression and
            the result of ``load_state_dict``.

    Returns:
        Whatever ``net.to(device)`` returns for the constructed network.

    Raises:
        AssertionError: If the rewritten constructor expression does not
            contain ``landscape_only=False``.
    """
    if verbose:
        print('... loading model from', model_path)
    # The checkpoint carries a non-tensor object under 'args' (it is read via
    # attribute access below), so it needs full unpickling. PyTorch 2.6
    # switched the default of `weights_only` to True, which rejects such
    # objects; request the legacy behaviour explicitly.
    # SECURITY: full unpickling and the `eval` below both execute code taken
    # from the checkpoint -- only load checkpoints from trusted sources.
    ckpt = torch.load(model_path, map_location='cpu', weights_only=False)
    args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if 'landscape_only' not in args:
        # Drop the last character (assumed to be the closing parenthesis of
        # the constructor call) and append the extra keyword argument.
        args = args[:-1] + ', landscape_only=False)'
    else:
        # Strip all spaces so that spellings such as `landscape_only = True`
        # are matched by the replacement as well.
        args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
    assert "landscape_only=False" in args
    if verbose:
        print(f"instantiating : {args}")
    # Names in the expression are resolved against this module's globals.
    net = eval(args)
    # strict=False: key mismatches are returned in `s` instead of raising.
    s = net.load_state_dict(ckpt['model'], strict=False)
    if verbose:
        print(s)
    return net.to(device)
|
|
|
|
class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
    """``AsymmetricCroCo3DStereo`` subclass whose two downstream heads are
    built by ``mast3r_head_factory``."""

    def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
        """Store the descriptor settings, then defer to the parent constructor.

        Note that the default ``('norm')`` is the plain string ``'norm'``:
        parentheses without a trailing comma do not create a tuple.
        """
        # These attributes are assigned before the parent constructor runs;
        # `set_downstream_head` reads `self.desc_conf_mode` and hands `self`
        # to the head factory. Presumably the parent constructor triggers
        # that method -- TODO confirm in AsymmetricCroCo3DStereo.
        self.desc_mode = desc_mode
        self.two_confs = two_confs
        self.desc_conf_mode = desc_conf_mode
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """Load a model from a local checkpoint file or via the parent class.

        When the argument is an existing file it is loaded with
        ``load_model`` on the CPU and ``**kw`` is not used; otherwise the
        call is forwarded unchanged to the parent ``from_pretrained``.
        """
        if not os.path.isfile(pretrained_model_name_or_path):
            return super().from_pretrained(pretrained_model_name_or_path, **kw)
        return load_model(pretrained_model_name_or_path, device='cpu')

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
        """Record the head configuration and create both downstream heads.

        ``patch_size``, ``img_size`` and ``**kw`` are accepted but not used
        by this method.
        """
        self.output_mode, self.head_type = output_mode, head_type
        self.depth_mode, self.conf_mode = depth_mode, conf_mode
        # The descriptor confidence mode falls back to the general one.
        if self.desc_conf_mode is None:
            self.desc_conf_mode = conf_mode

        # Build the two heads (first, then second) with identical settings.
        has_conf = bool(conf_mode)
        self.downstream_head1 = mast3r_head_factory(head_type, output_mode, self, has_conf=has_conf)
        self.downstream_head2 = mast3r_head_factory(head_type, output_mode, self, has_conf=has_conf)

        # Wrap each head with `transpose_to_landscape`, enabled only when
        # `landscape_only` is truthy.
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)
|
|