Spaces:
Sleeping
Sleeping
Upload 4 files
Browse files
demo.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 3 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 4 |
+
#
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
# sparse gradio demo functions
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
import math
|
| 9 |
+
import gradio
|
| 10 |
+
import os
|
| 11 |
+
import numpy as np
|
| 12 |
+
import functools
|
| 13 |
+
import trimesh
|
| 14 |
+
import copy
|
| 15 |
+
from scipy.spatial.transform import Rotation
|
| 16 |
+
import tempfile
|
| 17 |
+
import shutil
|
| 18 |
+
|
| 19 |
+
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
|
| 20 |
+
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
|
| 21 |
+
|
| 22 |
+
import mast3r.utils.path_to_dust3r # noqa
|
| 23 |
+
from dust3r.image_pairs import make_pairs
|
| 24 |
+
from dust3r.utils.image import load_images
|
| 25 |
+
from dust3r.utils.device import to_numpy
|
| 26 |
+
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
|
| 27 |
+
from dust3r.demo import get_args_parser as dust3r_get_args_parser
|
| 28 |
+
|
| 29 |
+
import matplotlib.pyplot as pl
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SparseGAState():
    """Bundle a sparse global-alignment result with the on-disk artifacts tied to it.

    Holds the optimized scene together with its cache directory and exported
    model file, and optionally erases both from disk when the object is
    garbage collected.
    """

    def __init__(self, sparse_ga, should_delete=False, cache_dir=None, outfile_name=None):
        # when true, __del__ is allowed to remove cache_dir and outfile_name from disk
        self.should_delete = should_delete
        self.sparse_ga = sparse_ga
        self.outfile_name = outfile_name
        self.cache_dir = cache_dir

    def __del__(self):
        """Remove the cache directory and the output file if deletion was requested."""
        if self.should_delete:
            cache_dir = self.cache_dir
            if cache_dir is not None and os.path.isdir(cache_dir):
                shutil.rmtree(cache_dir)
            self.cache_dir = None

            outfile = self.outfile_name
            if outfile is not None and os.path.isfile(outfile):
                os.remove(outfile)
            self.outfile_name = None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_args_parser():
    """Build the command-line parser of the MASt3R demo.

    Starts from the dust3r demo parser, adds the ``--share`` and
    ``--gradio_delete_cache`` options, restricts the choices of the
    ``model_name`` argument to the MASt3R checkpoint and renames the program.
    """
    parser = dust3r_get_args_parser()
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--gradio_delete_cache', default=None, type=int,
                        help='age/frequency at which gradio removes the file. If >0, matching cache is purged')

    # override the model choices inherited from the dust3r parser
    for model_action in (act for act in parser._actions if act.dest == 'model_name'):
        model_action.choices = ["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"]

    # change defaults
    parser.prog = 'mast3r demo'
    return parser
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False,
                                 transparent_cams=False, silent=False):
    """Assemble a trimesh scene (geometry plus one camera per pose) and export it.

    The geometry is either a single colored point cloud (``as_pointcloud``) or
    the concatenation of one mesh per view. Points are kept where ``mask`` is
    set and their coordinates are finite. The scene is finally re-expressed
    relative to the first camera before being written to ``outfile``, which is
    returned.
    """
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d, imgs = to_numpy(pts3d), to_numpy(imgs)
    focals, cams2world = to_numpy(focals), to_numpy(cams2world)

    scene = trimesh.Scene()

    if not as_pointcloud:
        # one mesh per view, merged into a single geometry
        view_meshes = []
        for view_idx, img in enumerate(imgs):
            view_pts = pts3d[view_idx].reshape(img.shape)
            view_mask = mask[view_idx] & np.isfinite(view_pts.sum(axis=-1))
            view_meshes.append(pts3d_to_trimesh(img, view_pts, view_mask))
        scene.add_geometry(trimesh.Trimesh(**cat_meshes(view_meshes)))
    else:
        # full pointcloud: gather the masked points and their colors over all views
        xyz = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)]).reshape(-1, 3)
        rgb = np.concatenate([p[m] for p, m in zip(imgs, mask)]).reshape(-1, 3)
        finite = np.isfinite(xyz.sum(axis=1))
        scene.add_geometry(trimesh.PointCloud(xyz[finite], colors=rgb[finite]))

    # add each camera
    color_per_camera = isinstance(cam_color, list)
    for cam_idx, pose_c2w in enumerate(cams2world):
        if color_per_camera:
            edge_color = cam_color[cam_idx]
        else:
            # a single user color if given, otherwise cycle through the palette
            edge_color = cam_color or CAM_COLORS[cam_idx % len(CAM_COLORS)]
        cam_image = None if transparent_cams else imgs[cam_idx]
        add_scene_cam(scene, pose_c2w, edge_color, cam_image, focals[cam_idx],
                      imsize=imgs[cam_idx].shape[1::-1], screen_width=cam_size)

    # 180 degree rotation about the y axis, combined with the OPENGL matrix and the first camera pose
    half_turn_y = np.eye(4)
    half_turn_y[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ half_turn_y))

    if not silent:
        print('(exporting 3D scene to', outfile, ')')
    scene.export(file_obj=outfile)
    return outfile
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
    """
    Export the reconstructed scene held by ``scene_state`` as a glb file.

    Returns the path of the written file, or None when there is no scene state
    or the state carries no output file name. ``mask_sky`` is accepted but not
    used by this function.
    """
    outfile = None if scene_state is None else scene_state.outfile_name
    if outfile is None:
        return None

    # optimized values stored in the scene
    scene = scene_state.sparse_ga
    rgbimg = scene.imgs
    focals = scene.get_focals().cpu()
    cams2world = scene.get_im_poses().cpu()

    # dense 3D points come from the TSDF post-process when a threshold is set, from the scene otherwise
    dense_source = TSDFPostProcess(scene, TSDF_thresh=TSDF_thresh) if TSDF_thresh > 0 else scene
    pts3d, _, confs = to_numpy(dense_source.get_dense_pts3d(clean_depth=clean_depth))

    # keep only the points whose confidence exceeds the threshold
    msk = to_numpy([c > min_conf_thr for c in confs])
    return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world,
                                        as_pointcloud=as_pointcloud, transparent_cams=transparent_cams,
                                        cam_size=cam_size, silent=silent)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_reconstructed_scene(outdir, gradio_delete_cache, model, device, silent, image_size, current_scene_state,
                            filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
                            win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
    """
    from a list of images, run mast3r inference, sparse global aligner.
    then run get_3D_model_from_scene

    The cache directory and output file of ``current_scene_state`` are reused
    when that state exists and is not flagged for deletion; otherwise new ones
    are created under ``outdir``. Extra keyword arguments are forwarded to
    ``sparse_global_alignment``.

    Returns a ``(scene_state, outfile)`` pair: the new SparseGAState and the
    value returned by get_3D_model_from_scene.
    """
    imgs = load_images(filelist, size=image_size, verbose=not silent)
    if len(imgs) == 1:
        # a single image is paired with a copy of itself so that a pair can be formed
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]['idx'] = 1
        filelist = [filelist[0], filelist[0] + '_2']

    # scene graph description, e.g. "swin-3-noncyclic" or "oneref-0"
    windowed = scenegraph_type in ["swin", "logwin"]
    scene_graph_params = [scenegraph_type]
    if windowed:
        scene_graph_params.append(str(winsize))
    elif scenegraph_type == "oneref":
        scene_graph_params.append(str(refid))
    if windowed and not win_cyclic:
        scene_graph_params.append('noncyclic')
    scene_graph = '-'.join(scene_graph_params)
    pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
    if optim_level == 'coarse':
        # coarse level: skip the second optimization stage
        niter2 = 0

    # a previous state can be reused only if it is not going to delete its files
    reusable_state = current_scene_state is not None and not current_scene_state.should_delete

    # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
    if reusable_state and current_scene_state.cache_dir is not None:
        cache_dir = current_scene_state.cache_dir
    elif gradio_delete_cache:
        cache_dir = tempfile.mkdtemp(suffix='_cache', dir=outdir)
    else:
        cache_dir = os.path.join(outdir, 'cache')
    scene = sparse_global_alignment(filelist, pairs, cache_dir,
                                    model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
                                    opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
                                    matching_conf_thr=matching_conf_thr, **kw)
    if reusable_state and current_scene_state.outfile_name is not None:
        outfile_name = current_scene_state.outfile_name
    else:
        # mkstemp creates the file atomically; the deprecated mktemp only returned a
        # name that another process could claim before the export wrote to it
        fd, outfile_name = tempfile.mkstemp(suffix='_scene.glb', dir=outdir)
        os.close(fd)

    scene_state = SparseGAState(scene, gradio_delete_cache, cache_dir, outfile_name)
    outfile = get_3D_model_from_scene(silent, scene_state, min_conf_thr, as_pointcloud, mask_sky,
                                      clean_depth, transparent_cams, cam_size, TSDF_thresh)
    return scene_state, outfile
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
    """Recompute the scene-graph widgets for the current files and graph type.

    Returns ``(win_col, winsize, win_cyclic, refid)``: the window column, the
    window-size slider, the cyclic checkbox and the reference-id slider, with
    their visibility and bounds updated. The incoming ``refid`` value is not
    read; the returned slider is reset to 0.
    """
    # Count at least one file: with zero files (or one, for cyclic logwin) the
    # logarithms below would raise ValueError and slider maxima would fall below minima.
    num_files = max(1, len(inputfiles)) if inputfiles is not None else 1
    show_win_controls = scenegraph_type in ["swin", "logwin"]
    show_winsize = scenegraph_type in ["swin", "logwin"]
    show_cyclic = scenegraph_type in ["swin", "logwin"]
    max_winsize, min_winsize = 1, 1
    if scenegraph_type == "swin":
        if win_cyclic:
            max_winsize = max(1, math.ceil((num_files - 1) / 2))
        else:
            max_winsize = max(1, num_files - 1)
    elif scenegraph_type == "logwin":
        # math.log2 is exact for powers of two, unlike math.log(x, 2)
        if win_cyclic:
            half_size = max(1, math.ceil((num_files - 1) / 2))
            max_winsize = max(1, math.ceil(math.log2(half_size)))
        else:
            max_winsize = max(1, math.ceil(math.log2(num_files)))
    winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
                            minimum=min_winsize, maximum=max_winsize, step=1, visible=show_winsize)
    win_cyclic = gradio.Checkbox(value=win_cyclic, label="Cyclic sequence", visible=show_cyclic)
    win_col = gradio.Column(visible=show_win_controls)
    refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
                          maximum=num_files - 1, step=1, visible=scenegraph_type == 'oneref')
    return win_col, winsize, win_cyclic, refid
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False,
              share=False, gradio_delete_cache=False):
    """Build the gradio interface of the MASt3R demo and launch it.

    ``tmpdirname``, ``gradio_delete_cache``, ``model``, ``device``, ``silent``
    and ``image_size`` are bound into the reconstruction callback; ``silent``
    is also bound into the model-export callback. ``share``, ``server_name``
    and ``server_port`` are forwarded to ``demo.launch``.
    """
    # NOTE(review): the nesting of the `with` containers below was reconstructed from a
    # listing that had lost its indentation — confirm the layout against the original file.
    if not silent:
        print('Outputing stuff in', tmpdirname)

    # callbacks with the non-UI arguments already bound
    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, gradio_delete_cache, model, device,
                                  silent, image_size)
    model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)

    def get_context(delete_cache):
        # create the Blocks container; delete_cache is only passed on when it is set
        css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
        title = "MASt3R Demo"
        if delete_cache:
            return gradio.Blocks(css=css, title=title, delete_cache=(delete_cache, delete_cache))
        else:
            return gradio.Blocks(css=css, title="MASt3R Demo")  # for compatibility with older versions

    with get_context(gradio_delete_cache) as demo:
        # the scene state is saved so that you can change conf_thr, cam_size... without rerunning the inference
        scene = gradio.State(None)
        gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
        with gradio.Column():
            inputfiles = gradio.File(file_count="multiple")
            with gradio.Row():
                with gradio.Column():
                    # optimization settings
                    with gradio.Row():
                        lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
                        niter1 = gradio.Number(value=500, precision=0, minimum=0, maximum=10_000,
                                               label="num_iterations", info="For coarse alignment!")
                        lr2 = gradio.Slider(label="Fine LR", value=0.014, minimum=0.005, maximum=0.05, step=0.001)
                        niter2 = gradio.Number(value=200, precision=0, minimum=0, maximum=100_000,
                                               label="num_iterations", info="For refinement!")
                        optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
                                                      value='refine', label="OptLevel",
                                                      info="Optimization level")
                    # matching and scene-graph settings
                    with gradio.Row():
                        matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=5.,
                                                          minimum=0., maximum=30., step=0.1,
                                                          info="Before Fallback to Regr3D!")
                        shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
                                                            info="Only optimize one set of intrinsics for all views")
                        scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
                                                           ("swin: sliding window", "swin"),
                                                           ("logwin: sliding window with long range", "logwin"),
                                                           ("oneref: match one image with all", "oneref")],
                                                          value='complete', label="Scenegraph",
                                                          info="Define how to make pairs",
                                                          interactive=True)
                        # window controls start hidden; set_scenegraph_options updates their visibility
                        with gradio.Column(visible=False) as win_col:
                            winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
                                                    minimum=1, maximum=1, step=1)
                            win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
                        refid = gradio.Slider(label="Scene Graph: Id", value=0,
                                              minimum=0, maximum=0, step=1, visible=False)
            run_btn = gradio.Button("Run")

            with gradio.Row():
                # adjust the confidence threshold
                min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.5, minimum=0.0, maximum=10, step=0.1)
                # adjust the camera size in the output pointcloud
                cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
                TSDF_thresh = gradio.Slider(label="TSDF Threshold", value=0., minimum=0., maximum=1., step=0.01)
            with gradio.Row():
                as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
                # two post process implemented
                mask_sky = gradio.Checkbox(value=False, label="Mask sky")
                clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")

            outmodel = gradio.Model3D()

            # events
            # any change of graph type, input files or cyclic flag refreshes the scene-graph widgets
            scenegraph_type.change(set_scenegraph_options,
                                   inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                                   outputs=[win_col, winsize, win_cyclic, refid])
            inputfiles.change(set_scenegraph_options,
                              inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                              outputs=[win_col, winsize, win_cyclic, refid])
            win_cyclic.change(set_scenegraph_options,
                              inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                              outputs=[win_col, winsize, win_cyclic, refid])
            # full reconstruction: updates both the stored scene state and the displayed model
            run_btn.click(fn=recon_fun,
                          inputs=[scene, inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
                                  as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                                  scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
                          outputs=[scene, outmodel])
            # the display controls below only re-export the model from the stored scene state
            min_conf_thr.release(fn=model_from_scene_fun,
                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                         clean_depth, transparent_cams, cam_size, TSDF_thresh],
                                 outputs=outmodel)
            cam_size.change(fn=model_from_scene_fun,
                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                    clean_depth, transparent_cams, cam_size, TSDF_thresh],
                            outputs=outmodel)
            TSDF_thresh.change(fn=model_from_scene_fun,
                               inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                       clean_depth, transparent_cams, cam_size, TSDF_thresh],
                               outputs=outmodel)
            as_pointcloud.change(fn=model_from_scene_fun,
                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                         clean_depth, transparent_cams, cam_size, TSDF_thresh],
                                 outputs=outmodel)
            mask_sky.change(fn=model_from_scene_fun,
                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                    clean_depth, transparent_cams, cam_size, TSDF_thresh],
                            outputs=outmodel)
            clean_depth.change(fn=model_from_scene_fun,
                               inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                       clean_depth, transparent_cams, cam_size, TSDF_thresh],
                               outputs=outmodel)
            transparent_cams.change(model_from_scene_fun,
                                    inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                            clean_depth, transparent_cams, cam_size, TSDF_thresh],
                                    outputs=outmodel)
    demo.launch(share=share, server_name=server_name, server_port=server_port)
