|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import PIL.Image |
|
|
import numpy as np |
|
|
from scipy.spatial.transform import Rotation |
|
|
import torch |
|
|
import cv2 |
|
|
import matplotlib as mpl |
|
|
import matplotlib.cm as cm |
|
|
import matplotlib.pyplot as plt |
|
|
from dust3r.utils.geometry import ( |
|
|
geotrf, |
|
|
get_med_dist_between_poses, |
|
|
depthmap_to_absolute_camera_coordinates, |
|
|
) |
|
|
from dust3r.utils.device import to_numpy |
|
|
from dust3r.utils.image import rgb, img_to_arr |
|
|
from matplotlib.backends.backend_agg import FigureCanvasAgg |
|
|
from matplotlib.figure import Figure |
|
|
|
|
|
try: |
|
|
import trimesh |
|
|
except ImportError: |
|
|
print("/!\\ module trimesh is not installed, cannot visualize results /!\\") |
|
|
|
|
|
|
|
|
def float2uint8(x):
    """Scale floats in [0, 1] up to uint8 values in [0, 255]."""
    scaled = x * 255.0
    return scaled.astype(np.uint8)
|
|
|
|
|
|
|
|
def uint82float(img):
    """Scale uint8 values in [0, 255] down to floats in [0, 1]."""
    contiguous = np.ascontiguousarray(img)
    return contiguous / 255.0
|
|
|
|
|
|
|
|
def cat_3d(vecs):
    """Flatten one array/tensor (or a list of them) into a single (N, 3) array."""
    is_single = isinstance(vecs, (np.ndarray, torch.Tensor))
    arrays = to_numpy([vecs] if is_single else vecs)
    return np.concatenate([a.reshape(-1, 3) for a in arrays])
|
|
|
|
|
|
|
|
def show_raw_pointcloud(pts3d, colors, point_size=2):
    """Open an interactive trimesh window displaying a colored point cloud.

    pts3d / colors: array, tensor, or list thereof; flattened to (N, 3) by cat_3d.
    point_size: on-screen point size passed to the trimesh viewer.
    """
    scene = trimesh.Scene()

    # merge all points/colors into a single cloud
    pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors))
    scene.add_geometry(pct)

    # blocks until the viewer window is closed
    scene.show(line_settings={"point_size": point_size})
|
|
|
|
|
|
|
|
def pts3d_to_trimesh(img, pts3d, valid=None):
    """Triangulate a dense (H, W, 3) point map into mesh arrays.

    Each pixel quad yields two triangles, emitted with both windings so the
    surface is visible from either side. Face colors come from the image.

    img: (H, W, 3) colors, same shape as pts3d
    pts3d: (H, W, 3) 3D points
    valid: optional (H, W) boolean mask; faces touching an invalid pixel are dropped
    returns: dict(vertices, face_colors, faces)
    """
    H, W, THREE = img.shape
    assert THREE == 3
    assert img.shape == pts3d.shape

    vertices = pts3d.reshape(-1, 3)

    # vertex indices of each quad's four corners
    grid = np.arange(H * W).reshape(H, W)
    tl = grid[:-1, :-1].ravel()
    tr = grid[:-1, 1:].ravel()
    bl = grid[1:, :-1].ravel()
    br = grid[1:, 1:].ravel()

    # two triangles per quad, each with both windings (double-sided mesh)
    faces = np.concatenate(
        (
            np.c_[tl, tr, bl],
            np.c_[bl, tr, tl],
            np.c_[tr, bl, br],
            np.c_[br, bl, tr],
        ),
        axis=0,
    )

    # color each triangle by its anchoring corner pixel
    tl_col = img[:-1, :-1].reshape(-1, 3)
    br_col = img[1:, 1:].reshape(-1, 3)
    face_colors = np.concatenate((tl_col, tl_col, br_col, br_col), axis=0)

    if valid is not None:
        assert valid.shape == (H, W)
        keep = valid.ravel()[faces].all(axis=-1)
        faces = faces[keep]
        face_colors = face_colors[keep]

    assert len(faces) == len(face_colors)
    return dict(vertices=vertices, face_colors=face_colors, faces=faces)
|
|
|
|
|
|
|
|
def cat_meshes(meshes):
    """Concatenate mesh dicts (as produced by pts3d_to_trimesh) into one mesh.

    Face indices of each mesh are shifted by the number of vertices that
    precede it so they keep pointing at the right vertices.

    Fix: the previous version shifted face indices *in place*
    (``faces[i][:] += offset``), silently corrupting the caller's meshes;
    the shift is now done on copies.

    meshes: iterable of dict(vertices, faces, face_colors)
    returns: dict(vertices, face_colors, faces)
    """
    vertices, faces, colors = zip(
        *[(m["vertices"], m["faces"], m["face_colors"]) for m in meshes]
    )
    # offsets[i] = total number of vertices before mesh i
    offsets = np.cumsum([0] + [len(v) for v in vertices])
    shifted_faces = [f + off for f, off in zip(faces, offsets)]

    return dict(
        vertices=np.concatenate(vertices),
        face_colors=np.concatenate(colors),
        faces=np.concatenate(shifted_faces),
    )
|
|
|
|
|
|
|
|
def show_duster_pairs(view1, view2, pred1, pred2):
    """Interactively browse DUSt3R image pairs: display both images with their
    confidence maps, then optionally pop up the merged point cloud.

    view1/view2/pred1/pred2: batched dicts; assumes keys 'instance', 'idx',
    'img', 'conf', 'pts3d', 'pts3d_in_other_view' — TODO confirm against the
    producing model/dataloader.
    """
    import matplotlib.pyplot as pl

    # interactive mode: each pl.clf()/imshow refreshes the same window
    pl.ion()

    for e in range(len(view1["instance"])):
        i = view1["idx"][e]
        j = view2["idx"][e]
        img1 = rgb(view1["img"][e])
        img2 = rgb(view2["img"][e])
        conf1 = pred1["conf"][e].squeeze()
        conf2 = pred2["conf"][e].squeeze()
        # heuristic pair quality: product of mean confidences
        score = conf1.mean() * conf2.mean()
        print(f">> Showing pair #{e} {i}-{j} {score=:g}")
        pl.clf()
        pl.subplot(221).imshow(img1)
        pl.subplot(223).imshow(img2)
        pl.subplot(222).imshow(conf1, vmin=1, vmax=30)
        pl.subplot(224).imshow(conf2, vmin=1, vmax=30)
        pts1 = pred1["pts3d"][e]
        pts2 = pred2["pts3d_in_other_view"][e]
        pl.subplots_adjust(0, 0, 1, 1, 0, 0)
        # prompt before opening the (blocking) 3D viewer
        if input("show pointcloud? (y/n) ") == "y":
            show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5)
