import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from packaging import version as pver
import random
import json
from typing import Union, Callable, Tuple, Optional
import OpenEXR
import Imath


class ImageTransform:
    def __init__(self, crop_size, sample_size, max_crop, use_flip=False):
        self.use_flip = use_flip
        self.crop_size = crop_size
        self.max_crop = max_crop
        self.sample_size = sample_size
        self.crop_transform = transforms.CenterCrop(crop_size) if crop_size else lambda x: x
        self.resize_transform = (
            transforms.Resize(sample_size) if sample_size else lambda x: x
        )
        # Depth maps are resized with nearest-neighbour interpolation so that
        # depth values are not blended across object boundaries.
        self.resize_transform_depth = (
            transforms.Resize(sample_size, interpolation=transforms.InterpolationMode.NEAREST)
            if sample_size else lambda x: x
        )

    def preprocess_images(self, images, depths=None):
        video = images

        if self.use_flip:
            # Flip support was never wired up: the original code referenced
            # self.pixel_transforms and self.sample_n_frames, neither of which
            # is defined on this class, so fail loudly instead.
            raise NotImplementedError("use_flip is not supported by preprocess_images")
        flip_flag = torch.zeros(images.shape[0], dtype=torch.bool, device=video.device)

        ori_h, ori_w = video.shape[-2:]
        if self.max_crop:
            # Crop the largest centered region that has the aspect ratio of
            # crop_size, rather than crop_size itself.
            crop_ratio = min(ori_h / self.crop_size[0], ori_w / self.crop_size[1])
            new_crop_size = (int(self.crop_size[0] * crop_ratio), int(self.crop_size[1] * crop_ratio))
            self.crop_transform = transforms.CenterCrop(new_crop_size)

        video = self.crop_transform(video)
        if depths is not None:
            depths = self.crop_transform(depths)

        new_h, new_w = video.shape[-2:]
        # Pixel-coordinate shift introduced by the center crop.
        shift = ((new_w - ori_w) / 2, (new_h - ori_h) / 2)

        ori_h, ori_w = video.shape[-2:]
        video = self.resize_transform(video)
        if depths is not None:
            depths = self.resize_transform_depth(depths)
        new_h, new_w = video.shape[-2:]
        # Per-axis scale factor introduced by the resize.
        scale = (new_w / ori_w, new_h / ori_h)

        return video, depths, shift, scale, flip_flag

    def apply_img_transform(self, i, j, shift, scale):
        # Map pixel coordinates through the same crop (shift) and resize
        # (scale) applied by preprocess_images, e.g. to update camera
        # intrinsics so they stay consistent with the transformed frames.
        i = (i + shift[0]) * scale[0]
        j = (j + shift[1]) * scale[1]
        return i, j
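

# Illustrative sketch (not part of the original module): how the shift and
# scale returned by preprocess_images feed apply_img_transform to keep
# pinhole intrinsics consistent. All sizes and the focal length below are
# made-up values.
def _example_update_intrinsics():
    tfm = ImageTransform(crop_size=(256, 256), sample_size=(128, 128), max_crop=False)
    frames = torch.rand(4, 3, 320, 480)  # (num_frames, channels, H, W)
    video, _, shift, scale, _ = tfm.preprocess_images(frames)
    # Focal lengths only change with the resize; the principal point moves
    # with both the crop and the resize.
    fx, fy = 300.0 * scale[0], 300.0 * scale[1]
    cx, cy = tfm.apply_img_transform(240.0, 160.0, shift, scale)
    return video.shape, (fx, fy, cx, cy)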


def custom_meshgrid(*args):
    # torch.meshgrid gained the `indexing` keyword in PyTorch 1.10; earlier
    # versions already used 'ij' indexing by default.
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing='ij')


def get_grid_uvs(batch_shape, H, W, device, dtype=None, flip_flag=None, nh=None, nw=None, margin=0):
    if dtype is None:
        dtype = torch.float32
    if nh is None:
        nh = H
    if nw is None:
        nw = W

    B, V = batch_shape

    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, nh, device=device, dtype=dtype),
        torch.linspace(0, W - 1, nw, device=device, dtype=dtype),
    )
    # Shift to pixel centers and broadcast over batch (B) and view (V) dims.
    i = i.reshape([1, 1, nh * nw]).expand([B, V, nh * nw]) + 0.5
    j = j.reshape([1, 1, nh * nw]).expand([B, V, nh * nw]) + 0.5

    if margin != 0:
        # Shrink the grid toward the image center, leaving a border of
        # `margin` (a fraction of the image size) on every side.
        marginw = 1 - 2 * margin
        i = marginw * i + margin * W
        j = marginw * j + margin * H

    n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
    if n_flip > 0:
        # Views marked in flip_flag sample the grid with a reversed u-axis.
        j_flip, i_flip = custom_meshgrid(
            torch.linspace(0, H - 1, nh, device=device, dtype=dtype),
            torch.linspace(W - 1, 0, nw, device=device, dtype=dtype),
        )
        i_flip = i_flip.reshape([1, 1, nh * nw]).expand(B, 1, nh * nw) + 0.5
        j_flip = j_flip.reshape([1, 1, nh * nw]).expand(B, 1, nh * nw) + 0.5
        i[:, flip_flag, ...] = i_flip
        j[:, flip_flag, ...] = j_flip
    return i, j
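
# Example (illustrative): H = W = 2 with nh = nw = 2 yields pixel-center
# coordinates i = [0.5, 1.5, 0.5, 1.5] (x / width axis) and
# j = [0.5, 0.5, 1.5, 1.5] (y / height axis) for each (batch, view) slot.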


def get_rays_from_uvs(i, j, K, c2w):
    # K packs pinhole intrinsics as (..., 4) = (fx, fy, cx, cy).
    fx, fy, cx, cy = K.chunk(4, dim=-1)

    # Unproject pixel coordinates to camera-space directions on the z = 1
    # plane, then normalize.
    zs = torch.ones_like(i)
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)
    directions = directions / directions.norm(dim=-1, keepdim=True)

    # Rotate directions into world space; ray origins are the camera centers.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)
    rays_o = c2w[..., :3, 3]
    rays_o = rays_o[:, :, None].expand_as(rays_d)
    return rays_o, rays_d


def project_to_uvs(pts, K, c2w):
    # Transform world-space points into camera space, then apply the pinhole
    # projection. Returns pixel coordinates and camera-space depth.
    w2c = torch.linalg.inv(c2w)
    cam_pts = torch.einsum("...ij,...vj->...vi", w2c[..., :3, :3], pts) + w2c[..., None, :3, 3]

    fx, fy, cx, cy = K.chunk(4, dim=-1)

    xs = cam_pts[..., 0]
    ys = cam_pts[..., 1]
    zs = cam_pts[..., 2]

    us = (fx * xs / zs) + cx
    vs = (fy * ys / zs) + cy
    uvs = torch.stack([us, vs], dim=-1)
    return uvs, zs


def get_rays(K, c2w, H, W, device, flip_flag=None, nh=None, nw=None):
    i, j = get_grid_uvs(K.shape[:2], H=H, W=W, dtype=K.dtype, device=device, flip_flag=flip_flag, nh=nh, nw=nw)
    return get_rays_from_uvs(i, j, K, c2w)
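

# Illustrative round-trip sketch (not part of the original module): points
# taken along rays from get_rays should reproject onto the sampling grid via
# project_to_uvs. The camera values below are made-up.
def _example_ray_projection_roundtrip():
    B, V, H, W = 1, 2, 4, 4
    K = torch.tensor([4.0, 4.0, W / 2, H / 2]).expand(B, V, 4)  # (fx, fy, cx, cy)
    c2w = torch.eye(4).expand(B, V, 4, 4)
    rays_o, rays_d = get_rays(K, c2w, H, W, device='cpu')
    pts = rays_o + 2.0 * rays_d  # points two units along each ray
    uvs, depth = project_to_uvs(pts, K, c2w)
    return uvs  # ~ the pixel centers produced by get_grid_uvs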


def ray_condition(K, c2w, H, W, device, flip_flag=None, get_batch_index=True):
    batch_shape = K.shape[:2]
    B, V = batch_shape
    rays_o, rays_d = get_rays(K, c2w, H, W, device, flip_flag=flip_flag)
    # Plücker coordinates (o x d, d): a 6D embedding of each camera ray.
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6).permute(0, 1, 4, 2, 3).contiguous()
    rays_o = rays_o.reshape(B, c2w.shape[1], H, W, 3).permute(0, 1, 4, 2, 3).contiguous()
    rays_d = rays_d.reshape(B, c2w.shape[1], H, W, 3).permute(0, 1, 4, 2, 3).contiguous()
    if get_batch_index:
        # Keep only the first batch element, dropping the batch dimension.
        plucker = plucker[0]
        rays_o = rays_o[0]
        rays_d = rays_d[0]
    return plucker, rays_o, rays_d
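

# Illustrative sketch (not part of the original module): building the 6D
# Plücker ray embedding that serves as a per-pixel camera conditioning
# signal. The camera values below are made-up.
def _example_plucker_embedding():
    B, V, H, W = 1, 3, 8, 8
    K = torch.tensor([8.0, 8.0, W / 2, H / 2]).expand(B, V, 4)
    c2w = torch.eye(4).expand(B, V, 4, 4)
    plucker, rays_o, rays_d = ray_condition(K, c2w, H, W, device='cpu')
    return plucker.shape  # (V, 6, H, W) since get_batch_index=True by default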


def mirror_frame_indices(sampling_num_frames: int, total_num_frames: int, video_mirror_clip_length: Optional[int] = None, stride: int = 1, start_index: Optional[int] = None, return_target: bool = False):
    # NOTE: return_target is accepted but not used in this implementation.
    if video_mirror_clip_length is None:
        video_mirror_clip_length = total_num_frames

    # Optionally restrict sampling to a random sub-clip of the video.
    if total_num_frames > video_mirror_clip_length:
        idx = random.randint(0, total_num_frames - video_mirror_clip_length)
        mapping = list(range(idx, idx + video_mirror_clip_length))
        total_num_frames = video_mirror_clip_length
    else:
        mapping = list(range(total_num_frames))
    # Reflect the index sequence back and forth (mirror padding) until it can
    # hold a strided window of sampling_num_frames samples.
    n_repeat = max((sampling_num_frames * stride - total_num_frames) // (total_num_frames - 1), 0) + 1
    mapping_repeat = mapping.copy()
    for i in range(n_repeat):
        if i % 2 == 0:
            mapping_repeat += mapping[-2::-1]
        else:
            mapping_repeat += mapping[1:]
    if start_index is None:
        start_index = random.randint(0, len(mapping_repeat) - sampling_num_frames * stride)
    sample_idx = list(range(start_index, start_index + sampling_num_frames * stride, stride))
    sample_idx = [mapping_repeat[idx] for idx in sample_idx]
    return sample_idx
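
# Example (illustrative): with 3 frames, 5 samples, and stride 1, the index
# sequence is mirror-padded to [0, 1, 2, 1, 0, 1, 2], so
# mirror_frame_indices(5, 3, start_index=0) returns [0, 1, 2, 1, 0].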


def weighted_sample(arr: np.ndarray, num_samples: int,
                    bias: Union[str, Callable[[np.ndarray], np.ndarray]] = 'uniform') -> Tuple[np.ndarray, np.ndarray]:
    """
    Sample elements from a numpy array without replacement, with optional
    weighting toward the end of the array.

    Parameters:
    - arr: np.ndarray - Input array to sample from.
    - num_samples: int - Number of samples to draw.
    - bias: str or callable - Weighting strategy: 'uniform', 'linear',
      'squared', 'exponential', or a custom function mapping positions in
      [0, 1] to weights.

    Returns:
    - np.ndarray - Sampled values.
    - np.ndarray or None - Selection probabilities of the sampled values
      (None when bias is 'uniform').
    """
    n = len(arr)

    if bias == 'uniform':
        probabilities = None
    else:
        positions = np.linspace(0, 1, n)

        if bias == 'linear':
            weights = positions
        elif bias == 'squared':
            weights = np.square(positions)
        elif bias == 'exponential':
            weights = np.exp(3 * positions)
        elif callable(bias):
            weights = bias(positions)
        else:
            raise ValueError("Invalid bias type. Use 'uniform', 'linear', 'squared', 'exponential', or a custom callable.")

        probabilities = weights / weights.sum()

    sampled_inds = np.random.choice(np.arange(len(arr)), size=num_samples, replace=False, p=probabilities)
    sampled_vals = arr[sampled_inds]
    sampled_probabilities = probabilities[sampled_inds] if probabilities is not None else None
    return sampled_vals, sampled_probabilities
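

# Illustrative sketch (not part of the original module): drawing frame
# indices biased toward the end of a sequence.
def _example_weighted_sample():
    frame_ids = np.arange(10)
    vals, probs = weighted_sample(frame_ids, num_samples=3, bias='squared')
    # vals skews toward the later indices; probs holds their selection
    # probabilities (probs is None when bias='uniform').
    return vals, probs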


def read_exr_depth_to_numpy(exr_file) -> np.ndarray:
    # Read the Z (depth) channel of an open OpenEXR file into a numpy array.
    header = exr_file.header()
    dw = header["dataWindow"]
    h = dw.max.y - dw.min.y + 1
    w = dw.max.x - dw.min.x + 1

    chan_info = header['channels']['Z']
    pix_type = chan_info.type.v

    # Map the EXR pixel type of the Z channel to the matching numpy dtype.
    if pix_type == Imath.PixelType(Imath.PixelType.HALF).v:
        dtype = np.float16
    elif pix_type == Imath.PixelType(Imath.PixelType.FLOAT).v:
        dtype = np.float32
    elif pix_type == Imath.PixelType(Imath.PixelType.DOUBLE).v:
        dtype = np.float64
    else:
        raise ValueError(f"Unknown EXR pixel type: {pix_type}")

    raw = exr_file.channel("Z")
    depth_map = np.frombuffer(raw, dtype=dtype).reshape(h, w)
    return depth_map
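
# Illustrative usage (hypothetical file path): OpenEXR.InputFile opens an EXR
# file whose handle can be passed to read_exr_depth_to_numpy.
#
#   exr = OpenEXR.InputFile("depth_0001.exr")
#   depth = read_exr_depth_to_numpy(exr)  # (H, W) float array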


def merge_input_target_data_dicts(data_fields_input, data_fields_target, original_output_dict_input, original_output_dict_target):
    original_output_dict = {}
    data_fields = list(set(data_fields_input + data_fields_target))
    for data_field in data_fields:
        if data_field in original_output_dict_input and data_field in original_output_dict_target:
            out_data_field = torch.cat((original_output_dict_input[data_field], original_output_dict_target[data_field]))
        elif data_field in original_output_dict_input:
            out_data_field = original_output_dict_input[data_field]
        elif data_field in original_output_dict_target:
            out_data_field = original_output_dict_target[data_field]
        original_output_dict[data_field] = out_data_field
    return original_output_dict


def write_dict_to_json(data, filename):
    with open(filename, 'w') as json_file:
        json.dump(data, json_file, indent=4)


def read_json_to_dict(filename):
    with open(filename, "r") as f:
        json_dict = json.load(f)
    return json_dict