# Jiahua0's picture
# Upload folder using huggingface_hub
# ff47419 verified
import torch.nn as nn
import torch
import roma
import numpy as np
import cv2
from functools import cache
def todevice(batch, device, callback=None, non_blocking=False):
    """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    Args:
        batch: list, tuple, dict of tensors or other things.
        device: pytorch device or 'numpy'.
        callback: function that would be called on every sub-elements.
        non_blocking: forwarded to `Tensor.to` (async host->device copy when
            the source tensor is in pinned memory).

    Returns:
        The same nested structure with every tensor/array moved to `device`.
    """
    if callback:
        batch = callback(batch)

    # Recurse into containers, forwarding callback and non_blocking.
    # (Previously both were dropped below the top level, so `callback` was
    # never applied to sub-elements and `non_blocking` was ignored for
    # nested tensors, contrary to the documented contract.)
    if isinstance(batch, dict):
        return {
            k: todevice(v, device, callback=callback, non_blocking=non_blocking)
            for k, v in batch.items()
        }
    if isinstance(batch, (tuple, list)):
        return type(batch)(
            todevice(x, device, callback=callback, non_blocking=non_blocking)
            for x in batch
        )

    x = batch
    if device == "numpy":
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x
to_device = todevice  # alias
def to_numpy(x):
    """Recursively convert every tensor inside `x` to a numpy array."""
    return todevice(x, device="numpy")
def to_cpu(x):
    """Recursively move every tensor inside `x` to the CPU."""
    return todevice(x, device="cpu")
def to_cuda(x):
    """Recursively move every tensor inside `x` to the default CUDA device."""
    return todevice(x, device="cuda")
def signed_log1p(x):
    """Symmetric log1p: sign(x) * log(1 + |x|), an odd function through the origin."""
    return torch.sign(x) * torch.log1p(x.abs())
def l2_dist(a, b, weight):
    """Weighted squared-Euclidean distance along the last dimension."""
    diff = a - b
    return diff.square().sum(dim=-1) * weight
def l1_dist(a, b, weight):
    """Weighted Euclidean (non-squared) distance along the last dimension.

    NOTE(review): despite the name, this computes the L2 norm of the
    difference, not an abs-sum L1 distance; the name follows the project's
    "l1 = non-squared, l2 = squared" convention — kept as-is for callers.
    """
    diff = a - b
    return diff.norm(dim=-1) * weight
# Registry mapping a distance-mode name to its implementation.
ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)
def _check_edges(edges):
indices = sorted({i for edge in edges for i in edge})
assert indices == list(range(len(indices))), "bad pair indices: missing values "
return len(indices)
def NoGradParamDict(x):
    """Wrap a dict of parameters in an nn.ParameterDict with gradients disabled."""
    assert isinstance(x, dict)
    params = nn.ParameterDict(x)
    return params.requires_grad_(False)
def edge_str(i, j):
    """Canonical string key for the image pair (i, j)."""
    return "{}_{}".format(i, j)
def i_j_ij(ij):
    """Return (edge-name string, (i, j)) for a pair of image indices."""
    i, j = ij
    return edge_str(i, j), ij
def edge_conf(conf_i, conf_j):
    """Scalar edge confidence: product of the two mean confidence maps."""
    return float(conf_i.mean() * conf_j.mean())
def get_imshapes(edges, pred_i, pred_j):
    """Collect the (H, W) shape of every image referenced by the edge list.

    Asserts that each image keeps a consistent shape across all the edges
    it appears in. Returns a list of shapes indexed by image id.
    """
    n_imgs = 1 + max(max(e) for e in edges)
    imshapes = [None] * n_imgs
    for idx, (i, j) in enumerate(edges):
        shape_i = tuple(pred_i[idx]["pts3d_is_self_view"].shape[:2])
        shape_j = tuple(pred_j[idx]["pts3d_in_other_view"].shape[:2])
        if imshapes[i] is not None:
            assert imshapes[i] == shape_i, f"incorrect shape for image {i}"
        if imshapes[j] is not None:
            assert imshapes[j] == shape_j, f"incorrect shape for image {j}"
        imshapes[i] = shape_i
        imshapes[j] = shape_j
    return imshapes
