|
|
import torch.nn as nn |
|
|
import torch |
|
|
import roma |
|
|
import numpy as np |
|
|
import cv2 |
|
|
from functools import cache |
|
|
|
|
|
|
|
|
def todevice(batch, device, callback=None, non_blocking=False): |
|
|
"""Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). |
|
|
|
|
|
batch: list, tuple, dict of tensors or other things |
|
|
device: pytorch device or 'numpy' |
|
|
callback: function that would be called on every sub-elements. |
|
|
""" |
|
|
if callback: |
|
|
batch = callback(batch) |
|
|
|
|
|
if isinstance(batch, dict): |
|
|
return {k: todevice(v, device) for k, v in batch.items()} |
|
|
|
|
|
if isinstance(batch, (tuple, list)): |
|
|
return type(batch)(todevice(x, device) for x in batch) |
|
|
|
|
|
x = batch |
|
|
if device == "numpy": |
|
|
if isinstance(x, torch.Tensor): |
|
|
x = x.detach().cpu().numpy() |
|
|
elif x is not None: |
|
|
if isinstance(x, np.ndarray): |
|
|
x = torch.from_numpy(x) |
|
|
if torch.is_tensor(x): |
|
|
x = x.to(device, non_blocking=non_blocking) |
|
|
return x |
|
|
|
|
|
|
|
|
# Backward-compatible alias for `todevice`.
to_device = todevice
|
|
|
|
|
|
|
|
def to_numpy(x):
    """Recursively convert any tensors inside *x* to numpy arrays."""
    result = todevice(x, "numpy")
    return result
|
|
|
|
|
|
|
|
def to_cpu(x):
    """Recursively move any tensors inside *x* to the CPU."""
    result = todevice(x, "cpu")
    return result
|
|
|
|
|
|
|
|
def to_cuda(x):
    """Recursively move any tensors inside *x* to the default CUDA device."""
    result = todevice(x, "cuda")
    return result
|
|
|
|
|
|
|
|
def signed_log1p(x):
    """Odd-symmetric log compression: sign(x) * log(1 + |x|)."""
    return torch.sign(x) * torch.log1p(x.abs())
|
|
|
|
|
|
|
|
def l2_dist(a, b, weight):
    """Weighted squared-L2 distance between a and b along the last dim."""
    diff = a - b
    return diff.square().sum(dim=-1) * weight
|
|
|
|
|
|
|
|
def l1_dist(a, b, weight):
    """Weighted per-point distance between a and b.

    NOTE(review): despite the name, `.norm(dim=-1)` computes the Euclidean
    (L2) norm of the difference, i.e. the unsquared L2 distance — confirm
    the naming is intentional before relying on true L1 semantics.
    """
    diff = a - b
    return diff.norm(dim=-1) * weight
|
|
|
|
|
|
|
|
# Registry of the available pairwise distance functions, keyed by name.
ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)
|
|
|
|
|
|
|
|
def _check_edges(edges): |
|
|
indices = sorted({i for edge in edges for i in edge}) |
|
|
assert indices == list(range(len(indices))), "bad pair indices: missing values " |
|
|
return len(indices) |
|
|
|
|
|
|
|
|
def NoGradParamDict(x):
    """Wrap a plain dict into an nn.ParameterDict with gradients disabled."""
    assert isinstance(x, dict)
    params = nn.ParameterDict(x)
    params.requires_grad_(False)
    return params
|
|
|
|
|
|
|
|
def edge_str(i, j):
    """Canonical string key "i_j" identifying the (i, j) edge."""
    return "{}_{}".format(i, j)
|
|
|
|
|
|
|
|
def i_j_ij(ij):
    """Return (edge key string, original (i, j) pair)."""
    i, j = ij
    # inlined edge_str(i, j)
    return f"{i}_{j}", ij
|
|
|
|
|
|
|
|
def edge_conf(conf_i, conf_j):
    """Scalar confidence of an edge: product of the two mean confidences."""
    return float(conf_i.mean() * conf_j.mean())
|
|
|
|
|
|
|
|
def get_imshapes(edges, pred_i, pred_j):
    """Collect the (H, W) shape of every image from per-edge predictions.

    edges: list of (i, j) image-index pairs, one per prediction.
    pred_i / pred_j: per-edge dicts holding the 3D point maps.
    Returns a list of (H, W) tuples indexed by image id; asserts that an
    image seen through several edges always has the same shape.
    """
    n_imgs = 1 + max(max(pair) for pair in edges)
    imshapes = [None] * n_imgs
    for idx, (i, j) in enumerate(edges):
        shape_i = tuple(pred_i[idx]["pts3d_is_self_view"].shape[:2])
        shape_j = tuple(pred_j[idx]["pts3d_in_other_view"].shape[:2])
        if imshapes[i]:
            assert imshapes[i] == shape_i, f"incorrect shape for image {i}"
        if imshapes[j]:
            assert imshapes[j] == shape_j, f"incorrect shape for image {j}"
        imshapes[i], imshapes[j] = shape_i, shape_j
    return imshapes
|
|
|
|
|
|
|
|
def get_conf_trf(mode):
    """Return the confidence-transform function selected by *mode*.

    Supported modes: "log", "sqrt", "m1" (subtract one), and "id"/"none"
    (identity). Raises ValueError for anything else.
    """
    table = {
        "log": lambda x: x.log(),
        "sqrt": lambda x: x.sqrt(),
        "m1": lambda x: x - 1,
        "id": lambda x: x,
        "none": lambda x: x,
    }
    if mode not in table:
        raise ValueError(f"bad mode for {mode=}")
    return table[mode]
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def _compute_img_conf(imshapes, device, edges, edge2conf_i, edge2conf_j): |
|
|
im_conf = nn.ParameterList([torch.zeros(hw, device=device) for hw in imshapes]) |
|
|
for e, (i, j) in enumerate(edges): |
|
|
im_conf[i] = torch.maximum(im_conf[i], edge2conf_i[edge_str(i, j)]) |
|
|
im_conf[j] = torch.maximum(im_conf[j], edge2conf_j[edge_str(i, j)]) |
|
|
return im_conf |
|
|
|
|
|
|
|
|
def xy_grid(
    W,
    H,
    device=None,
    origin=(0, 0),
    unsqueeze=None,
    cat_dim=-1,
    homogeneous=False,
    **arange_kw,
):
    """Output a (H,W,2) array of int32
    with output[j,i,0] = i + origin[0]
         output[j,i,1] = j + origin[1]

    device: None -> numpy arrays, otherwise torch tensors on that device.
    origin: (x0, y0) offset added to the coordinates.
    unsqueeze: optional dim on which to unsqueeze each channel.
    cat_dim: dim used to stack the x/y channels; None keeps them as a tuple.
    homogeneous: append a third channel of ones (homogeneous coordinates).
    arange_kw: extra keyword args forwarded to arange (e.g. dtype).
    """
    if device is None:
        # numpy backend
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
    else:
        # torch backend, pinned to the requested device
        arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
        meshgrid, stack = torch.meshgrid, torch.stack
        ones = lambda *a: torch.ones(*a, device=device)

    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
    # Fix: np.meshgrid returns a *list* while torch.meshgrid returns a tuple;
    # normalize to a tuple so the `+ (ones,)` concatenation below works on
    # both backends (it used to raise TypeError on the numpy path).
    grid = tuple(meshgrid(tw, th, indexing="xy"))
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid
|
|
|
|
|
|
|
|
def estimate_focal_knowing_depth(
    pts3d, pp, focal_mode="median", min_focal=0.0, max_focal=np.inf
):
    """Reprojection method, for when the absolute depth is known:
    1) estimate the camera focal using a robust estimator
    2) reproject points onto true rays, minimizing a certain error

    pts3d: (B, H, W, 3) per-pixel 3D points in the camera frame.
    pp: principal point, broadcastable to (B, 2).
    focal_mode: "median" (closed-form per-pixel votes) or "weiszfeld"
        (iteratively reweighted least squares).
    min_focal / max_focal: clipping bounds, expressed as multiples of the
        focal corresponding to a 60-degree field of view.
    Returns a (B,) tensor of estimated focals (in pixels).
    """
    B, H, W, THREE = pts3d.shape
    assert THREE == 3

    # Pixel coordinates centered on the principal point: (1, H*W, 2) - (B, 1, 2).
    pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(
        -1, 1, 2
    )
    # Flatten the spatial dims: (B, H*W, 3).
    pts3d = pts3d.flatten(1, 2)

    if focal_mode == "median":
        with torch.no_grad():
            # Each pixel votes for a focal via the pinhole relation
            # u = f*x/z  =>  f = u*z/x (and similarly for v, y).
            u, v = pixels.unbind(dim=-1)
            x, y, z = pts3d.unbind(dim=-1)
            fx_votes = (u * z) / x
            fy_votes = (v * z) / y

            # Assume square pixels: pool the x- and y-votes and take a
            # NaN-robust median (votes with x==0 or y==0 are NaN/inf).
            f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
            focal = torch.nanmedian(f_votes, dim=-1).values

    elif focal_mode == "weiszfeld":
        # Normalized image-plane coordinates (x/z, y/z); invalid divisions
        # (z == 0) are zeroed out so they contribute nothing below.
        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(
            posinf=0, neginf=0
        )

        dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
        dot_xy_xy = xy_over_z.square().sum(dim=-1)

        # Initial least-squares estimate of the focal.
        focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)

        # Weiszfeld iterations: reweight by inverse residual distance to
        # approximate an L1 (robust) solution; fixed 10 iterations.
        for iter in range(10):
            # Per-point reprojection residual for the current focal.
            dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)

            # Inverse-distance weights, clipped to avoid division by zero.
            w = dis.clip(min=1e-8).reciprocal()

            # Weighted least-squares update of the focal.
            focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
    else:
        raise ValueError(f"bad {focal_mode=}")

    # Reference focal corresponding to a 60-degree field of view; the
    # min/max bounds are expressed relative to it.
    focal_base = max(H, W) / (
        2 * np.tan(np.deg2rad(60) / 2)
    )
    focal = focal.clip(min=min_focal * focal_base, max=max_focal * focal_base)

    return focal
|
|
|
|
|
|
|
|
def estimate_focal(pts3d_i, pp=None):
    """Estimate a scalar focal length from one (H, W, 3) point map.

    pp defaults to the image center; uses the robust "weiszfeld" estimator.
    """
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W / 2, H / 2), device=pts3d_i.device)
    batched_pts = pts3d_i.unsqueeze(0)
    batched_pp = pp.unsqueeze(0)
    focal = estimate_focal_knowing_depth(batched_pts, batched_pp, focal_mode="weiszfeld")
    return float(focal.ravel())
|
|
|
|
|
|
|
|
def rigid_points_registration(pts1, pts2, conf):
    """Weighted similarity registration of two point sets via roma.

    Returns (scale, rotation, translation) mapping pts1 onto pts2,
    with per-point weights taken from conf.
    """
    flat1 = pts1.reshape(-1, 3)
    flat2 = pts2.reshape(-1, 3)
    R, T, s = roma.rigid_points_registration(
        flat1, flat2, weights=conf.ravel(), compute_scaling=True
    )
    return s, R, T
|
|
|
|
|
|
|
|
def sRT_to_4x4(scale, R, T, device):
    """Pack (scale, 3x3 rotation, translation) into a 4x4 homogeneous transform."""
    out = torch.eye(4, device=device)
    out[:3, :3] = scale * R
    out[:3, 3] = T.ravel()
    return out
|
|
|
|
|
|
|
|
def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a list of 3-D points.

    H: 3x3 or 4x4 projection matrix (typically a Homography)
    p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)

    ncol: int. number of columns of the result (2 or 3)
    norm: float. if != 0, the resut is projected on the z=norm plane.

    Returns an array of projected 2d points.
    """
    assert Trf.ndim >= 2
    # Coerce pts to the same backend (and dtype) as Trf.
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # Remember the original leading shape so we can restore it at the end.
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # Fast path: batched torch transform (B,i,j) applied to an image of
    # points (B,H,W,j) via einsum, avoiding any reshape.
    if (
        isinstance(Trf, torch.Tensor)
        and isinstance(pts, torch.Tensor)
        and Trf.ndim == 3
        and pts.ndim == 4
    ):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            # Pure linear transform (e.g. 3x3 on 3D points).
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # Affine/homogeneous transform: rotate then add translation column.
            pts = (
                torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
                + Trf[:, None, None, :d, d]
            )
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        # Generic path: flatten any batch dims of Trf to (B, r, c) and make
        # pts (B, N, d) so a single batched matmul applies everywhere.
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> flatten to (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> add a points dim: (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            # Homogeneous matrix, inhomogeneous points: multiply by the
            # linear part and add the translation row.
            Trf = Trf.swapaxes(-1, -2)
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            # Matching dimensions: plain (transposed) matrix product.
            Trf = Trf.swapaxes(-1, -2)
            pts = pts @ Trf
        else:
            # Fallback: transform column vectors directly, then restore
            # points-last layout.
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        # Perspective division onto the z=1 plane, then optional rescaling.
        pts = pts / pts[..., -1:]
        if norm != 1:
            pts *= norm

    # Keep only the first ncol coordinates and restore the input shape.
    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
|
|
|
|
|
|
|
|
def inv(mat):
    """Invert a torch or numpy matrix"""
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")
|
|
|
|
|
|
|
|
@cache
def pixel_grid(H, W):
    """(H, W, 2) float32 grid with entry [y, x] == (x, y); memoized per size."""
    xs, ys = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
    return np.stack((xs, ys), axis=-1).astype(np.float32)
|
|
|
|
|
|
|
|
def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Recover a camera pose from a dense 3D point map via PnP-RANSAC.

    pts3d: (H, W, 3) point map; msk: boolean validity mask over pixels.
    focal: known focal length, or None to sweep a tentative range.
    pp: principal point, defaults to the image center.
    Returns (focal, cam-to-world 4x4 transform) or None on failure.
    """
    # PnP needs at least 4 correspondences.
    if msk.sum() < 4:
        return None
    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    # Candidate focals: either the given one, or a log-spaced sweep
    # around plausible values for this image size.
    if focal is None:
        S = max(W, H)
        tentative_focals = np.geomspace(S / 2, S * 3, 21)
    else:
        tentative_focals = [focal]

    pp = (W / 2, H / 2) if pp is None else to_numpy(pp)

    best = (0,)
    for f in tentative_focals:
        K = np.float32([(f, 0, pp[0]), (0, f, pp[1]), (0, 0, 1)])

        success, R, T, inliers = cv2.solvePnPRansac(
            pts3d[msk],
            pixels[msk],
            K,
            None,
            iterationsCount=niter_PnP,
            reprojectionError=5,
            flags=cv2.SOLVEPNP_SQPNP,
        )
        if not success:
            continue

        # Keep the candidate with the largest inlier count.
        score = len(inliers)
        if score > best[0]:
            best = score, R, T, f

    if not best[0]:
        return None

    _, R, T, best_focal = best
    R = cv2.Rodrigues(R)[0]  # rotation vector -> rotation matrix
    R, T = map(torch.from_numpy, (R, T))
    # solvePnP returns world-to-cam; invert to get cam-to-world.
    return best_focal, inv(sRT_to_4x4(1, R, T, device))
|
|
|
|
|
|
|
|
def get_med_dist_between_poses(poses):
    """Median pairwise distance between the camera centers of the given 4x4 poses."""
    from scipy.spatial.distance import pdist

    centers = [to_numpy(pose[:3, 3]) for pose in poses]
    return np.median(pdist(centers))
|
|
|
|
|
|
|
|
def align_multiple_poses(src_poses, target_poses):
    """Estimate the similarity (s, R, T) aligning src_poses onto target_poses.

    Both inputs are (N, 4, 4) pose matrices.
    """
    N = len(src_poses)
    assert src_poses.shape == target_poses.shape == (N, 4, 4)

    def center_and_z(poses):
        # Register camera centers plus a point slightly along each z-axis,
        # so orientation is constrained as well as position.
        eps = get_med_dist_between_poses(poses) / 100
        centers = poses[:, :3, 3]
        z_dirs = poses[:, :3, 2]
        return torch.cat((centers, centers + eps * z_dirs))

    R, T, s = roma.rigid_points_registration(
        center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True
    )
    return s, R, T
|
|
|
|
|
|
|
|
def cosine_schedule(t, lr_start, lr_end):
    """Cosine-annealed value going from lr_start (t=0) to lr_end (t=1)."""
    assert 0 <= t <= 1
    weight = (1 + np.cos(t * np.pi)) / 2  # 1 at t=0, 0 at t=1
    return lr_end + (lr_start - lr_end) * weight
|
|
|
|
|
|
|
|
def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation from lr_start (t=0) to lr_end (t=1)."""
    assert 0 <= t <= 1
    delta = lr_end - lr_start
    return lr_start + delta * t
|
|
|
|
|
|
|
|
def cycled_linear_schedule(t, lr_start, lr_end, num_cycles=2):
    """Linear lr_start->lr_end ramp repeated num_cycles times over t in [0, 1]."""
    assert 0 <= t <= 1
    scaled = t * num_cycles
    phase = scaled - int(scaled)  # fractional position within the current cycle
    if t == 1:
        phase = 1  # finish exactly at lr_end instead of wrapping back to 0
    # inlined linear_schedule(phase, lr_start, lr_end)
    return lr_start + (lr_end - lr_start) * phase
|
|
|
|
|
|
|
|
def adjust_learning_rate_by_lr(optimizer, lr):
    """Set every param group's lr, honoring an optional per-group 'lr_scale'."""
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
|
|
|