# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import random
from typing import Callable, Optional, Tuple, Union

import Imath
import numpy as np
import OpenEXR
import torch
import torchvision.transforms as transforms
from packaging import version as pver

class ImageTransform:
    """Center-crop + resize transform that also tracks the induced uv-coordinate mapping."""

    def __init__(self, crop_size, sample_size, max_crop, use_flip=False):
        self.use_flip = use_flip
        self.crop_size = crop_size
        self.max_crop = max_crop
        self.sample_size = sample_size
        self.crop_transform = transforms.CenterCrop(crop_size) if crop_size else lambda x: x
        self.resize_transform = (
            transforms.Resize(sample_size) if sample_size else lambda x: x
        )
        # Depth maps are resized with nearest-neighbor interpolation to avoid
        # blending depth values across object boundaries.
        self.resize_transform_depth = (
            transforms.Resize(sample_size, interpolation=transforms.InterpolationMode.NEAREST)
            if sample_size
            else lambda x: x
        )
    def preprocess_images(self, images, depths=None):
        # Returns the preprocessed images (and depths, when given) together
        # with the shift, scale, and flip parameters that describe the
        # transformation applied to them.
        video = images
        if self.use_flip:
            # Random flipping is unsupported: the original code path
            # referenced attributes that do not exist on this class.
            raise NotImplementedError("use_flip is not supported by ImageTransform")
        flip_flag = torch.zeros(
            images.shape[0], dtype=torch.bool, device=video.device
        )
        ori_h, ori_w = video.shape[-2:]
        if self.max_crop:
            # Scale the crop window up to the largest size that still fits the
            # target aspect ratio inside the original frame.
            crop_ratio = min(ori_h / self.crop_size[0], ori_w / self.crop_size[1])
            new_crop_size = (int(self.crop_size[0] * crop_ratio), int(self.crop_size[1] * crop_ratio))
            self.crop_transform = transforms.CenterCrop(new_crop_size)
        video = self.crop_transform(video)
        if depths is not None:
            depths = self.crop_transform(depths)
        new_h, new_w = video.shape[-2:]
        # NOTE: shift is expressed in (u, v) order, i.e. (x, y), not (h, w).
        shift = ((new_w - ori_w) / 2, (new_h - ori_h) / 2)
        # Resize to the sample size, tracking the per-axis scale factor.
        ori_h, ori_w = video.shape[-2:]
        video = self.resize_transform(video)
        if depths is not None:
            depths = self.resize_transform_depth(depths)
        new_h, new_w = video.shape[-2:]
        scale = (new_w / ori_w, new_h / ori_h)
        # use_flip raises in the guard above, so no flip is applied here;
        # flip_flag is returned (all False) to keep the interface stable.
        return video, depths, shift, scale, flip_flag
    def apply_img_transform(self, i, j, shift, scale):
        # Maps pixel (u, v) coordinates in the un-transformed image to the
        # corresponding coordinates after the crop and resize above:
        # first shift, then scale.
        i = (i + shift[0]) * scale[0]
        j = (j + shift[1]) * scale[1]
        return i, j
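# Illustrative sketch (assumed sizes, not part of the original module):
# preprocess a clip and map a source-image pixel into the transformed frame.
#
#     tfm = ImageTransform(crop_size=(480, 640), sample_size=(240, 320), max_crop=False)
#     frames = torch.rand(8, 3, 540, 960)                          # T, C, H, W
#     video, _, shift, scale, _ = tfm.preprocess_images(frames)
#     u, v = tfm.apply_img_transform(500.0, 270.0, shift, scale)   # -> (170.0, 120.0)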
def custom_meshgrid(*args):
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
if pver.parse(torch.__version__) < pver.parse('1.10'):
return torch.meshgrid(*args)
else:
return torch.meshgrid(*args, indexing='ij')
def get_grid_uvs(batch_shape, H, W, device, dtype=None, flip_flag=None, nh=None, nw=None, margin=0):
    # Builds (B, V, nh*nw) grids of pixel-center (u, v) coordinates for an
    # H x W image, optionally subsampled to nh x nw and shrunk toward the
    # image center by a relative margin.
    if dtype is None:
        dtype = torch.float32
    if nh is None:
        nh = H
    if nw is None:
        nw = W
    B, V = batch_shape
    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, nh, device=device, dtype=dtype),
        torch.linspace(0, W - 1, nw, device=device, dtype=dtype),
    )
    i = i.reshape([1, 1, nh * nw]).expand([B, V, nh * nw]) + 0.5  # [B, V, HxW]
    j = j.reshape([1, 1, nh * nw]).expand([B, V, nh * nw]) + 0.5  # [B, V, HxW]
    if margin != 0:
        # Linearly remap coordinates into the central (1 - 2*margin) region.
        marginw = 1 - 2 * margin
        i = marginw * i + margin * W
        j = marginw * j + margin * H
    n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
    if n_flip > 0:
        # Horizontally mirrored u coordinates for the flipped views.
        j_flip, i_flip = custom_meshgrid(
            torch.linspace(0, H - 1, nh, device=device, dtype=dtype),
            torch.linspace(W - 1, 0, nw, device=device, dtype=dtype),
        )
        i_flip = i_flip.reshape([1, 1, nh * nw]).expand(B, 1, nh * nw) + 0.5
        j_flip = j_flip.reshape([1, 1, nh * nw]).expand(B, 1, nh * nw) + 0.5
        i[:, flip_flag, ...] = i_flip
        j[:, flip_flag, ...] = j_flip
    return i, j
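# Quick illustration (assumed sizes): pixel-center grid for a 2-view batch.
#
#     i, j = get_grid_uvs((1, 2), H=4, W=6, device="cpu")
#     # i, j: [1, 2, 24]; i[0, 0, 0] == 0.5 (pixel-center convention)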
def get_rays_from_uvs(i, j, K, c2w):
    # i, j: [B, V, HW] pixel coordinates
    # K:    [B, V, 4] intrinsics packed as (fx, fy, cx, cy)
    # c2w:  [B, V, 4, 4] camera-to-world extrinsics
    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each [B, V, 1]
    zs = torch.ones_like(i)  # [B, V, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)
    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3
    # Rotate camera-frame directions into world space; ray origins are the
    # camera centers.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
    return rays_o, rays_d
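# Illustrative sketch (assumed intrinsics/pose): world-space rays for one
# identity-pose view. K packs (fx, fy, cx, cy) per view.
#
#     K = torch.tensor([[[500.0, 500.0, 320.0, 240.0]]])  # [B=1, V=1, 4]
#     c2w = torch.eye(4).reshape(1, 1, 4, 4)
#     i, j = get_grid_uvs((1, 1), H=480, W=640, device="cpu")
#     rays_o, rays_d = get_rays_from_uvs(i, j, K, c2w)     # each [1, 1, HW, 3]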
def project_to_uvs(pts, K, c2w):
    # Projects world-space points into pixel (u, v) coordinates; also returns
    # the camera-space depth so callers can cull points behind the camera.
    w2c = torch.linalg.inv(c2w)
    cam_pts = torch.einsum("...ij,...vj->...vi", w2c[..., :3, :3], pts) + w2c[..., None, :3, 3]
    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each [B, V, 1]
    xs = cam_pts[..., 0]
    ys = cam_pts[..., 1]
    zs = cam_pts[..., 2]
    us = (fx * xs / zs) + cx
    vs = (fy * ys / zs) + cy
    uvs = torch.stack([us, vs], dim=-1)
    return uvs, zs
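# Illustrative round trip, continuing the ray sketch above: points taken
# along each ray should project back to the pixel coordinates they came from.
#
#     pts = rays_o + 2.0 * rays_d                 # 2 units along every ray
#     uvs, depth = project_to_uvs(pts, K, c2w)    # uvs ~ torch.stack([i, j], dim=-1)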
def get_rays(K, c2w, H, W, device, flip_flag=None, nh=None, nw=None):
    i, j = get_grid_uvs(K.shape[:2], H=H, W=W, dtype=K.dtype, device=device, flip_flag=flip_flag, nh=nh, nw=nw)
    return get_rays_from_uvs(i, j, K, c2w)
def ray_condition(K, c2w, H, W, device, flip_flag=None, get_batch_index=True):
    # Builds per-pixel Plücker ray coordinates (o x d, d) as a
    # [B, V, 6, H, W] conditioning tensor, plus the ray origins/directions.
    batch_shape = K.shape[:2]
    B, V = batch_shape
    rays_o, rays_d = get_rays(K, c2w, H, W, device, flip_flag=flip_flag)
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)  # B, V, HW, 3
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, V, H, W, 6).permute(0, 1, 4, 2, 3).contiguous()
    rays_o = rays_o.reshape(B, V, H, W, 3).permute(0, 1, 4, 2, 3).contiguous()
    rays_d = rays_d.reshape(B, V, H, W, 3).permute(0, 1, 4, 2, 3).contiguous()
    if get_batch_index:
        # Keep only the first batch element, dropping the batch dimension.
        plucker = plucker[0]
        rays_o = rays_o[0]
        rays_d = rays_d[0]
    return plucker, rays_o, rays_d
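# Illustrative sketch (assumed shapes): Plücker conditioning for two views
# sharing an identity pose.
#
#     K = torch.tensor([[[500.0, 500.0, 32.0, 24.0]] * 2])  # [1, 2, 4]
#     c2w = torch.eye(4).expand(1, 2, 4, 4)
#     plucker, rays_o, rays_d = ray_condition(K, c2w, H=48, W=64, device="cpu")
#     # plucker: [2, 6, 48, 64] after the leading batch index is dropped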
def mirror_frame_indices(sampling_num_frames: int, total_num_frames: int,
                         video_mirror_clip_length: Optional[int] = None, stride: int = 1,
                         start_index: Optional[int] = None, return_target: bool = False):
    # Samples sampling_num_frames indices at the given stride, mirror-padding
    # the clip (forward, backward, forward, ...) whenever it is shorter than
    # the window the stride requires. return_target is currently unused.
    if video_mirror_clip_length is None:
        video_mirror_clip_length = total_num_frames
    if total_num_frames > video_mirror_clip_length:
        # Restrict sampling to a random sub-clip of the requested length.
        idx = random.randint(0, total_num_frames - video_mirror_clip_length)
        mapping = list(range(idx, idx + video_mirror_clip_length))
        total_num_frames = video_mirror_clip_length
    else:
        mapping = list(range(total_num_frames))
    n_repeat = max((sampling_num_frames * stride - total_num_frames) // (total_num_frames - 1), 0) + 1
    mapping_repeat = mapping.copy()
    for i in range(n_repeat):
        # Alternate reversed and forward passes to "ping-pong" through the clip.
        if i % 2 == 0:
            mapping_repeat += mapping[-2::-1]
        else:
            mapping_repeat += mapping[1:]
    if start_index is None:
        start_index = random.randint(0, len(mapping_repeat) - sampling_num_frames * stride)
    sample_idx = list(range(start_index, start_index + sampling_num_frames * stride, stride))
    sample_idx = [mapping_repeat[idx] for idx in sample_idx]
    return sample_idx
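# Worked example (deterministic because start_index is pinned): a 4-frame
# clip mirror-padded to cover 10 samples.
#
#     mirror_frame_indices(10, 4, stride=1, start_index=0)
#     # -> [0, 1, 2, 3, 2, 1, 0, 1, 2, 3]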
def weighted_sample(arr: np.ndarray, num_samples: int,
                    bias: Union[str, Callable[[np.ndarray], np.ndarray]] = 'uniform') -> Tuple[np.ndarray, np.ndarray]:
    """
    Sample elements from a numpy array without replacement, with optional
    weighting toward the end of the array.

    Parameters:
    - arr: np.ndarray - Input array to sample from.
    - num_samples: int - Number of samples to draw.
    - bias: str or callable - Weighting strategy: 'uniform', 'linear', 'squared', 'exponential',
      or a custom function that maps an array of [0, 1] positions to weights.

    Returns:
    - np.ndarray - Sampled values.
    - np.ndarray or None - Probabilities of the sampled values (None for 'uniform').
    """
n = len(arr)
if bias == 'uniform':
probabilities = None # uniform by default in np.random.choice
else:
# Relative position in array: from 0 (start) to 1 (end)
positions = np.linspace(0, 1, n)
if bias == 'linear':
weights = positions
elif bias == 'squared':
weights = np.square(positions)
elif bias == 'exponential':
weights = np.exp(3 * positions) # the factor controls steepness; adjust as needed
elif callable(bias):
weights = bias(positions)
else:
raise ValueError("Invalid bias type. Use 'uniform', 'linear', 'squared', 'exponential', or a custom callable.")
probabilities = weights / weights.sum()
sampled_inds = np.random.choice(np.arange(len(arr)), size=num_samples, replace=False, p=probabilities)
sampled_vals = arr[sampled_inds]
sampled_probabilities = probabilities[sampled_inds] if probabilities is not None else probabilities
return sampled_vals, sampled_probabilities
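# Illustrative usage (assumed data): bias sampling toward later frames.
#
#     frame_ids = np.arange(100)
#     vals, probs = weighted_sample(frame_ids, num_samples=8, bias='exponential')
#     # vals skew toward indices near 99; probs are their selection weights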
def read_exr_depth_to_numpy(exr_file) -> np.ndarray:
    # Reads the "Z" channel of an open OpenEXR file into an (H, W) array.
    header = exr_file.header()
    dw = header["dataWindow"]
    h = dw.max.y - dw.min.y + 1
    w = dw.max.x - dw.min.x + 1
    # Detect the channel's pixel type; OpenEXR stores UINT, HALF, or FLOAT.
    chan_info = header['channels']['Z']
    pix_type = chan_info.type.v
    if pix_type == Imath.PixelType(Imath.PixelType.HALF).v:
        dtype = np.float16
    elif pix_type == Imath.PixelType(Imath.PixelType.FLOAT).v:
        dtype = np.float32
    elif pix_type == Imath.PixelType(Imath.PixelType.UINT).v:
        dtype = np.uint32
    else:
        raise ValueError(f"Unknown EXR pixel type: {pix_type}")
    # Read the raw channel bytes and reshape into the image grid.
    raw = exr_file.channel("Z")
    depth_map = np.frombuffer(raw, dtype=dtype).reshape(h, w)
    return depth_map
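# Illustrative usage (hypothetical filename): load a depth map from disk.
#
#     exr = OpenEXR.InputFile("depth_0001.exr")
#     depth = read_exr_depth_to_numpy(exr)  # (H, W), dtype matches the file
#     exr.close()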
def merge_input_target_data_dicts(data_fields_input, data_fields_target, original_output_dict_input, original_output_dict_target):
    # Merges per-field tensors from the input and target dicts, concatenating
    # along the first (frame) dimension when a field appears in both.
    original_output_dict = {}
    data_fields = list(set(data_fields_input + data_fields_target))
    for data_field in data_fields:
        if data_field in original_output_dict_input and data_field in original_output_dict_target:
            out_data_field = torch.cat((original_output_dict_input[data_field], original_output_dict_target[data_field]))
        elif data_field in original_output_dict_input:
            out_data_field = original_output_dict_input[data_field]
        elif data_field in original_output_dict_target:
            out_data_field = original_output_dict_target[data_field]
        else:
            # Field requested but present in neither dict; skip rather than
            # reuse a stale value from a previous iteration.
            continue
        original_output_dict[data_field] = out_data_field
    return original_output_dict
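# Illustrative sketch (assumed field names): fields present in both dicts are
# concatenated along the frame dimension.
#
#     merged = merge_input_target_data_dicts(
#         ["video"], ["video"],
#         {"video": torch.zeros(2, 3, 8, 8)},
#         {"video": torch.ones(4, 3, 8, 8)},
#     )
#     # merged["video"].shape == (6, 3, 8, 8)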
def write_dict_to_json(data, filename):
with open(filename, 'w') as json_file:
json.dump(data, json_file, indent=4)
def read_json_to_dict(filename):
with open(filename, "r") as f:
json_dict = json.load(f)
return json_dict