| | import numpy as np |
| | import cv2 |
| | import PIL |
| | from PIL import Image |
| | import os |
| | from datetime import datetime |
| | import pdb |
| | import torch.nn.functional as F |
| | import numpy as np |
| | import os |
| | import cv2 |
| | import copy |
| | from scipy.interpolate import UnivariateSpline, interp1d |
| | import numpy as np |
| | import PIL.Image |
| | import torch |
| | import torchvision |
| | from tqdm import tqdm |
| | from pathlib import Path |
| | from typing import Tuple, Optional |
| | import cv2 |
| | import PIL |
| | import numpy |
| | import skimage.io |
| | import torch |
| | import torch.nn.functional as F |
| | from decord import VideoReader, cpu |
| |
|
| | from PIL import Image |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| |
|
| | |
| |
|
def resize_frames(frames, resolution):
    """Resize a batch of frames to ``resolution`` with bilinear interpolation.

    Args:
        frames: A NumPy array or a torch tensor holding a batch of frames.
            NumPy input is multiplied by 255 and cast to ``uint8`` before
            resizing, so it is treated as float data in ``[0, 1]``. A NumPy
            batch is treated as channel-first when ``shape[1]`` is 1 or 3; a
            tensor batch is treated as channel-last when ``shape[-1]`` is 1
            or 3.
        resolution: ``(height, width)`` of the output frames.

    Returns:
        The resized batch, in the same container type and channel layout as
        the input.
    """
    if not torch.is_tensor(frames):
        print(" frames are not tensors but numpy arrays ")
        channel_first = frames.shape[1] in (1, 3)
        if channel_first:
            # PIL works on channel-last images.
            frames = frames.transpose(0, 2, 3, 1)
        out_size = (resolution[1], resolution[0])  # PIL expects (width, height)
        resized = []
        for frame in frames:
            pixels = (frame * 255).astype('uint8')
            single_channel = pixels.ndim == 3 and pixels.shape[-1] == 1
            if single_channel:
                # Image.fromarray rejects (H, W, 1) arrays; use a 2-D image.
                pixels = pixels[..., 0]
            img = Image.fromarray(pixels).resize(out_size, Image.BILINEAR)
            arr = np.array(img).astype("float32") / 255.0
            if single_channel:
                arr = arr[..., None]
            resized.append(arr)
        resized_frames = np.stack(resized, axis=0)
        if channel_first:
            resized_frames = resized_frames.transpose(0, 3, 1, 2)
        return resized_frames

    channel_last = frames.shape[-1] in (1, 3)
    if channel_last:
        # NHWC -> NCHW. (0, 3, 1, 2) keeps height before width; the previous
        # (0, 3, 2, 1) swapped them, so `resolution` hit the wrong axes and
        # the result came back spatially transposed.
        frames = frames.permute(0, 3, 1, 2)
    resized_frames = F.interpolate(
        frames, size=resolution, mode='bilinear', align_corners=True
    )
    if channel_last:
        resized_frames = resized_frames.permute(0, 2, 3, 1)
    return resized_frames
| |
|
| |
|
def read_video_frames_av(video_path, process_length, stride, sample_size, max_res, dataset="open"):
    """Decode a video at a resolution derived from ``sample_size``.

    The target size is ``sample_size`` floored to a multiple of 16. If its
    longer side exceeds ``max_res`` it is rescaled so the longer side matches
    ``max_res`` and rounded to a multiple of 16. The video's native
    resolution is only logged; it does not influence the output size.

    Args:
        video_path: Path of the video, handed to ``decord.VideoReader``.
        process_length: Maximum number of frames to keep after striding;
            ``-1`` keeps all of them.
        stride: Step between sampled frame indices.
        sample_size: ``(height, width)`` the output resolution is based on.
        max_res: Upper bound for the longer output side.
        dataset: When equal to ``"open"`` the source video's shape is logged.
            The decode path itself does not depend on this value.

    Returns:
        ``float32`` array of the sampled frames scaled to ``[0, 1]``.
    """
    div = 16  # output height/width are forced to multiples of this

    if dataset == "open":
        print("==> processing video:", video_path)
        probe = VideoReader(video_path, ctx=cpu(0))
        print("==> original video shape:", (len(probe), *probe.get_batch([0]).shape[1:]))
        del probe

    # The size always comes from `sample_size`; the original also decoded the
    # first frame for its size and then immediately overwrote that result.
    base_height, base_width = sample_size[0], sample_size[1]
    height = int(base_height / div) * div
    width = int(base_width / div) * div
    if max(height, width) > max_res:
        scale = max_res / max(base_height, base_width)
        height = round(base_height * scale / div) * div
        width = round(base_width * scale / div) * div

    vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
    frames_idx = list(range(0, len(vid), stride))
    # Decode one frame once and reuse its shape for both log lines.
    frame_shape = tuple(vid.get_batch([0]).shape[1:])
    print(f"==> downsampled shape: {(len(frames_idx), *frame_shape)}, with stride: {stride}")
    if process_length != -1 and process_length < len(frames_idx):
        frames_idx = frames_idx[:process_length]
    print(f"==> final processing shape: {(len(frames_idx), *frame_shape)}")
    frames = vid.get_batch(frames_idx).asnumpy().astype("float32") / 255.0
    return frames
