Spaces:
Runtime error
Runtime error
| # Copyright (C) 2024-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # | |
| # -------------------------------------------------------- | |
| # utility functions for DUSt3R | |
| # -------------------------------------------------------- | |
| import torch | |
def fill_default_args(kwargs, func):
    """Complete ``kwargs`` in place with the default values declared by ``func``.

    Every parameter of ``func`` that declares a default and is not already a
    key of ``kwargs`` is inserted with that default; keys already present are
    left untouched. The same (mutated) dict object is returned.
    """
    import inspect  # a bit hacky but it works reliably

    params = inspect.signature(func).parameters
    for name, param in params.items():
        if param.default is not inspect.Parameter.empty:
            kwargs.setdefault(name, param.default)
    return kwargs
def freeze_all_params(modules):
    """Set ``requires_grad = False`` on everything reachable from ``modules``.

    Each entry is either an object exposing ``named_parameters()`` (every
    parameter it yields is frozen) or, when that attribute is missing, an
    object whose own ``requires_grad`` flag is cleared directly (e.g. a bare
    parameter tensor).
    """
    for entry in modules:
        try:
            for _name, weight in entry.named_parameters():
                weight.requires_grad = False
        except AttributeError:
            # no named_parameters(): treat the entry itself as a parameter
            entry.requires_grad = False
def is_symmetrized(gt1, gt2):
    """Tell whether ``gt2`` is ``gt1`` with each consecutive pair of items swapped.

    With ``x = gt1['instance']`` and ``y = gt2['instance']``, the batches are
    considered symmetrized when, for every even index ``i``,
    ``x[i] == y[i + 1]`` and ``x[i + 1] == y[i]``.

    Returns:
        False when ``x`` and ``y`` differ in length or have an odd length
        (this includes the special case of batch size 1); otherwise whether
        every pair matches as described above. Two empty lists yield True.
    """
    x = gt1['instance']
    y = gt2['instance']
    # Pairs can only be formed from equal-length, even-sized batches; this also
    # covers the special case of batchsize 1 and avoids indexing past the end.
    if len(x) != len(y) or len(x) % 2:
        return False
    return all(x[i] == y[i + 1] and x[i + 1] == y[i] for i in range(0, len(x), 2))
def flip(tensor):
    """Swap each consecutive pair of rows: tensor[0::2] <=> tensor[1::2].

    Rows ``[a0, b0, a1, b1, ...]`` become ``[b0, a0, b1, a1, ...]`` along
    dim 0. The even and odd slices must have the same length for the stack.
    """
    evens = tensor[0::2]
    odds = tensor[1::2]
    pairs = torch.stack([odds, evens], dim=1)  # each row holds (odd, even)
    return pairs.flatten(0, 1)
def interleave(tensor1, tensor2):
    """Interleave two tensors along dim 0, in both orders.

    Returns ``(res1, res2)`` where ``res1`` is ``[t1[0], t2[0], t1[1], t2[1], ...]``
    and ``res2`` is ``[t2[0], t1[0], t2[1], t1[1], ...]``.
    """
    def _alternate(first, second):
        # stack side by side, then merge the new pair axis into the batch axis
        return torch.stack((first, second), dim=1).flatten(0, 1)

    return _alternate(tensor1, tensor2), _alternate(tensor2, tensor1)
def transpose_to_landscape(head, activate=True):
    """ Predict in the correct aspect-ratio,
    then transpose the result in landscape
    and stack everything back together.

    Args:
        head: callable invoked as ``head(decout, (H, W))`` that returns a dict
            of tensors whose dim 0 indexes the batch.
        activate: when False, return a wrapper that performs no transposition
            and asserts that every row of ``true_shape`` is identical.

    Returns:
        A ``wrapper(decout, true_shape)`` callable. Rows of ``true_shape`` are
        read as ``(height, width)`` pairs.
    """
    def wrapper_no(decout, true_shape):
        assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical'
        H, W = true_shape[0].cpu().tolist()
        res = head(decout, (H, W))
        return res

    def wrapper_yes(decout, true_shape):
        B = len(true_shape)
        # by definition, the batch is in landscape mode so W >= H
        # NOTE(review): min/max are taken over the whole batch, which assumes all
        # images share the same pair of side lengths (possibly swapped) — confirm.
        H, W = int(true_shape.min()), int(true_shape.max())

        height, width = true_shape.T
        is_landscape = (width >= height)
        is_portrait = ~is_landscape

        if is_landscape.all():
            return head(decout, (H, W))
        if is_portrait.all():
            # run the head on the (W, H) grid, then swap axes 1 and 2 of each output
            return transposed(head(decout, (W, H)))

        # batch is a mix of both portrait & landscape
        def selout(ar): return [d[ar] for d in decout]
        l_result = head(selout(is_landscape), (H, W))
        p_result = transposed(head(selout(is_portrait), (W, H)))

        # allocate full result; it starts uninitialised because every row is
        # overwritten by exactly one of the two masked assignments below
        result = {}
        for k in l_result | p_result:
            x = l_result[k].new_empty((B, *l_result[k].shape[1:]))
            x[is_landscape] = l_result[k]
            x[is_portrait] = p_result[k]
            result[k] = x

        return result

    return wrapper_yes if activate else wrapper_no
def transposed(dic):
    """Return a new dict with axes 1 and 2 of every value swapped.

    Keys and their order are preserved; the input dict is not modified.
    """
    swapped = {}
    for key, value in dic.items():
        swapped[key] = value.swapaxes(1, 2)
    return swapped
| def invalid_to_nans(arr, valid_mask, ndim=999): | |
| if valid_mask is not None: | |
| arr = arr.clone() | |
| arr[~valid_mask] = float('nan') | |
| if arr.ndim > ndim: | |
| arr = arr.flatten(-2 - (arr.ndim - ndim), -2) | |
| return arr | |
def invalid_to_zeros(arr, valid_mask, ndim=999):
    """Zero out invalid entries of ``arr`` and count valid entries per item.

    Args:
        arr: tensor whose leading dims are indexed by ``valid_mask``.
        valid_mask: boolean mask of valid entries, or None (everything valid).
            It may be non-contiguous, e.g. the result of a transpose/swapaxes.
        ndim: if ``arr`` has more than ``ndim`` dims, the dims preceding the
            last one are merged so that exactly ``ndim`` dims remain. The
            default of 999 effectively disables this step.

    Returns:
        ``(arr, nnz)``. With a mask, ``arr`` is a zeroed copy (the input is left
        untouched) and ``nnz`` is a tensor of valid counts per item along dim 0.
        Without a mask, ``nnz`` is a plain int: the number of elements per item
        (0 for an empty batch).
    """
    if valid_mask is not None:
        arr = arr.clone()
        arr[~valid_mask] = 0
        # reshape, not view: a mask produced by swapaxes/transpose is
        # non-contiguous and .view() would raise on it
        nnz = valid_mask.reshape(len(valid_mask), -1).sum(1)
    else:
        nnz = arr.numel() // len(arr) if len(arr) else 0  # number of point per image
    if arr.ndim > ndim:
        arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
    return arr, nnz