def get_conf_trf(mode):
    """Return the confidence-transformation function selected by `mode`.

    Supported modes: 'log', 'sqrt', 'm1' (subtract one), and the
    identity aliases 'id' / 'none'. Raises ValueError otherwise.
    """
    transforms = {
        "log": lambda x: x.log(),
        "sqrt": lambda x: x.sqrt(),
        "m1": lambda x: x - 1,
        "id": lambda x: x,
        "none": lambda x: x,
    }
    if mode not in transforms:
        raise ValueError(f"bad mode for {mode=}")
    return transforms[mode]
@torch.no_grad()
def _compute_img_conf(imshapes, device, edges, edge2conf_i, edge2conf_j):
    """Per-image confidence map: pixelwise max over all edges touching the image."""
    im_conf = nn.ParameterList([torch.zeros(hw, device=device) for hw in imshapes])
    for i, j in edges:
        key = edge_str(i, j)
        im_conf[i] = torch.maximum(im_conf[i], edge2conf_i[key])
        im_conf[j] = torch.maximum(im_conf[j], edge2conf_j[key])
    return im_conf
def xy_grid(
    W,
    H,
    device=None,
    origin=(0, 0),
    unsqueeze=None,
    cat_dim=-1,
    homogeneous=False,
    **arange_kw,
):
    """Build a (H, W, 2) integer grid of pixel coordinates such that
    output[j, i, 0] = i + origin[0] and output[j, i, 1] = j + origin[1].

    Returned as a numpy array when `device` is None, otherwise as a torch
    tensor on `device`. With `homogeneous=True` a constant plane of ones is
    appended; with `cat_dim=None` the separate x/y grids are returned.
    """
    if device is None:
        # numpy backend
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
    else:
        # torch backend: every new tensor is created directly on `device`
        def arange(*args, **kw):
            return torch.arange(*args, device=device, **kw)

        def ones(*shape):
            return torch.ones(*shape, device=device)

        meshgrid, stack = torch.meshgrid, torch.stack

    tw = arange(origin[0], origin[0] + W, **arange_kw)
    th = arange(origin[1], origin[1] + H, **arange_kw)
    grid = meshgrid(tw, th, indexing="xy")
    if homogeneous:
        # append ones -> homogeneous (x, y, 1) coordinates
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid
def estimate_focal_knowing_depth(
    pts3d, pp, focal_mode="median", min_focal=0.0, max_focal=np.inf
):
    """Reprojection method, for when the absolute depth is known:
    1) estimate the camera focal using a robust estimator
    2) reproject points onto true rays, minimizing a certain error

    Args:
        pts3d: (B, H, W, 3) 3-D points in camera coordinates.
        pp: principal point(s); reshaped internally to (-1, 1, 2).
        focal_mode: 'median' (per-pixel focal votes + robust median) or
            'weiszfeld' (iteratively re-weighted least squares).
        min_focal, max_focal: clamp bounds, expressed as multiples of the
            focal corresponding to a 60-degree field of view.

    Returns:
        (B,) tensor of estimated focals (pixels), clipped to the bounds.
    """
    B, H, W, THREE = pts3d.shape
    assert THREE == 3
    # centered pixel grid: pixel coordinates relative to the principal point
    pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(
        -1, 1, 2
    )  # B,HW,2
    pts3d = pts3d.flatten(1, 2)  # (B, HW, 3)
    if focal_mode == "median":
        with torch.no_grad():
            # direct estimation of focal: under pinhole projection u = f*x/z,
            # so each pixel votes f = u*z/x (and f = v*z/y)
            u, v = pixels.unbind(dim=-1)
            x, y, z = pts3d.unbind(dim=-1)
            fx_votes = (u * z) / x
            fy_votes = (v * z) / y
            # assume square pixels, hence same focal for X and Y
            f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
            # nanmedian is robust to outliers and ignores NaN votes (x==0 / y==0)
            focal = torch.nanmedian(f_votes, dim=-1).values
    elif focal_mode == "weiszfeld":
        # init focal with l2 closed form
        # we try to find focal = argmin Sum | pixel - focal * (x,y)/z|
        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(
            posinf=0, neginf=0
        )  # homogeneous (x,y,1)
        dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
        dot_xy_xy = xy_over_z.square().sum(dim=-1)
        # closed-form least-squares initialization
        focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)
        # iterative re-weighted least-squares (fixed 10 iterations)
        for iter in range(10):
            # re-weighting by inverse of distance
            dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)
            # print(dis.nanmean(-1))
            w = dis.clip(min=1e-8).reciprocal()  # clip avoids division by zero
            # update the scaling with the new weights
            focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
    else:
        raise ValueError(f"bad {focal_mode=}")
    # reference focal for a 60-degree FOV: size / (2*tan(30deg))
    focal_base = max(H, W) / (
        2 * np.tan(np.deg2rad(60) / 2)
    )  # size / 1.1547005383792515
    focal = focal.clip(min=min_focal * focal_base, max=max_focal * focal_base)
    # print(focal)
    return focal