|
misc.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# utility functions for MASt3R
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import os
|
| 8 |
+
import hashlib
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def mkdir_for(f):
    """Create the parent directory of path ``f`` if needed and return ``f`` unchanged.

    A bare file name (no directory component) needs no directory to be created
    and is returned as is.
    """
    parent = os.path.dirname(f)
    if parent:
        # os.makedirs('') raises FileNotFoundError, so skip it for bare file names
        os.makedirs(parent, exist_ok=True)
    return f
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def hash_md5(s):
    """Return the hexadecimal MD5 digest of string ``s`` encoded as UTF-8."""
    digest = hashlib.md5()
    digest.update(s.encode('utf-8'))
    return digest.hexdigest()
|
model.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# MASt3R model class
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
from mast3r.catmlp_dpt_head import mast3r_head_factory
|
| 12 |
+
|
| 13 |
+
import mast3r.utils.path_to_dust3r # noqa
|
| 14 |
+
from dust3r.model import AsymmetricCroCo3DStereo # noqa
|
| 15 |
+
from dust3r.utils.misc import transpose_to_landscape # noqa
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
inf = float('inf')
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_model(model_path, device, verbose=True):
    """Instantiate the network described in a checkpoint file and load its weights.

    The checkpoint is read on CPU, the constructor expression stored in
    ``ckpt['args'].model`` is patched to force ``landscape_only=False`` and then
    evaluated, the weights in ``ckpt['model']`` are loaded non-strictly, and the
    network is returned after being moved to ``device``.

    SECURITY: the constructor expression comes from the checkpoint and is passed
    to ``eval``; only load checkpoints from a trusted source.
    """
    if verbose:
        print('... loading model from', model_path)
    # NOTE(review): torch.load presumably unpickles the stored objects — another reason to load trusted files only
    ckpt = torch.load(model_path, map_location='cpu')
    # the model constructor call is stored as a string; swap the patch-embedding class name in it
    args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if 'landscape_only' not in args:
        # replace the last character (presumably the closing parenthesis) with the extra option
        args = args[:-1] + ', landscape_only=False)'
    else:
        # strip spaces so that the textual replacement matches whatever the original spacing was
        args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
    assert "landscape_only=False" in args
    if verbose:
        print(f"instantiating : {args}")
    # SECURITY: eval executes the expression taken from the checkpoint; names in it are resolved in this scope
    net = eval(args)
    # strict=False: key mismatches are reported in `s` (printed below when verbose) rather than raised
    s = net.load_state_dict(ckpt['model'], strict=False)
    if verbose:
        print(s)
    return net.to(device)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
    """AsymmetricCroCo3DStereo variant whose two downstream heads are built by ``mast3r_head_factory``."""

    def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
        # NOTE: the default ('norm') is the plain string 'norm', not a 1-tuple.
        # The attributes are assigned before the parent constructor runs
        # (presumably it ends up calling set_downstream_head, which reads desc_conf_mode — TODO confirm).
        self.desc_conf_mode = desc_conf_mode
        self.two_confs = two_confs
        self.desc_mode = desc_mode
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """Load a local checkpoint file if the path points to one, else defer to the parent loader."""
        if not os.path.isfile(pretrained_model_name_or_path):
            return super(AsymmetricMASt3R, cls).from_pretrained(pretrained_model_name_or_path, **kw)
        # local checkpoint: loaded on CPU, the extra keyword arguments are not used on this path
        return load_model(pretrained_model_name_or_path, device='cpu')

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
        """Record the head configuration and allocate the two downstream heads with their wrappers."""
        assert (img_size[0] % patch_size == 0
                and img_size[1] % patch_size == 0), f'{img_size=} must be multiple of {patch_size=}'
        self.output_mode, self.head_type = output_mode, head_type
        self.depth_mode, self.conf_mode = depth_mode, conf_mode
        # the descriptor confidence mode falls back to the regular confidence mode
        if self.desc_conf_mode is None:
            self.desc_conf_mode = conf_mode

        # allocate heads
        has_conf = bool(conf_mode)
        self.downstream_head1 = mast3r_head_factory(head_type, output_mode, self, has_conf=has_conf)
        self.downstream_head2 = mast3r_head_factory(head_type, output_mode, self, has_conf=has_conf)

        # magic wrapper
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)
|
sparse_ga.py
ADDED
|
@@ -0,0 +1,1039 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# MASt3R Sparse Global Alignment
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import roma
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import numpy as np
|
| 13 |
+
import os
|
| 14 |
+
from collections import namedtuple
|
| 15 |
+
from functools import lru_cache
|
| 16 |
+
from scipy import sparse as sp
|
| 17 |
+
|
| 18 |
+
from mast3r.utils.misc import mkdir_for, hash_md5
|
| 19 |
+
from mast3r.cloud_opt.utils.losses import gamma_loss
|
| 20 |
+
from mast3r.cloud_opt.utils.schedules import linear_schedule, cosine_schedule
|
| 21 |
+
from mast3r.fast_nn import fast_reciprocal_NNs, merge_corres
|
| 22 |
+
|
| 23 |
+
import mast3r.utils.path_to_dust3r # noqa
|
| 24 |
+
from dust3r.utils.geometry import inv, geotrf # noqa
|
| 25 |
+
from dust3r.utils.device import to_cpu, to_numpy, todevice # noqa
|
| 26 |
+
from dust3r.post_process import estimate_focal_knowing_depth # noqa
|
| 27 |
+
from dust3r.optim_factory import adjust_learning_rate_by_lr # noqa
|
| 28 |
+
from dust3r.cloud_opt.base_opt import clean_pointcloud
|
| 29 |
+
from dust3r.viz import SceneViz
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SparseGA():
    """Result holder for a sparse global alignment run.

    Stores, per image, the RGB image, the optimized intrinsics, camera-to-world
    poses, sparse depthmaps, sparse 3D points and the colour sampled at each
    sparse point.

    Args:
        img_paths: image identifiers, compared against ``view['instance']`` of
            every view in ``pairs_in``.
        pairs_in: iterable of ``(view1, view2)`` dicts providing at least the
            ``'instance'`` and ``'img'`` keys.
        res_fine: dict with the keys ``'intrinsics'``, ``'cam2w'``,
            ``'depthmaps'`` and ``'pts3d'``.
        anchors: mapping ``image index -> tuple`` whose first element holds the
            pixel coordinates (x, y in its first two columns) of the sparse
            points; they must be integer-valued since they index the image.
        canonical_paths: optional list of cached canonical-view files, needed
            only by :meth:`get_dense_pts3d` / :meth:`show`.

    Raises:
        ValueError: if an entry of ``img_paths`` appears in no pair.
    """

    def __init__(self, img_paths, pairs_in, res_fine, anchors, canonical_paths=None):
        def fetch_img(im):
            # network inputs are read as (1, 3, H, W) in [-1, 1]; convert to (H, W, 3) in [0, 1]
            def torgb(x): return (x[0].permute(1, 2, 0).numpy() * .5 + .5).clip(min=0., max=1.)
            for im1, im2 in pairs_in:
                if im1['instance'] == im:
                    return torgb(im1['img'])
                if im2['instance'] == im:
                    return torgb(im2['img'])
            # Previously fell through and returned None, which only failed later
            # with an unrelated-looking TypeError when indexing the image.
            raise ValueError(f'image {im!r} does not appear in any of the input pairs')

        self.canonical_paths = canonical_paths
        self.img_paths = img_paths
        self.imgs = [fetch_img(img) for img in img_paths]
        self.intrinsics = res_fine['intrinsics']
        self.cam2w = res_fine['cam2w']
        self.depthmaps = res_fine['depthmaps']
        self.pts3d = res_fine['pts3d']
        self.pts3d_colors = []
        self.working_device = self.cam2w.device
        for i, im in enumerate(self.imgs):
            # sample the colour of every sparse 3D point at its anchor pixel
            x, y = anchors[i][0][..., :2].detach().cpu().numpy().T
            self.pts3d_colors.append(im[y, x])
            assert self.pts3d_colors[-1].shape == self.pts3d[i].shape
        self.n_imgs = len(self.imgs)

    def get_focals(self):
        """Return one focal (``K[0, 0]``) per image as a tensor on the working device."""
        return torch.tensor([ff[0, 0] for ff in self.intrinsics]).to(self.working_device)

    def get_principal_points(self):
        """Return the stacked principal points (first two rows of the last column of each K)."""
        return torch.stack([ff[:2, -1] for ff in self.intrinsics]).to(self.working_device)

    def get_im_poses(self):
        """Return the camera-to-world poses."""
        return self.cam2w

    def get_sparse_pts3d(self):
        """Return the per-image sparse 3D points."""
        return self.pts3d

    def get_dense_pts3d(self, clean_depth=True, subsample=8):
        """Densify the sparse depthmaps using the cached canonical views.

        Args:
            clean_depth: when true, the confidences are post-processed with
                ``clean_pointcloud``.
            subsample: subsampling factor forwarded to ``make_dense_pts3d``.

        Returns:
            ``(pts3d, depthmaps, confs)`` lists, one entry per canonical path.
        """
        assert self.canonical_paths, 'cache_path is required for dense 3d points'
        # same computation as the module-level helper; reuse it instead of duplicating it
        pts3d, depthmaps, confs = make_dense_pts3d(
            self.intrinsics, self.cam2w, self.depthmaps, self.canonical_paths, subsample,
            device=self.cam2w.device)

        if clean_depth:
            confs = clean_pointcloud(confs, self.intrinsics, inv(self.cam2w), depthmaps, pts3d)

        return pts3d, depthmaps, confs

    def get_pts3d_colors(self):
        """Return the colours sampled at the sparse points (one array per image)."""
        return self.pts3d_colors

    def get_depthmaps(self):
        """Return the per-image sparse depthmaps."""
        return self.depthmaps

    def get_masks(self):
        """Return one select-everything slice per image."""
        return [slice(None, None) for _ in self.imgs]

    def show(self, show_cams=True):
        """Display the dense reconstruction; cameras are drawn only if ``show_cams``."""
        pts3d, _, confs = self.get_dense_pts3d()
        show_reconstruction(self.imgs, self.intrinsics if show_cams else None, self.cam2w,
                            [p.clip(min=-50, max=50) for p in pts3d],
                            masks=[c > 1 for c in confs])
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def convert_dust3r_pairs_naming(imgs, pairs_in):
    """Tag both views of every pair with the image name they refer to.

    Each view dict carries an integer ``'idx'``; the matching entry of ``imgs``
    is written under ``'instance'``. The pairs are modified in place and the
    same ``pairs_in`` object is returned.
    """
    for pair in pairs_in:
        for view in (pair[0], pair[1]):
            view['instance'] = imgs[view['idx']]
    return pairs_in
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc_conf='desc_conf',
                            device='cuda', dtype=torch.float32, shared_intrinsics=False, **kw):
    """Run the full sparse alignment pipeline with MASt3R.

    Args:
        imgs: list of image paths.
        pairs_in: image pairs in dust3r naming (views referenced by ``'idx'``).
        cache_path: directory (str) where temporary files are dumped.
        model: network handed to ``forward_mast3r``.
        subsample: subsampling factor shared by all stages.
        desc_conf: key of the descriptor confidence in the network output.
        device, dtype: where and in which precision the optimization runs.
        shared_intrinsics: optimize a single set of intrinsics for all cameras.
        **kw: forwarded to ``sparse_scene_optimizer``, e.g.
            ``lr1``/``niter1`` (coarse alignment, 3D matching),
            ``lr2``/``niter2`` (refinement, 2D reprojection error),
            ``lora_depth`` (dimensionality reduction of the depthmaps).

    Returns:
        A ``SparseGA`` built from the fine result, or from the coarse one when
        no fine result is available.
    """
    # dust3r identifies views by index, mast3r by image name
    pairs_in = convert_dust3r_pairs_naming(imgs, pairs_in)

    # network forward pass (results are cached on disk)
    forward_out = forward_mast3r(pairs_in, model,
                                 cache_path=cache_path, subsample=subsample,
                                 desc_conf=desc_conf, device=device)
    pairs, cache_path = forward_out

    # canonical pointmaps
    canonical = prepare_canonical_data(imgs, pairs, subsample, cache_path=cache_path,
                                       mode='avg-angle', device=device)
    tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21 = canonical

    # minimal spanning tree over the pairwise scores
    # (all pairs are kept for the losses, not only the tree edges)
    mst = compute_min_spanning_tree(pairwise_scores)

    # gather everything the optimizer needs
    condensed = condense_data(imgs, tmp_pairs, canonical_views, preds_21, dtype)
    imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21 = condensed

    imgs, res_coarse, res_fine = sparse_scene_optimizer(
        imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d,
        preds_21, canonical_paths, mst,
        shared_intrinsics=shared_intrinsics, cache_path=cache_path, device=device, dtype=dtype, **kw)

    final_result = res_fine if res_fine else res_coarse
    return SparseGA(imgs, pairs_in, final_result, anchors, canonical_paths)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d,
                           preds_21, canonical_paths, mst, cache_path,
                           lr1=0.2, niter1=500, loss1=gamma_loss(1.1),
                           lr2=0.02, niter2=500, loss2=gamma_loss(0.4),
                           lossd=gamma_loss(1.1),
                           opt_pp=True, opt_depth=True,
                           schedule=cosine_schedule, depth_mode='add', exp_depth=False,
                           lora_depth=False,  # dict(k=96, gamma=15, min_norm=.5),
                           shared_intrinsics=False,
                           init=None, device='cuda', dtype=torch.float32,
                           matching_conf_thr=5., loss_dust3r_w=0.01,
                           verbose=True, dbg=()):
    """Optimize cameras, intrinsics and sparse depthmaps in two stages.

    Stage 1 (coarse) minimizes ``loss_3d`` over poses and scales only; stage 2
    (fine, skipped when ``niter2`` is falsy) minimizes ``loss_2d`` and also
    unlocks focals, principal points (if ``opt_pp``) and depth (if ``opt_depth``).
    A DUSt3R-style regression term weighted by ``loss_dust3r_w`` is added for
    pairs whose best matching confidence is not above ``matching_conf_thr``.

    Args:
        imgs: list of image names; index in this list identifies a camera.
        subsample, cache_path: only forwarded to the low-rank depth projection
            (``lora_depth``).
        imsizes, pps, base_focals, core_depth: per-image initial values.
            NOTE: ``pps`` and ``core_depth`` entries are normalized in place.
        anchors: per-image ``(pixels, idxs, offsets)`` used by ``make_pts3d``.
        corres: triple whose last element is the list of pair slices.
        corres2d: per-image 2D correspondences.
        preds_21: ``preds_21[img2][img1] -> (points, confs)`` regression targets.
        canonical_paths: canonical-view cache files (read for shared intrinsics).
        mst: ``(root, edges)`` spanning tree defining the kinematic chain.
        lr1/niter1/loss1, lr2/niter2/loss2: settings of the two stages.
        lossd: pixel loss used for the DUSt3R regression term.
        schedule: learning-rate schedule ``schedule(alpha, lr_base, lr_end)``.
        depth_mode: ``'add'`` or ``'mul'``.
        exp_depth: optimize log-depth instead of depth.
        lora_depth: falsy, or kwargs for ``spectral_projection_of_depthmaps``.
        shared_intrinsics: optimize one focal / principal point for all cameras.
        init: optional dict ``img -> {'intrinsics', 'depthmap', 'cam2w', 'freeze'}``.
            A caller-supplied dict gets an empty entry added for every missing image.
        dbg: unused.

    Returns:
        ``(imgs, res_coarse, res_fine)`` where each result is a dict with the keys
        ``intrinsics``, ``cam2w``, ``depthmaps`` and ``pts3d``; ``res_fine`` is
        ``None`` when ``niter2`` is falsy.

    Raises:
        ValueError: on an unknown ``depth_mode``.
    """
    # A mutable ``{}`` default would be shared (and grown by setdefault below) across calls.
    if init is None:
        init = {}

    # extrinsic parameters
    vec0001 = torch.tensor((0, 0, 0, 1), dtype=dtype, device=device)
    quats = [nn.Parameter(vec0001.clone()) for _ in range(len(imgs))]
    trans = [nn.Parameter(torch.zeros(3, device=device, dtype=dtype)) for _ in range(len(imgs))]

    # initialize
    ones = torch.ones((len(imgs), 1), device=device, dtype=dtype)
    median_depths = torch.ones(len(imgs), device=device, dtype=dtype)
    for idx, img in enumerate(imgs):
        init_values = init.setdefault(img, {})
        if verbose and init_values:
            print(f' >> initializing img=...{img[-25:]} [{idx}] for {set(init_values)}')

        K = init_values.get('intrinsics')
        if K is not None:
            K = K.detach()
            focal = K[:2, :2].diag().mean()
            pp = K[:2, 2]
            base_focals[idx] = focal
            pps[idx] = pp
        pps[idx] /= imsizes[idx]  # default principal_point would be (0.5, 0.5)

        depth = init_values.get('depthmap')
        if depth is not None:
            core_depth[idx] = depth.detach()

        # depth is stored relative to its median
        median_depths[idx] = med_depth = core_depth[idx].median()
        core_depth[idx] /= med_depth

        cam2w = init_values.get('cam2w')
        if cam2w is not None:
            rot = cam2w[:3, :3].detach()
            cam_center = cam2w[:3, 3].detach()
            quats[idx].data[:] = roma.rotmat_to_unitquat(rot)
            trans_offset = med_depth * torch.cat((imsizes[idx] / base_focals[idx] * (0.5 - pps[idx]), ones[:1, 0]))
            trans[idx].data[:] = cam_center + rot @ trans_offset
            del rot
            assert False, 'inverse kinematic chain not yet implemented'

    # intrinsics parameters
    if shared_intrinsics:
        # Optimize a single set of intrinsics for all cameras. Use averages as init.
        confs = torch.stack([torch.load(pth)[0][2].mean() for pth in canonical_paths]).to(pps)
        weighting = confs / confs.sum()
        pp = nn.Parameter((weighting @ pps).to(dtype))
        pps = [pp for _ in range(len(imgs))]
        focal_m = weighting @ base_focals
        log_focal = nn.Parameter(focal_m.view(1).log().to(dtype))
        log_focals = [log_focal for _ in range(len(imgs))]
    else:
        pps = [nn.Parameter(pp.to(dtype)) for pp in pps]
        log_focals = [nn.Parameter(f.view(1).log().to(dtype)) for f in base_focals]

    diags = imsizes.float().norm(dim=1)
    min_focals = 0.25 * diags  # diag = 1.2~1.4*max(W,H) => beta >= 1/(2*1.2*tan(fov/2)) ~= 0.26
    max_focals = 10 * diags

    # a spanning tree over n cameras has n-1 edges
    assert len(mst[1]) == len(pps) - 1

    def make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth):
        """Build K (and, unless ``trans`` is None, poses and depthmaps) from the raw parameters."""
        # make intrinsics
        focals = torch.cat(log_focals).exp().clip(min=min_focals, max=max_focals)
        pps = torch.stack(pps)
        K = torch.eye(3, dtype=dtype, device=device)[None].expand(len(imgs), 3, 3).clone()
        K[:, 0, 0] = K[:, 1, 1] = focals
        K[:, 0:2, 2] = pps * imsizes
        if trans is None:
            return K

        # security! optimization is always trying to crush the scale down
        sizes = torch.cat(log_sizes).exp()
        global_scaling = 1 / sizes.min()

        # compute distance of camera to focal plane
        # tan(fov) = W/2 / focal
        z_cameras = sizes * median_depths * focals / base_focals

        # make extrinsic
        rel_cam2cam = torch.eye(4, dtype=dtype, device=device)[None].expand(len(imgs), 4, 4).clone()
        rel_cam2cam[:, :3, :3] = roma.unitquat_to_rotmat(F.normalize(torch.stack(quats), dim=1))
        rel_cam2cam[:, :3, 3] = torch.stack(trans)

        # camera are defined as a kinematic chain
        tmp_cam2w = [None] * len(K)
        tmp_cam2w[mst[0]] = rel_cam2cam[mst[0]]
        for i, j in mst[1]:
            # i is the cam_i_to_world reference, j is the relative pose = cam_j_to_cam_i
            tmp_cam2w[j] = tmp_cam2w[i] @ rel_cam2cam[j]
        tmp_cam2w = torch.stack(tmp_cam2w)

        # smart reparameterization of cameras
        trans_offset = z_cameras.unsqueeze(1) * torch.cat((imsizes / focals.unsqueeze(1) * (0.5 - pps), ones), dim=-1)
        new_trans = global_scaling * (tmp_cam2w[:, :3, 3:4] - tmp_cam2w[:, :3, :3] @ trans_offset.unsqueeze(-1))
        cam2w = torch.cat((torch.cat((tmp_cam2w[:, :3, :3], new_trans), dim=2),
                          vec0001.view(1, 1, 4).expand(len(K), 1, 4)), dim=1)

        depthmaps = []
        for i in range(len(imgs)):
            core_depth_img = core_depth[i]
            if exp_depth:
                core_depth_img = core_depth_img.exp()
            if lora_depth:  # compute core_depth as a low-rank decomposition of 3d points
                core_depth_img = lora_depth_proj[i] @ core_depth_img
            if depth_mode == 'add':
                core_depth_img = z_cameras[i] + (core_depth_img - 1) * (median_depths[i] * sizes[i])
            elif depth_mode == 'mul':
                core_depth_img = z_cameras[i] * core_depth_img
            else:
                raise ValueError(f'Bad {depth_mode=}')
            depthmaps.append(global_scaling * core_depth_img)

        return K, (inv(cam2w), cam2w), depthmaps

    K = make_K_cam_depth(log_focals, pps, None, None, None, None)

    if shared_intrinsics:
        print('init focal (shared) = ', to_numpy(K[0, 0, 0]).round(2))
    else:
        print('init focals =', to_numpy(K[:, 0, 0]))

    # spectral low-rank projection of depthmaps
    if lora_depth:
        core_depth, lora_depth_proj = spectral_projection_of_depthmaps(
            imgs, K, core_depth, subsample, cache_path=cache_path, **lora_depth)
    if exp_depth:
        core_depth = [d.clip(min=1e-4).log() for d in core_depth]
    core_depth = [nn.Parameter(d.ravel().to(dtype)) for d in core_depth]
    log_sizes = [nn.Parameter(torch.zeros(1, dtype=dtype, device=device)) for _ in range(len(imgs))]

    # Fetch img slices
    _, confs_sum, imgs_slices = corres

    # Define which pairs are fine to use with matching
    def matching_check(x): return x.max() > matching_conf_thr
    is_matching_ok = {}
    for s in imgs_slices:
        is_matching_ok[s.img1, s.img2] = matching_check(s.confs)

    # Prepare slices and corres for losses
    dust3r_slices = [s for s in imgs_slices if not is_matching_ok[s.img1, s.img2]]
    loss3d_slices = [s for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
    cleaned_corres2d = []
    for img1, pix1, confs, _, img1_slices in corres2d:
        cf_sum = 0
        pix1_filtered = []
        confs_filtered = []
        curstep = 0
        cleaned_slices = []
        for img2, slice2 in img1_slices:
            if is_matching_ok[img1, img2]:
                tslice = slice(curstep, curstep + slice2.stop - slice2.start, slice2.step)
                pix1_filtered.append(pix1[tslice])
                confs_filtered.append(confs[tslice])
                cleaned_slices.append((img2, slice2))
            # position in pix1 advances for every slice, kept or not
            curstep += slice2.stop - slice2.start
        if pix1_filtered != []:
            pix1_filtered = torch.cat(pix1_filtered)
            confs_filtered = torch.cat(confs_filtered)
            cf_sum = confs_filtered.sum()
        cleaned_corres2d.append((img1, pix1_filtered, confs_filtered, cf_sum, cleaned_slices))

    def loss_dust3r(cam2w, pts3d, pix_loss):
        # In the case no correspondence could be established, fallback to DUSt3R GA regression loss formulation (sparsified)
        loss = 0.
        cf_sum = 0.
        for s in dust3r_slices:
            if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
                continue
            # fallback to dust3r regression
            tgt_pts, tgt_confs = preds_21[imgs[s.img2]][imgs[s.img1]]
            tgt_pts = geotrf(cam2w[s.img2], tgt_pts)
            cf_sum += tgt_confs.sum()
            loss += tgt_confs @ pix_loss(pts3d[s.img1], tgt_pts)
        return loss / cf_sum if cf_sum != 0. else 0.

    def loss_3d(K, w2cam, pts3d, pix_loss):
        # For each correspondence, we have two 3D points (one for each image of the pair).
        # For each 3D point, we have 2 reproj errors
        if any(v.get('freeze') for v in init.values()):
            pts3d_1 = []
            pts3d_2 = []
            confs = []
            for s in loss3d_slices:
                if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
                    continue
                pts3d_1.append(pts3d[s.img1][s.slice1])
                pts3d_2.append(pts3d[s.img2][s.slice2])
                confs.append(s.confs)
        else:
            pts3d_1 = [pts3d[s.img1][s.slice1] for s in loss3d_slices]
            pts3d_2 = [pts3d[s.img2][s.slice2] for s in loss3d_slices]
            confs = [s.confs for s in loss3d_slices]

        if pts3d_1 != []:
            confs = torch.cat(confs)
            pts3d_1 = torch.cat(pts3d_1)
            pts3d_2 = torch.cat(pts3d_2)
            loss = confs @ pix_loss(pts3d_1, pts3d_2)
            cf_sum = confs.sum()
        else:
            loss = 0.
            cf_sum = 1.

        return loss / cf_sum

    def loss_2d(K, w2cam, pts3d, pix_loss):
        # For each correspondence, we have two 3D points (one for each image of the pair).
        # For each 3D point, we have 2 reproj errors
        proj_matrix = K @ w2cam[:, :3]
        loss = npix = 0
        for img1, pix1_filtered, confs_filtered, cf_sum, cleaned_slices in cleaned_corres2d:
            if init[imgs[img1]].get('freeze', 0) >= 1:
                continue  # no need
            pts3d_in_img1 = [pts3d[img2][slice2] for img2, slice2 in cleaned_slices]
            if pts3d_in_img1 != []:
                pts3d_in_img1 = torch.cat(pts3d_in_img1)
                loss += confs_filtered @ pix_loss(pix1_filtered, reproj2d(proj_matrix[img1], pts3d_in_img1))
                npix += confs_filtered.sum()

        return loss / npix if npix != 0 else 0.

    def optimize_loop(loss_func, lr_base, niter, pix_loss, lr_end=0):
        # create optimizer (lr is overwritten at every iteration by the schedule)
        params = pps + log_focals + quats + trans + log_sizes + core_depth
        optimizer = torch.optim.Adam(params, lr=1, weight_decay=0, betas=(0.9, 0.9))
        # a "meta" loss is a factory called with the remaining progress; a plain loss is used as is
        ploss = pix_loss if 'meta' in repr(pix_loss) else (lambda a: pix_loss)

        with tqdm(total=niter) as bar:
            # ``niter or 1``: with niter == 0 still run one forward pass to produce a result
            for step in range(niter or 1):
                K, (w2cam, cam2w), depthmaps = make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth)
                pts3d = make_pts3d(anchors, K, cam2w, depthmaps, base_focals=base_focals)
                if niter == 0:
                    break

                alpha = (step / niter)
                lr = schedule(alpha, lr_base, lr_end)
                adjust_learning_rate_by_lr(optimizer, lr)
                pix_loss = ploss(1 - alpha)
                optimizer.zero_grad()
                loss = loss_func(K, w2cam, pts3d, pix_loss) + loss_dust3r_w * loss_dust3r(cam2w, pts3d, lossd)
                loss.backward()
                optimizer.step()

                # make sure the pose remains well optimizable
                for i in range(len(imgs)):
                    quats[i].data[:] /= quats[i].data.norm()

                loss = float(loss)
                if loss != loss:
                    break  # NaN loss
                bar.set_postfix_str(f'{lr=:.4f}, {loss=:.3f}')
                bar.update(1)

        if niter:
            print(f'>> final loss = {loss}')
        return dict(intrinsics=K.detach(), cam2w=cam2w.detach(),
                    depthmaps=[d.detach() for d in depthmaps], pts3d=[p.detach() for p in pts3d])

    # at start, don't optimize 3d points
    for i, img in enumerate(imgs):
        trainable = not (init[img].get('freeze'))
        pps[i].requires_grad_(False)
        log_focals[i].requires_grad_(False)
        quats[i].requires_grad_(trainable)
        trans[i].requires_grad_(trainable)
        log_sizes[i].requires_grad_(trainable)
        core_depth[i].requires_grad_(False)

    res_coarse = optimize_loop(loss_3d, lr_base=lr1, niter=niter1, pix_loss=loss1)

    res_fine = None
    if niter2:
        # now we can optimize 3d points
        for i, img in enumerate(imgs):
            if init[img].get('freeze', 0) >= 1:
                continue
            pps[i].requires_grad_(bool(opt_pp))
            log_focals[i].requires_grad_(True)
            core_depth[i].requires_grad_(opt_depth)

        # refinement with 2d reproj
        res_fine = optimize_loop(loss_2d, lr_base=lr2, niter=niter2, pix_loss=loss2)

    K = make_K_cam_depth(log_focals, pps, None, None, None, None)
    if shared_intrinsics:
        print('Final focal (shared) = ', to_numpy(K[0, 0, 0]).round(2))
    else:
        print('Final focals =', to_numpy(K[:, 0, 0]))

    return imgs, res_coarse, res_fine
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
@lru_cache
def mask110(device, dtype):
    """Return the constant vector (1, 1, 0) for the given device and dtype.

    The result is memoized per (device, dtype), so repeated calls hand back
    the same tensor object.
    """
    values = (1, 1, 0)
    return torch.tensor(values, dtype=dtype, device=device)
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def proj3d(inv_K, pixels, z):
    """Back-project pixels at depth ``z`` into camera space.

    Only the diagonal and the first two entries of the last column of
    ``inv_K`` are used. 2D pixels are first extended with a trailing 1.
    """
    if pixels.shape[-1] == 2:
        unit = torch.ones_like(pixels[..., :1])
        pixels = torch.cat((pixels, unit), dim=-1)
    scale = inv_K.diag()
    # the mask zeroes the third entry so only x and y are shifted
    shift = inv_K[:, 2] * mask110(z.device, z.dtype)
    return z.unsqueeze(-1) * (pixels * scale + shift)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def make_pts3d(anchors, K, cam2w, depthmaps, base_focals=None, ret_depth=False):
    """Turn anchored depth values into world-space 3D points, image by image.

    Args:
        anchors: mapping ``img -> (pixels, idxs, offsets)``; ``idxs`` selects the
            depth value of each pixel in ``depthmaps[img]`` and ``offsets`` is a
            multiplicative correction applied to it.
        K: stacked intrinsics; ``K[:, 0, 0]`` is read as the focal.
        cam2w: stacked camera-to-world transforms.
        depthmaps: per-image depth values.
        base_focals: when given, offsets are rescaled by ``base_focal / focal``.
        ret_depth: also return the camera-space z of every point.

    Returns:
        The list of per-image 3D points, or ``(points, depths)`` if ``ret_depth``.
    """
    focals = K[:, 0, 0]
    K_inv = inv(K)
    world_points = []
    cam_depths = []

    for img, (pixels, idxs, offsets) in anchors.items():
        if base_focals is not None:
            # compensate for focal:
            #   depth + depth * (offset - 1) * base_focal / focal
            # = depth * (1 + (offset - 1) * (base_focal / focal))
            focal_ratio = base_focals[img] / focals[img]
            offsets = 1 + (offsets - 1) * focal_ratio

        # from depthmaps to 3d points (camera frame)
        z = depthmaps[img][idxs] * offsets
        points = proj3d(K_inv[img], pixels, z)
        if ret_depth:
            cam_depths.append(points[..., 2])  # before camera rotation

        # move to world coordinates
        world_points.append(geotrf(cam2w[img], points))

    return (world_points, cam_depths) if ret_depth else world_points
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def make_dense_pts3d(intrinsics, cam2w, depthmaps, canonical_paths, subsample, device='cuda'):
    """Densify sparse depthmaps using the cached canonical views.

    For every canonical-view file, a full pixel grid is anchored to the sparse
    depth values, then ``make_pts3d`` produces one 3D point per pixel.

    Returns:
        ``(pts3d, depthmaps_out, confs)`` lists, one entry per canonical path.
    """
    confs, base_focals, anchors = [], [], {}

    for i, canon_path in enumerate(canonical_paths):
        (canon, canon2, conf), focal = torch.load(canon_path, map_location=device)
        confs.append(conf)
        base_focals.append(focal)

        # one (x, y) row per pixel of the canonical view
        H, W = conf.shape
        grid = np.mgrid[:W, :H].T.reshape(-1, 2)
        pixels = torch.from_numpy(grid).float().to(device)

        idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample)
        anchors[i] = (pixels, idxs[i], offsets[i])

    # densify sparse depthmaps
    flat_depthmaps = [d.ravel() for d in depthmaps]
    pts3d, depthmaps_out = make_pts3d(anchors, intrinsics, cam2w, flat_depthmaps,
                                      base_focals=base_focals, ret_depth=True)

    return pts3d, depthmaps_out, confs
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
@torch.no_grad()
def forward_mast3r(pairs, model, cache_path, desc_conf='desc_conf',
                   device='cuda', subsample=8, **matching_kw):
    """Run (or reuse from cache) the network and the matching for every pair.

    For each pair, two prediction files and one correspondence file are expected
    under ``cache_path``. Missing ones are computed with ``model``; if ``model``
    is None such pairs are skipped. A correspondence file stored for the reversed
    pair is reused with its two point sets swapped.

    Returns:
        ``(res_paths, cache_path)`` where
        ``res_paths[(name1, name2)] = ((path1, path2), path_corres)``.
    """
    res_paths = {}

    for img1, img2 in tqdm(pairs):
        name1, name2 = img1['instance'], img2['instance']
        idx1 = hash_md5(name1)
        idx2 = hash_md5(name2)

        path1 = cache_path + f'/forward/{idx1}/{idx2}.pth'
        path2 = cache_path + f'/forward/{idx2}/{idx1}.pth'
        path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx1}-{idx2}.pth'
        path_corres2 = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx2}-{idx1}.pth'

        # correspondences only exist for the reversed pair: swap and store them
        if not os.path.isfile(path_corres) and os.path.isfile(path_corres2):
            score, (xy1, xy2, confs) = torch.load(path_corres2)
            torch.save((score, (xy2, xy1, confs)), path_corres)

        everything_cached = all(os.path.isfile(p) for p in (path1, path2, path_corres))
        if not everything_cached:
            if model is None:
                continue
            res = symmetric_inference(model, img1, img2, device=device)
            X11, X21, X22, X12 = [r['pts3d'][0] for r in res]
            C11, C21, C22, C12 = [r['conf'][0] for r in res]
            descs = [r['desc'][0] for r in res]
            qonfs = [r[desc_conf][0] for r in res]

            # save
            torch.save(to_cpu((X11, C11, X21, C21)), mkdir_for(path1))
            torch.save(to_cpu((X22, C22, X12, C12)), mkdir_for(path2))

            # perform reciprocal matching
            corres = extract_correspondences(descs, qonfs, device=device, subsample=subsample)

            # geometric mean of the four mean confidences
            conf_score = (C11.mean() * C12.mean() * C21.mean() * C22.mean()).sqrt().sqrt()
            matching_score = (float(conf_score), float(corres[2].sum()), len(corres[2]))
            if cache_path is not None:
                torch.save((matching_score, corres), mkdir_for(path_corres))

        res_paths[name1, name2] = (path1, path2), path_corres

    del model
    torch.cuda.empty_cache()

    return res_paths, cache_path
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def symmetric_inference(model, img1, img2, device):
    """Decode a pair in both orders while encoding each image only once.

    Returns:
        ``(res11, res21, res22, res12)``: the head outputs of decoding (1, 2)
        followed by those of decoding (2, 1).
    """
    shape1 = torch.from_numpy(img1['true_shape']).to(device, non_blocking=True)
    shape2 = torch.from_numpy(img2['true_shape']).to(device, non_blocking=True)
    img1 = img1['img'].to(device, non_blocking=True)
    img2 = img2['img'].to(device, non_blocking=True)

    # compute encoder only once
    feat1, feat2, pos1, pos2 = model._encode_image_pairs(img1, img2, shape1, shape2)

    def decode_pair(feat_a, feat_b, pos_a, pos_b, shape_a, shape_b):
        dec_a, dec_b = model._decoder(feat_a, pos_a, feat_b, pos_b)
        # heads run in full precision
        with torch.cuda.amp.autocast(enabled=False):
            tokens_a = [tok.float() for tok in dec_a]
            out_a = model._downstream_head(1, tokens_a, shape_a)
            tokens_b = [tok.float() for tok in dec_b]
            out_b = model._downstream_head(2, tokens_b, shape_b)
        return out_a, out_b

    # decoder 1-2
    res11, res21 = decode_pair(feat1, feat2, pos1, pos2, shape1, shape2)
    # decoder 2-1
    res22, res12 = decode_pair(feat2, feat1, pos2, pos1, shape2, shape1)

    return (res11, res21, res22, res12)
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
def extract_correspondences(feats, qonfs, subsample=8, device=None, ptmap_key='pred_desc'):
    """Match two images reciprocally using both decoding orders and merge the result.

    Args:
        feats: ``(feat11, feat21, feat22, feat12)`` feature maps.
        qonfs: ``(qonf11, qonf21, qonf22, qonf12)`` matching confidences.
        subsample: forwarded to ``fast_reciprocal_NNs``.
        device: device used for matching and for the returned tensors.
        ptmap_key: selects the nearest-neighbour options (``'3d'`` in the key
            switches to the CPU variant).

    Returns:
        ``(xy1, xy2, confs)`` moved to ``device``; each confidence is the
        geometric mean of the confidences at the two matched locations.
    """
    feat11, feat21, feat22, feat12 = feats
    qonf11, qonf21, qonf22, qonf12 = qonfs
    assert feat11.shape[:2] == feat12.shape[:2] == qonf11.shape == qonf12.shape
    assert feat21.shape[:2] == feat22.shape[:2] == qonf21.shape == qonf22.shape

    opt = (dict(device='cpu', workers=32) if '3d' in ptmap_key
           else dict(device=device, dist='dot', block_size=2**13))

    # matching the two pairs
    # TODO add non symmetric / pixel_tol options
    idx1, idx2, qonf1, qonf2 = [], [], [], []
    both_orders = [(feat11, feat21, qonf11.cpu(), qonf21.cpu()),
                   (feat12, feat22, qonf12.cpu(), qonf22.cpu())]
    for A, B, QA, QB in both_orders:
        nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=subsample, ret_xy=False, **opt)
        nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=subsample, ret_xy=False, **opt)

        in_img1 = np.r_[nn1to2[0], nn2to1[1]]
        in_img2 = np.r_[nn1to2[1], nn2to1[0]]
        idx1.append(in_img1)
        idx2.append(in_img2)
        qonf1.append(QA.ravel()[in_img1])
        qonf2.append(QB.ravel()[in_img2])

    # merge corres from opposite pairs
    H1, W1 = feat11.shape[:2]
    H2, W2 = feat22.shape[:2]
    cat = np.concatenate

    xy1, xy2, idx = merge_corres(cat(idx1), cat(idx2), (H1, W1), (H2, W2), ret_xy=True, ret_index=True)
    merged_conf = np.sqrt(cat(qonf1)[idx] * cat(qonf2)[idx])
    corres = (xy1.copy(), xy2.copy(), merged_conf)

    return todevice(corres, device)
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
@torch.no_grad()
def prepare_canonical_data(imgs, tmp_pairs, subsample, order_imgs=False, min_conf_thr=0,
                           cache_path=None, device='cuda', **kw):
    """Build, per image, its canonical view, focal estimate and anchor offsets.

    Args:
        imgs: list of image names.
        tmp_pairs: mapping ``(img1, img2) -> ((path1, path2), path_corres)``.
        subsample: subsampling factor for predictions and anchors.
        order_imgs: unused.
        min_conf_thr: correspondences at or below this confidence are dropped
            (0 keeps everything).
        cache_path: directory for the canonical-view cache. When falsy nothing
            is cached and the returned ``canonical_paths`` is empty.
        device: device on which tensors are loaded and created.
        **kw: forwarded to ``canonical_view`` (and part of the cache file name).

    Returns:
        ``(tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21)``.
    """
    canonical_views = {}
    pairwise_scores = torch.zeros((len(imgs), len(imgs)), device=device)
    canonical_paths = []
    preds_21 = {}

    for img in tqdm(imgs):
        # Per-image state. Without this reset, running with cache_path=None (the
        # default) read `cache`, `canon` and `focal` before assignment.
        cache = None
        canon = canon2 = cconf = focal = None
        if cache_path:
            cache = os.path.join(cache_path, 'canon_views', hash_md5(img) + f'_{subsample=}_{kw=}.pth')
            canonical_paths.append(cache)
            try:
                (canon, canon2, cconf), focal = torch.load(cache, map_location=device)
            except IOError:
                # cache does not exist yet, we create it!
                canon = focal = None

        # collect all pred1
        n_pairs = sum((img in pair) for pair in tmp_pairs)

        ptmaps11 = confs11 = None
        pixels = {}
        n = 0
        for (img1, img2), ((path1, path2), path_corres) in tmp_pairs.items():
            score = None
            if img == img1:
                X, C, X2, C2 = torch.load(path1, map_location=device)
                score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr)
                pixels[img2] = xy1, confs
                # Subsample preds_21
                preds_21.setdefault(img, {})[img2] = (X2[::subsample, ::subsample].reshape(-1, 3),
                                                      C2[::subsample, ::subsample].ravel())

            if img == img2:
                X, C, X2, C2 = torch.load(path2, map_location=device)
                score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr)
                pixels[img1] = xy2, confs
                preds_21.setdefault(img, {})[img1] = (X2[::subsample, ::subsample].reshape(-1, 3),
                                                      C2[::subsample, ::subsample].ravel())

            if score is not None:
                i, j = imgs.index(img1), imgs.index(img2)
                # the third score component (number of correspondences) is used as pair score
                score = score[2]
                pairwise_scores[i, j] = score
                pairwise_scores[j, i] = score

                if canon is not None:
                    continue  # canonical view came from the cache, no need to stack predictions
                if ptmaps11 is None:
                    H, W = C.shape
                    ptmaps11 = torch.empty((n_pairs, H, W, 3), device=device)
                    confs11 = torch.empty((n_pairs, H, W), device=device)

                ptmaps11[n] = X
                confs11[n] = C
                n += 1

        if canon is None:
            canon, canon2, cconf = canonical_view(ptmaps11, confs11, subsample, **kw)
            del ptmaps11
            del confs11

        # compute focals
        H, W = canon.shape[:2]
        pp = torch.tensor([W / 2, H / 2], device=device)
        if focal is None:
            focal = estimate_focal_knowing_depth(canon[None], pp, focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5)
            if cache:
                torch.save(to_cpu(((canon, canon2, cconf), focal)), mkdir_for(cache))

        # extract depth offsets with correspondences
        core_depth = canon[subsample // 2::subsample, subsample // 2::subsample, 2]
        idxs, offsets = anchor_depth_offsets(canon2, pixels, subsample=subsample)

        canonical_views[img] = (pp, (H, W), focal.view(1), core_depth, pixels, idxs, offsets)

    return tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def load_corres(path_corres, device, min_conf_thr):
    """Load the cached correspondences of one image pair and filter them by confidence.

    The file is expected to contain ``(score, (xy1, xy2, confs))`` as written
    by ``torch.save``.

    Args:
        path_corres: path of the correspondence file.
        device: forwarded to ``torch.load`` as ``map_location``.
        min_conf_thr: only matches with a confidence strictly above this value
            are kept. A falsy value (``None`` or ``0``) keeps every match.

    Returns:
        ``(score, (xy1, xy2, confs))`` where the three match tensors are
        restricted to the retained matches; ``score`` is returned untouched.
    """
    # NOTE(review): torch.load unpickles its input; only point it at trusted cache files.
    score, matches = torch.load(path_corres, map_location=device)
    xy1, xy2, confs = matches

    if min_conf_thr:
        keep = confs > min_conf_thr
    else:
        keep = slice(None)  # no threshold: select everything

    return score, (xy1[keep], xy2[keep], confs[keep])
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
# Record describing the matches of one image pair: for each side, the image
# index, the slice of that image's concatenated matches, the matched pixels and
# the anchor index of every match; plus the per-match confidences and their sum.
PairOfSlices = namedtuple(
    'ImgPair',
    'img1, slice1, pix1, anchor_idxs1, '
    'img2, slice2, pix2, anchor_idxs2, '
    'confs, confs_sum',
)
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def condense_data(imgs, tmp_paths, canonical_views, preds_21, dtype=torch.float32):
    """Aggregate canonical views and pairwise matches into flat per-image structures.

    Args:
        imgs: list of image identifiers (hashable); positions in this list are
            the integer image indices used in the outputs.
        tmp_paths: iterable of ``(img1, img2)`` identifier pairs to consider.
        canonical_views: dict ``img -> (pp, (H, W), focal, core_depth,
            pixels_confs, idxs, offsets)`` where ``pixels_confs[img2]`` is
            ``(pixels, confs)`` and ``idxs[img2]`` / ``offsets[img2]`` hold one
            entry per match.
        preds_21: dict ``img -> {img2: (pred, conf)}``.
        dtype: dtype the matched pixels and confidences are converted to.

    Returns:
        Tuple ``(imsizes, principal_points, focals, core_depth, img_anchors,
        corres, corres2d, subsamp_preds_21)``:

        * ``imsizes``: tensor of ``(W, H)`` per image.
        * ``principal_points``: stacked principal points.
        * ``focals``: concatenated focals.
        * ``core_depth``: list with the 4th entry of every canonical view.
        * ``img_anchors``: ``{image index: (homogeneous pixels, anchor indices,
          offsets)}``, each concatenated over all partner images.
        * ``corres``: ``(all_confs, float(all_confs.sum()), [PairOfSlices, ...])``.
        * ``corres2d``: list of ``(image index, pixels, confs, confs sum,
          [(other image index, slice), ...])``.
        * ``subsamp_preds_21``: ``preds_21`` indexed with the anchor indices of
          the second image.
    """
    # Position of each image in `imgs`. `setdefault` keeps the FIRST occurrence,
    # which is what `imgs.index(img)` returns, without a linear scan per lookup.
    index_of = {}
    for idx, img in enumerate(imgs):
        index_of.setdefault(img, idx)

    principal_points = []
    shapes = []
    focals = []
    core_depth = []
    img_anchors = {}
    tmp_pixels = {}

    for idx1, img1 in enumerate(imgs):
        pp, shape, focal, anchors, pixels_confs, idxs, offsets = canonical_views[img1]

        principal_points.append(pp)
        shapes.append(shape)
        focals.append(focal)
        core_depth.append(anchors)

        img_uv1 = []
        img_idxs = []
        img_offs = []
        start = 0  # running offset into this image's concatenated matches

        for img2, (pixels, match_confs) in pixels_confs.items():
            if img2 not in index_of:
                continue
            assert len(pixels) == len(idxs[img2]) == len(offsets[img2])
            # homogeneous pixel coordinates (u, v, 1)
            img_uv1.append(torch.cat((pixels, torch.ones_like(pixels[:, :1])), dim=-1))
            img_idxs.append(idxs[img2])
            img_offs.append(offsets[img2])
            stop = start + len(pixels)
            # remember where the matches towards img2 live in the concatenation
            tmp_pixels[img1, img2] = pixels.to(dtype), match_confs.to(dtype), slice(start, stop)
            start = stop
        img_anchors[idx1] = (torch.cat(img_uv1), torch.cat(img_idxs), torch.cat(img_offs))

    all_confs = []
    imgs_slices = []
    corres2d = {i: [] for i in range(len(imgs))}

    for name1, name2 in tmp_paths:
        try:
            pix1, confs1, slice1 = tmp_pixels[name1, name2]
            pix2, confs2, slice2 = tmp_pixels[name2, name1]
        except KeyError:
            # the pair needs matches stored in both directions
            continue
        i1, i2 = index_of[name1], index_of[name2]
        confs = (confs1 * confs2).sqrt()  # geometric mean of both confidences

        # prepare for loss_3d
        all_confs.append(confs)
        anchor_idxs1 = canonical_views[name1][5][name2]
        anchor_idxs2 = canonical_views[name2][5][name1]
        imgs_slices.append(PairOfSlices(i1, slice1, pix1, anchor_idxs1,
                                        i2, slice2, pix2, anchor_idxs2,
                                        confs, float(confs.sum())))

        # prepare for loss_2d
        corres2d[i1].append((pix1, confs, i2, slice2))
        corres2d[i2].append((pix2, confs, i1, slice1))

    all_confs = torch.cat(all_confs)
    corres = (all_confs, float(all_confs.sum()), imgs_slices)

    def aggreg_matches(img1, list_matches):
        # merge all the matches of one image into single tensors
        pix1, confs, img2, slice2 = zip(*list_matches)
        all_pix1 = torch.cat(pix1).to(dtype)
        all_confs = torch.cat(confs).to(dtype)
        return img1, all_pix1, all_confs, float(all_confs.sum()), list(zip(img2, slice2))
    corres2d = [aggreg_matches(img, m) for img, m in corres2d.items()]

    imsizes = torch.tensor([(W, H) for H, W in shapes], device=pp.device)  # (W,H)
    principal_points = torch.stack(principal_points)
    focals = torch.cat(focals)

    # Subsample preds_21 at the anchor indices of the second image
    subsamp_preds_21 = {}
    for imk, preds_of_img in preds_21.items():
        subsamp_preds_21[imk] = {}
        for im2k, (pred, conf) in preds_of_img.items():
            anchor_idxs = img_anchors[index_of[im2k]][1]
            subsamp_preds_21[imk][im2k] = (pred[anchor_idxs], conf[anchor_idxs])

    return imsizes, principal_points, focals, core_depth, img_anchors, corres, corres2d, subsamp_preds_21
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
def canonical_view(ptmaps11, confs11, subsample, mode='avg-angle'):
    """Fuse several pointmaps of the same image into one canonical pointmap.

    Args:
        ptmaps11: tensor ``(n, H, W, 3)`` with the ``n`` pointmaps to fuse.
        confs11: tensor ``(n, H, W)`` with the matching confidences.
        subsample: block size; ``H`` and ``W`` must be multiples of it because
            ``F.pixel_unshuffle`` is applied with this factor.
        mode: ``'avg-reldepth'`` averages the depth relative to each block's
            centre pixel; ``'avg-angle'`` averages the angle between each pixel
            and its block centre and converts it back to a relative depth.

    Returns:
        ``(canon, canon2, confs)``: the confidence-weighted average pointmap
        ``(H, W, 3)``, the per-pixel depth relative to the block centre, and
        the fused confidence map.

    Raises:
        AssertionError: if no view is given or the two inputs differ in length.
        ValueError: if ``mode`` is not one of the two supported values.
    """
    assert len(ptmaps11) == len(confs11) > 0, 'not a single view1 for this image'

    # canonical pointmap is just a weighted average
    # (weights are conf - 0.999; presumably confidences are >= 1 -- TODO confirm)
    confs11 = confs11.unsqueeze(-1) - 0.999
    canon = (confs11 * ptmaps11).sum(0) / confs11.sum(0)

    canon_depth = ptmaps11[..., 2].unsqueeze(1)
    S = slice(subsample // 2, None, subsample)  # centre pixel of every block
    center_depth = canon_depth[:, :, S, S]
    center_depth = torch.clip(center_depth, min=torch.finfo(center_depth.dtype).eps)

    # one channel per position inside a block
    stacked_depth = F.pixel_unshuffle(canon_depth, subsample)
    stacked_confs = F.pixel_unshuffle(confs11[:, None, :, :, 0], subsample)

    if mode == 'avg-reldepth':
        rel_depth = stacked_depth / center_depth
        stacked_canon = (stacked_confs * rel_depth).sum(dim=0) / stacked_confs.sum(dim=0)
        canon2 = F.pixel_shuffle(stacked_canon.unsqueeze(0), subsample).squeeze()

    elif mode == 'avg-angle':
        xy = ptmaps11[..., 0:2].permute(0, 3, 1, 2)
        stacked_xy = F.pixel_unshuffle(xy, subsample)
        B, _, H, W = stacked_xy.shape
        # lateral distance of every pixel to its block centre
        stacked_radius = (stacked_xy.view(B, 2, -1, H, W) - xy[:, :, None, S, S]).norm(dim=1)
        stacked_radius.clip_(min=1e-8)  # avoid dividing by zero at the centre itself

        stacked_angle = torch.arctan((stacked_depth - center_depth) / stacked_radius)
        avg_angle = (stacked_confs * stacked_angle).sum(dim=0) / stacked_confs.sum(dim=0)

        # back to depth
        stacked_depth = stacked_radius.mean(dim=0) * torch.tan(avg_angle)

        canon2 = F.pixel_shuffle((1 + stacked_depth / canon[S, S, 2]).unsqueeze(0), subsample).squeeze()
    else:
        raise ValueError(f'bad {mode=}')

    confs = (confs11.square().sum(dim=0) / confs11.sum(dim=0)).squeeze()
    return canon, canon2, confs
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
def anchor_depth_offsets(canon_depth, pixels, subsample=8):
    """Attach every matched pixel to the anchor of its block and compute its depth ratio.

    Anchors sit at the centre of each ``subsample x subsample`` block of the
    2-D map ``canon_depth``.

    Args:
        canon_depth: 2-D tensor ``(H, W)``; anchor values must be positive.
        pixels: dict ``img2 -> (xy, confs)`` where ``xy`` is ``(N, 2)`` with
            ``(x, y)`` pixel coordinates; ``confs`` is not used here.
        subsample: block size of the anchor grid.

    Returns:
        ``(core_idxs, core_offs)``: two dicts keyed like ``pixels`` giving, for
        every match, the flat index of its anchor and the ratio between the
        value at the pixel and the value at the anchor.
    """
    device = canon_depth.device

    # regular grid of anchors, one at the centre of each block
    height, width = canon_depth.shape
    half = subsample // 2
    anchor_yx = np.mgrid[half:height:subsample, half:width:subsample]
    grid_w = anchor_yx.shape[2]
    anchor_y, anchor_x = anchor_yx.reshape(2, -1)
    anchor_depth = canon_depth[anchor_y, anchor_x]
    assert (anchor_depth > 0).all()

    core_idxs = {}  # core_idxs[img2][match] = flat anchor index
    core_offs = {}  # core_offs[img2][match] = depth ratio w.r.t. that anchor

    for other, (xy, _unused_confs) in pixels.items():
        col, row = xy.long().T

        # the nearest anchor is the one of the block containing the pixel
        flat_idx = (row // subsample) * grid_w + (col // subsample)
        core_idxs[other] = flat_idx.to(device)

        ratio = canon_depth[row, col] / anchor_depth[flat_idx]
        core_offs[other] = ratio.detach().to(device)

    return core_idxs, core_offs
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
def spectral_clustering(graph, k=None, normalized_cuts=False):
    """Return the first eigenpairs of the Laplacian of an affinity matrix.

    Args:
        graph: square affinity matrix. Its diagonal is zeroed IN PLACE.
        k: number of eigenpairs to keep (``None`` keeps all of them).
        normalized_cuts: if true, use ``D^-1/2 (D - W) D^-1/2`` instead of
            ``D - W``.

    Returns:
        ``(eigenvalues[:k], eigenvectors[:, :k])`` as returned by
        ``torch.linalg.eigh``.
    """
    graph.fill_diagonal_(0)  # no self-loops

    # graph laplacian L = D - W
    node_degree = graph.sum(dim=-1)
    lap = torch.diag(node_degree) - graph

    if normalized_cuts:
        inv_sqrt_deg = torch.diag(node_degree.sqrt().reciprocal())
        lap = inv_sqrt_deg @ lap @ inv_sqrt_deg

    values, vectors = torch.linalg.eigh(lap)
    return values[:k], vectors[:, :k]
|
| 901 |
+
|
| 902 |
+
|
| 903 |
+
def sim_func(p1, p2, gamma):
    """Gaussian similarity between 3-D points, based on a depth-relative distance.

    ``p1`` and ``p2`` are 3-D tensors whose last axis holds ``(x, y, z)``; they
    are combined with broadcasting. The distance between the points is divided
    by the SUM of their two depths before being squared.

    Returns:
        ``exp(-gamma * (|p1 - p2| / (z1 + z2)) ** 2)``.
    """
    distance = (p1 - p2).norm(dim=-1)
    depth_sum = p1[:, :, 2] + p2[:, :, 2]
    relative = distance / depth_sum
    return torch.exp(-gamma * relative.square())
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
def backproj(K, depthmap, subsample):
    """Back-project a subsampled depthmap to 3-D points.

    Args:
        K: camera intrinsics; its inverse is applied to the pixel grid.
        depthmap: 2-D tensor ``(H, W)``, one depth per block.
        subsample: block size; pixel ``(i, j)`` of ``depthmap`` is placed at
            ``(subsample // 2 + j * subsample, subsample // 2 + i * subsample)``.

    Returns:
        Tensor ``(H, W, 3)``: the depth times ``geotrf(inv(K), uv, ncol=3)``.
    """
    H, W = depthmap.shape
    half = subsample // 2

    # mgrid yields (2, W, H); the transpose gives (H, W, 2) with (u, v) last
    grid = np.mgrid[half:subsample * W:subsample, half:subsample * H:subsample]
    uv = grid.T.reshape(H, W, 2)

    # presumably the viewing ray of each pixel with unit depth -- verify against geotrf
    rays = geotrf(inv(K), todevice(uv, K.device), ncol=3)
    return depthmap.unsqueeze(-1) * rays
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
def spectral_projection_depth(K, depthmap, subsample, k=64, cache_path='',
                              normalized_cuts=True, gamma=7, min_norm=5):
    """Encode a depthmap in a low-rank spectral basis built from its own 3-D geometry.

    The basis is made of the first ``k`` eigenvectors of the Laplacian of the
    similarity graph between the back-projected 3-D points. It is cached on
    disk when ``cache_path`` is given.

    Args:
        K: camera intrinsics, used for back-projection and as target device.
        depthmap: 2-D depth tensor.
        subsample: block size forwarded to ``backproj``.
        k: number of basis vectors.
        cache_path: path prefix of the cache file; any falsy value (``''`` or
            ``None``) disables caching.
        normalized_cuts: forwarded to ``spectral_clustering``.
        gamma: sharpness of the similarity function.
        min_norm: forwarded to ``lora_encode_normed``.

    Returns:
        ``(coeffs, lora_proj)`` such that ``depthmap ~= lora_proj @ coeffs``.
    """
    lora_proj = None

    if cache_path:
        cache_path = cache_path + f'_{k=}_norm={normalized_cuts}_{gamma=}.pth'
        try:
            # NOTE(review): torch.load unpickles its input; the cache must be trusted.
            lora_proj = torch.load(cache_path, map_location=K.device)
        except IOError:
            pass  # cache miss: the basis is computed below

    if lora_proj is None:
        # reconstruct 3d points in camera coordinates
        xyz = backproj(K, depthmap, subsample).reshape(-1, 3)

        # compute all pairwise similarities
        graph = sim_func(xyz[:, None], xyz[None, :], gamma=gamma)
        _, lora_proj = spectral_clustering(graph, k, normalized_cuts=normalized_cuts)

        if cache_path:
            torch.save(lora_proj.cpu(), mkdir_for(cache_path))

    lora_proj, coeffs = lora_encode_normed(lora_proj, depthmap.ravel(), min_norm=min_norm)

    # depthmap ~= lora_proj @ coeffs
    return coeffs, lora_proj
|
| 941 |
+
|
| 942 |
+
|
| 943 |
+
def lora_encode_normed(lora_proj, x, min_norm, global_norm=False):
    """Project ``x`` on the basis ``lora_proj`` after rescaling the basis vectors.

    ``lora_proj`` is rescaled IN PLACE so that the resulting coefficients have
    comparable magnitudes, then ``x`` is encoded again in the rescaled basis.

    Args:
        lora_proj: basis matrix with one basis vector per column.
        x: signal to encode.
        min_norm: lower bound applied to the per-vector scale; a falsy value
            disables the rescaling (unless ``global_norm`` is set).
        global_norm: if true, use one scale for the whole basis, computed from
            all the first-pass coefficients but the first one.

    Returns:
        ``(lora_proj, coeffs)``, both detached, with ``x ~= lora_proj @ coeffs``.
    """
    # first pass: coefficients in the original basis
    first_pass = torch.linalg.pinv(lora_proj) @ x
    if first_pass.ndim == 1:
        first_pass = first_pass.unsqueeze(1)

    # rectify the norm of the basis vectors to be ~ equal
    if global_norm:
        lora_proj.mul_(first_pass[1:].norm() * min_norm / first_pass.shape[1])
    elif min_norm:
        lora_proj.mul_(first_pass.norm(dim=1).clip(min=min_norm))

    # second pass in double precision to limit rounding errors
    coeffs = torch.linalg.pinv(lora_proj.double()) @ x.double()

    return lora_proj.detach(), coeffs.float().detach()
|
| 958 |
+
|
| 959 |
+
|
| 960 |
+
@torch.no_grad()
def spectral_projection_of_depthmaps(imgs, intrinsics, depthmaps, subsample, cache_path=None, **kw):
    """Run ``spectral_projection_depth`` on every image.

    Args:
        imgs: image identifiers; each one is hashed to name its cache file.
        intrinsics: per-image intrinsics, indexed like ``imgs``.
        depthmaps: per-image depthmaps, indexed like ``imgs``.
        subsample: block size forwarded to ``spectral_projection_depth``.
        cache_path: root directory of the cache, or a falsy value to disable
            caching.
        **kw: extra options forwarded to ``spectral_projection_depth``.

    Returns:
        ``(core_depth, lora_proj)``: two lists with, for every image, the two
        values returned by ``spectral_projection_depth``.
    """
    core_depth = []
    lora_proj = []

    for i, img in enumerate(tqdm(imgs)):
        # An empty string (not None) means "no cache": the callee hands the path
        # to torch.load, which raises an IOError for '' but not for None.
        cache = os.path.join(cache_path, 'lora_depth', hash_md5(img)) if cache_path else ''
        depth, proj = spectral_projection_depth(intrinsics[i], depthmaps[i], subsample,
                                                cache_path=cache, **kw)
        core_depth.append(depth)
        lora_proj.append(proj)

    return core_depth, lora_proj
|
| 974 |
+
|
| 975 |
+
|
| 976 |
+
def reproj2d(Trf, pts3d):
    """Transform 3-D points with ``Trf`` and apply the perspective division.

    Args:
        Trf: transform whose top-left 3x3 block and first three entries of the
            last column are applied to the points (presumably intrinsics
            already composed with the world-to-camera pose -- verify against
            callers).
        pts3d: tensor ``(N, 3)``.

    Returns:
        Tensor ``(N, 2)`` of ``(x / z, y / z)`` clipped to ``[-1000, 2000]``.
    """
    linear = Trf[:3, :3]
    shift = Trf[:3, 3]
    transformed = pts3d @ linear.transpose(-1, -2) + shift

    # depths are floored so that points at or behind the camera do not yield nans
    depth = transformed[:, 2:3].clip(min=1e-3)
    projected = transformed[:, 0:2] / depth
    return projected.clip(min=-1000, max=2000)
|
| 981 |
+
|
| 982 |
+
|
| 983 |
+
def bfs(tree, start_node):
|
| 984 |
+
order, predecessors = sp.csgraph.breadth_first_order(tree, start_node, directed=False)
|
| 985 |
+
ranks = np.arange(len(order))
|
| 986 |
+
ranks[order] = ranks.copy()
|
| 987 |
+
return ranks, predecessors
|
| 988 |
+
|
| 989 |
+
|
| 990 |
+
def compute_min_spanning_tree(pws):
    """Build the spanning tree keeping the strongest pairwise scores and order its edges.

    Scores are negated before calling ``minimum_spanning_tree``, so the tree
    minimises the negated scores, i.e. it favours the highest scores.

    Args:
        pws: 2-D tensor of pairwise scores; zero entries mean "no edge".

    Returns:
        ``(root, edges)``: the node chosen as root and the list of
        ``(parent, child)`` edges in breadth-first order from that root.
    """
    # Build the sparse graph in one shot rather than reading the tensor entry by
    # entry (each scalar read is a separate tensor access / device sync).
    rows, cols = pws.nonzero(as_tuple=True)
    weights = np.negative(pws[rows, cols].detach().double().cpu().numpy())
    sparse_graph = sp.coo_array((weights, (rows.cpu().numpy(), cols.cpu().numpy())),
                                shape=tuple(pws.shape))
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph)

    # now reorder the oriented edges, starting from the central point:
    # successive traversals from node 0, then from the last-visited node of each
    # previous traversal
    ranks1, _ = bfs(msp, 0)
    ranks2, _ = bfs(msp, ranks1.argmax())
    ranks1, _ = bfs(msp, ranks2.argmax())
    # this is the point farthest from any leaf
    root = np.minimum(ranks1, ranks2).argmax()

    # find the ordered list of edges that describe the tree
    order, predecessors = sp.csgraph.breadth_first_order(msp, root, directed=False)
    order = order[1:]  # the root does not have a predecessor
    edges = [(predecessors[i], i) for i in order]

    return root, edges
|
| 1009 |
+
|
| 1010 |
+
|
| 1011 |
+
def show_reconstruction(shapes_or_imgs, K, cam2w, pts3d, gt_cam2w=None, gt_K=None, cam_size=None, masks=None, **kw):
    """Display cameras and point clouds in a ``SceneViz`` window.

    Args:
        shapes_or_imgs: either a 2-D ``np.ndarray`` of image shapes (its
            columns are reversed before being passed as ``imsizes``) or the
            images themselves.
        K: camera intrinsics; cameras are only drawn when it is not ``None``.
        cam2w: camera-to-world poses ``(n, 4, 4)``.
        pts3d: one point cloud per camera, or ``None``.
        gt_cam2w: optional ground-truth poses, drawn with marker ``'o'``.
        gt_K: intrinsics of the ground-truth cameras (defaults to ``K``).
        cam_size: displayed camera size; by default the median distance between
            each camera and its nearest neighbour.
        masks: optional per-camera masks selecting the points to display.
        **kw: forwarded to ``SceneViz.show``.
    """
    viz = SceneViz()

    cc = cam2w[:, :3, 3]  # camera centres
    cs = cam_size or float(torch.cdist(cc, cc).fill_diagonal_(np.inf).min(dim=0).values.median())
    colors = 64 + np.random.randint(255 - 64, size=(len(cam2w), 3))

    if isinstance(shapes_or_imgs, np.ndarray) and shapes_or_imgs.ndim == 2:
        imgs = None  # only shapes are known: no pixel colours available
        cam_kws = dict(imsizes=shapes_or_imgs[:, ::-1], cam_size=cs)
    else:
        imgs = shapes_or_imgs
        cam_kws = dict(images=imgs, cam_size=cs)
    if K is not None:
        viz.add_cameras(to_numpy(cam2w), to_numpy(K), colors=colors, **cam_kws)

    if gt_cam2w is not None:
        if gt_K is None:
            gt_K = K
        viz.add_cameras(to_numpy(gt_cam2w), to_numpy(gt_K), colors=colors, marker='o', **cam_kws)

    if pts3d is not None:
        for i, p in enumerate(pts3d):
            if not len(p):
                continue
            cam_color = tuple(colors[i].tolist())
            if masks is None:
                viz.add_pointcloud(to_numpy(p), color=cam_color)
            elif imgs is None:
                # masks without images: fall back to the camera colour instead of
                # failing on the undefined image
                viz.add_pointcloud(to_numpy(p), mask=masks[i], color=cam_color)
            else:
                viz.add_pointcloud(to_numpy(p), mask=masks[i], color=imgs[i])
    viz.show(**kw)
|