|
|
|
|
|
|
|
|
def auto_cam_size(im_poses):
    """Heuristic camera-frustum size: 10% of the median inter-camera distance."""
    return 0.1 * get_med_dist_between_poses(im_poses)
|
|
|
|
|
|
|
|
class SceneViz:
    """Incremental builder for a trimesh scene: point clouds, RGB-D frames and
    camera frustums, shown with :meth:`show`.

    Fixes over the previous version: ``add_rgbd`` was defined twice — the
    first definition was dead (shadowed by the second) and called an
    undefined helper ``depthmap_to_pts3d`` — and ``add_pointcloud`` used an
    undefined ``bb()`` as an assert message (a NameError would have masked
    the assertion failure).
    """

    def __init__(self):
        self.scene = trimesh.Scene()

    def add_pointcloud(self, pts3d, color=(0, 0, 0), mask=None, denoise=False):
        """Add one or several point clouds to the scene.

        pts3d: (..., 3) array/tensor, or a list of them
        color: constant RGB triplet, or per-point colors matching pts3d
        mask: optional boolean mask(s) selecting which points to keep
        denoise: drop the 1% of points farthest from the median centroid
        returns self (fluent)
        """
        pts3d = to_numpy(pts3d)
        mask = to_numpy(mask)
        if not isinstance(pts3d, list):
            pts3d = [pts3d.reshape(-1, 3)]
            if mask is not None:
                mask = [mask.ravel()]
        if not isinstance(color, (tuple, list)):
            color = [color.reshape(-1, 3)]
        if mask is None:
            mask = [slice(None)] * len(pts3d)

        pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
        pct = trimesh.PointCloud(pts)

        if isinstance(color, (list, np.ndarray, torch.Tensor)):
            # per-point colors: must align 1:1 with the selected points
            color = to_numpy(color)
            col = np.concatenate([p[m] for p, m in zip(color, mask)])
            assert col.shape == pts.shape, "colors must align with points"
            pct.visual.vertex_colors = uint8(col.reshape(-1, 3))
        else:
            # single constant color broadcast to every point
            assert len(color) == 3
            pct.visual.vertex_colors = np.broadcast_to(uint8(color), pts.shape)

        if denoise:
            # reject outliers: keep points within the 99th percentile of
            # distance to the median centroid
            centroid = np.median(pct.vertices, axis=0)
            dist_to_centroid = np.linalg.norm(pct.vertices - centroid, axis=-1)
            dist_thr = np.quantile(dist_to_centroid, 0.99)
            valid = dist_to_centroid < dist_thr

            pct = trimesh.PointCloud(
                pct.vertices[valid], color=pct.visual.vertex_colors[valid]
            )

        self.scene.add_geometry(pct)
        return self

    def add_rgbd(
        self, image, depth, intrinsics=None, cam2world=None, zfar=np.inf, mask=None
    ):
        """Back-project a depth map to 3D and add it as a colored point cloud.

        image: (H, W, 3) colors (or PIL image, converted via img_to_arr)
        depth: (H, W) depth map
        intrinsics: 3x3 camera matrix; defaults to focal = max(H, W),
            principal point at the image center
        cam2world: optional 4x4 camera-to-world pose
        zfar: far-plane cutoff; points with depth >= zfar are dropped
        mask: optional extra (H, W) validity mask
        returns self (fluent)
        """
        image = img_to_arr(image)

        if intrinsics is None:
            H, W, THREE = image.shape
            focal = max(H, W)
            intrinsics = np.float32([[focal, 0, W / 2], [0, focal, H / 2], [0, 0, 1]])

        pts3d, mask2 = depthmap_to_absolute_camera_coordinates(
            depth, intrinsics, cam2world
        )
        mask2 &= depth < zfar

        if mask is not None:
            mask2 &= mask

        return self.add_pointcloud(pts3d, image, mask=mask2)

    def add_camera(
        self,
        pose_c2w,
        focal=None,
        color=(0, 0, 0),
        image=None,
        imsize=None,
        cam_size=0.03,
    ):
        """Add one camera frustum.

        focal may be a scalar or a full 3x3 intrinsics matrix, in which case
        the focal becomes the geometric mean of fx and fy and the image size
        is deduced from the principal point.
        returns self (fluent)
        """
        pose_c2w, focal, color, image = to_numpy((pose_c2w, focal, color, image))
        image = img_to_arr(image)
        if isinstance(focal, np.ndarray) and focal.shape == (3, 3):
            intrinsics = focal
            focal = (intrinsics[0, 0] * intrinsics[1, 1]) ** 0.5
            if imsize is None:
                imsize = (2 * intrinsics[0, 2], 2 * intrinsics[1, 2])

        add_scene_cam(
            self.scene,
            pose_c2w,
            color,
            image,
            focal,
            imsize=imsize,
            screen_width=cam_size,
            marker=None,
        )
        return self

    def add_cameras(
        self, poses, focals=None, images=None, imsizes=None, colors=None, **kw
    ):
        """Add many cameras at once; any per-camera argument may be None."""
        get = lambda arr, idx: None if arr is None else arr[idx]
        for i, pose_c2w in enumerate(poses):
            self.add_camera(
                pose_c2w,
                get(focals, i),
                image=get(images, i),
                color=get(colors, i),
                imsize=get(imsizes, i),
                **kw,
            )
        return self

    def show(self, point_size=2):
        """Open the interactive trimesh viewer (blocking)."""
        self.scene.show(line_settings={"point_size": point_size})