def estimate_focal(pts3d_i, pp=None):
    """Estimate a single camera focal (pixels) from one (H, W, 3) point map.

    Defaults the principal point to the image center, then runs the
    Weiszfeld estimator on a batch of one.
    """
    if pp is None:
        height, width, channels = pts3d_i.shape
        assert channels == 3
        pp = torch.tensor((width / 2, height / 2), device=pts3d_i.device)
    batched = estimate_focal_knowing_depth(
        pts3d_i[None], pp[None], focal_mode="weiszfeld"
    )
    return float(batched.ravel())
def rigid_points_registration(pts1, pts2, conf):
    """Similarity registration of two point clouds, weighted by confidence.

    Returns (scale, R, T) such that s * R @ pts1 + T ~= pts2.
    """
    flat1 = pts1.reshape(-1, 3)
    flat2 = pts2.reshape(-1, 3)
    R, T, s = roma.rigid_points_registration(
        flat1, flat2, weights=conf.ravel(), compute_scaling=True
    )
    return s, R, T  # return un-scaled (R, T)
def sRT_to_4x4(scale, R, T, device):
    """Assemble a 4x4 homogeneous transform from scale, rotation and translation."""
    pose = torch.eye(4, device=device)
    pose[:3, :3] = R * scale
    pose[:3, 3] = T.ravel()  # translation is not scaled
    return pose