| |
|
def read_video_frames(video_path, process_length, stride, max_res, dataset="open"):
    """Load a video (or a single image) as a ``float32`` frame batch in ``[0, 1]``.

    The output height and width are the source dimensions rounded to
    multiples of 64. If the longer rounded side exceeds ``max_res``, the
    source size is first rescaled so its longer side equals ``max_res`` and
    then rounded to multiples of 64.

    Args:
        video_path: Path to a video, or to a ``.jpg``/``.jpeg``/``.png``/
            ``.bmp`` image, which is returned as a one-frame batch.
        process_length: Maximum number of frames to keep after striding;
            ``-1`` keeps all of them. Ignored for images.
        stride: Step between sampled frame indices. Ignored for images.
        max_res: Upper bound for the longer output side.
        dataset: Only ``"open"`` is supported for videos.

    Returns:
        Array of shape ``(num_frames, height, width, 3)`` for images; for
        videos, the batch returned by decord converted to ``float32``.

    Raises:
        ValueError: If ``video_path`` is not an image and ``dataset`` is not
            ``"open"`` (this previously failed on an undefined variable).
    """
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')

    def is_image(path):
        return path.lower().endswith(image_extensions)

    def target_size(original_height, original_width):
        # Shared by the image and video paths: snap to multiples of 64,
        # shrinking first when the longer side would exceed max_res.
        height = round(original_height / 64) * 64
        width = round(original_width / 64) * 64
        if max(height, width) > max_res:
            scale = max_res / max(original_height, original_width)
            height = round(original_height * scale / 64) * 64
            width = round(original_width * scale / 64) * 64
        return height, width

    if is_image(video_path):
        print("==> Detected image. Loading as single-frame video:", video_path)
        # convert() returns an independent image, so the file handle can be
        # released as soon as the pixels are loaded.
        with Image.open(video_path) as source:
            img = source.convert("RGB")
        height, width = target_size(img.height, img.width)
        img = img.resize((width, height), Image.BICUBIC)
        frames = (np.array(img).astype("float32") / 255.0)[None, ...]
        print(f"==> image shape: {frames.shape}")
        return frames

    if dataset != "open":
        raise ValueError(
            f"Unsupported dataset {dataset!r}: only 'open' is handled for videos"
        )

    print("==> processing video:", video_path)
    vid = VideoReader(video_path, ctx=cpu(0))
    source_shape = tuple(vid.get_batch([0]).shape[1:])
    print("==> original video shape:", (len(vid), *source_shape))
    height, width = target_size(source_shape[0], source_shape[1])

    # Re-open so decord performs the resize while decoding.
    vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
    frames_idx = list(range(0, len(vid), stride))
    frame_shape = tuple(vid.get_batch([0]).shape[1:])
    print(f"==> downsampled shape: {(len(frames_idx), *frame_shape)}, with stride: {stride}")
    if process_length != -1 and process_length < len(frames_idx):
        frames_idx = frames_idx[:process_length]
    print(f"==> final processing shape: {(len(frames_idx), *frame_shape)}")
    frames = vid.get_batch(frames_idx).asnumpy().astype("float32") / 255.0
    return frames
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
def save_video(data, images_path, folder=None, fps=8):
    """Encode frames as an H.264 video written to ``images_path``.

    Args:
        data: One of
            * a NumPy array or torch tensor of frames; it is multiplied by
              255 and cast to ``uint8``, so float data is expected in
              ``[0, 1]``;
            * a list of image file names, each opened with PIL and used as
              is (no rescaling).
        images_path: Destination path of the video file.
        folder: Directory prepended to every entry when ``data`` is a list.
            ``None`` uses the entries as given.
        fps: Frame rate of the written video.

    Raises:
        TypeError: If ``data`` is not an array, a tensor or a list
            (previously an unbound-variable error).
    """
    if isinstance(data, (np.ndarray, torch.Tensor)):
        if isinstance(data, np.ndarray):
            frames = torch.from_numpy(data)
        else:
            frames = data.detach().cpu()
        scaled = frames * 255
        if scaled.is_floating_point():
            # Casting out-of-range floats to uint8 is not well defined, so
            # saturate instead. In-range values are unaffected.
            scaled = scaled.clamp(0, 255)
        tensor_data = scaled.to(torch.uint8)
    elif isinstance(data, list):
        images = []
        for name in data:
            path = name if folder is None else os.path.join(folder, name)
            with Image.open(path) as img:
                images.append(np.array(img))
        tensor_data = torch.from_numpy(np.stack(images, axis=0)).to(torch.uint8)
    else:
        raise TypeError(
            f"Unsupported data type for save_video: {type(data).__name__}"
        )
    torchvision.io.write_video(
        images_path, tensor_data, fps=fps, video_codec='h264', options={'crf': '10'}
    )
| |
|
| |
|
def sphere2pose(c2ws_input, theta, phi, r, device, x=None, y=None):
    """Offset and rotate a batch of 4x4 pose matrices.

    On a copy of ``c2ws_input`` the translation column is shifted first
    (``z -= r``, ``x -= x``, ``y += y``). The result is then left-multiplied
    by a rotation about the x axis (``theta``) followed by a rotation about
    the y axis (``phi``).

    Args:
        c2ws_input: ``(n, 4, 4)`` tensor of poses; it is not modified.
        theta: Rotation about the x axis, in degrees.
        phi: Rotation about the y axis, in degrees.
        r: Amount subtracted from the z translation.
        device: Device the rotation matrices are created on; the poses are
            expected to live on the same device.
        x: Optional amount subtracted from the x translation.
        y: Optional amount added to the y translation.

    Returns:
        ``(n, 4, 4)`` tensor of transformed poses.
    """
    c2ws = c2ws_input.clone()

    c2ws[:, 2, 3] -= r
    # Each guard now checks the value it actually applies. The guards used to
    # be swapped, so passing only one of x / y raised a TypeError.
    if y is not None:
        c2ws[:, 1, 3] += y
    if x is not None:
        c2ws[:, 0, 3] -= x

    dtype = c2ws.dtype

    theta_rad = torch.deg2rad(torch.as_tensor(theta, dtype=dtype, device=device))
    rot_mat_x = torch.eye(4, dtype=dtype, device=device)
    rot_mat_x[1, 1] = torch.cos(theta_rad)
    rot_mat_x[1, 2] = -torch.sin(theta_rad)
    rot_mat_x[2, 1] = torch.sin(theta_rad)
    rot_mat_x[2, 2] = torch.cos(theta_rad)

    phi_rad = torch.deg2rad(torch.as_tensor(phi, dtype=dtype, device=device))
    rot_mat_y = torch.eye(4, dtype=dtype, device=device)
    rot_mat_y[0, 0] = torch.cos(phi_rad)
    rot_mat_y[0, 2] = torch.sin(phi_rad)
    rot_mat_y[2, 0] = -torch.sin(phi_rad)
    rot_mat_y[2, 2] = torch.cos(phi_rad)

    # matmul broadcasts the single 4x4 rotation over the pose batch.
    c2ws = torch.matmul(rot_mat_x, c2ws)
    c2ws = torch.matmul(rot_mat_y, c2ws)
    return c2ws
| |
|
| |
|
def generate_traj_specified(c2ws_anchor, theta, phi, d_r, d_x, d_y, frame, device, opts):
    """Build a trajectory that moves linearly from the start pose to a target offset.

    Every motion parameter is swept with ``np.linspace`` from 0 to its target
    over ``frame`` steps and shifted by the matching start offset read from
    ``opts`` (``init_theta``, ``init_phi``, ``init_dz``, ``init_dx``,
    ``init_dy``; each defaults to 0.0 when missing).

    Returns:
        The anchor pose(s) followed by one ``sphere2pose`` result per step,
        concatenated along dim 0.
    """
    start_x = getattr(opts, "init_dx", 0.0)
    start_y = getattr(opts, "init_dy", 0.0)
    start_z = getattr(opts, "init_dz", 0.0)
    start_theta = getattr(opts, "init_theta", 0.0)
    start_phi = getattr(opts, "init_phi", 0.0)

    theta_steps = np.linspace(0, theta, frame) + start_theta
    phi_steps = np.linspace(0, phi, frame) + start_phi
    radius_steps = np.linspace(0, d_r, frame) + start_z
    x_steps = np.linspace(0, d_x, frame) + start_x
    y_steps = np.linspace(0, d_y, frame) + start_y

    # The untouched anchor is always the first entry of the trajectory.
    poses = [copy.deepcopy(c2ws_anchor)]
    poses.extend(
        sphere2pose(
            c2ws_anchor,
            np.float32(angle_x),
            np.float32(angle_y),
            np.float32(radius),
            device,
            np.float32(shift_x),
            np.float32(shift_y),
        )
        for angle_x, angle_y, radius, shift_x, shift_y in zip(
            theta_steps, phi_steps, radius_steps, x_steps, y_steps
        )
    )
    return torch.cat(poses, dim=0)
| |
|
def generate_traj_specified_frame(c2ws_anchor, theta, phi, d_r, d_x, d_y, frame, device, opts):
    """Build a trajectory whose every step applies the same per-frame increment.

    Each motion parameter is its target divided by ``frame``, repeated
    ``frame`` times and shifted by the matching start offset read from
    ``opts`` (``init_theta``, ``init_phi``, ``init_dz``, ``init_dx``,
    ``init_dy``; each defaults to 0.0 when missing). All generated poses are
    therefore derived from identical parameters.

    Returns:
        The anchor pose(s) followed by ``frame`` ``sphere2pose`` results,
        concatenated along dim 0.
    """
    start_x = getattr(opts, "init_dx", 0.0)
    start_y = getattr(opts, "init_dy", 0.0)
    start_z = getattr(opts, "init_dz", 0.0)
    start_theta = getattr(opts, "init_theta", 0.0)
    start_phi = getattr(opts, "init_phi", 0.0)

    theta_steps = np.full(frame, theta / frame) + start_theta
    phi_steps = np.full(frame, phi / frame) + start_phi
    radius_steps = np.full(frame, d_r / frame) + start_z
    x_steps = np.full(frame, d_x / frame) + start_x
    y_steps = np.full(frame, d_y / frame) + start_y

    # The untouched anchor is always the first entry of the trajectory.
    poses = [copy.deepcopy(c2ws_anchor)]
    poses.extend(
        sphere2pose(
            c2ws_anchor,
            np.float32(angle_x),
            np.float32(angle_y),
            np.float32(radius),
            device,
            np.float32(shift_x),
            np.float32(shift_y),
        )
        for angle_x, angle_y, radius, shift_x, shift_y in zip(
            theta_steps, phi_steps, radius_steps, x_steps, y_steps
        )
    )
    return torch.cat(poses, dim=0)
| |
|
def generate_traj_specified_fast(c2ws_anchor, theta, phi, d_r, d_x, d_y, frame, device):
    """Build a trajectory that reaches its target early and then holds it.

    The first ``frame // 3`` poses sweep every motion parameter linearly from
    0 to its target. The remaining poses are copies of the last moving pose,
    so the result has ``frame`` entries per anchor pose.

    Args:
        c2ws_anchor: Anchor pose(s) passed to ``sphere2pose``.
        theta, phi, d_r, d_x, d_y: Target values of the sweep.
        frame: Total number of poses to produce; must be at least 3.
        device: Device passed to ``sphere2pose``.

    Returns:
        All poses concatenated along dim 0.

    Raises:
        ValueError: If ``frame`` is below 3. There would be no moving pose to
            hold, which previously surfaced as an IndexError.
    """
    moving = frame // 3
    if moving < 1:
        raise ValueError(f"frame must be at least 3, got {frame}")

    thetas = np.linspace(0, theta, moving)
    phis = np.linspace(0, phi, moving)
    rs = np.linspace(0, d_r, moving)
    xs = np.linspace(0, d_x, moving)
    ys = np.linspace(0, d_y, moving)

    c2ws_list = [
        sphere2pose(
            c2ws_anchor,
            np.float32(th),
            np.float32(ph),
            np.float32(radius),
            device,
            np.float32(dx),
            np.float32(dy),
        )
        for th, ph, radius, dx, dy in zip(thetas, phis, rs, xs, ys)
    ]

    # Hold the final pose for the rest of the clip.
    last_c2w = c2ws_list[-1]
    c2ws_list.extend(last_c2w.clone() for _ in range(frame - moving))

    return torch.cat(c2ws_list, dim=0)
| |
|
| |
|
def txt_interpolation(input_list, n, mode='smooth'):
    """Resample a sequence of keyframe values to ``n`` evenly spaced samples.

    The keyframes are placed uniformly on ``[0, 1]`` and evaluated at ``n``
    uniformly spaced positions on the same interval.

    Args:
        input_list: Keyframe values.
        n: Number of output samples.
        mode: ``'smooth'`` fits a cubic ``UnivariateSpline`` with SciPy's
            default smoothing factor; ``'linear'`` uses piecewise-linear
            ``interp1d``.

    Returns:
        NumPy array with ``n`` values. A single keyframe yields a constant
        array; previously that input could not be interpolated and failed.

    Raises:
        KeyError: If ``mode`` is neither ``'smooth'`` nor ``'linear'``.
    """
    if mode not in ('smooth', 'linear'):
        raise KeyError(f"Invalid txt interpolation mode: {mode}")

    if len(input_list) == 1:
        # One keyframe defines a constant track; neither interpolator can be
        # evaluated across [0, 1] from a single sample.
        return np.full(n, float(input_list[0]))

    x = np.linspace(0, 1, len(input_list))
    if mode == 'smooth':
        f = UnivariateSpline(x, input_list, k=3)
    else:
        f = interp1d(x, input_list)
    return f(np.linspace(0, 1, n))
| |
|
| |
|
def generate_traj_txt(c2ws_anchor, phi, theta, r, frame, device, x=None, y=None):
    """Build a camera trajectory from keyframed motion parameters.

    The camera coordinate system in COLMAP is right-down-forward;
    PyTorch3D is left-up-forward.

    Each keyframe sequence is resampled to ``frame`` values. Sequences with
    more than three keyframes use the smooth spline, with the first and last
    samples pinned to the first and last keyframes; shorter sequences are
    interpolated linearly.

    Args:
        c2ws_anchor: Anchor pose(s) passed to ``sphere2pose``.
        phi: Keyframes of the rotation passed as ``phi`` to ``sphere2pose``.
        theta: Keyframes of the rotation passed as ``theta``.
        r: Keyframes of the ``r`` offset.
        frame: Number of poses generated after the anchor.
        device: Device passed to ``sphere2pose``.
        x: Optional keyframes of the ``x`` offset. ``None`` means no offset;
            the default used to crash on ``len(None)``.
        y: Optional keyframes of the ``y`` offset, handled like ``x``.

    Returns:
        The anchor pose(s) followed by one pose per resampled step,
        concatenated along dim 0.
    """

    def resample(keyframes):
        if keyframes is None:
            # A zero offset leaves the pose unchanged in sphere2pose.
            return np.zeros(frame)
        if len(keyframes) > 3:
            values = txt_interpolation(keyframes, frame, mode='smooth')
            # The smoothing spline need not pass through its samples, so pin
            # the endpoints to the requested start and end values.
            values[0] = keyframes[0]
            values[-1] = keyframes[-1]
            return values
        return txt_interpolation(keyframes, frame, mode='linear')

    phis = resample(phi)
    thetas = resample(theta)
    rs = resample(r)
    xs = resample(x)
    ys = resample(y)

    c2ws_list = [copy.deepcopy(c2ws_anchor)]
    for th, ph, radius, x_i, y_i in zip(thetas, phis, rs, xs, ys):
        c2ws_list.append(
            sphere2pose(
                c2ws_anchor,
                np.float32(th),
                np.float32(ph),
                np.float32(radius),
                device,
                np.float32(x_i),
                np.float32(y_i),
            )
        )
    return torch.cat(c2ws_list, dim=0)
| |
|
| |
|
class Warper:
    """Forward-warp frames from one camera view to another using per-pixel depth.

    Pixels are lifted to 3-D with the source intrinsics and depth, moved by
    the relative transform between the two views, projected with the target
    intrinsics and splatted bilinearly into the target image.
    """

    def __init__(self, resolution: tuple = None, device: str = 'gpu0'):
        """
        :param resolution: optional (h, w); when set, inputs are asserted to
                           have exactly this spatial size
        :param device: cpu/gpu0/gpu1/..., resolved by get_device()
        """
        self.resolution = resolution
        self.device = self.get_device(device)
        self.dtype = torch.float32

    def _assert_resolution(self, tensor: torch.Tensor) -> None:
        """Asserts that dims 2 and 3 of `tensor` match the configured resolution."""
        if self.resolution is not None:
            # Compare as tuples so a list-valued resolution is accepted too.
            assert tuple(tensor.shape[2:4]) == tuple(self.resolution)

    def forward_warp(
        self,
        frame1: torch.Tensor,
        mask1: Optional[torch.Tensor],
        depth1: torch.Tensor,
        transformation1: torch.Tensor,
        transformation2: torch.Tensor,
        intrinsic1: torch.Tensor,
        intrinsic2: Optional[torch.Tensor],
        mask=False,
        twice=False,
    ) -> Tuple[
        torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]
    ]:
        """
        Given a frame1 and global transformations transformation1 and transformation2, warps frame1 to next view using
        bilinear splatting.
        All arrays should be torch tensors with batch dimension and channel first
        :param frame1: (b, 3, h, w). If frame1 is not in the range [-1, 1], either set is_image=False when calling
                        bilinear_splatting on frame within this function, or modify clipping in bilinear_splatting()
                        method accordingly.
        :param mask1: (b, 1, h, w) - 1 for known, 0 for unknown. Optional
        :param depth1: (b, 1, h, w)
        :param transformation1: (b, 4, 4) extrinsic transformation matrix of first view: [R, t; 0, 1]
        :param transformation2: (b, 4, 4) extrinsic transformation matrix of second view: [R, t; 0, 1]
        :param intrinsic1: (b, 3, 3) camera intrinsic matrix
        :param intrinsic2: (b, 3, 3) camera intrinsic matrix. Optional
        :param mask: if true (and twice is false), the warped frame is post-processed by clean_points()
        :param twice: if true, the warped frame is additionally splatted back with the negated warped flow
        :return: twice=False: (warped_frame2, mask2, None, flow12)
                 twice=True: (twice_warped_frame1, warped_frame2, None, None)
        """
        self._assert_resolution(frame1)
        b, c, h, w = frame1.shape
        if mask1 is None:
            mask1 = torch.ones(size=(b, 1, h, w)).to(frame1)
        if intrinsic2 is None:
            intrinsic2 = intrinsic1.clone()

        assert frame1.shape == (b, 3, h, w)
        assert mask1.shape == (b, 1, h, w)
        assert depth1.shape == (b, 1, h, w)
        assert transformation1.shape == (b, 4, 4)
        assert transformation2.shape == (b, 4, 4)
        assert intrinsic1.shape == (b, 3, 3)
        assert intrinsic2.shape == (b, 3, 3)

        target = {'device': self.device, 'dtype': self.dtype}
        frame1 = frame1.to(**target)
        mask1 = mask1.to(**target)
        depth1 = depth1.to(**target)
        transformation1 = transformation1.to(**target)
        transformation2 = transformation2.to(**target)
        intrinsic1 = intrinsic1.to(**target)
        intrinsic2 = intrinsic2.to(**target)

        trans_points1 = self.compute_transformed_points(
            depth1, transformation1, transformation2, intrinsic1, intrinsic2
        )
        # Perspective divide: (x, y) pixel coordinates in the target view.
        trans_coordinates = (
            trans_points1[:, :, :, :2, 0] / trans_points1[:, :, :, 2:3, 0]
        )
        trans_depth1 = trans_points1[:, :, :, 2, 0]
        grid = self.create_grid(b, h, w).to(trans_coordinates)
        flow12 = trans_coordinates.permute(0, 3, 1, 2) - grid

        # Both modes start from the same forward splat.
        warped_frame2, mask2 = self.bilinear_splatting(
            frame1, mask1, trans_depth1, flow12, None, is_image=True
        )
        if not twice:
            if mask:
                warped_frame2, mask2 = self.clean_points(warped_frame2, mask2)
            return warped_frame2, mask2, None, flow12

        warped_flow, _ = self.bilinear_splatting(
            flow12, mask1, trans_depth1, flow12, None, is_image=False
        )
        twice_warped_frame1, _ = self.bilinear_splatting(
            warped_frame2,
            mask2,
            depth1.squeeze(1),
            -warped_flow,
            None,
            is_image=True,
        )
        return twice_warped_frame1, warped_frame2, None, None

    def compute_transformed_points(
        self,
        depth1: torch.Tensor,
        transformation1: torch.Tensor,
        transformation2: torch.Tensor,
        intrinsic1: torch.Tensor,
        intrinsic2: Optional[torch.Tensor],
    ):
        """
        Computes transformed position for each pixel location.
        Same as compute_transformed_points_frame() on the regular pixel grid.
        :return: (b, h, w, 3, 1) homogeneous image points in the second view
        """
        return self.compute_transformed_points_frame(
            depth1, transformation1, transformation2, intrinsic1, intrinsic2, None
        )

    def compute_transformed_points_frame(
        self,
        depth1: torch.Tensor,
        transformation1: torch.Tensor,
        transformation2: torch.Tensor,
        intrinsic1: torch.Tensor,
        intrinsic2: Optional[torch.Tensor],
        grid: Optional[torch.Tensor] = None,
    ):
        """
        Computes transformed positions for each pixel location or provided grid.

        Points are lifted with inv(intrinsic1) and depth1, mapped by
        transformation2 @ inv(transformation1) and projected with intrinsic2.
        The poses are therefore used as view transforms applied to points
        (world-to-camera extrinsics).

        Args:
            depth1: (b, 1, h, w) depth map
            transformation1: (b, 4, 4) transform of the source view
            transformation2: (b, 4, 4) transform of the target view
            intrinsic1: (b, 3, 3) intrinsics of source
            intrinsic2: (b, 3, 3) intrinsics of target (optional)
            grid: batched grid as returned by create_grid(), indexed as
                  grid[0][0] = x coords and grid[0][1] = y coords, each
                  (h, w). Only the first batch entry is read. Optional

        Returns:
            (b, h, w, 3, 1) homogeneous image points in the target view;
            divide the first two rows by the third to get pixel coordinates.
        """
        self._assert_resolution(depth1)
        b, _, h, w = depth1.shape
        if intrinsic2 is None:
            intrinsic2 = intrinsic1.clone()

        transformation = torch.bmm(
            transformation2, torch.linalg.inv(transformation1)
        )

        if grid is None:
            x2d = torch.arange(0, w)[None].repeat([h, 1]).to(depth1)
            y2d = torch.arange(0, h)[:, None].repeat([1, w]).to(depth1)
        else:
            x2d = grid[0][0].to(depth1)
            y2d = grid[0][1].to(depth1)

        # Use the depth's dtype/device for the constants so every operand of
        # the matmuls below shares one dtype.
        ones_2d = torch.ones(size=(h, w)).to(depth1)
        pos_vectors_homo = torch.stack([x2d, y2d, ones_2d], dim=2)[
            None, :, :, :, None
        ]
        ones_4d = ones_2d[None, :, :, None, None].repeat([b, 1, 1, 1, 1])

        intrinsic1_inv_4d = torch.linalg.inv(intrinsic1)[:, None, None]
        intrinsic2_4d = intrinsic2[:, None, None]
        depth_4d = depth1[:, 0][:, :, :, None, None]
        trans_4d = transformation[:, None, None]

        # Back-project pixels to 3-D points in the first view's frame.
        unnormalized_pos = torch.matmul(intrinsic1_inv_4d, pos_vectors_homo)
        world_points = depth_4d * unnormalized_pos
        world_points_homo = torch.cat([world_points, ones_4d], dim=3)

        # Move into the second view and project.
        trans_world_homo = torch.matmul(trans_4d, world_points_homo)
        trans_world = trans_world_homo[:, :, :, :3]
        trans_norm_points = torch.matmul(intrinsic2_4d, trans_world)
        return trans_norm_points

    def bilinear_splatting(
        self,
        frame1: torch.Tensor,
        mask1: Optional[torch.Tensor],
        depth1: torch.Tensor,
        flow12: torch.Tensor,
        flow12_mask: Optional[torch.Tensor],
        is_image: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Bilinear splatting
        :param frame1: (b,c,h,w)
        :param mask1: (b,1,h,w): 1 for known, 0 for unknown. Optional
        :param depth1: (b,h,w) - no channel dimension; it is unsqueezed internally
        :param flow12: (b,2,h,w)
        :param flow12_mask: (b,1,h,w): 1 for valid flow, 0 for invalid flow. Optional
        :param is_image: if true, output will be clipped to (-1,1) range
        :return: warped_frame2: (b,c,h,w)
                 mask2: (b,1,h,w): 1 for known and 0 for unknown
        """
        self._assert_resolution(frame1)
        b, c, h, w = frame1.shape
        if mask1 is None:
            mask1 = torch.ones(size=(b, 1, h, w)).to(frame1)
        if flow12_mask is None:
            flow12_mask = torch.ones(size=(b, 1, h, w)).to(flow12)
        grid = self.create_grid(b, h, w).to(frame1)

        # Target positions, shifted by one pixel because the accumulation
        # buffers carry a one-pixel border on every side.
        pos = (flow12 + grid) + 1
        pos_floor = torch.floor(pos).long()
        pos_ceil = torch.ceil(pos).long()

        def clamp_to_buffer(xy):
            return torch.stack(
                [
                    torch.clamp(xy[:, 0], min=0, max=w + 1),
                    torch.clamp(xy[:, 1], min=0, max=h + 1),
                ],
                dim=1,
            )

        pos = clamp_to_buffer(pos)
        pos_floor = clamp_to_buffer(pos_floor)
        pos_ceil = clamp_to_buffer(pos_ceil)

        # Per-axis proximity to the floor / ceil neighbour.
        near_floor_x = 1 - (pos[:, 0:1] - pos_floor[:, 0:1])
        near_ceil_x = 1 - (pos_ceil[:, 0:1] - pos[:, 0:1])
        near_floor_y = 1 - (pos[:, 1:2] - pos_floor[:, 1:2])
        near_ceil_y = 1 - (pos_ceil[:, 1:2] - pos[:, 1:2])

        sat_depth1 = torch.clamp(depth1, min=0, max=1000)
        log_depth1 = torch.log(1 + sat_depth1)
        # Guard the normalisation: an all-zero depth map used to give 0 / 0
        # and turn every weight into NaN. Positive maxima are unaffected.
        max_log_depth = log_depth1.max().clamp(min=torch.finfo(log_depth1.dtype).tiny)
        # Larger depth -> exponentially smaller weight, so nearer surfaces
        # dominate where several source pixels land on the same target.
        depth_weights = torch.exp(log_depth1 / max_log_depth * 50)
        visibility = mask1 * flow12_mask / depth_weights.unsqueeze(1)

        # (row index, column index, proximity weight) for NW, SW, NE, SE.
        corners = (
            (pos_floor[:, 1], pos_floor[:, 0], near_floor_y * near_floor_x),
            (pos_ceil[:, 1], pos_floor[:, 0], near_ceil_y * near_floor_x),
            (pos_floor[:, 1], pos_ceil[:, 0], near_floor_y * near_ceil_x),
            (pos_ceil[:, 1], pos_ceil[:, 0], near_ceil_y * near_ceil_x),
        )

        warped_frame = torch.zeros(
            size=(b, h + 2, w + 2, c), dtype=frame1.dtype, device=frame1.device
        )
        warped_weights = torch.zeros(
            size=(b, h + 2, w + 2, 1), dtype=frame1.dtype, device=frame1.device
        )
        frame1_cl = frame1.permute(0, 2, 3, 1)  # channel-last for indexed writes
        batch_indices = torch.arange(b)[:, None, None].to(frame1.device)

        for rows, cols, proximity in corners:
            weight = (proximity * visibility).permute(0, 2, 3, 1)
            warped_frame.index_put_(
                (batch_indices, rows, cols), frame1_cl * weight, accumulate=True
            )
            warped_weights.index_put_(
                (batch_indices, rows, cols), weight, accumulate=True
            )

        # Back to channel-first and drop the one-pixel border.
        cropped_warped_frame = warped_frame.permute(0, 3, 1, 2)[:, :, 1:-1, 1:-1]
        cropped_weights = warped_weights.permute(0, 3, 1, 2)[:, :, 1:-1, 1:-1]

        mask = cropped_weights > 0
        zero_value = -1 if is_image else 0
        zero_tensor = torch.tensor(zero_value, dtype=frame1.dtype, device=frame1.device)
        warped_frame2 = torch.where(
            mask, cropped_warped_frame / cropped_weights, zero_tensor
        )
        mask2 = mask.to(frame1)

        if is_image:
            assert warped_frame2.min() >= -1.1
            assert warped_frame2.max() <= 1.1
            warped_frame2 = torch.clamp(warped_frame2, min=-1, max=1)
        return warped_frame2, mask2

    def clean_points(self, warped_frame2, mask2):
        """
        Grows the unknown region of a warped frame and blanks the pixels it covers.
        The hole mask (1 - mask2) is dilated with a 5x5 kernel; covered pixels
        are set to the minimum value and marked unknown.
        Only a batch size of 1 is supported.
        :param warped_frame2: (1,3,h,w) in [-1, 1]
        :param mask2: (1,1,h,w): 1 for known, 0 for unknown
        :return: cleaned frame (1,3,h,w) in [-1, 1] and updated mask (1,1,h,w)
        """
        frame01 = (warped_frame2 + 1.0) / 2.0
        holes = ((1 - mask2) >= 0.5).to(mask2.dtype)
        holes_hwc = holes.squeeze(0).repeat(3, 1, 1).permute(1, 2, 0) * 255.0
        kernel = numpy.ones((5, 5), numpy.uint8)
        dilated = cv2.dilate(numpy.array(holes_hwc.cpu().numpy()), kernel, iterations=1)
        grown = (numpy.uint8(dilated) / 255.0 >= 0.5).astype(numpy.float32)
        # Keep the frame's dtype: the float64 NumPy mask used to promote the
        # returned tensors to float64.
        mask_new = (
            torch.from_numpy(grown)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .to(device=self.device, dtype=warped_frame2.dtype)
        )
        frame01 = frame01 * (1 - mask_new)
        return frame01 * 2.0 - 1.0, 1 - mask_new[:, 0:1, :, :]

    @staticmethod
    def create_grid(b, h, w):
        """
        Returns an integer pixel grid of shape (b, 2, h, w).
        Channel 0 holds x (column) indices, channel 1 holds y (row) indices.
        """
        x_2d = torch.arange(0, w)[None].repeat([h, 1])
        y_2d = torch.arange(0, h)[:, None].repeat([1, w])
        grid = torch.stack([x_2d, y_2d], dim=0)
        return grid[None].repeat([b, 1, 1, 1])

    @staticmethod
    def read_image(path: Path) -> numpy.ndarray:
        """Reads an image with skimage and returns it as loaded (an array, not a tensor)."""
        return skimage.io.imread(path.as_posix())

    @staticmethod
    def read_depth(path: Path) -> numpy.ndarray:
        """
        Reads a depth map from a .png, .npy or .npz (key 'depth') file.
        :raises RuntimeError: for any other file suffix
        """
        if path.suffix == '.png':
            depth = skimage.io.imread(path.as_posix())
        elif path.suffix == '.npy':
            depth = numpy.load(path.as_posix())
        elif path.suffix == '.npz':
            with numpy.load(path.as_posix()) as depth_data:
                depth = depth_data['depth']
        else:
            raise RuntimeError(f'Unknown depth format: {path.suffix}')
        return depth

    @staticmethod
    def camera_intrinsic_transform(
        capture_width=1920,
        capture_height=1080,
        patch_start_point: tuple = (0, 0),
        focal_length=2100,
    ):
        """
        Builds a 4x4 intrinsic matrix with the principal point at the image
        centre, shifted by the start of a cropped patch.
        :param capture_width: full image width in pixels
        :param capture_height: full image height in pixels
        :param patch_start_point: (start_y, start_x) of the patch
        :param focal_length: value written to both focal entries
        """
        start_y, start_x = patch_start_point
        camera_intrinsics = numpy.eye(4)
        camera_intrinsics[0, 0] = focal_length
        camera_intrinsics[0, 2] = capture_width / 2.0 - start_x
        camera_intrinsics[1, 1] = focal_length
        camera_intrinsics[1, 2] = capture_height / 2.0 - start_y
        return camera_intrinsics

    @staticmethod
    def get_device(device: str):
        """
        Returns torch device object
        :param device: cpu/gpu0/gpu1; a bare 'gpu' selects GPU 0
        :return: the CUDA device when requested and available, otherwise CPU
        """
        if device.startswith('gpu') and torch.cuda.is_available():
            gpu_num = int(device[3:] or 0)
            return torch.device(f'cuda:{gpu_num}')
        # 'cpu', unknown names and GPU requests without CUDA use the CPU.
        return torch.device('cpu')
| |
|