|
|
|
|
|
|
|
|
def show_raw_pointcloud_with_cams(
    imgs, pts3d, mask, focals, cams2world, point_size=2, cam_size=0.05, cam_color=None
):
    """Visualization of a pointcloud with cameras

    imgs = (N, H, W, 3) or N-size list of [(H,W,3), ...]
    pts3d = (N, H, W, 3) or N-size list of [(H,W,3), ...]
    focals = (N,) or N-size list of [focal, ...]
    cams2world = (N,4,4) or N-size list of [(4,4), ...]
    mask = per-view boolean masks selecting the points to display
    cam_color = single color, list of per-camera colors, or None to cycle
        through CAM_COLORS
    """
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d = to_numpy(pts3d)
    imgs = to_numpy(imgs)
    focals = to_numpy(focals)
    cams2world = to_numpy(cams2world)

    scene = trimesh.Scene()

    # merge the masked points of all views into one colored cloud
    pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
    col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
    pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
    scene.add_geometry(pct)

    # one frustum per camera, textured with its image when available
    for i, pose_c2w in enumerate(cams2world):
        if isinstance(cam_color, list):
            camera_edge_color = cam_color[i]
        else:
            # single color for all, or cycle through the default palette
            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
        add_scene_cam(
            scene,
            pose_c2w,
            camera_edge_color,
            imgs[i] if i < len(imgs) else None,
            focals[i],
            screen_width=cam_size,
        )

    # blocks until the viewer window is closed
    scene.show(line_settings={"point_size": point_size})
|
|
|
|
|
|
|
|
def add_scene_cam(
    scene,
    pose_c2w,
    edge_color,
    image=None,
    focal=None,
    imsize=None,
    screen_width=0.03,
    marker=None,
):
    """Add a wireframe camera frustum (and optionally its image plane) to a
    trimesh scene.

    scene: trimesh.Scene to add geometry to
    pose_c2w: (4, 4) camera-to-world pose
    edge_color: RGB color of the frustum edges
    image: optional (H, W, 3) image textured onto the frustum's base
    focal: focal length (scalar or 1-element array); if falsy, a default of
        min(H, W) * 1.1 is used
    imsize: optional (W, H) when no image is given
    screen_width: on-screen size of the frustum
    marker: "o" to additionally draw a small sphere at the camera center
    """
    if image is not None:
        image = np.asarray(image)
        H, W, THREE = image.shape
        assert THREE == 3
        if image.dtype != np.uint8:
            image = np.uint8(255 * image)
    elif imsize is not None:
        W, H = imsize
    elif focal is not None:
        # back out an image size consistent with the default-focal rule below
        H = W = focal / 1.1
    else:
        H = W = 1

    if isinstance(focal, np.ndarray):
        focal = focal[0]
    if not focal:
        focal = min(H, W) * 1.1  # default focal when none provided

    # frustum depth scales with focal; width is screen_width / sqrt(2)
    height = max(screen_width / 10, focal * screen_width / H)
    width = screen_width * 0.5**0.5
    # the 4-section cone has a square base; rotate 45 deg so it is axis-aligned,
    # and shift so the cone apex sits at the camera center
    rot45 = np.eye(4)
    rot45[:3, :3] = Rotation.from_euler("z", np.deg2rad(45)).as_matrix()
    rot45[2, 3] = -height
    aspect_ratio = np.eye(4)
    aspect_ratio[0, 0] = W / H
    # OPENGL flips the Y/Z axes to match the viewer's convention
    transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
    cam = trimesh.creation.cone(width, height, sections=4)

    if image is not None:
        # texture the image onto the quad formed by the cone's base corners
        # (vertex indices [4, 5, 1, 3] of trimesh's 4-section cone — TODO confirm
        # this ordering is stable across trimesh versions)
        vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
        faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
        img = trimesh.Trimesh(vertices=vertices, faces=faces)
        uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
        img.visual = trimesh.visual.TextureVisuals(
            uv_coords, image=PIL.Image.fromarray(image)
        )
        scene.add_geometry(img)

    # build thin "edges" by tripling the cone vertices: original, slightly
    # shrunk (0.95x), and slightly rotated (2 deg) copies
    rot2 = np.eye(4)
    rot2[:3, :3] = Rotation.from_euler("z", np.deg2rad(2)).as_matrix()
    vertices = np.r_[cam.vertices, 0.95 * cam.vertices, geotrf(rot2, cam.vertices)]
    vertices = geotrf(transform, vertices)
    faces = []
    for face in cam.faces:
        if 0 in face:
            continue  # skip faces touching the cone apex (vertex 0)
        a, b, c = face
        a2, b2, c2 = face + len(cam.vertices)
        a3, b3, c3 = face + 2 * len(cam.vertices)

        # sliver triangles between the original and shrunk copies
        faces.append((a, b, b2))
        faces.append((a, a2, c))
        faces.append((c2, b, c))

        # and between the original and rotated copies
        faces.append((a, b, b3))
        faces.append((a, a3, c))
        faces.append((c3, b, c))

    # add reversed windings so the edges are visible from both sides
    faces += [(c, b, a) for a, b, c in faces]

    cam = trimesh.Trimesh(vertices=vertices, faces=faces)
    cam.visual.face_colors[:, :3] = edge_color
    scene.add_geometry(cam)

    if marker == "o":
        # small sphere marking the camera center
        marker = trimesh.creation.icosphere(3, radius=screen_width / 4)
        marker.vertices += pose_c2w[:3, 3]
        marker.visual.face_colors[:, :3] = edge_color
        scene.add_geometry(marker)
|
|
|
|
|
|
|
|
def cat(a, b):
    """Stack two arrays of 3D points (or colors) into one (N, 3) array."""
    flattened = [v.reshape(-1, 3) for v in (a, b)]
    return np.concatenate(flattened)
|
|
|
|
|
|
|
|
# Convention-change matrix: flips the Y and Z axes (OpenCV-style camera to
# OpenGL-style camera coordinates).
OPENGL = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
|
|
|
|
|
|
|
|
# RGB palette cycled through when drawing camera frustums without an
# explicitly requested color.
CAM_COLORS = [
    (255, 0, 0),
    (0, 0, 255),
    (0, 255, 0),
    (255, 0, 255),
    (255, 204, 0),
    (0, 204, 204),
    (128, 255, 255),
    (255, 128, 255),
    (255, 255, 128),
    (0, 0, 0),
    (128, 128, 128),
]
|
|
|
|
|
|
|
|
def uint8(colors):
    """Convert colors to uint8, scaling floating-point inputs by 255.

    Floating-point inputs are assumed to lie in [0, 1]; integer inputs are
    assumed to already be in [0, 255]. Values outside [0, 256) trip the
    assertion.

    Fix: the previous version scaled floating-point ndarrays *in place*
    (``colors *= 255``), silently corrupting the caller's array; scaling is
    now done on a fresh array.
    """
    colors = np.asarray(colors)
    if np.issubdtype(colors.dtype, np.floating):
        colors = colors * 255  # new array — never mutate the caller's data
    assert 0 <= colors.min() and colors.max() < 256
    return np.uint8(colors)
|
|
|
|
|
|
|
|
def segment_sky(image):
    """Heuristic sky segmentation via HSV thresholds, morphology and
    connected components.

    image: (H, W, 3) image; floats are clipped to [0, 1] and scaled to uint8
    returns: torch.BoolTensor of shape (H, W), True where sky was detected
    """
    import cv2
    from scipy import ndimage

    image = to_numpy(image)
    if np.issubdtype(image.dtype, np.floating):
        image = np.uint8(255 * image.clip(min=0, max=1))
    # NOTE(review): uses COLOR_BGR2HSV although images elsewhere in this file
    # appear to be RGB — confirm the intended channel order
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    # bright pixels with hue in the blue-ish band
    lower_blue = np.array([0, 0, 100])
    upper_blue = np.array([30, 255, 255])
    mask = cv2.inRange(hsv, lower_blue, upper_blue).view(bool)

    # also accept increasingly saturated pixels as they get brighter
    # (white/grey overcast sky)
    mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150)
    mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180)
    mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220)

    # morphological opening removes small speckles
    kernel = np.ones((5, 5), np.uint8)
    mask2 = ndimage.binary_opening(mask, structure=kernel)

    # keep only the dominant connected components: every component whose area
    # is more than half of the largest one
    _, labels, stats, _ = cv2.connectedComponentsWithStats(
        mask2.view(np.uint8), connectivity=8
    )
    cc_sizes = stats[1:, cv2.CC_STAT_AREA]  # component 0 is the background
    order = cc_sizes.argsort()[::-1]
    i = 0
    selection = []
    while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2:
        selection.append(1 + order[i])  # +1: skip the background label
        i += 1
    mask3 = np.in1d(labels, selection).reshape(labels.shape)

    return torch.from_numpy(mask3)
|
|
|
|
|
|
|
|
def get_vertical_colorbar(h, vmin, vmax, cmap_name="jet", label=None, cbar_precision=2):
    """Render a standalone vertical colorbar as an RGB float image.

    :param h: target height in pixels (width scales proportionally)
    :param vmin: min value of the scale
    :param vmax: max value of the scale
    :param cmap_name: matplotlib colormap name
    :param label: optional text label for the colorbar
    :param cbar_precision: decimals for tick labels (0 also strips the ".0")
    :return: float32 RGB image in [0, 1], of height h
    """
    fig = Figure(figsize=(2, 8), dpi=100)
    fig.subplots_adjust(right=1.5)
    canvas = FigureCanvasAgg(fig)  # off-screen Agg rendering

    ax = fig.add_subplot(111)
    cmap = cm.get_cmap(cmap_name)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    # 6 evenly spaced ticks over [vmin, vmax]
    tick_cnt = 6
    tick_loc = np.linspace(vmin, vmax, tick_cnt)
    cb1 = mpl.colorbar.ColorbarBase(
        ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical"
    )

    tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc]
    if cbar_precision == 0:
        tick_label = [x[:-2] for x in tick_label]  # drop trailing ".0"

    cb1.set_ticklabels(tick_label)

    cb1.ax.tick_params(labelsize=18, rotation=0)
    if label is not None:
        cb1.set_label(label)

    fig.tight_layout()

    # rasterize the figure into an RGBA byte buffer
    canvas.draw()
    s, (width, height) = canvas.print_to_buffer()

    im = np.frombuffer(s, np.uint8).reshape((height, width, 4))

    # drop alpha, normalize to [0, 1]
    im = im[:, :, :3].astype(np.float32) / 255.0
    if h != im.shape[0]:
        # rescale to the requested height, preserving aspect ratio
        w = int(im.shape[1] / im.shape[0] * h)
        im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)

    return im