def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a list of 3-D points.

    Trf: 3x3 or 4x4 projection matrix (typically a Homography), possibly
        batched as (B, d, d); numpy array or torch tensor.
    pts: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
    ncol: int. number of columns of the result (2 or 3)
    norm: float. if != 0, the result is projected on the z=norm plane.
    Returns an array of projected 2d points with shape (*pts.shape[:-1], ncol).
    """
    assert Trf.ndim >= 2
    # coerce pts to the same array framework (and dtype) as Trf
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code: batched torch case (B,d,d) x (B,H,W,d) via einsum
    if (
        isinstance(Trf, torch.Tensor)
        and isinstance(pts, torch.Tensor)
        and Trf.ndim == 3
        and pts.ndim == 4
    ):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            # linear part only (no translation column)
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # homogeneous matrix: rotate/scale then add the translation column
            pts = (
                torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
                + Trf[:, None, None, :d, d]
            )
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        # generic path (numpy or non-batched / oddly-shaped torch inputs)
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            # homogeneous transform applied to non-homogeneous points:
            # multiply by the linear block and add the last (translation) row
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
def inv(mat):
    """Invert a square matrix given as a torch tensor or numpy array."""
    if torch.is_tensor(mat):
        return torch.linalg.inv(mat)
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")
@cache
def pixel_grid(H, W):
    """Cached (H, W, 2) float32 grid of pixel coordinates: grid[j, i] == (i, j)."""
    cols, rows = np.meshgrid(np.arange(W), np.arange(H))  # default 'xy' indexing
    return np.stack((cols, rows), axis=-1).astype(np.float32)
def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Recover a camera pose (and focal, if unknown) with RANSAC-PnP.

    Returns (focal, 4x4 cam-to-world pose) or None when fewer than 4 valid
    points are available or PnP never succeeds.
    """
    if msk.sum() < 4:
        # PnP needs at least 4 correspondences
        return None

    pts3d = to_numpy(pts3d)
    msk = to_numpy(msk)
    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        # unknown focal: sweep a log-spaced range of candidates
        S = max(W, H)
        tentative_focals = np.geomspace(S / 2, S * 3, 21)
    else:
        tentative_focals = [focal]

    pp = (W / 2, H / 2) if pp is None else to_numpy(pp)

    best = (0,)
    for focal in tentative_focals:
        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
        success, R, T, inliers = cv2.solvePnPRansac(
            pts3d[msk],
            pixels[msk],
            K,
            None,
            iterationsCount=niter_PnP,
            reprojectionError=5,
            flags=cv2.SOLVEPNP_SQPNP,
        )
        if not success:
            continue
        # keep the candidate with the most inliers
        score = len(inliers)
        if score > best[0]:
            best = score, R, T, focal

    if not best[0]:
        return None  # RANSAC never produced a valid pose
    _, R, T, best_focal = best
    R = cv2.Rodrigues(R)[0]  # rotation vector -> world-to-cam matrix
    R, T = map(torch.from_numpy, (R, T))
    return best_focal, inv(sRT_to_4x4(1, R, T, device))  # cam-to-world
def get_med_dist_between_poses(poses):
    """Median pairwise distance between the camera centers of 4x4 poses."""
    from scipy.spatial.distance import pdist

    centers = [to_numpy(p[:3, 3]) for p in poses]
    return np.median(pdist(centers))
def align_multiple_poses(src_poses, target_poses):
    """Compute the similarity (s, R, T) aligning src_poses onto target_poses."""
    N = len(src_poses)
    assert src_poses.shape == target_poses.shape == (N, 4, 4)

    def anchor_points(poses):
        # camera centers plus a point slightly along each camera's z-axis,
        # so the registration constrains orientation as well as position
        eps = get_med_dist_between_poses(poses) / 100
        centers = poses[:, :3, 3]
        return torch.cat((centers, centers + eps * poses[:, :3, 2]))

    R, T, s = roma.rigid_points_registration(
        anchor_points(src_poses), anchor_points(target_poses), compute_scaling=True
    )
    return s, R, T
def cosine_schedule(t, lr_start, lr_end):
    """Cosine interpolation from lr_start (t=0) to lr_end (t=1)."""
    assert 0 <= t <= 1
    w = (1 + np.cos(t * np.pi)) / 2  # 1 at t=0, 0 at t=1
    return lr_end + (lr_start - lr_end) * w
def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation from lr_start (t=0) to lr_end (t=1)."""
    assert 0 <= t <= 1
    return lr_start + t * (lr_end - lr_start)
def cycled_linear_schedule(t, lr_start, lr_end, num_cycles=2):
    """Repeat a linear lr_start->lr_end schedule num_cycles times over t in [0, 1]."""
    assert 0 <= t <= 1
    phase = (t * num_cycles) % 1  # position within the current cycle
    if t == 1:
        phase = 1  # end exactly at lr_end instead of restarting
    return linear_schedule(phase, lr_start, lr_end)
def adjust_learning_rate_by_lr(optimizer, lr):
    """Set the learning rate of every param group, honoring per-group 'lr_scale'."""
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr