# vmem/utils/util.py
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import kornia
from matplotlib import cm
from torchvision.io import write_video
from PIL import Image, ImageOps
import os
import math
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
DEFAULT_FOV_RAD = 0.9424777960769379 # 54 degrees by default
def get_default_intrinsics(
fov_rad=DEFAULT_FOV_RAD,
aspect_ratio=1.0,
):
if not isinstance(fov_rad, torch.Tensor):
fov_rad = torch.tensor(
[fov_rad] if isinstance(fov_rad, (int, float)) else fov_rad
)
if aspect_ratio >= 1.0: # W >= H
focal_x = 0.5 / torch.tan(0.5 * fov_rad)
focal_y = focal_x * aspect_ratio
else: # W < H
focal_y = 0.5 / torch.tan(0.5 * fov_rad)
focal_x = focal_y / aspect_ratio
intrinsics = focal_x.new_zeros((focal_x.shape[0], 3, 3))
intrinsics[:, torch.eye(3, device=focal_x.device, dtype=bool)] = torch.stack(
[focal_x, focal_y, torch.ones_like(focal_x)], dim=-1
)
intrinsics[:, :, -1] = torch.tensor(
[0.5, 0.5, 1.0], device=focal_x.device, dtype=focal_x.dtype
)
return intrinsics
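# Minimal usage sketch: with the default 54-degree FOV and a square aspect
# ratio, the normalized focal length is 0.5 / tan(27 deg) ~= 0.981 and the
# principal point sits at the image center (0.5, 0.5).
def _demo_default_intrinsics():
    K = get_default_intrinsics()  # shape [1, 3, 3], normalized coordinates
    assert K.shape == (1, 3, 3)
    assert torch.allclose(K[0, :2, 2], torch.tensor([0.5, 0.5]))
    return K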
def to_hom(X):
# get homogeneous coordinates of the input
X_hom = torch.cat([X, torch.ones_like(X[..., :1])], dim=-1)
return X_hom
def to_hom_pose(pose):
# get homogeneous coordinates of the input pose
if pose.shape[-2:] == (3, 4):
pose_hom = torch.eye(4, device=pose.device)[None].repeat(pose.shape[0], 1, 1)
pose_hom[:, :3, :] = pose
return pose_hom
return pose
def get_image_grid(img_h, img_w):
    # Adding 0.5 samples pixel centers rather than corners; this matters
    # especially when img_h and img_w are small (e.g., 72).
y_range = torch.arange(img_h, dtype=torch.float32).add_(0.5)
x_range = torch.arange(img_w, dtype=torch.float32).add_(0.5)
Y, X = torch.meshgrid(y_range, x_range, indexing="ij") # [H,W]
xy_grid = torch.stack([X, Y], dim=-1).view(-1, 2) # [HW,2]
return to_hom(xy_grid) # [HW,3]
def img2cam(X, cam_intr):
return X @ cam_intr.inverse().transpose(-1, -2)
def cam2world(X, pose):
X_hom = to_hom(X)
pose_inv = torch.linalg.inv(to_hom_pose(pose))[..., :3, :4]
return X_hom @ pose_inv.transpose(-1, -2)
def get_center_and_ray(img_h, img_w, pose, intr): # [HW,2]
    # Given the intrinsic/extrinsic matrices, compute the camera center and
    # per-pixel ray directions (assumes a perspective camera model)
grid_img = get_image_grid(img_h, img_w) # [HW,3]
grid_3D_cam = img2cam(grid_img.to(intr.device), intr.float()) # [B,HW,3]
center_3D_cam = torch.zeros_like(grid_3D_cam) # [B,HW,3]
# transform from camera to world coordinates
grid_3D = cam2world(grid_3D_cam, pose) # [B,HW,3]
center_3D = cam2world(center_3D_cam, pose) # [B,HW,3]
ray = grid_3D - center_3D # [B,HW,3]
return center_3D, ray, grid_3D_cam
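# Illustrative sketch: with an identity pose, every ray originates at the world
# origin and points toward a pixel center (the intrinsics here are the
# normalized defaults scaled to a hypothetical 4x4 image).
def _demo_center_and_ray():
    intr = get_default_intrinsics()  # [1, 3, 3], normalized
    intr[:, :2] *= 4  # scale to pixel units for a 4x4 grid
    pose = torch.eye(4)[None, :3, :]  # [1, 3, 4]
    centers, rays, _ = get_center_and_ray(4, 4, pose, intr)
    assert centers.abs().max() == 0 and rays.shape == (1, 16, 3)
    return rays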
def get_plucker_coordinates(
extrinsics_src,
extrinsics,
intrinsics=None,
fov_rad=DEFAULT_FOV_RAD,
target_size=[72, 72],
):
# Support for batch dimension
has_batch_dim = len(extrinsics.shape) == 4
if has_batch_dim:
# [B, N, 4, 4] -> reshape to handle batch
batch_size, num_cameras = extrinsics.shape[0:2]
extrinsics_flat = extrinsics.reshape(-1, *extrinsics.shape[2:])
# Handle extrinsics_src appropriately
if len(extrinsics_src.shape) == 3: # [B, 4, 4]
extrinsics_src_expanded = extrinsics_src.unsqueeze(1).expand(-1, num_cameras, -1, -1)
extrinsics_src_flat = extrinsics_src_expanded.reshape(-1, *extrinsics_src.shape[1:])
else: # [4, 4] - single extrinsics_src for all batches
extrinsics_src_flat = extrinsics_src.expand(batch_size * num_cameras, -1, -1)
# Handle intrinsics for batch
if intrinsics is None:
intrinsics = get_default_intrinsics(fov_rad).to(extrinsics.device)
intrinsics = intrinsics.expand(batch_size * num_cameras, -1, -1)
elif len(intrinsics.shape) == 3: # [N, 3, 3]
if intrinsics.shape[0] == num_cameras:
intrinsics = intrinsics.expand(batch_size, -1, -1, -1).reshape(-1, *intrinsics.shape[1:])
else:
intrinsics = intrinsics.expand(batch_size * num_cameras, -1, -1)
elif len(intrinsics.shape) == 4: # [B, N, 3, 3]
intrinsics = intrinsics.reshape(-1, *intrinsics.shape[2:])
else:
# Original behavior for non-batch input
extrinsics_flat = extrinsics
extrinsics_src_flat = extrinsics_src
if intrinsics is None:
intrinsics = get_default_intrinsics(fov_rad).to(extrinsics.device)
    # If intrinsics are in pixel units, normalize them, assuming they were
    # expressed at the full resolution (target_size upscaled by the latent
    # downsampling factor of 8): the x row is divided by width, the y row by
    # height, matching the (w, h) rescale back to pixels below.
    if not (
        torch.all(intrinsics[:, :2, -1] >= 0)
        and torch.all(intrinsics[:, :2, -1] <= 1)
    ):
        intrinsics[:, :2] /= (
            intrinsics.new_tensor([target_size[1], target_size[0]]).view(1, -1, 1) * 8
        )
# Ensure normalized intrinsics
assert (
torch.all(intrinsics[:, :2, -1] >= 0)
and torch.all(intrinsics[:, :2, -1] <= 1)
), "Intrinsics should be expressed in resolution-independent normalized image coordinates."
    c2w_src = torch.linalg.inv(extrinsics_src_flat)
    # Express every camera relative to the source camera. Batched matmul
    # broadcasts, so this also works when extrinsics_src has no batch dim.
    extrinsics_rel = extrinsics_flat @ c2w_src
intrinsics[:, :2] *= extrinsics_flat.new_tensor(
[
target_size[1], # w
target_size[0], # h
]
).view(1, -1, 1)
centers, rays, grid_cam = get_center_and_ray(
img_h=target_size[0],
img_w=target_size[1],
pose=extrinsics_rel[:, :3, :],
intr=intrinsics,
)
rays = torch.nn.functional.normalize(rays, dim=-1)
plucker = torch.cat((rays, torch.cross(centers, rays, dim=-1)), dim=-1)
plucker = plucker.permute(0, 2, 1).reshape(plucker.shape[0], -1, *target_size)
# Reshape back to batch dimension if needed
if has_batch_dim:
plucker = plucker.reshape(batch_size, num_cameras, *plucker.shape[1:])
return plucker
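# Illustrative sketch (default intrinsics assumed): Plucker embedding for two
# cameras, the second translated along +x. Each output map stacks 3
# ray-direction channels and 3 moment (center x direction) channels.
def _demo_plucker_coordinates():
    w2c = torch.eye(4).unsqueeze(0).repeat(2, 1, 1)
    w2c[1, 0, 3] = 1.0  # shift the second camera
    plucker = get_plucker_coordinates(w2c[0], w2c, target_size=[72, 72])
    assert plucker.shape == (2, 6, 72, 72)
    return plucker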
def get_value_dict(
curr_imgs,
curr_imgs_clip,
curr_input_frame_indices,
curr_c2ws,
curr_Ks,
curr_input_camera_indices,
all_c2ws,
camera_scale,
):
assert sorted(curr_input_camera_indices) == sorted(
range(len(curr_input_camera_indices))
)
H, W, T, F = curr_imgs.shape[-2], curr_imgs.shape[-1], len(curr_imgs), 8
value_dict = {}
value_dict["cond_frames_without_noise"] = curr_imgs_clip[curr_input_frame_indices]
value_dict["cond_frames"] = curr_imgs + 0.0 * torch.randn_like(curr_imgs)
value_dict["cond_frames_mask"] = torch.zeros(T, dtype=torch.bool)
value_dict["cond_frames_mask"][curr_input_frame_indices] = True
value_dict["cond_aug"] = 0.0
    if curr_c2ws.shape[-2] == 3:  # [*, 3, 4] poses need a homogeneous bottom row
        c2w = to_hom_pose(curr_c2ws.float())
    else:
        c2w = curr_c2ws
# camera centering
ref_c2ws = all_c2ws
camera_dist_2med = torch.norm(
ref_c2ws[:, :3, 3] - ref_c2ws[:, :3, 3].median(0, keepdim=True).values,
dim=-1,
)
valid_mask = camera_dist_2med <= torch.clamp(
torch.quantile(camera_dist_2med, 0.97) * 10,
max=1e6,
)
c2w[:, :3, 3] -= ref_c2ws[valid_mask, :3, 3].mean(0, keepdim=True)
w2c = torch.linalg.inv(c2w)
# camera normalization
camera_dists = c2w[:, :3, 3].clone()
translation_scaling_factor = (
camera_scale
if torch.isclose(
torch.norm(camera_dists[0]),
torch.zeros(1),
atol=1e-5,
).any()
else (camera_scale / torch.norm(camera_dists[0]))
)
w2c[:, :3, 3] *= translation_scaling_factor
c2w[:, :3, 3] *= translation_scaling_factor
value_dict["plucker_coordinate"] = get_plucker_coordinates(
extrinsics_src=w2c[0],
extrinsics=w2c,
intrinsics=curr_Ks.float().clone(),
target_size=(H // F, W // F),
)
value_dict["c2w"] = c2w
value_dict["K"] = curr_Ks
value_dict["camera_mask"] = torch.zeros(T, dtype=torch.bool)
value_dict["camera_mask"][curr_input_camera_indices] = True
return value_dict
def parse_meta_data(file_path, image_height=288, image_width=512):
with open(file_path, 'r') as file:
lines = file.readlines()
# First line is the video URL
video_url = lines[0].strip()
line = lines[1]
data = line.strip().split()
# Construct the camera intrinsics matrix K
focal_length_x = float(data[1])
focal_length_y = float(data[2])
principal_point_x = float(data[3])
principal_point_y = float(data[4])
original_K = [
[focal_length_x, 0, principal_point_x],
[0, focal_length_y, principal_point_y],
[0, 0, 1]
]
K = [
[focal_length_x * image_width, 0, principal_point_x * image_width],
[0, focal_length_y * image_height, principal_point_y * image_height],
[0, 0, 1]
]
    # Containers for per-frame data
    timestamp_to_c2ws = {}
    timestamps = []
    # Every line after the URL carries a timestamp and a flattened 3x4 pose
    # (line 1 also carries the intrinsics parsed above)
    for line in lines[1:]:
data = line.strip().split()
timestamp = int(data[0])
R_t = [float(x) for x in data[7:]]
P = [
R_t[0:4],
R_t[4:8],
R_t[8:12],
[0, 0, 0, 1]
]
timestamp_to_c2ws[timestamp] = np.array(P)
timestamps.append(timestamp)
return timestamps, np.array(K), timestamp_to_c2ws, original_K
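# Sketch of the metadata layout assumed above (RealEstate10K-style camera
# files): line 0 is the video URL; each following line is
#   timestamp fx fy cx cy k1 k2 p00 p01 p02 p03 p10 ... p23
# with normalized intrinsics in fields 1-4 and the flattened 3x4 pose in
# fields 7-18. The temporary path below is purely illustrative.
def _demo_parse_meta_data(tmp_path="/tmp/_meta_demo.txt"):
    row = "1000 0.8 1.42 0.5 0.5 0 0 1 0 0 0 0 1 0 0 0 0 1 0"
    with open(tmp_path, "w") as f:
        f.write("https://example.com/video\n" + row + "\n")
    ts, K, ts_to_pose, _ = parse_meta_data(tmp_path)
    assert ts == [1000] and ts_to_pose[1000].shape == (4, 4)
    return K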
def get_wh_with_fixed_shortest_side(w, h, size):
    # If size is None or non-positive, return the original (w, h)
if size is None or size <= 0:
return w, h
if w < h:
new_w = size
new_h = int(size * h / w)
else:
new_h = size
new_w = int(size * w / h)
return new_w, new_h
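# Worked example: a 1920x1080 frame with size=576 becomes (1024, 576); the
# shorter side is pinned to `size` and the aspect ratio is preserved.
def _demo_fixed_shortest_side():
    assert get_wh_with_fixed_shortest_side(1920, 1080, 576) == (1024, 576)
    return get_wh_with_fixed_shortest_side(1920, 1080, 576)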
def get_resizing_factor(
target_shape: Tuple[int, int], # H, W
current_shape: Tuple[int, int], # H, W
cover_target: bool = True,
# If True, the output shape will fully cover the target shape.
    # If False, the target shape will fully cover the output shape.
) -> float:
r_bound = target_shape[1] / target_shape[0]
aspect_r = current_shape[1] / current_shape[0]
if r_bound >= 1.0:
if cover_target:
if aspect_r >= r_bound:
factor = min(target_shape) / min(current_shape)
elif aspect_r < 1.0:
factor = max(target_shape) / min(current_shape)
else:
factor = max(target_shape) / max(current_shape)
else:
if aspect_r >= r_bound:
factor = max(target_shape) / max(current_shape)
elif aspect_r < 1.0:
factor = min(target_shape) / max(current_shape)
else:
factor = min(target_shape) / min(current_shape)
else:
if cover_target:
if aspect_r <= r_bound:
factor = min(target_shape) / min(current_shape)
elif aspect_r > 1.0:
factor = max(target_shape) / min(current_shape)
else:
factor = max(target_shape) / max(current_shape)
else:
if aspect_r <= r_bound:
factor = max(target_shape) / max(current_shape)
elif aspect_r > 1.0:
factor = min(target_shape) / max(current_shape)
else:
factor = min(target_shape) / min(current_shape)
return factor
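# Worked example: covering a 576x1024 (H, W) target from a 720x1280 source of
# the same 16:9 aspect ratio gives factor 576 / 720 = 0.8.
def _demo_resizing_factor():
    factor = get_resizing_factor((576, 1024), (720, 1280))
    assert abs(factor - 0.8) < 1e-9
    return factor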
def transform_img_and_K(
image: torch.Tensor,
size: Union[int, Tuple[int, int]],
scale: float = 1.0,
center: Tuple[float, float] = (0.5, 0.5),
K: Union[torch.Tensor, np.ndarray, None] = None,
size_stride: int = 1,
mode: str = "crop",
):
assert mode in [
"crop",
"pad",
"stretch",
], f"mode should be one of ['crop', 'pad', 'stretch'], got {mode}"
h, w = image.shape[-2:]
if isinstance(size, (tuple, list)):
        # => if size is a tuple or list, we first rescale to fully cover the `size`
        # area and then crop a `size` region from the rescaled image
W, H = size
else:
# => if size is int, we rescale the image to fit the shortest side to size
# => if size is None, no rescaling is applied
W, H = get_wh_with_fixed_shortest_side(w, h, size)
W, H = (
math.floor(W / size_stride + 0.5) * size_stride,
math.floor(H / size_stride + 0.5) * size_stride,
)
if mode == "stretch":
rh, rw = H, W
else:
rfs = get_resizing_factor(
(H, W),
(h, w),
cover_target=mode != "pad",
)
(rh, rw) = [int(np.ceil(rfs * s)) for s in (h, w)]
rh, rw = int(rh / scale), int(rw / scale)
image = torch.nn.functional.interpolate(
image, (rh, rw), mode="area", antialias=False
)
cy_center = int(center[1] * image.shape[-2])
cx_center = int(center[0] * image.shape[-1])
if mode != "pad":
ct = max(0, cy_center - H // 2)
cl = max(0, cx_center - W // 2)
ct = min(ct, image.shape[-2] - H)
cl = min(cl, image.shape[-1] - W)
image = TF.crop(image, top=ct, left=cl, height=H, width=W)
pl, pt = 0, 0
else:
pt = max(0, H // 2 - cy_center)
pl = max(0, W // 2 - cx_center)
pb = max(0, H - pt - image.shape[-2])
pr = max(0, W - pl - image.shape[-1])
image = TF.pad(
image,
[pl, pt, pr, pb],
)
cl, ct = 0, 0
if K is not None:
K = K.clone()
# K[:, :2, 2] += K.new_tensor([pl, pt])
if torch.all(K[:, :2, -1] >= 0) and torch.all(K[:, :2, -1] <= 1):
K[:, :2] *= K.new_tensor([rw, rh])[None, :, None] # normalized K
else:
K[:, :2] *= K.new_tensor([rw / w, rh / h])[None, :, None] # unnormalized K
K[:, :2, 2] += K.new_tensor([pl - cl, pt - ct])
return image, K
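# Minimal sketch (shapes and K values are illustrative): crop a 720p frame to
# 576x1024 while keeping a normalized K consistent with the resized pixels.
def _demo_transform_img_and_K():
    img = torch.rand(1, 3, 720, 1280)
    K = torch.tensor([[[1.0, 0.0, 0.5], [0.0, 16 / 9, 0.5], [0.0, 0.0, 1.0]]])
    out, K_out = transform_img_and_K(img, (1024, 576), K=K, mode="crop")
    assert out.shape[-2:] == (576, 1024) and K_out.shape == (1, 3, 3)
    return out, K_out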
def load_img_and_K(
image_path_or_size: Union[str, torch.Size],
size: Optional[Union[int, Tuple[int, int]]],
scale: float = 1.0,
center: Tuple[float, float] = (0.5, 0.5),
K: Union[torch.Tensor, np.ndarray, None] = None,
size_stride: int = 1,
center_crop: bool = False,
image_as_tensor: bool = True,
context_rgb: Union[np.ndarray, None] = None,
device: str = "cuda",
):
if isinstance(image_path_or_size, torch.Size):
image = Image.new("RGBA", image_path_or_size[::-1])
else:
image = Image.open(image_path_or_size).convert("RGBA")
w, h = image.size
if size is None:
size = (w, h)
image = np.array(image).astype(np.float32) / 255
if image.shape[-1] == 4:
rgb, alpha = image[:, :, :3], image[:, :, 3:]
if context_rgb is not None:
image = rgb * alpha + context_rgb * (1 - alpha)
else:
image = rgb * alpha + (1 - alpha)
image = image.transpose(2, 0, 1)
image = torch.from_numpy(image).to(dtype=torch.float32)
image = image.unsqueeze(0)
if isinstance(size, (tuple, list)):
        # => if size is a tuple or list, we first rescale to fully cover the `size`
        # area and then crop a `size` region from the rescaled image
W, H = size
else:
# => if size is int, we rescale the image to fit the shortest side to size
# => if size is None, no rescaling is applied
W, H = get_wh_with_fixed_shortest_side(w, h, size)
W, H = (
math.floor(W / size_stride + 0.5) * size_stride,
math.floor(H / size_stride + 0.5) * size_stride,
)
rfs = get_resizing_factor((math.floor(H * scale), math.floor(W * scale)), (h, w))
resize_size = rh, rw = [int(np.ceil(rfs * s)) for s in (h, w)]
image = torch.nn.functional.interpolate(
image, resize_size, mode="area", antialias=False
)
if scale < 1.0:
pw = math.ceil((W - resize_size[1]) * 0.5)
ph = math.ceil((H - resize_size[0]) * 0.5)
image = F.pad(image, (pw, pw, ph, ph), "constant", 1.0)
cy_center = int(center[1] * image.shape[-2])
cx_center = int(center[0] * image.shape[-1])
if center_crop:
side = min(H, W)
ct = max(0, cy_center - side // 2)
cl = max(0, cx_center - side // 2)
ct = min(ct, image.shape[-2] - side)
cl = min(cl, image.shape[-1] - side)
image = TF.crop(image, top=ct, left=cl, height=side, width=side)
else:
ct = max(0, cy_center - H // 2)
cl = max(0, cx_center - W // 2)
ct = min(ct, image.shape[-2] - H)
cl = min(cl, image.shape[-1] - W)
image = TF.crop(image, top=ct, left=cl, height=H, width=W)
if K is not None:
K = K.clone()
if torch.all(K[:2, -1] >= 0) and torch.all(K[:2, -1] <= 1):
K[:2] *= K.new_tensor([rw, rh])[:, None] # normalized K
else:
K[:2] *= K.new_tensor([rw / w, rh / h])[:, None] # unnormalized K
K[:2, 2] -= K.new_tensor([cl, ct])
if image_as_tensor:
        # tensor of shape (1, 3, H, W) with values in [-1, 1]
image = image.to(device) * 2.0 - 1.0
else:
        # PIL Image with values in [0, 255]
image = image.permute(0, 2, 3, 1).numpy()[0]
image = Image.fromarray((image * 255).astype(np.uint8))
return image, K
def geodesic_distance(extrinsic1: Union[np.ndarray, torch.Tensor],
extrinsic2: Union[np.ndarray, torch.Tensor],
weight_translation: float = 0.01,):
"""
Computes the geodesic distance between two camera poses in SE(3).
Parameters:
extrinsic1 (Union[np.ndarray, torch.Tensor]): 4x4 extrinsic matrix of the first pose.
        extrinsic2 (Union[np.ndarray, torch.Tensor]): 4x4 extrinsic matrix of the second pose.
        weight_translation (float): Weight on the translation distance added to the rotation angle.
Returns:
Union[float, torch.Tensor]: Geodesic distance between the two poses.
"""
if torch.is_tensor(extrinsic1):
# Extract the rotation and translation components
R1 = extrinsic1[:3, :3]
t1 = extrinsic1[:3, 3]
R2 = extrinsic2[:3, :3]
t2 = extrinsic2[:3, 3]
# Compute the translation distance (Euclidean distance)
translation_distance = torch.norm(t1 - t2)
# Compute the relative rotation matrix
R_relative = torch.matmul(R1.T, R2)
# Compute the angular distance from the trace of the relative rotation matrix
trace_value = torch.trace(R_relative)
# Clamp the trace value to avoid numerical issues
trace_value = torch.clamp(trace_value, -1.0, 3.0)
angular_distance = torch.acos((trace_value - 1) / 2)
else:
# Extract the rotation and translation components
R1 = extrinsic1[:3, :3]
t1 = extrinsic1[:3, 3]
R2 = extrinsic2[:3, :3]
t2 = extrinsic2[:3, 3]
# Compute the translation distance (Euclidean distance)
translation_distance = np.linalg.norm(t1 - t2)
# Compute the relative rotation matrix
R_relative = np.dot(R1.T, R2)
# Compute the angular distance from the trace of the relative rotation matrix
trace_value = np.trace(R_relative)
# Clamp the trace value to avoid numerical issues
trace_value = np.clip(trace_value, -1.0, 3.0)
angular_distance = np.arccos((trace_value - 1) / 2)
# Combine the two distances
geodesic_dist = translation_distance*weight_translation + angular_distance
return geodesic_dist
def inverse_geodesic_distance(extrinsic1,
extrinsic2,
weight_translation=0.01):
"""
Computes the inverse geodesic distance between two camera poses in SE(3).
Parameters:
extrinsic1 (np.ndarray): 4x4 extrinsic matrix of the first pose.
extrinsic2 (np.ndarray): 4x4 extrinsic matrix of the second pose.
Returns:
float: Inverse geodesic distance between the two poses.
"""
# Compute the geodesic distance
geodesic_dist = geodesic_distance(extrinsic1, extrinsic2, weight_translation)
# Compute the inverse geodesic distance
inverse_geodesic_dist = 1.0 / (geodesic_dist + 1e-6)
return inverse_geodesic_dist
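# Worked example: poses differing by a 90-degree rotation about z and a unit
# translation are pi/2 + 0.01 * 1 ~= 1.58 apart; the inverse is ~0.63.
def _demo_geodesic_distance():
    a = np.eye(4)
    b = np.eye(4)
    b[:3, :3] = [[0, -1, 0], [1, 0, 0], [0, 0, 1]]  # 90 degrees about z
    b[0, 3] = 1.0
    d = geodesic_distance(a, b)
    assert np.isclose(d, np.pi / 2 + 0.01)
    return d, inverse_geodesic_distance(a, b)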
def average_camera_pose(camera_poses):
"""
Compute a better average of camera poses in SE(3).
Args:
camera_poses: List or array of camera poses, each a 4x4 matrix
Returns:
Average camera pose as a 4x4 matrix
"""
# Extract rotation and translation components
rotations = camera_poses[:, :3, :3].detach().cpu().numpy()
translations = camera_poses[:, :3, 3].detach().cpu().numpy()
# Average translation with simple mean
avg_translation = np.mean(translations, axis=0)
# Convert rotations to quaternions for better averaging
import scipy.spatial.transform as transform
quats = [transform.Rotation.from_matrix(R).as_quat() for R in rotations]
# Ensure quaternions are in the same hemisphere to avoid issues with averaging
for i in range(1, len(quats)):
if np.dot(quats[0], quats[i]) < 0:
quats[i] = -quats[i]
# Average the quaternions and convert back to rotation matrix
avg_quat = np.mean(quats, axis=0)
avg_quat = avg_quat / np.linalg.norm(avg_quat) # Normalize
avg_rotation = transform.Rotation.from_quat(avg_quat).as_matrix()
# Construct the average pose
avg_pose = np.eye(4)
avg_pose[:3, :3] = avg_rotation
avg_pose[:3, 3] = avg_translation
return avg_pose
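# Sketch: averaging two poses rotated +/-10 degrees about z recovers
# (numerically) the identity rotation with the mean translation. scipy is
# assumed available, as it already is inside average_camera_pose.
def _demo_average_camera_pose():
    import scipy.spatial.transform as transform
    poses = torch.eye(4).repeat(2, 1, 1)
    for i, deg in enumerate((-10.0, 10.0)):
        R = transform.Rotation.from_euler("z", deg, degrees=True).as_matrix()
        poses[i, :3, :3] = torch.from_numpy(R).float()
    avg = average_camera_pose(poses)
    assert np.allclose(avg[:3, :3], np.eye(3), atol=1e-6)
    return avg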
def encode_image(
image,
image_encoder,
device,
dtype,
) -> torch.Tensor:
image = image.to(device=device, dtype=dtype)
image_embeddings = image_encoder(image)
return image_embeddings
def encode_vae_image(
image,
vae,
device,
dtype,
):
image = image.to(device=device, dtype=dtype)
image_latents = vae.encode(image, 1)
return image_latents
def do_sample(
model,
ae,
denoiser,
sampler,
c,
uc,
c2w,
K,
cond_frames_mask,
H=576,
W=768,
C=4,
F=8,
T=8,
cfg=2.0,
decoding_t=1,
verbose=True,
global_pbar=None,
return_latents=False,
device: str = "cuda",
**_,
):
num_samples = [1, T]
with torch.inference_mode(), torch.autocast("cuda"):
additional_model_inputs = {"num_frames": T}
        additional_sampler_inputs = {
            "c2w": c2w.to(device),
            "K": K.to(device),
            "input_frame_mask": cond_frames_mask.to(device),
        }
if global_pbar is not None:
additional_sampler_inputs["global_pbar"] = global_pbar
shape = (math.prod(num_samples), C, H // F, W // F)
randn = torch.randn(shape).to(device)
samples_z = sampler(
lambda input, sigma, c: denoiser(
model,
input,
sigma,
c,
**additional_model_inputs,
),
randn,
scale=cfg,
cond=c,
uc=uc,
verbose=verbose,
**additional_sampler_inputs,
)
if samples_z is None:
return
samples = ae.decode(samples_z, decoding_t)
if return_latents:
return samples, samples_z
return samples
def decode_output(
samples,
T,
indices=None,
):
# decode model output into dict if it is not
if isinstance(samples, dict):
        # model with a postprocessor that outputs a dict
for sample, value in samples.items():
if isinstance(value, torch.Tensor):
value = value.detach().cpu()
elif isinstance(value, np.ndarray):
value = torch.from_numpy(value)
else:
value = torch.tensor(value)
if indices is not None and value.shape[0] == T:
value = value[indices]
samples[sample] = value
else:
# model without postprocessor and outputs tensor (rgb)
samples = samples.detach().cpu()
if indices is not None and samples.shape[0] == T:
samples = samples[indices]
samples = {"samples-rgb/image": samples}
return samples
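# Sketch: a raw [T, C, H, W] tensor is moved to CPU and wrapped under the
# "samples-rgb/image" key, optionally subsampled by `indices`.
def _demo_decode_output():
    out = decode_output(torch.rand(4, 3, 8, 8), T=4, indices=[0, 2])
    assert out["samples-rgb/image"].shape == (2, 3, 8, 8)
    return out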
def select_frames(timestamps, min_num_frames=2, skip_frame=10, random_start=False):
"""
Select frames from a video sequence based on defined criteria.
Args:
timestamps: List of timestamps for the frames
min_num_frames: Minimum number of frames required
skip_frame: Number of frames to skip between selections
random_start: If True, start from a random frame
Returns:
tuple: (selected_frame_indices, selected_frame_timestamps) or (None, None) if criteria not met
"""
num_frames = len(timestamps)
    if num_frames < min_num_frames:
        print(f"[Worker PID={os.getpid()}] Episode has fewer than {min_num_frames} frames")
        return None, None
    # Decide on start/end frames
    if num_frames < 2:
        print(f"[Worker PID={os.getpid()}] Episode has fewer than 2 frames")
        return None, None
elif num_frames < skip_frame:
cur_skip_frame = num_frames - 1
else:
cur_skip_frame = skip_frame
    if random_start:
        start_frame = np.random.randint(0, cur_skip_frame)
    else:
        start_frame = 0
# Gather frame indices
selected_frame_indices = list(range(start_frame, num_frames, cur_skip_frame))
selected_frame_timestamps = [timestamps[i] for i in selected_frame_indices]
return selected_frame_indices, selected_frame_timestamps
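# Worked example: 25 timestamps with skip_frame=10 select indices [0, 10, 20].
def _demo_select_frames():
    indices, stamps = select_frames(list(range(25)), skip_frame=10)
    assert indices == [0, 10, 20]
    return indices, stamps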
def tensor2im(input_image, imtype=np.uint8):
if not isinstance(input_image, np.ndarray):
if isinstance(input_image, torch.Tensor): # get the data from a variable
image_tensor = input_image.data
else:
return input_image
        image_numpy = image_tensor[0].clamp(0.0, 1.0).cpu().float().numpy()  # convert to a numpy array
        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0  # post-processing: transpose and scale
else: # if it is a numpy array, do nothing
image_numpy = input_image
return image_numpy.astype(imtype)
class LatentStorer:
def __init__(self):
self.latent = None
def __call__(self, i, t, latent):
self.latent = latent
def sobel_filter(disp, mode="sobel", beta=10.0):
    # Edge-aware weight map: alpha is ~1 in smooth regions and decays toward 0
    # where the disparity gradient (i.e., a depth discontinuity) is large.
    sobel_grad = kornia.filters.spatial_gradient(disp, mode=mode, normalized=False)
    sobel_mag = torch.sqrt(
        sobel_grad[:, :, 0, Ellipsis] ** 2 + sobel_grad[:, :, 1, Ellipsis] ** 2
    )
    alpha = torch.exp(-1.0 * beta * sobel_mag).detach()
    return alpha
def apply_colormap(image, cmap="viridis"):
colormap = cm.get_cmap(cmap)
colormap = torch.tensor(colormap.colors).to(image.device)
image_long = (image * 255).long()
image_long_min = torch.min(image_long)
image_long_max = torch.max(image_long)
assert image_long_min >= 0, f"the min value is {image_long_min}"
assert image_long_max <= 255, f"the max value is {image_long_max}"
return colormap[image_long[..., 0]]
def apply_depth_colormap(
depth,
near_plane=None,
far_plane=None,
cmap="viridis",
):
near_plane = near_plane or float(torch.min(depth))
far_plane = far_plane or float(torch.max(depth))
depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
depth = torch.clip(depth, 0, 1)
colored_image = apply_colormap(depth, cmap=cmap)
return colored_image
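# Sketch: colorize a random [H, W, 1] depth map; the result gains an RGB axis.
def _demo_apply_depth_colormap():
    rgb = apply_depth_colormap(torch.rand(72, 72, 1))
    assert rgb.shape == (72, 72, 3)
    return rgb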
def save_video(video, path, fps=10):
    # write_video expects uint8 frames in [T, H, W, C]; input comes in [T, C, H, W]
    video = video.permute(0, 2, 3, 1)
    video_codec = "libx264"
    video_options = {
        "crf": "23",  # Constant Rate Factor (lower = higher quality; 23 is the x264 default)
        "preset": "slow",
    }
    write_video(str(path), video, fps=fps, video_codec=video_codec, options=video_options)
def visualize_camera_poses(camera_poses, axis_length=0.1):
"""
Visualizes a set of camera poses in 3D using Matplotlib.
Parameters
----------
camera_poses : np.ndarray
An array of shape (N, 4, 4) containing N camera poses.
Each pose is a 4x4 transformation matrix.
axis_length : float
Length of the camera axes to draw.
"""
if isinstance(camera_poses, torch.Tensor):
camera_poses = camera_poses.detach().cpu().numpy()
# Create a 3D figure
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# Iterate over all camera poses
for i in range(camera_poses.shape[0]):
# Extract rotation (R) and translation (t)
R = camera_poses[i][:3, :3]
t = camera_poses[i][:3, 3]
# Plot the camera center
ax.scatter(t[0], t[1], t[2], c='k', marker='o', s=20)
# Define the end-points of each local axis
x_axis_end = t + R[:, 0] * axis_length
y_axis_end = t + R[:, 1] * axis_length
z_axis_end = t + R[:, 2] * axis_length
# Draw the axes as lines
ax.plot([t[0], x_axis_end[0]], [t[1], x_axis_end[1]],
[t[2], x_axis_end[2]], color='r') # X-axis (red)
ax.plot([t[0], y_axis_end[0]], [t[1], y_axis_end[1]],
[t[2], y_axis_end[2]], color='g') # Y-axis (green)
ax.plot([t[0], z_axis_end[0]], [t[1], z_axis_end[1]],
[t[2], z_axis_end[2]], color='b') # Z-axis (blue)
# Make axes have equal scale
set_axes_equal(ax)
ax.set_title("Camera Poses Visualization")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
plt.show()
def set_axes_equal(ax):
"""
Make axes of 3D plot have equal scale so that spheres appear as spheres, cubes as cubes, etc.
This is a workaround to Matplotlib's set_aspect('equal') which is not supported in 3D.
"""
x_limits = ax.get_xlim3d()
y_limits = ax.get_ylim3d()
z_limits = ax.get_zlim3d()
x_range = x_limits[1] - x_limits[0]
y_range = y_limits[1] - y_limits[0]
z_range = z_limits[1] - z_limits[0]
max_range = max(x_range, y_range, z_range)
x_middle = np.mean(x_limits)
y_middle = np.mean(y_limits)
z_middle = np.mean(z_limits)
ax.set_xlim3d([x_middle - 0.5 * max_range, x_middle + 0.5 * max_range])
ax.set_ylim3d([y_middle - 0.5 * max_range, y_middle + 0.5 * max_range])
ax.set_zlim3d([z_middle - 0.5 * max_range, z_middle + 0.5 * max_range])
def tensor_to_pil(image):
if isinstance(image, torch.Tensor):
if image.dim() == 4:
image = image.squeeze(0)
image = image.permute(1, 2, 0).detach().cpu().numpy()
# Detect the range of the input tensor
if image.min() < -0.1: # If we have negative values, assume [-1, 1] range
image = (image + 1) / 2.0 # Convert from [-1, 1] to [0, 1]
# Otherwise, assume it's already in [0, 1] range
image = (image * 255)
image = np.clip(image, 0, 255)
image = image.astype(np.uint8)
return Image.fromarray(image)
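# Sketch: values spanning [-1, 1] trip the range heuristic above and are
# remapped to [0, 255] before PIL conversion.
def _demo_tensor_to_pil():
    img = torch.linspace(-1, 1, 192).reshape(1, 3, 8, 8)
    return tensor_to_pil(img)  # 8x8 RGB PIL image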
def center_crop_pil_image(input_image, target_width=1024, target_height=576):
w, h = input_image.size
h_ratio = h / target_height
w_ratio = w / target_width
if h_ratio > w_ratio:
h = int(h / w_ratio)
if h < target_height:
h = target_height
input_image = input_image.resize((target_width, h), Image.Resampling.LANCZOS)
else:
w = int(w / h_ratio)
if w < target_width:
w = target_width
input_image = input_image.resize((w, target_height), Image.Resampling.LANCZOS)
    return ImageOps.fit(input_image, (target_width, target_height), Image.Resampling.BICUBIC)
def resize_pil_image(img, long_edge_size):
    S = max(img.size)
    # LANCZOS for downsampling, BICUBIC for upsampling
    interp = PIL.Image.Resampling.LANCZOS if S > long_edge_size else PIL.Image.Resampling.BICUBIC
    new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size)
    return img.resize(new_size, interp)
def visualize_surfels(
surfels,
draw_normals=False,
normal_scale=20,
disk_resolution=16,
disk_alpha=0.5
):
"""
Visualize surfels as 2D disks oriented by their normals in 3D using matplotlib.
Args:
surfels (list of Surfel): Each Surfel has at least:
- position: (x, y, z)
- normal: (nx, ny, nz)
- radius: scalar
- color: (R, G, B) in [0..255] (optional)
draw_normals (bool): If True, draws the surfel normals as quiver arrows.
normal_scale (float): Scale factor for the normal arrows.
disk_resolution (int): Number of segments to approximate each disk.
disk_alpha (float): Alpha (transparency) for the filled disks.
"""
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# Prepare arrays for optional quiver (if draw_normals=True)
positions = []
normals = []
# We'll accumulate 3D polygons in a list for Poly3DCollection
polygons = []
polygon_colors = []
for s in surfels:
# --- Extract surfel data ---
position = s.position
normal = s.normal
radius = s.radius
if isinstance(position, torch.Tensor):
x, y, z = position.detach().cpu().numpy()
nx, ny, nz = normal.detach().cpu().numpy()
radius = radius.detach().cpu().numpy()
        else:
            x, y, z = position
            nx, ny, nz = normal
# Convert color from [0..255] to [0..1], or use default
if s.color is None:
color = (0.2, 0.6, 1.0) # Light blue
else:
r, g, b = s.color
color = (r/255.0, g/255.0, b/255.0)
# --- Build local coordinate axes for the disk ---
normal = np.array([nx, ny, nz], dtype=float)
norm_len = np.linalg.norm(normal)
# Skip degenerate normals to avoid nan
if norm_len < 1e-12:
continue
normal /= norm_len
# Pick an 'up' vector that is not too close to the normal
# so we can build a tangent plane
up = np.array([0, 0, 1], dtype=float)
if abs(normal.dot(up)) > 0.9:
up = np.array([0, 1, 0], dtype=float)
# xAxis = normal x up
xAxis = np.cross(normal, up)
xAxis /= np.linalg.norm(xAxis)
# yAxis = normal x xAxis
yAxis = np.cross(normal, xAxis)
yAxis /= np.linalg.norm(yAxis)
# --- Create a circle of 'disk_resolution' segments in local 2D coords ---
angles = np.linspace(0, 2*np.pi, disk_resolution, endpoint=False)
circle_points_3d = []
for theta in angles:
# local 2D circle: (r*cosθ, r*sinθ)
px = radius * np.cos(theta)
py = radius * np.sin(theta)
# transform to 3D world space: position + px*xAxis + py*yAxis
world_pt = np.array([x, y, z]) + px * xAxis + py * yAxis
circle_points_3d.append(world_pt)
# We have a list of [x, y, z]. For a filled polygon, Poly3DCollection
# wants them as a single Nx3 array.
circle_points_3d = np.array(circle_points_3d)
polygons.append(circle_points_3d)
polygon_colors.append(color)
# Collect positions and normals for quiver (if used)
positions.append([x, y, z])
normals.append(normal)
# --- Draw the disks as polygons ---
poly_collection = Poly3DCollection(
polygons,
facecolors=polygon_colors,
edgecolors='k', # black edge
linewidths=0.5,
alpha=disk_alpha
)
ax.add_collection3d(poly_collection)
# --- Optionally draw normal vectors (quiver) ---
if draw_normals and len(positions) > 0:
X = [p[0] for p in positions]
Y = [p[1] for p in positions]
Z = [p[2] for p in positions]
Nx = [n[0] for n in normals]
Ny = [n[1] for n in normals]
Nz = [n[2] for n in normals]
# Note: If your scene is large, you may want to increase `length`.
ax.quiver(
X, Y, Z,
Nx, Ny, Nz,
length=normal_scale,
color='red',
normalize=True
)
# --- Axis labels, aspect ratio, etc. ---
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
try:
ax.set_box_aspect((1, 1, 1))
except AttributeError:
pass # older MPL versions
plt.title("Surfels as Disks (Oriented by Normal)")
plt.show()
def visualize_pointcloud(
points,
colors=None,
title='Point Cloud',
point_size=1,
alpha=1.0
):
"""
Visualize a 3D point cloud using Matplotlib, with an option to provide
per-point RGB or RGBA colors, ensuring equal scaling for the x, y, and z axes.
Parameters
----------
points : np.ndarray or torch.Tensor
A numpy array (or Tensor) of shape [N, 3] where each row is a 3D point (x, y, z).
colors : None, str, or np.ndarray
- If None, a default single color ('blue') is used.
- If a string, that color will be used for all points.
- If a numpy array, it should have shape [N, 3] or [N, 4], where each row
corresponds to the color of the matching point in `points`.
Values should be in the range [0, 1] if using floats.
title : str, optional
The title of the plot. Default is 'Point Cloud'.
point_size : float, optional
The size of the points in the scatter plot. Default is 1.
alpha : float, optional
The overall alpha (transparency) value for the points. Default is 1.0.
Examples
--------
>>> import numpy as np
>>> # Generate random points
>>> pts = np.random.rand(1000, 3)
>>> # Generate random colors in [0,1]
>>> cols = np.random.rand(1000, 3)
>>> visualize_pointcloud(pts, colors=cols, title="Random Point Cloud with Colors")
"""
# Convert Torch tensors to NumPy arrays if needed
if isinstance(points, torch.Tensor):
points = points.detach().cpu().numpy()
if isinstance(colors, torch.Tensor):
colors = colors.detach().cpu().numpy()
# Flatten points if they are in a higher-dimensional array
if len(points.shape) > 2:
points = points.reshape(-1, 3)
if colors is not None and isinstance(colors, np.ndarray) and len(colors.shape) > 2:
colors = colors.reshape(-1, colors.shape[-1])
# Validate shape of points
if points.shape[1] != 3:
raise ValueError("`points` array must have shape [N, 3].")
# Validate or set colors
if colors is None:
colors = 'blue'
elif isinstance(colors, np.ndarray):
colors = np.asarray(colors)
if colors.shape[0] != points.shape[0]:
raise ValueError(
"Colors array length must match the number of points."
)
if colors.shape[1] not in [3, 4]:
raise ValueError(
"Colors array must have shape [N, 3] or [N, 4]."
)
# Extract coordinates
x = points[:, 0]
y = points[:, 1]
z = points[:, 2]
# Create a 3D figure
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')
# Scatter plot with specified or per-point colors
ax.scatter(x, y, z, c=colors, s=point_size, alpha=alpha)
# Set labels and title
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title(title)
# Ensure all axes have the same scale
max_range = np.array([x.max() - x.min(),
y.max() - y.min(),
z.max() - z.min()]).max() / 2.0
mid_x = (x.max() + x.min()) * 0.5
mid_y = (y.max() + y.min()) * 0.5
mid_z = (z.max() + z.min()) * 0.5
ax.set_xlim(mid_x - max_range, mid_x + max_range)
ax.set_ylim(mid_y - max_range, mid_y + max_range)
ax.set_zlim(mid_z - max_range, mid_z + max_range)
# Adjust viewing angle for better visibility
ax.view_init(elev=20., azim=30)
plt.tight_layout()
plt.show()
def visualize_depth(depth_image,
file_name="rendered_depth.png",
visualization_dir="visualization",
size=(512, 288)):
"""
Visualize a depth map as a grayscale image.
Parameters
----------
    depth_image : np.ndarray
        A 2D array of depth values.
    file_name : str
        File name for the saved image.
    visualization_dir : str
        The directory to save the visualization image.
    size : tuple of int
        Output (width, height) of the saved image.
Returns
-------
PIL.Image
The visualization image.
"""
# Normalize the depth values for visualization
depth_min = depth_image.min()
depth_max = depth_image.max()
print(f"Depth min: {depth_min}, max: {depth_max}")
depth_image = np.clip(depth_image, 0, depth_max)
depth_vis = (depth_image - depth_min) / (depth_max - depth_min)
depth_vis = (depth_vis * 255).astype(np.uint8)
# Convert the depth image to a PIL image
depth_vis_img = Image.fromarray(depth_vis, mode='L')
depth_vis_img = depth_vis_img.resize(size, Image.NEAREST)
    # Save the visualization image
    os.makedirs(visualization_dir, exist_ok=True)
    depth_vis_img.save(os.path.join(visualization_dir, file_name))
return depth_vis_img
class Surfel:
def __init__(self, position, normal, radius=1.0, color=None):
"""
position: (x, y, z)
normal: (nx, ny, nz)
radius: scalar
color: (r, g, b) or None
"""
self.position = position
self.normal = normal
self.radius = radius
self.color = color
def __repr__(self):
return (f"Surfel(position={self.position}, "
f"normal={self.normal}, radius={self.radius}, "
f"color={self.color})")
class Octree:
def __init__(self, points, indices=None, bbox=None, max_points=10):
self.points = points
if indices is None:
indices = np.arange(points.shape[0])
self.indices = indices
if bbox is None:
min_bound = points.min(axis=0)
max_bound = points.max(axis=0)
center = (min_bound + max_bound) / 2
half_size = np.max(max_bound - min_bound) / 2
bbox = (center, half_size)
self.center, self.half_size = bbox
        self.children = []  # child nodes
self.max_points = max_points
if len(self.indices) > self.max_points:
self.subdivide()
def subdivide(self):
cx, cy, cz = self.center
hs = self.half_size / 2
offsets = np.array([[dx, dy, dz] for dx in (-hs, hs)
for dy in (-hs, hs)
for dz in (-hs, hs)])
for offset in offsets:
child_center = self.center + offset
child_indices = []
for idx in self.indices:
p = self.points[idx]
if np.all(np.abs(p - child_center) <= hs):
child_indices.append(idx)
child_indices = np.array(child_indices)
if len(child_indices) > 0:
child = Octree(self.points, indices=child_indices, bbox=(child_center, hs), max_points=self.max_points)
self.children.append(child)
self.indices = None
def sphere_intersects_node(self, center, r):
diff = np.abs(center - self.center)
max_diff = diff - self.half_size
max_diff = np.maximum(max_diff, 0)
dist_sq = np.sum(max_diff**2)
return dist_sq <= r*r
def query_ball_point(self, point, r):
results = []
if not self.sphere_intersects_node(point, r):
return results
if len(self.children) == 0:
if self.indices is not None:
for idx in self.indices:
if np.linalg.norm(self.points[idx] - point) <= r:
results.append(idx)
return results
else:
for child in self.children:
results.extend(child.query_ball_point(point, r))
return results
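# Sketch: radius query on a random cloud, cross-checked against brute force.
# The set() guards against the rare boundary point assigned to two children.
def _demo_octree_query():
    pts = np.random.rand(200, 3)
    tree = Octree(pts, max_points=16)
    query, radius = np.array([0.5, 0.5, 0.5]), 0.2
    found = sorted(set(tree.query_ball_point(query, radius)))
    brute = [i for i, p in enumerate(pts) if np.linalg.norm(p - query) <= radius]
    assert found == brute
    return found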