|
|
|
|
|
|
|
|
def colorize_np(
    x,
    cmap_name="jet",
    mask=None,
    range=None,
    append_cbar=False,
    cbar_in_image=False,
    cbar_precision=2,
):
    """
    turn a grayscale image into a color image

    :param x: input grayscale, [H, W]
    :param cmap_name: the colorization method
    :param mask: the mask image, [H, W]; masked-out pixels are rendered white
    :param range: the range for scaling, automatic if None, [min, max]
    :param append_cbar: if append the color bar
    :param cbar_in_image: put the color bar inside the image to keep the output
        image the same size as the input image
    :param cbar_precision: decimals for the colorbar tick labels
    :return: colorized image, [H, W, 3]

    Fixes over the previous version: the input array is no longer mutated when
    a mask is given, and the colorbar is only rendered when it is actually
    appended (it was previously rendered — an expensive matplotlib call —
    and then discarded when append_cbar was False).
    """
    if range is not None:
        vmin, vmax = range
    elif mask is not None:
        # vmin/vmax from the valid, nonzero region only
        vmin = np.min(x[mask][np.nonzero(x[mask])])
        vmax = np.max(x[mask])
        # fill the invalid region with vmin — on a copy, never in place
        x = np.where(mask, x, vmin)
    else:
        vmin, vmax = np.percentile(x, (1, 100))
        vmax += 1e-6  # guard against vmax == vmin

    # normalize to [0, 1] and apply the colormap
    x = np.clip(x, vmin, vmax)
    x = (x - vmin) / (vmax - vmin)

    cmap = cm.get_cmap(cmap_name)
    x_new = cmap(x)[:, :, :3]

    if mask is not None:
        # blend masked-out pixels to white
        mask = np.float32(mask[:, :, np.newaxis])
        x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask)

    if not append_cbar:
        return x_new

    cbar = get_vertical_colorbar(
        h=x.shape[0],
        vmin=vmin,
        vmax=vmax,
        cmap_name=cmap_name,
        cbar_precision=cbar_precision,
    )

    if cbar_in_image:
        # overwrite the right edge so the output keeps the input's size
        x_new[:, -cbar.shape[1] :, :] = cbar
        return x_new
    # otherwise widen the image: content | 5-px black gap | colorbar
    return np.concatenate((x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1)
|
|
|
|
|
|
|
|
def colorize(
    x, cmap_name="jet", mask=None, range=None, append_cbar=False, cbar_in_image=False
):
    """
    turn a grayscale image into a color image

    :param x: torch.Tensor, grayscale image, [H, W] or [B, H, W]
    :param mask: torch.Tensor or None, mask image, [H, W] or [B, H, W] or None
    :return: torch.FloatTensor on x's device, [H, W, 3] (or [B, H, W, 3])
    """

    device = x.device
    x = x.cpu().numpy()
    if mask is not None:
        # binarize the mask; 0.99 treats anything but ~1.0 as invalid
        mask = mask.cpu().numpy() > 0.99
        kernel = np.ones((3, 3), np.uint8)

    # promote [H, W] inputs to a 1-element batch
    if x.ndim == 2:
        x = x[None]
        if mask is not None:
            mask = mask[None]

    out = []
    for x_ in x:
        if mask is not None:
            # erode the mask slightly to clean up its boundary
            # NOTE(review): the *whole* (possibly batched) mask is eroded once
            # per loop iteration and is not indexed per batch item, so with
            # B > 1 the mask shrinks cumulatively and its shape may not match
            # what colorize_np expects — verify intended behavior for B > 1
            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1).astype(bool)

        x_ = colorize_np(x_, cmap_name, mask, range, append_cbar, cbar_in_image)
        out.append(torch.from_numpy(x_).to(device).float())
    # squeeze(0) collapses the synthetic batch dim for [H, W] inputs
    out = torch.stack(out).squeeze(0)
    return out
|
|
|
|
|
|
|
|
def draw_correspondences(
    imgs1, imgs2, coords1, coords2, interval=10, color_by=0, radius=2
):
    """
    draw correspondences between two images

    :param imgs1: tensor [B, H, W, 3]
    :param imgs2: tensor [B, H, W, 3]
    :param coords1: tensor [B, N, 2]
    :param coords2: tensor [B, N, 2]
    :param interval: int, subsampling stride between drawn points
    :param color_by: color points based on image 1 or image 2 (0 or 1)
    :param radius: circle radius in pixels
    :return: np.ndarray [B, 2*H, W, 3]
    """

    def _subsample(coords):
        # subsample every `interval`-th coordinate and flatten to (N, 2)
        return coords.detach().cpu().numpy()[::interval, ::interval].reshape(-1, 2)

    rendered = []
    for b in np.arange(len(imgs1)):
        pair_img = drawMatches(
            imgs1[b].detach().cpu().numpy(),
            imgs2[b].detach().cpu().numpy(),
            _subsample(coords1[b]),
            _subsample(coords2[b]),
            radius=radius,
            color_by=color_by,
            row_cat=True,
        )
        rendered.append(pair_img)
    return np.stack(rendered)
|
|
|
|
|
|
|
|
def draw_correspondences_lines(
    imgs1, imgs2, coords1, coords2, interval=10, color_by=0, radius=2
):
    """
    draw correspondences between two images, with connecting lines

    :param imgs1: tensor [B, H, W, 3]
    :param imgs2: tensor [B, H, W, 3]
    :param coords1: tensor [B, N, 2]
    :param coords2: tensor [B, N, 2]
    :param interval: int, subsampling stride between drawn points
    :param color_by: color points based on image 1 or image 2 (0 or 1)
    :param radius: circle radius in pixels
    :return: np.ndarray [B, 2*H, W, 3]
    """

    def _subsample(coords):
        # subsample every `interval`-th coordinate and flatten to (N, 2)
        return coords.detach().cpu().numpy()[::interval, ::interval].reshape(-1, 2)

    rendered = []
    for b in np.arange(len(imgs1)):
        pair_img = drawMatches_lines(
            imgs1[b].detach().cpu().numpy(),
            imgs2[b].detach().cpu().numpy(),
            _subsample(coords1[b]),
            _subsample(coords2[b]),
            radius=radius,
            color_by=color_by,
            row_cat=True,
        )
        rendered.append(pair_img)
    return np.stack(rendered)
