|
|
import math
import os
from typing import Callable, Dict, List, Optional, Tuple, Union

import kornia
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from matplotlib import cm
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from PIL import Image, ImageOps
from torchvision.io import write_video

# Default field of view: 0.3 * pi rad (54 degrees).
DEFAULT_FOV_RAD = 0.9424777960769379
|
|
|
|
|
|
|
|
|
|
|
def get_default_intrinsics(
    fov_rad=DEFAULT_FOV_RAD,
    aspect_ratio=1.0,
):
    """Build normalized 3x3 intrinsics (principal point at 0.5) from a FOV."""
    if not isinstance(fov_rad, torch.Tensor):
        fov_rad = torch.tensor(
            [fov_rad] if isinstance(fov_rad, (int, float)) else fov_rad
        )
    if aspect_ratio >= 1.0:  # W >= H: apply the FOV along x
        focal_x = 0.5 / torch.tan(0.5 * fov_rad)
        focal_y = focal_x * aspect_ratio  # keeps pixels square
    else:  # W < H: apply the FOV along y
        focal_y = 0.5 / torch.tan(0.5 * fov_rad)
        focal_x = focal_y / aspect_ratio
    intrinsics = focal_x.new_zeros((focal_x.shape[0], 3, 3))
    intrinsics[:, torch.eye(3, device=focal_x.device, dtype=bool)] = torch.stack(
        [focal_x, focal_y, torch.ones_like(focal_x)], dim=-1
    )
    intrinsics[:, :, -1] = torch.tensor(
        [0.5, 0.5, 1.0], device=focal_x.device, dtype=focal_x.dtype
    )
    return intrinsics
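# Sanity check (hedged example, not from the original source): the default FOV
# is 54 degrees, so focal = 0.5 / tan(27 deg) is roughly 0.98, and
#   get_default_intrinsics()
# returns a (1, 3, 3) tensor close to
#   [[0.98, 0.00, 0.50],
#    [0.00, 0.98, 0.50],
#    [0.00, 0.00, 1.00]]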
|
|
|
|
|
def to_hom(X):
    # Append a homogeneous coordinate of 1 to the last dimension.
    X_hom = torch.cat([X, torch.ones_like(X[..., :1])], dim=-1)
    return X_hom
|
|
|
|
|
|
|
|
def to_hom_pose(pose):
    # Promote (N, 3, 4) poses to homogeneous (N, 4, 4); pass (N, 4, 4) through.
    if pose.shape[-2:] == (3, 4):
        pose_hom = torch.eye(4, device=pose.device)[None].repeat(pose.shape[0], 1, 1)
        pose_hom[:, :3, :] = pose
        return pose_hom
    return pose
|
|
|
|
|
|
|
|
|
|
|
def get_image_grid(img_h, img_w):
    # Homogeneous pixel coordinates, sampled at pixel centers (the +0.5).
    y_range = torch.arange(img_h, dtype=torch.float32).add_(0.5)
    x_range = torch.arange(img_w, dtype=torch.float32).add_(0.5)
    Y, X = torch.meshgrid(y_range, x_range, indexing="ij")
    xy_grid = torch.stack([X, Y], dim=-1).view(-1, 2)  # (H*W, 2)
    return to_hom(xy_grid)  # (H*W, 3)
|
|
|
|
|
|
|
|
def img2cam(X, cam_intr):
    return X @ cam_intr.inverse().transpose(-1, -2)
|
|
|
|
|
|
|
|
def cam2world(X, pose):
    # `pose` is world-to-camera; invert it to map camera-space points to world.
    X_hom = to_hom(X)
    pose_inv = torch.linalg.inv(to_hom_pose(pose))[..., :3, :4]
    return X_hom @ pose_inv.transpose(-1, -2)
|
|
|
|
|
|
|
|
def get_center_and_ray(img_h, img_w, pose, intr):
    """Compute camera centers and per-pixel ray directions in world space.

    `pose` holds world-to-camera extrinsics; returns world-space centers,
    world-space (unnormalized) rays, and the camera-space pixel grid.
    """
    # Image coordinate grid, back-projected to depth-1 points in camera space.
    grid_img = get_image_grid(img_h, img_w)
    grid_3D_cam = img2cam(grid_img.to(intr.device), intr.float())
    center_3D_cam = torch.zeros_like(grid_3D_cam)  # camera center at the origin

    # Transform from camera to world coordinates.
    grid_3D = cam2world(grid_3D_cam, pose)
    center_3D = cam2world(center_3D_cam, pose)
    ray = grid_3D - center_3D

    return center_3D, ray, grid_3D_cam
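# Minimal shape sketch (hedged; names are illustrative, not from the source):
#   pose = torch.eye(4)[None, :3, :]   # (1, 3, 4) world-to-camera
#   intr = get_default_intrinsics()    # (1, 3, 3)
#   centers, rays, _ = get_center_and_ray(8, 8, pose, intr)
#   # centers: (1, 64, 3), all zeros for an identity pose; rays: (1, 64, 3).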
|
|
|
|
|
def get_plucker_coordinates(
    extrinsics_src,
    extrinsics,
    intrinsics=None,
    fov_rad=DEFAULT_FOV_RAD,
    target_size=[72, 72],
):
    """Compute per-pixel Plucker coordinates relative to a source camera."""
    has_batch_dim = len(extrinsics.shape) == 4

    if has_batch_dim:
        # Flatten (B, N, 4, 4) camera batches to (B * N, 4, 4).
        batch_size, num_cameras = extrinsics.shape[0:2]
        extrinsics_flat = extrinsics.reshape(-1, *extrinsics.shape[2:])

        # Broadcast the source extrinsics to match.
        if len(extrinsics_src.shape) == 3:
            extrinsics_src_expanded = extrinsics_src.unsqueeze(1).expand(
                -1, num_cameras, -1, -1
            )
            extrinsics_src_flat = extrinsics_src_expanded.reshape(
                -1, *extrinsics_src.shape[1:]
            )
        else:
            extrinsics_src_flat = extrinsics_src.expand(
                batch_size * num_cameras, -1, -1
            )

        if intrinsics is None:
            intrinsics = get_default_intrinsics(fov_rad).to(extrinsics.device)
            intrinsics = intrinsics.expand(batch_size * num_cameras, -1, -1)
        elif len(intrinsics.shape) == 3:
            if intrinsics.shape[0] == num_cameras:
                intrinsics = intrinsics.expand(batch_size, -1, -1, -1).reshape(
                    -1, *intrinsics.shape[1:]
                )
            else:
                intrinsics = intrinsics.expand(batch_size * num_cameras, -1, -1)
        elif len(intrinsics.shape) == 4:
            intrinsics = intrinsics.reshape(-1, *intrinsics.shape[2:])
    else:
        extrinsics_flat = extrinsics
        extrinsics_src_flat = extrinsics_src
        if intrinsics is None:
            intrinsics = get_default_intrinsics(fov_rad).to(extrinsics.device)

    # Clone (and materialize any expanded view) so the in-place rescaling
    # below cannot write into shared memory or a caller-owned tensor.
    intrinsics = intrinsics.clone()

    # Convert pixel-valued intrinsics to normalized image coordinates: the x
    # row scales by width and the y row by height. target_size is (H, W) at
    # 1/8 of the pixel resolution, hence the factor of 8; this mirrors the
    # (width, height) order used by the scale-up below.
    if not (
        torch.all(intrinsics[:, :2, -1] >= 0)
        and torch.all(intrinsics[:, :2, -1] <= 1)
    ):
        intrinsics[:, :2] /= (
            intrinsics.new_tensor([target_size[1], target_size[0]]).view(1, -1, 1) * 8
        )

    assert (
        torch.all(intrinsics[:, :2, -1] >= 0)
        and torch.all(intrinsics[:, :2, -1] <= 1)
    ), "Intrinsics should be expressed in resolution-independent normalized image coordinates."

    # Express every camera relative to the source camera
    # (w2c_target @ c2w_source). Plain matmul performs the batched product
    # and, unlike a fixed einsum pattern, broadcasts an unbatched source.
    c2w_src = torch.linalg.inv(extrinsics_src_flat)
    extrinsics_rel = extrinsics_flat @ c2w_src

    # Scale normalized intrinsics back up to the target raster size.
    intrinsics[:, :2] *= extrinsics_flat.new_tensor(
        [
            target_size[1],  # width
            target_size[0],  # height
        ]
    ).view(1, -1, 1)

    centers, rays, grid_cam = get_center_and_ray(
        img_h=target_size[0],
        img_w=target_size[1],
        pose=extrinsics_rel[:, :3, :],
        intr=intrinsics,
    )

    # Plucker coordinates: normalized direction d and moment o x d.
    rays = torch.nn.functional.normalize(rays, dim=-1)
    plucker = torch.cat((rays, torch.cross(centers, rays, dim=-1)), dim=-1)
    plucker = plucker.permute(0, 2, 1).reshape(plucker.shape[0], -1, *target_size)

    if has_batch_dim:
        plucker = plucker.reshape(batch_size, num_cameras, *plucker.shape[1:])

    return plucker
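# Shape sketch (hedged example; names are illustrative, not from the source):
#   w2c = torch.eye(4)[None].repeat(4, 1, 1)           # (T=4, 4, 4) extrinsics
#   p = get_plucker_coordinates(w2c[0], w2c, target_size=[9, 16])
#   # p: (4, 6, 9, 16): per pixel, 3 ray-direction plus 3 moment channels.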
|
|
|
|
|
|
|
|
def get_value_dict(
    curr_imgs,
    curr_imgs_clip,
    curr_input_frame_indices,
    curr_c2ws,
    curr_Ks,
    curr_input_camera_indices,
    all_c2ws,
    camera_scale,
):
    # Input camera indices must be exactly 0..K-1.
    assert sorted(curr_input_camera_indices) == list(
        range(len(curr_input_camera_indices))
    )
    H, W, T, F = curr_imgs.shape[-2], curr_imgs.shape[-1], len(curr_imgs), 8

    value_dict = {}
    value_dict["cond_frames_without_noise"] = curr_imgs_clip[curr_input_frame_indices]
    # cond_aug is 0, so no noise is actually added here.
    value_dict["cond_frames"] = curr_imgs + 0.0 * torch.randn_like(curr_imgs)
    value_dict["cond_frames_mask"] = torch.zeros(T, dtype=torch.bool)
    value_dict["cond_frames_mask"][curr_input_frame_indices] = True
    value_dict["cond_aug"] = 0.0

    # Promote (T, 3, 4) poses to (T, 4, 4); 4x4 inputs pass through unchanged.
    c2w = to_hom_pose(curr_c2ws.float())

    # Recenter the scene on the mean camera position, ignoring outlier
    # cameras far from the median.
    ref_c2ws = all_c2ws
    camera_dist_2med = torch.norm(
        ref_c2ws[:, :3, 3] - ref_c2ws[:, :3, 3].median(0, keepdim=True).values,
        dim=-1,
    )
    valid_mask = camera_dist_2med <= torch.clamp(
        torch.quantile(camera_dist_2med, 0.97) * 10,
        max=1e6,
    )
    c2w[:, :3, 3] -= ref_c2ws[valid_mask, :3, 3].mean(0, keepdim=True)
    w2c = torch.linalg.inv(c2w)

    # Normalize the scene scale by the first camera's distance to the origin
    # (fall back to camera_scale alone when that distance is ~0).
    camera_dists = c2w[:, :3, 3].clone()
    translation_scaling_factor = (
        camera_scale
        if torch.isclose(
            torch.norm(camera_dists[0]),
            torch.zeros(1),
            atol=1e-5,
        ).any()
        else (camera_scale / torch.norm(camera_dists[0]))
    )
    w2c[:, :3, 3] *= translation_scaling_factor
    c2w[:, :3, 3] *= translation_scaling_factor
    value_dict["plucker_coordinate"] = get_plucker_coordinates(
        extrinsics_src=w2c[0],
        extrinsics=w2c,
        intrinsics=curr_Ks.float().clone(),
        target_size=(H // F, W // F),
    )

    value_dict["c2w"] = c2w
    value_dict["K"] = curr_Ks
    value_dict["camera_mask"] = torch.zeros(T, dtype=torch.bool)
    value_dict["camera_mask"][curr_input_camera_indices] = True

    return value_dict
|
|
|
|
|
def parse_meta_data(file_path, image_height=288, image_width=512):
    """Parse a camera-trajectory metadata file.

    The first line holds the video URL; every following line holds, per frame:
    a timestamp, normalized intrinsics (fx, fy, cx, cy), and, from column 7
    on, a flattened 3x4 pose matrix.
    """
    with open(file_path, 'r') as file:
        lines = file.readlines()

    video_url = lines[0].strip()

    # Intrinsics are shared across frames; read them from the first frame.
    line = lines[1]
    data = line.strip().split()

    focal_length_x = float(data[1])
    focal_length_y = float(data[2])
    principal_point_x = float(data[3])
    principal_point_y = float(data[4])

    # Normalized (resolution-independent) intrinsics.
    original_K = [
        [focal_length_x, 0, principal_point_x],
        [0, focal_length_y, principal_point_y],
        [0, 0, 1]
    ]

    # Intrinsics in pixels at the requested image resolution.
    K = [
        [focal_length_x * image_width, 0, principal_point_x * image_width],
        [0, focal_length_y * image_height, principal_point_y * image_height],
        [0, 0, 1]
    ]

    timestamp_to_c2ws = {}
    timestamps = []

    for line in lines[1:]:
        data = line.strip().split()
        timestamp = int(data[0])
        R_t = [float(x) for x in data[7:]]
        P = [
            R_t[0:4],
            R_t[4:8],
            R_t[8:12],
            [0, 0, 0, 1]
        ]
        timestamp_to_c2ws[timestamp] = np.array(P)
        timestamps.append(timestamp)
    return timestamps, np.array(K), timestamp_to_c2ws, original_K
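# Line layout assumed by the parser above (hedged; RealEstate10K-style order):
#   <timestamp> <fx> <fy> <cx> <cy> <col5> <col6> <p00> <p01> ... <p23>
# with normalized intrinsics, columns 5 and 6 skipped, and columns 7+ holding
# the flattened 3x4 pose.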
|
|
|
|
|
|
|
|
def get_wh_with_fixed_shortest_side(w, h, size):
    # Resize so the shorter side equals `size`, preserving the aspect ratio.
    if size is None or size <= 0:
        return w, h
    if w < h:
        new_w = size
        new_h = int(size * h / w)
    else:
        new_h = size
        new_w = int(size * w / h)
    return new_w, new_h
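# For instance (hedged example): get_wh_with_fixed_shortest_side(1920, 1080, 576)
# returns (1024, 576): the 1080 side maps to 576 and the 1920 side to
# int(576 * 1920 / 1080) = 1024.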
|
|
|
|
|
def get_resizing_factor(
    target_shape: Tuple[int, int],
    current_shape: Tuple[int, int],
    cover_target: bool = True,
) -> float:
    """Return the scale factor mapping `current_shape` (H, W) onto
    `target_shape` (H, W), either covering the target (`cover_target=True`)
    or fitting inside it (`cover_target=False`)."""
    r_bound = target_shape[1] / target_shape[0]
    aspect_r = current_shape[1] / current_shape[0]
    if r_bound >= 1.0:
        if cover_target:
            if aspect_r >= r_bound:
                factor = min(target_shape) / min(current_shape)
            elif aspect_r < 1.0:
                factor = max(target_shape) / min(current_shape)
            else:
                factor = max(target_shape) / max(current_shape)
        else:
            if aspect_r >= r_bound:
                factor = max(target_shape) / max(current_shape)
            elif aspect_r < 1.0:
                factor = min(target_shape) / max(current_shape)
            else:
                factor = min(target_shape) / min(current_shape)
    else:
        if cover_target:
            if aspect_r <= r_bound:
                factor = min(target_shape) / min(current_shape)
            elif aspect_r > 1.0:
                factor = max(target_shape) / min(current_shape)
            else:
                factor = max(target_shape) / max(current_shape)
        else:
            if aspect_r <= r_bound:
                factor = max(target_shape) / max(current_shape)
            elif aspect_r > 1.0:
                factor = min(target_shape) / max(current_shape)
            else:
                factor = min(target_shape) / min(current_shape)
    return factor
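# Worked example (hedged): covering a (576, 768) target with a (1080, 1920)
# image gives r_bound = 768/576 (about 1.33) and aspect_r = 1920/1080 (about
# 1.78), so factor = 576/1080 (about 0.533); the resized (576, 1024) image
# fully covers the target before cropping.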
|
|
|
|
|
def transform_img_and_K(
    image: torch.Tensor,
    size: Union[int, Tuple[int, int]],
    scale: float = 1.0,
    center: Tuple[float, float] = (0.5, 0.5),
    K: Union[torch.Tensor, np.ndarray, None] = None,
    size_stride: int = 1,
    mode: str = "crop",
):
    assert mode in [
        "crop",
        "pad",
        "stretch",
    ], f"mode should be one of ['crop', 'pad', 'stretch'], got {mode}"

    h, w = image.shape[-2:]
    if isinstance(size, (tuple, list)):
        # Explicit (W, H) target size.
        W, H = size
    else:
        # Scalar: fix the shortest side, then snap both sides to size_stride.
        W, H = get_wh_with_fixed_shortest_side(w, h, size)
        W, H = (
            math.floor(W / size_stride + 0.5) * size_stride,
            math.floor(H / size_stride + 0.5) * size_stride,
        )

    if mode == "stretch":
        rh, rw = H, W
    else:
        rfs = get_resizing_factor(
            (H, W),
            (h, w),
            cover_target=mode != "pad",
        )
        (rh, rw) = [int(np.ceil(rfs * s)) for s in (h, w)]

    rh, rw = int(rh / scale), int(rw / scale)
    image = torch.nn.functional.interpolate(
        image, (rh, rw), mode="area", antialias=False
    )

    cy_center = int(center[1] * image.shape[-2])
    cx_center = int(center[0] * image.shape[-1])
    if mode != "pad":
        # Crop a (H, W) window around the requested center, clamped to bounds.
        ct = max(0, cy_center - H // 2)
        cl = max(0, cx_center - W // 2)
        ct = min(ct, image.shape[-2] - H)
        cl = min(cl, image.shape[-1] - W)
        image = TF.crop(image, top=ct, left=cl, height=H, width=W)
        pl, pt = 0, 0
    else:
        # Pad up to (H, W), keeping the requested center centered.
        pt = max(0, H // 2 - cy_center)
        pl = max(0, W // 2 - cx_center)
        pb = max(0, H - pt - image.shape[-2])
        pr = max(0, W - pl - image.shape[-1])
        image = TF.pad(
            image,
            [pl, pt, pr, pb],
        )
        cl, ct = 0, 0

    if K is not None:
        K = K.clone()
        # Normalized intrinsics scale by the new size; pixel intrinsics scale
        # by the resize ratio. Either way, shift the principal point by the
        # crop/pad offsets.
        if torch.all(K[:, :2, -1] >= 0) and torch.all(K[:, :2, -1] <= 1):
            K[:, :2] *= K.new_tensor([rw, rh])[None, :, None]  # normalized
        else:
            K[:, :2] *= K.new_tensor([rw / w, rh / h])[None, :, None]  # pixels
        K[:, :2, 2] += K.new_tensor([pl - cl, pt - ct])

    return image, K
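# Usage sketch (hedged; dummy data, not from the original source):
#   img = torch.rand(1, 3, 480, 640)
#   K = torch.tensor([[500.0, 0, 320], [0, 500.0, 240], [0, 0, 1]])[None]
#   img_out, K_out = transform_img_and_K(img, (512, 288), K=K, mode="crop")
#   # img_out: (1, 3, 288, 512); K_out is rescaled and shifted to match.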
|
|
|
|
|
|
|
|
def load_img_and_K(
    image_path_or_size: Union[str, torch.Size],
    size: Optional[Union[int, Tuple[int, int]]],
    scale: float = 1.0,
    center: Tuple[float, float] = (0.5, 0.5),
    K: Union[torch.Tensor, np.ndarray, None] = None,
    size_stride: int = 1,
    center_crop: bool = False,
    image_as_tensor: bool = True,
    context_rgb: Union[np.ndarray, None] = None,
    device: str = "cuda",
):
    if isinstance(image_path_or_size, torch.Size):
        # A torch.Size is (H, W); PIL wants (W, H).
        image = Image.new("RGBA", image_path_or_size[::-1])
    else:
        image = Image.open(image_path_or_size).convert("RGBA")

    w, h = image.size
    if size is None:
        size = (w, h)

    image = np.array(image).astype(np.float32) / 255
    if image.shape[-1] == 4:
        # Composite the alpha channel over the context color (white if absent).
        rgb, alpha = image[:, :, :3], image[:, :, 3:]
        if context_rgb is not None:
            image = rgb * alpha + context_rgb * (1 - alpha)
        else:
            image = rgb * alpha + (1 - alpha)
    image = image.transpose(2, 0, 1)
    image = torch.from_numpy(image).to(dtype=torch.float32)
    image = image.unsqueeze(0)

    if isinstance(size, (tuple, list)):
        # Explicit (W, H) target size.
        W, H = size
    else:
        # Scalar: fix the shortest side, then snap both sides to size_stride.
        W, H = get_wh_with_fixed_shortest_side(w, h, size)
        W, H = (
            math.floor(W / size_stride + 0.5) * size_stride,
            math.floor(H / size_stride + 0.5) * size_stride,
        )

    rfs = get_resizing_factor((math.floor(H * scale), math.floor(W * scale)), (h, w))
    resize_size = rh, rw = [int(np.ceil(rfs * s)) for s in (h, w)]
    image = torch.nn.functional.interpolate(
        image, resize_size, mode="area", antialias=False
    )
    if scale < 1.0:
        # Pad the downscaled image back up to (H, W) with white.
        pw = math.ceil((W - resize_size[1]) * 0.5)
        ph = math.ceil((H - resize_size[0]) * 0.5)
        image = F.pad(image, (pw, pw, ph, ph), "constant", 1.0)

    cy_center = int(center[1] * image.shape[-2])
    cx_center = int(center[0] * image.shape[-1])
    if center_crop:
        side = min(H, W)
        ct = max(0, cy_center - side // 2)
        cl = max(0, cx_center - side // 2)
        ct = min(ct, image.shape[-2] - side)
        cl = min(cl, image.shape[-1] - side)
        image = TF.crop(image, top=ct, left=cl, height=side, width=side)
    else:
        ct = max(0, cy_center - H // 2)
        cl = max(0, cx_center - W // 2)
        ct = min(ct, image.shape[-2] - H)
        cl = min(cl, image.shape[-1] - W)
        image = TF.crop(image, top=ct, left=cl, height=H, width=W)

    if K is not None:
        K = K.clone()
        # Same convention as transform_img_and_K: normalized intrinsics scale
        # by the new size, pixel intrinsics by the resize ratio.
        if torch.all(K[:2, -1] >= 0) and torch.all(K[:2, -1] <= 1):
            K[:2] *= K.new_tensor([rw, rh])[:, None]
        else:
            K[:2] *= K.new_tensor([rw / w, rh / h])[:, None]
        K[:2, 2] -= K.new_tensor([cl, ct])

    if image_as_tensor:
        # Map [0, 1] to [-1, 1].
        image = image.to(device) * 2.0 - 1.0
    else:
        image = image.permute(0, 2, 3, 1).numpy()[0]
        image = Image.fromarray((image * 255).astype(np.uint8))
    return image, K
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def geodesic_distance(extrinsic1: Union[np.ndarray, torch.Tensor],
                      extrinsic2: Union[np.ndarray, torch.Tensor],
                      weight_translation: float = 0.01):
    """
    Computes the geodesic distance between two camera poses in SE(3).

    Parameters:
        extrinsic1 (Union[np.ndarray, torch.Tensor]): 4x4 extrinsic matrix of the first pose.
        extrinsic2 (Union[np.ndarray, torch.Tensor]): 4x4 extrinsic matrix of the second pose.
        weight_translation (float): Weight applied to the translation term.

    Returns:
        Union[float, torch.Tensor]: Geodesic distance between the two poses.
    """
    if torch.is_tensor(extrinsic1):
        R1 = extrinsic1[:3, :3]
        t1 = extrinsic1[:3, 3]
        R2 = extrinsic2[:3, :3]
        t2 = extrinsic2[:3, 3]

        translation_distance = torch.norm(t1 - t2)

        # Rotation angle of the relative rotation R1^T R2, recovered from its
        # trace; the clamp guards acos against numerical drift.
        R_relative = torch.matmul(R1.T, R2)
        trace_value = torch.trace(R_relative)
        trace_value = torch.clamp(trace_value, -1.0, 3.0)
        angular_distance = torch.acos((trace_value - 1) / 2)
    else:
        R1 = extrinsic1[:3, :3]
        t1 = extrinsic1[:3, 3]
        R2 = extrinsic2[:3, :3]
        t2 = extrinsic2[:3, 3]

        translation_distance = np.linalg.norm(t1 - t2)

        R_relative = np.dot(R1.T, R2)
        trace_value = np.trace(R_relative)
        trace_value = np.clip(trace_value, -1.0, 3.0)
        angular_distance = np.arccos((trace_value - 1) / 2)

    # Weighted sum of the translation and rotation terms.
    geodesic_dist = translation_distance * weight_translation + angular_distance

    return geodesic_dist
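# Quick numeric check (hedged example): for a 90-degree rotation about z with
# a unit translation, trace(R_relative) = 1, so the angle is acos(0) = pi/2
# and geodesic_distance(...) = 1 * 0.01 + pi/2, roughly 1.58.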
|
|
|
|
|
|
|
|
def inverse_geodesic_distance(extrinsic1,
                              extrinsic2,
                              weight_translation=0.01):
    """
    Computes the inverse geodesic distance between two camera poses in SE(3).

    Parameters:
        extrinsic1 (np.ndarray): 4x4 extrinsic matrix of the first pose.
        extrinsic2 (np.ndarray): 4x4 extrinsic matrix of the second pose.

    Returns:
        float: Inverse geodesic distance between the two poses.
    """
    geodesic_dist = geodesic_distance(extrinsic1, extrinsic2, weight_translation)

    # Epsilon keeps the reciprocal finite for identical poses.
    inverse_geodesic_dist = 1.0 / (geodesic_dist + 1e-6)

    return inverse_geodesic_dist
|
|
|
|
|
|
|
|
|
|
|
def average_camera_pose(camera_poses):
    """
    Compute the average of a set of camera poses in SE(3), averaging
    translations directly and rotations via quaternions.

    Args:
        camera_poses: List or array of camera poses, each a 4x4 matrix

    Returns:
        Average camera pose as a 4x4 matrix
    """
    rotations = camera_poses[:, :3, :3].detach().cpu().numpy()
    translations = camera_poses[:, :3, 3].detach().cpu().numpy()

    avg_translation = np.mean(translations, axis=0)

    import scipy.spatial.transform as transform
    quats = [transform.Rotation.from_matrix(R).as_quat() for R in rotations]

    # Flip quaternions into the same hemisphere so the mean is well defined
    # (q and -q encode the same rotation).
    for i in range(1, len(quats)):
        if np.dot(quats[0], quats[i]) < 0:
            quats[i] = -quats[i]

    avg_quat = np.mean(quats, axis=0)
    avg_quat = avg_quat / np.linalg.norm(avg_quat)
    avg_rotation = transform.Rotation.from_quat(avg_quat).as_matrix()

    avg_pose = np.eye(4)
    avg_pose[:3, :3] = avg_rotation
    avg_pose[:3, 3] = avg_translation

    return avg_pose
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def encode_image(
    image,
    image_encoder,
    device,
    dtype,
) -> torch.Tensor:
    image = image.to(device=device, dtype=dtype)
    image_embeddings = image_encoder(image)
    return image_embeddings
|
|
|
|
|
|
|
|
def encode_vae_image(
    image,
    vae,
    device,
    dtype,
):
    # Encode the image into VAE latents.
    image = image.to(device=device, dtype=dtype)
    image_latents = vae.encode(image, 1)
    return image_latents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def do_sample(
    model,
    ae,
    denoiser,
    sampler,
    c,
    uc,
    c2w,
    K,
    cond_frames_mask,
    H=576,
    W=768,
    C=4,
    F=8,
    T=8,
    cfg=2.0,
    decoding_t=1,
    verbose=True,
    global_pbar=None,
    return_latents=False,
    device: str = "cuda",
    **_,
):
    num_samples = [1, T]
    with torch.inference_mode(), torch.autocast("cuda"):
        additional_model_inputs = {"num_frames": T}
        additional_sampler_inputs = {
            "c2w": c2w.to("cuda"),
            "K": K.to("cuda"),
            "input_frame_mask": cond_frames_mask.to("cuda"),
        }
        if global_pbar is not None:
            additional_sampler_inputs["global_pbar"] = global_pbar

        # Latent shape: (T, C, H/F, W/F).
        shape = (math.prod(num_samples), C, H // F, W // F)
        randn = torch.randn(shape).to(device)

        samples_z = sampler(
            lambda input, sigma, c: denoiser(
                model,
                input,
                sigma,
                c,
                **additional_model_inputs,
            ),
            randn,
            scale=cfg,
            cond=c,
            uc=uc,
            verbose=verbose,
            **additional_sampler_inputs,
        )
        if samples_z is None:
            return

        samples = ae.decode(samples_z, decoding_t)
        if return_latents:
            return samples, samples_z

        return samples
|
|
|
|
|
|
|
|
def decode_output(
    samples,
    T,
    indices=None,
):
    # Move sampler outputs to CPU; optionally keep only the frames in `indices`.
    if isinstance(samples, dict):
        for sample, value in samples.items():
            if isinstance(value, torch.Tensor):
                value = value.detach().cpu()
            elif isinstance(value, np.ndarray):
                value = torch.from_numpy(value)
            else:
                value = torch.tensor(value)

            if indices is not None and value.shape[0] == T:
                value = value[indices]
            samples[sample] = value
    else:
        samples = samples.detach().cpu()

        if indices is not None and samples.shape[0] == T:
            samples = samples[indices]
        samples = {"samples-rgb/image": samples}

    return samples
|
|
|
|
|
def select_frames(timestamps, min_num_frames=2, skip_frame=10, random_start=False):
    """
    Select frames from a video sequence based on defined criteria.

    Args:
        timestamps: List of timestamps for the frames
        min_num_frames: Minimum number of frames required
        skip_frame: Number of frames to skip between selections
        random_start: If True, start from a random frame

    Returns:
        tuple: (selected_frame_indices, selected_frame_timestamps) or (None, None) if criteria not met
    """
    num_frames = len(timestamps)
    if num_frames < min_num_frames:
        print(f"[Worker PID={os.getpid()}] Episode has less than {min_num_frames} frames")
        return None, None

    if num_frames < 2:
        print(f"[Worker PID={os.getpid()}] Episode has less than 2 frames")
        return None, None
    elif num_frames < skip_frame:
        # Not enough frames for the requested stride; fall back to the widest
        # stride that still yields two frames.
        cur_skip_frame = num_frames - 1
    else:
        cur_skip_frame = skip_frame

    if random_start:
        # Use the effective stride so the start index stays in range even for
        # short episodes.
        start_frame = np.random.randint(0, cur_skip_frame)
    else:
        start_frame = 0

    selected_frame_indices = list(range(start_frame, num_frames, cur_skip_frame))
    selected_frame_timestamps = [timestamps[i] for i in selected_frame_indices]

    return selected_frame_indices, selected_frame_timestamps
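# For example (hedged): with 25 timestamps and skip_frame=10, select_frames
# picks indices [0, 10, 20]; with only 5 timestamps the stride falls back to
# 4 and it picks [0, 4].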
|
|
|
|
|
|
|
|
def tensor2im(input_image, imtype=np.uint8):
    # Convert a (B, C, H, W) tensor in [0, 1] to an (H, W, C) uint8 array.
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].clamp(0.0, 1.0).cpu().float().numpy()
        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
    else:
        image_numpy = input_image
    return image_numpy.astype(imtype)
|
|
|
|
|
|
|
|
class LatentStorer:
    """Callback that stores the most recent latent passed to it."""

    def __init__(self):
        self.latent = None

    def __call__(self, i, t, latent):
        self.latent = latent
|
|
|
|
|
|
|
|
def sobel_filter(disp, mode="sobel", beta=10.0):
    # Edge-aware weights: ~1 in smooth regions, near 0 at strong disparity edges.
    sobel_grad = kornia.filters.spatial_gradient(disp, mode=mode, normalized=False)
    sobel_mag = torch.sqrt(sobel_grad[:, :, 0, ...] ** 2 + sobel_grad[:, :, 1, ...] ** 2)
    alpha = torch.exp(-1.0 * beta * sobel_mag).detach()

    return alpha
|
|
|
|
|
|
|
|
def apply_colormap(image, cmap="viridis"):
    colormap = cm.get_cmap(cmap)
    colormap = torch.tensor(colormap.colors).to(image.device)
    image_long = (image * 255).long()
    image_long_min = torch.min(image_long)
    image_long_max = torch.max(image_long)
    assert image_long_min >= 0, f"the min value is {image_long_min}"
    assert image_long_max <= 255, f"the max value is {image_long_max}"
    return colormap[image_long[..., 0]]
|
|
|
|
|
|
|
|
def apply_depth_colormap(
    depth,
    near_plane=None,
    far_plane=None,
    cmap="viridis",
):
    # Explicit None checks so a caller-provided plane of 0.0 is respected.
    near_plane = float(torch.min(depth)) if near_plane is None else near_plane
    far_plane = float(torch.max(depth)) if far_plane is None else far_plane

    depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
    depth = torch.clip(depth, 0, 1)

    colored_image = apply_colormap(depth, cmap=cmap)

    return colored_image
|
|
|
|
|
|
|
|
def save_video(video, path, fps=10):
    # Expects a (T, C, H, W) uint8 tensor; write_video wants (T, H, W, C).
    video = video.permute(0, 2, 3, 1)
    video_codec = "libx264"
    video_options = {
        "crf": "23",  # quality; lower is better
        "preset": "slow",
    }
    write_video(str(path), video, fps=fps, video_codec=video_codec, options=video_options)
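# Usage sketch (hedged; dummy frames). torchvision's write_video expects uint8:
#   frames = (torch.rand(16, 3, 288, 512) * 255).to(torch.uint8)
#   save_video(frames, "out.mp4", fps=10)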
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def visualize_camera_poses(camera_poses, axis_length=0.1):
    """
    Visualizes a set of camera poses in 3D using Matplotlib.

    Parameters
    ----------
    camera_poses : np.ndarray
        An array of shape (N, 4, 4) containing N camera poses.
        Each pose is a 4x4 transformation matrix.
    axis_length : float
        Length of the camera axes to draw.
    """
    if isinstance(camera_poses, torch.Tensor):
        camera_poses = camera_poses.detach().cpu().numpy()

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    for i in range(camera_poses.shape[0]):
        R = camera_poses[i][:3, :3]
        t = camera_poses[i][:3, 3]

        # Camera center.
        ax.scatter(t[0], t[1], t[2], c='k', marker='o', s=20)

        # Endpoints of the camera's local x/y/z axes.
        x_axis_end = t + R[:, 0] * axis_length
        y_axis_end = t + R[:, 1] * axis_length
        z_axis_end = t + R[:, 2] * axis_length

        # Draw the axes in red/green/blue.
        ax.plot([t[0], x_axis_end[0]], [t[1], x_axis_end[1]],
                [t[2], x_axis_end[2]], color='r')
        ax.plot([t[0], y_axis_end[0]], [t[1], y_axis_end[1]],
                [t[2], y_axis_end[2]], color='g')
        ax.plot([t[0], z_axis_end[0]], [t[1], z_axis_end[1]],
                [t[2], z_axis_end[2]], color='b')

    set_axes_equal(ax)

    ax.set_title("Camera Poses Visualization")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    plt.show()
|
|
|
|
|
def set_axes_equal(ax):
    """
    Make axes of a 3D plot have equal scale so that spheres appear as spheres, cubes as cubes, etc.
    This is a workaround for Matplotlib's set_aspect('equal'), which is not supported in 3D.
    """
    x_limits = ax.get_xlim3d()
    y_limits = ax.get_ylim3d()
    z_limits = ax.get_zlim3d()

    x_range = x_limits[1] - x_limits[0]
    y_range = y_limits[1] - y_limits[0]
    z_range = z_limits[1] - z_limits[0]

    max_range = max(x_range, y_range, z_range)
    x_middle = np.mean(x_limits)
    y_middle = np.mean(y_limits)
    z_middle = np.mean(z_limits)

    ax.set_xlim3d([x_middle - 0.5 * max_range, x_middle + 0.5 * max_range])
    ax.set_ylim3d([y_middle - 0.5 * max_range, y_middle + 0.5 * max_range])
    ax.set_zlim3d([z_middle - 0.5 * max_range, z_middle + 0.5 * max_range])
|
|
|
|
|
|
|
|
def tensor_to_pil(image):
    if isinstance(image, torch.Tensor):
        if image.dim() == 4:
            image = image.squeeze(0)
        image = image.permute(1, 2, 0).detach().cpu().numpy()

    # Heuristic: values below -0.1 imply a [-1, 1] range; remap to [0, 1].
    if image.min() < -0.1:
        image = (image + 1) / 2.0

    image = (image * 255)
    image = np.clip(image, 0, 255)
    image = image.astype(np.uint8)
    return Image.fromarray(image)
|
|
|
|
|
|
|
|
|
|
|
def center_crop_pil_image(input_image, target_width=1024, target_height=576):
    # Resize so the image covers the target box, then center-crop to it.
    w, h = input_image.size
    h_ratio = h / target_height
    w_ratio = w / target_width

    if h_ratio > w_ratio:
        h = int(h / w_ratio)
        if h < target_height:
            h = target_height
        input_image = input_image.resize((target_width, h), Image.Resampling.LANCZOS)
    else:
        w = int(w / h_ratio)
        if w < target_width:
            w = target_width
        input_image = input_image.resize((w, target_height), Image.Resampling.LANCZOS)

    return ImageOps.fit(input_image, (target_width, target_height), Image.Resampling.BICUBIC)
|
|
|
|
|
def resize_pil_image(img, long_edge_size):
    # Downscale with Lanczos, upscale with bicubic.
    S = max(img.size)
    if S > long_edge_size:
        interp = PIL.Image.Resampling.LANCZOS
    else:
        interp = PIL.Image.Resampling.BICUBIC
    new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size)
    return img.resize(new_size, interp)
|
|
|
|
|
def visualize_surfels(
    surfels,
    draw_normals=False,
    normal_scale=20,
    disk_resolution=16,
    disk_alpha=0.5
):
    """
    Visualize surfels as 2D disks oriented by their normals in 3D using matplotlib.

    Args:
        surfels (list of Surfel): Each Surfel has at least:
            - position: (x, y, z)
            - normal: (nx, ny, nz)
            - radius: scalar
            - color: (R, G, B) in [0..255] (optional)
        draw_normals (bool): If True, draws the surfel normals as quiver arrows.
        normal_scale (float): Scale factor for the normal arrows.
        disk_resolution (int): Number of segments to approximate each disk.
        disk_alpha (float): Alpha (transparency) for the filled disks.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    positions = []
    normals = []

    # One polygon (disk outline) and face color per surfel.
    polygons = []
    polygon_colors = []

    for s in surfels:
        position = s.position
        normal = s.normal
        radius = s.radius

        if isinstance(position, torch.Tensor):
            x, y, z = position.detach().cpu().numpy()
            nx, ny, nz = normal.detach().cpu().numpy()
            radius = radius.detach().cpu().numpy()
        else:
            x, y, z = position
            nx, ny, nz = normal

        if s.color is None:
            color = (0.2, 0.6, 1.0)
        else:
            r, g, b = s.color
            color = (r / 255.0, g / 255.0, b / 255.0)

        # Skip degenerate normals; otherwise normalize.
        normal = np.array([nx, ny, nz], dtype=float)
        norm_len = np.linalg.norm(normal)
        if norm_len < 1e-12:
            continue
        normal /= norm_len

        # Build an orthonormal basis (xAxis, yAxis) spanning the disk plane;
        # pick an 'up' vector that is not (nearly) parallel to the normal.
        up = np.array([0, 0, 1], dtype=float)
        if abs(normal.dot(up)) > 0.9:
            up = np.array([0, 1, 0], dtype=float)

        xAxis = np.cross(normal, up)
        xAxis /= np.linalg.norm(xAxis)
        yAxis = np.cross(normal, xAxis)
        yAxis /= np.linalg.norm(yAxis)

        # Sample the disk boundary and lift it into world space.
        angles = np.linspace(0, 2 * np.pi, disk_resolution, endpoint=False)
        circle_points_3d = []
        for theta in angles:
            px = radius * np.cos(theta)
            py = radius * np.sin(theta)
            world_pt = np.array([x, y, z]) + px * xAxis + py * yAxis
            circle_points_3d.append(world_pt)

        circle_points_3d = np.array(circle_points_3d)
        polygons.append(circle_points_3d)
        polygon_colors.append(color)

        positions.append([x, y, z])
        normals.append(normal)

    poly_collection = Poly3DCollection(
        polygons,
        facecolors=polygon_colors,
        edgecolors='k',
        linewidths=0.5,
        alpha=disk_alpha
    )
    ax.add_collection3d(poly_collection)

    if draw_normals and len(positions) > 0:
        X = [p[0] for p in positions]
        Y = [p[1] for p in positions]
        Z = [p[2] for p in positions]

        Nx = [n[0] for n in normals]
        Ny = [n[1] for n in normals]
        Nz = [n[2] for n in normals]

        ax.quiver(
            X, Y, Z,
            Nx, Ny, Nz,
            length=normal_scale,
            color='red',
            normalize=True
        )

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    try:
        ax.set_box_aspect((1, 1, 1))
    except AttributeError:
        pass

    plt.title("Surfels as Disks (Oriented by Normal)")
    plt.show()
|
|
|
|
|
def visualize_pointcloud(
    points,
    colors=None,
    title='Point Cloud',
    point_size=1,
    alpha=1.0
):
    """
    Visualize a 3D point cloud using Matplotlib, with an option to provide
    per-point RGB or RGBA colors, ensuring equal scaling for the x, y, and z axes.

    Parameters
    ----------
    points : np.ndarray or torch.Tensor
        A numpy array (or Tensor) of shape [N, 3] where each row is a 3D point (x, y, z).
    colors : None, str, or np.ndarray
        - If None, a default single color ('blue') is used.
        - If a string, that color will be used for all points.
        - If a numpy array, it should have shape [N, 3] or [N, 4], where each row
          corresponds to the color of the matching point in `points`.
          Values should be in the range [0, 1] if using floats.
    title : str, optional
        The title of the plot. Default is 'Point Cloud'.
    point_size : float, optional
        The size of the points in the scatter plot. Default is 1.
    alpha : float, optional
        The overall alpha (transparency) value for the points. Default is 1.0.

    Examples
    --------
    >>> import numpy as np
    >>> # Generate random points
    >>> pts = np.random.rand(1000, 3)
    >>> # Generate random colors in [0,1]
    >>> cols = np.random.rand(1000, 3)
    >>> visualize_pointcloud(pts, colors=cols, title="Random Point Cloud with Colors")
    """
    if isinstance(points, torch.Tensor):
        points = points.detach().cpu().numpy()
    if isinstance(colors, torch.Tensor):
        colors = colors.detach().cpu().numpy()

    # Flatten (H, W, 3)-style inputs to [N, 3].
    if len(points.shape) > 2:
        points = points.reshape(-1, 3)
    if colors is not None and isinstance(colors, np.ndarray) and len(colors.shape) > 2:
        colors = colors.reshape(-1, colors.shape[-1])

    if points.shape[1] != 3:
        raise ValueError("`points` array must have shape [N, 3].")

    if colors is None:
        colors = 'blue'
    elif isinstance(colors, np.ndarray):
        if colors.shape[0] != points.shape[0]:
            raise ValueError(
                "Colors array length must match the number of points."
            )
        if colors.shape[1] not in [3, 4]:
            raise ValueError(
                "Colors array must have shape [N, 3] or [N, 4]."
            )

    x = points[:, 0]
    y = points[:, 1]
    z = points[:, 2]

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter(x, y, z, c=colors, s=point_size, alpha=alpha)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(title)

    # Equal axis scaling around the cloud's midpoint.
    max_range = np.array([x.max() - x.min(),
                          y.max() - y.min(),
                          z.max() - z.min()]).max() / 2.0
    mid_x = (x.max() + x.min()) * 0.5
    mid_y = (y.max() + y.min()) * 0.5
    mid_z = (z.max() + z.min()) * 0.5

    ax.set_xlim(mid_x - max_range, mid_x + max_range)
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)

    ax.view_init(elev=20., azim=30)

    plt.tight_layout()
    plt.show()
|
|
|
|
|
def visualize_depth(depth_image,
                    file_name="rendered_depth.png",
                    visualization_dir="visualization",
                    size=(512, 288)):
    """
    Visualize a depth map as a grayscale image.

    Parameters
    ----------
    depth_image : np.ndarray
        A 2D array of depth values.
    file_name : str
        Name of the output image file.
    visualization_dir : str
        The directory to save the visualization image.
    size : tuple of int
        Output (width, height) of the saved image.

    Returns
    -------
    PIL.Image
        The visualization image.
    """
    depth_min = depth_image.min()
    depth_max = depth_image.max()
    print(f"Depth min: {depth_min}, max: {depth_max}")
    depth_image = np.clip(depth_image, 0, depth_max)
    # Normalize to [0, 255] grayscale.
    depth_vis = (depth_image - depth_min) / (depth_max - depth_min)
    depth_vis = (depth_vis * 255).astype(np.uint8)

    depth_vis_img = Image.fromarray(depth_vis, mode='L')
    depth_vis_img = depth_vis_img.resize(size, Image.NEAREST)

    os.makedirs(visualization_dir, exist_ok=True)
    depth_vis_img.save(os.path.join(visualization_dir, file_name))

    return depth_vis_img
|
|
|
|
|
class Surfel:
    def __init__(self, position, normal, radius=1.0, color=None):
        """
        position: (x, y, z)
        normal: (nx, ny, nz)
        radius: scalar
        color: (r, g, b) or None
        """
        self.position = position
        self.normal = normal
        self.radius = radius
        self.color = color

    def __repr__(self):
        return (f"Surfel(position={self.position}, "
                f"normal={self.normal}, radius={self.radius}, "
                f"color={self.color})")
|
|
|
|
|
|
|
|
|
|
|
class Octree:
    def __init__(self, points, indices=None, bbox=None, max_points=10):
        self.points = points
        if indices is None:
            indices = np.arange(points.shape[0])
        self.indices = indices

        # Derive a cubic bounding box from the points if none is given.
        if bbox is None:
            min_bound = points.min(axis=0)
            max_bound = points.max(axis=0)
            center = (min_bound + max_bound) / 2
            half_size = np.max(max_bound - min_bound) / 2
            bbox = (center, half_size)
        self.center, self.half_size = bbox

        self.children = []
        self.max_points = max_points

        if len(self.indices) > self.max_points:
            self.subdivide()

    def subdivide(self):
        # Split this node into up to eight child octants.
        hs = self.half_size / 2

        offsets = np.array([[dx, dy, dz] for dx in (-hs, hs)
                            for dy in (-hs, hs)
                            for dz in (-hs, hs)])
        for offset in offsets:
            child_center = self.center + offset
            child_indices = []

            for idx in self.indices:
                p = self.points[idx]
                if np.all(np.abs(p - child_center) <= hs):
                    child_indices.append(idx)
            child_indices = np.array(child_indices)
            if len(child_indices) > 0:
                child = Octree(self.points, indices=child_indices, bbox=(child_center, hs), max_points=self.max_points)
                self.children.append(child)

        # Interior nodes no longer hold points directly.
        self.indices = None

    def sphere_intersects_node(self, center, r):
        # Squared distance from the sphere center to this node's box.
        diff = np.abs(center - self.center)
        max_diff = diff - self.half_size
        max_diff = np.maximum(max_diff, 0)
        dist_sq = np.sum(max_diff**2)
        return dist_sq <= r * r

    def query_ball_point(self, point, r):
        # Return indices of all points within radius r of `point`.
        results = []
        if not self.sphere_intersects_node(point, r):
            return results

        if len(self.children) == 0:
            if self.indices is not None:
                for idx in self.indices:
                    if np.linalg.norm(self.points[idx] - point) <= r:
                        results.append(idx)
            return results
        else:
            for child in self.children:
                results.extend(child.query_ball_point(point, r))
            return results
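# Usage sketch (hedged; random data). The interface mirrors scipy's
# cKDTree.query_ball_point:
#   pts = np.random.rand(1000, 3)
#   tree = Octree(pts, max_points=16)
#   near = tree.query_ball_point(np.array([0.5, 0.5, 0.5]), r=0.1)
#   # `near` lists indices of points within 0.1 of the query point.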
|
|
|
|
|
|