|
|
|
|
|
|
|
|
def drawMatches(img1, img2, kp1, kp2, radius=2, mask=None, color_by=0, row_cat=False):
    """Draw matching keypoints on both images and concatenate them.

    img1/img2: (H, W, 3) float images in [0, 1]
    kp1/kp2: (N, 2) keypoints in (x, y) pixel coordinates
    radius: circle radius in pixels
    mask: optional (N,) scores/flags; truthy entries are drawn as crosses in
        img2, and matches are reordered so masked ones are drawn first
    color_by: 0 — color by angle around kp1's centroid; 1 — by kp2's centroid
    row_cat: stack the images vertically (True) or horizontally (False)
    returns: concatenated uint8 image

    Fix: removed the unreachable duplicate concatenate/return block that
    followed the function's return statement.
    """
    h1, w1 = img1.shape[:2]
    h2, w2 = img2.shape[:2]

    img1 = np.ascontiguousarray(float2uint8(img1))
    img2 = np.ascontiguousarray(float2uint8(img2))

    center1 = np.median(kp1, axis=0)
    center2 = np.median(kp2, axis=0)

    # 128-entry angular color wheel built from matplotlib's hsv colormap
    # (channels reversed, i.e. BGR order for OpenCV drawing)
    set_max = range(128)
    colors = {m: i for i, m in enumerate(set_max)}
    colors = {
        m: (255 * np.array(plt.cm.hsv(i / float(len(colors))))[:3][::-1]).astype(
            np.int32
        )
        for m, i in colors.items()
    }

    if mask is not None:
        # draw high-score (masked) matches first
        ind = np.argsort(mask)[::-1]
        kp1 = kp1[ind]
        kp2 = kp2[ind]
        mask = mask[ind]

    for i, (pt1, pt2) in enumerate(zip(kp1, kp2)):
        # color encodes the point's angle around the chosen image's centroid
        if color_by == 0:
            coord_angle = np.arctan2(pt1[1] - center1[1], pt1[0] - center1[0])
        elif color_by == 1:
            coord_angle = np.arctan2(pt2[1] - center2[1], pt2[0] - center2[0])

        corr_color = np.int32(64 * coord_angle / np.pi) % 128
        color = tuple(colors[corr_color].tolist())

        # draw only in-bounds points
        if (
            (pt1[0] <= w1 - 1)
            and (pt1[0] >= 0)
            and (pt1[1] <= h1 - 1)
            and (pt1[1] >= 0)
        ):
            img1 = cv2.circle(
                img1, (int(pt1[0]), int(pt1[1])), radius, color, -1, cv2.LINE_AA
            )

        if (
            (pt2[0] <= w2 - 1)
            and (pt2[0] >= 0)
            and (pt2[1] <= h2 - 1)
            and (pt2[1] >= 0)
        ):
            if mask is not None and mask[i]:
                # masked matches are drawn as crosses for emphasis
                img2 = cv2.drawMarker(
                    img2,
                    (int(pt2[0]), int(pt2[1])),
                    color,
                    markerType=cv2.MARKER_CROSS,
                    markerSize=int(5 * radius),
                    thickness=int(radius / 2),
                    line_type=cv2.LINE_AA,
                )
            else:
                img2 = cv2.circle(
                    img2, (int(pt2[0]), int(pt2[1])), radius, color, -1, cv2.LINE_AA
                )

    if row_cat:
        return np.concatenate([img1, img2], axis=0)
    return np.concatenate([img1, img2], axis=1)
|
|
|
|
|
|
|
|
def drawMatches_lines(
    img1, img2, kp1, kp2, radius=2, mask=None, color_by=0, row_cat=False
):
    """Like drawMatches, but draws connecting lines for a random ~10% subset
    of the matches (to keep the plot readable).

    img1/img2: (H, W, 3) float images in [0, 1]
    kp1/kp2: (N, 2) keypoints in (x, y) pixel coordinates
    radius: circle radius in pixels
    mask: optional (N,) scores used to reorder the matches (highest first)
    color_by: 0 — color by angle around kp1's centroid; 1 — by kp2's centroid
    row_cat: stack the images vertically (True) or horizontally (False)
    returns: concatenated uint8 image with circles and lines drawn on it

    Fix: removed the unreachable duplicate concatenate/return block that
    followed the function's return statement.
    """
    h1, w1 = img1.shape[:2]
    h2, w2 = img2.shape[:2]

    img1 = np.ascontiguousarray(float2uint8(img1))
    img2 = np.ascontiguousarray(float2uint8(img2))

    center1 = np.median(kp1, axis=0)
    center2 = np.median(kp2, axis=0)

    # 128-entry angular color wheel built from matplotlib's hsv colormap
    # (channels reversed, i.e. BGR order for OpenCV drawing)
    set_max = range(128)
    colors = {m: i for i, m in enumerate(set_max)}
    colors = {
        m: (255 * np.array(plt.cm.hsv(i / float(len(colors))))[:3][::-1]).astype(
            np.int32
        )
        for m, i in colors.items()
    }

    if mask is not None:
        # reorder so high-score matches come first
        ind = np.argsort(mask)[::-1]
        kp1 = kp1[ind]
        kp2 = kp2[ind]
        mask = mask[ind]

    # draw directly onto the concatenated canvas so lines can span both images
    if row_cat:
        whole_img = np.concatenate([img1, img2], axis=0)
    else:
        whole_img = np.concatenate([img1, img2], axis=1)
    for i, (pt1, pt2) in enumerate(zip(kp1, kp2)):
        if color_by == 0:
            coord_angle = np.arctan2(pt1[1] - center1[1], pt1[0] - center1[0])
        elif color_by == 1:
            coord_angle = np.arctan2(pt2[1] - center2[1], pt2[0] - center2[0])

        corr_color = np.int32(64 * coord_angle / np.pi) % 128
        color = tuple(colors[corr_color].tolist())
        # randomly keep ~10% of the matches
        rand_val = np.random.rand()
        if rand_val < 0.1:
            # both endpoints must be in bounds
            if (
                (pt1[0] <= w1 - 1)
                and (pt1[0] >= 0)
                and (pt1[1] <= h1 - 1)
                and (pt1[1] >= 0)
            ) and (
                (pt2[0] <= w2 - 1)
                and (pt2[0] >= 0)
                and (pt2[1] <= h2 - 1)
                and (pt2[1] >= 0)
            ):

                whole_img = cv2.circle(
                    whole_img,
                    (int(pt1[0]), int(pt1[1])),
                    radius,
                    color,
                    -1,
                    cv2.LINE_AA,
                )

                # pt2 lives in the second image: offset by h1 (vertical stack)
                # or w1 (horizontal stack) on the shared canvas
                if row_cat:
                    whole_img = cv2.circle(
                        whole_img,
                        (int(pt2[0]), int(pt2[1] + h1)),
                        radius,
                        color,
                        -1,
                        cv2.LINE_AA,
                    )
                    cv2.line(
                        whole_img,
                        (int(pt1[0]), int(pt1[1])),
                        (int(pt2[0]), int(pt2[1] + h1)),
                        color,
                        1,
                        cv2.LINE_AA,
                    )
                else:
                    whole_img = cv2.circle(
                        whole_img,
                        (int(pt2[0] + w1), int(pt2[1])),
                        radius,
                        color,
                        -1,
                        cv2.LINE_AA,
                    )
                    cv2.line(
                        whole_img,
                        (int(pt1[0]), int(pt1[1])),
                        (int(pt2[0] + w1), int(pt2[1])),
                        color,
                        1,
                        cv2.LINE_AA,
                    )
    return whole_img
|
|
|
|
|
|
|
|
import torch |
|
|
import os |
|
|
import time |
|
|
import viser |
|
|
|
|
|
|
|
|
def rotation_matrix_to_quaternion(R):
    """Convert a rotation matrix to a quaternion (w, x, y, z).

    :param R: [3, 3] rotation matrix
    :return: [4] quaternion, (w, x, y, z) order

    Fix: the previous version always used the trace-based formula, which
    divides by zero (and takes sqrt of a negative) for rotations with
    trace(R) <= -1, i.e. rotations by ~180 degrees. Shepperd's method now
    selects a numerically safe branch based on the largest diagonal element;
    results are unchanged for the previously working inputs (tr > 0).
    """
    tr = np.trace(R)
    q = np.zeros(4)
    if tr > 0:
        # standard trace-based formula (same as the original implementation)
        q[0] = 0.5 * np.sqrt(1 + tr)
        q[1] = (R[2, 1] - R[1, 2]) / (4 * q[0])
        q[2] = (R[0, 2] - R[2, 0]) / (4 * q[0])
        q[3] = (R[1, 0] - R[0, 1]) / (4 * q[0])
    elif R[0, 0] >= R[1, 1] and R[0, 0] >= R[2, 2]:
        s = 2.0 * np.sqrt(1 + R[0, 0] - R[1, 1] - R[2, 2])  # s = 4 * |qx|
        q[0] = (R[2, 1] - R[1, 2]) / s
        q[1] = 0.25 * s
        q[2] = (R[0, 1] + R[1, 0]) / s
        q[3] = (R[0, 2] + R[2, 0]) / s
    elif R[1, 1] >= R[2, 2]:
        s = 2.0 * np.sqrt(1 + R[1, 1] - R[0, 0] - R[2, 2])  # s = 4 * |qy|
        q[0] = (R[0, 2] - R[2, 0]) / s
        q[1] = (R[0, 1] + R[1, 0]) / s
        q[2] = 0.25 * s
        q[3] = (R[1, 2] + R[2, 1]) / s
    else:
        s = 2.0 * np.sqrt(1 + R[2, 2] - R[0, 0] - R[1, 1])  # s = 4 * |qz|
        q[0] = (R[1, 0] - R[0, 1]) / s
        q[1] = (R[0, 2] + R[2, 0]) / s
        q[2] = (R[1, 2] + R[2, 1]) / s
        q[3] = 0.25 * s
    return q
|
|
|
|
|
|
|
|
class PointCloudViewer:
    """Interactive viser-based viewer that replays point clouds saved once per
    training step, with playback controls (slider, prev/next, FPS)."""

    def __init__(self, pc_dir, device="cpu"):
        # pc_dir: directory of np.load-able files whose names end with the
        # training step number, e.g. "pc_1000.npz"
        self.server = viser.ViserServer()
        self.server.set_up_direction("-y")
        self.device = device
        # small helper to move numpy arrays to the target torch device
        self.tt = lambda x: torch.from_numpy(x).float().to(device)
        self.pc_dir = pc_dir
        self.pcs, self.all_steps = self.read_data()
        self.num_frames = len(self.all_steps)

        # NOTE(review): always False, so add_pc() currently raises
        # NotImplementedError after adding the point clouds — see add_pc
        self.fix_camera = False
        self.camera_scale = self.server.add_gui_slider(
            "camera_scale",
            min=0.01,
            max=1.0,
            step=0.01,
            initial_value=0.1,
        )

        # frustum handles, rescaled live from the camera_scale slider
        self.camera_handles = []

    def read_data(self):
        """Load all saved point-cloud files, keyed and ordered by step number.

        Returns (pcs, step_list): pcs maps step -> {"pc": loaded data};
        step_list is sorted ascending.
        """
        pc_list = os.listdir(self.pc_dir)
        # file names are assumed to end with "_<step>.<ext>"
        pc_list.sort(key=lambda x: int(x.split(".")[0].split("_")[-1]))
        pcs = {}
        step_list = []
        for pc_name in pc_list:
            pc = np.load(os.path.join(self.pc_dir, pc_name))
            step = int(pc_name.split(".")[0].split("_")[-1])
            pcs.update({step: {"pc": pc}})
            step_list.append(step)
        return pcs, step_list

    def parse_pc_data(self, pc, batch_idx=-1):
        """Extract one batch element from a saved point cloud.

        Assumes pc contains paired keys "pts3d_<i>"/"colors_<i>" for
        i = 1..K (hence the // 2) — TODO confirm against the saving code.
        """
        idx = batch_idx  # which element of the saved batch to show
        ret_dict = {}
        for i in range(len(pc.keys()) // 2):
            pred_pts = pc[f"pts3d_{i+1}"][idx].reshape(-1, 3)
            color = pc[f"colors_{i+1}"][idx].reshape(-1, 3)
            ret_dict.update({f"pred_pts_{i+1}": pred_pts, f"color_{i+1}": color})
        return ret_dict

    def add_pc(self, step):
        """Register the point cloud(s) of one training step under /frames/<step>."""
        pc = self.pcs[step]["pc"]
        pc_dict = self.parse_pc_data(pc)

        for i in range(len(pc_dict.keys()) // 2):
            self.server.add_point_cloud(
                name=f"/frames/{step}/pred_pts_{i+1}_{step}",
                points=pc_dict[f"pred_pts_{i+1}"],
                colors=pc_dict[f"color_{i+1}"],
                point_size=0.002,
            )

        if not self.fix_camera:
            # NOTE(review): fix_camera is always False (set in __init__), so
            # execution always stops here and everything below is dead code.
            # The code below also calls find_rigid_alignment_batched and uses
            # pc_dict keys ("pred_pts1_2", "color1_2", ...) that parse_pc_data
            # never produces — confirm before enabling.
            raise NotImplementedError

        R21, T21 = find_rigid_alignment_batched(
            torch.from_numpy(pc_dict["pred_pts1_2"][None]),
            torch.from_numpy(pc_dict["pred_pts1_1"][None]),
        )
        R12, T12 = find_rigid_alignment_batched(
            torch.from_numpy(pc_dict["pred_pts2_1"][None]),
            torch.from_numpy(pc_dict["pred_pts2_2"][None]),
        )
        R21 = R21[0].numpy()
        T21 = T21.numpy()
        R12 = R12[0].numpy()
        T12 = T12.numpy()
        # align each view's cross-predicted points into the other's frame
        pred_pts1_2 = pc_dict["pred_pts1_2"] @ R21.T + T21
        pred_pts2_1 = pc_dict["pred_pts2_1"] @ R12.T + T12
        self.server.add_point_cloud(
            name=f"/frames/{step}/pred_pts1_2_{step}",
            points=pred_pts1_2,
            colors=pc_dict["color1_2"],
            point_size=0.002,
        )

        self.server.add_point_cloud(
            name=f"/frames/{step}/pred_pts2_1_{step}",
            points=pred_pts2_1,
            colors=pc_dict["color2_1"],
            point_size=0.002,
        )
        # assumes 224x224 images — TODO confirm
        img1 = pc_dict["color1_1"].reshape(224, 224, 3)
        img2 = pc_dict["color2_2"].reshape(224, 224, 3)
        self.camera_handles.append(
            self.server.add_camera_frustum(
                name=f"/frames/{step}/camera1_{step}",
                fov=2.0 * np.arctan(224.0 / 490.0),
                aspect=1.0,
                scale=self.camera_scale.value,
                color=(1.0, 0, 0),
                image=img1,
            )
        )
        self.camera_handles.append(
            self.server.add_camera_frustum(
                name=f"/frames/{step}/camera2_{step}",
                fov=2.0 * np.arctan(224.0 / 490.0),
                aspect=1.0,
                scale=self.camera_scale.value,
                color=(0, 0, 1.0),
                wxyz=rotation_matrix_to_quaternion(R21),
                position=T21,
                image=img2,
            )
        )

    def animate(self):
        """Build the playback GUI and run the (infinite) playback loop."""
        with self.server.add_gui_folder("Playback"):
            gui_timestep = self.server.add_gui_slider(
                "Train Step",
                min=0,
                max=self.num_frames - 1,
                step=1,
                initial_value=0,
                disabled=True,
            )
            gui_next_frame = self.server.add_gui_button("Next Step", disabled=True)
            gui_prev_frame = self.server.add_gui_button("Prev Step", disabled=True)
            gui_playing = self.server.add_gui_checkbox("Playing", False)
            gui_framerate = self.server.add_gui_slider(
                "FPS", min=1, max=60, step=0.1, initial_value=1
            )
            gui_framerate_options = self.server.add_gui_button_group(
                "FPS options", ("10", "20", "30", "60")
            )

        @gui_next_frame.on_click
        def _(_) -> None:
            gui_timestep.value = (gui_timestep.value + 1) % self.num_frames

        @gui_prev_frame.on_click
        def _(_) -> None:
            gui_timestep.value = (gui_timestep.value - 1) % self.num_frames

        @gui_playing.on_update
        def _(_) -> None:
            # manual controls are disabled while auto-playing
            gui_timestep.disabled = gui_playing.value
            gui_next_frame.disabled = gui_playing.value
            gui_prev_frame.disabled = gui_playing.value

        @gui_framerate_options.on_click
        def _(_) -> None:
            gui_framerate.value = int(gui_framerate_options.value)

        prev_timestep = gui_timestep.value

        @gui_timestep.on_update
        def _(_) -> None:
            # show only the currently selected frame
            nonlocal prev_timestep
            current_timestep = gui_timestep.value
            with self.server.atomic():
                frame_nodes[current_timestep].visible = True
                frame_nodes[prev_timestep].visible = False
            prev_timestep = current_timestep
            self.server.flush()

        self.server.add_frame(
            "/frames",
            show_axes=False,
        )
        frame_nodes = []
        for i in range(self.num_frames):
            step = self.all_steps[i]
            frame_nodes.append(
                self.server.add_frame(
                    f"/frames/{step}",
                    show_axes=False,
                )
            )
            self.add_pc(step)

        # start with only the selected frame visible
        for i, frame_node in enumerate(frame_nodes):

            frame_node.visible = i == gui_timestep.value

        prev_timestep = gui_timestep.value
        # playback loop: advance when playing, keep frustum scale in sync
        while True:
            if gui_playing.value:
                gui_timestep.value = (gui_timestep.value + 1) % self.num_frames
            for handle in self.camera_handles:
                handle.scale = self.camera_scale.value
            time.sleep(1.0 / gui_framerate.value)

    def run(self):
        """Entry point: start the animation loop (never returns).

        NOTE(review): animate() already loops forever, so the sleep loop
        below is unreachable.
        """
        self.animate()
        while True:
            time.sleep(10.0)
|
|
|
|
|
|
|
|
from sklearn.decomposition import PCA |
|
|
|
|
|
|
|
|
def colorize_feature_map(x):
    """
    Args:
        x: torch.Tensor, [B, H, W, D]
    Returns:
        torch.Tensor, [B, H, W, 3] (batch dim squeezed if B == 1)
    """
    device = x.device
    batched = x.cpu().numpy()

    colored = [
        torch.from_numpy(colorize_feature_map_np(item)).to(device)
        for item in batched
    ]
    return torch.stack(colored).squeeze(0)
|
|
|
|
|
|
|
|
def colorize_feature_map_np(x):
    """Project a D-dim feature map to RGB via 3-component PCA.

    Args:
        x: np.ndarray, [H, W, D]
    Returns:
        np.ndarray, [H, W, 3], min-max normalized to [0, 1]
    """
    h, w = x.shape[0], x.shape[1]
    flat = x.reshape(-1, x.shape[-1])
    feats = PCA(n_components=3).fit_transform(flat)

    # global min-max normalization across all three components
    lo = feats.min()
    hi = feats.max()
    feats = (feats - lo) / (hi - lo)
    return feats.reshape(h, w, 3)
|
|
|