# Sat3DGen/source/rendering/transform_perspective.py
import numpy as np
from scipy.spatial.transform import Rotation
import torch
from einops import repeat, rearrange
from easydict import EasyDict as edict
import torch.nn.functional as F
from source.rendering.aabb import intersect_aabb_end
from source.rendering.point_sampler import perturb_points_per_ray
def decompose_rotmat(R_c2w):
    """Extract (roll, pitch, yaw) in degrees from a camera-to-world rotation.

    The camera-to-world rotation is inverted to world-to-camera, aligned with
    the XYZ frame by a fixed -90 degree rotation about X, and then factored
    as intrinsic Y-X-Z Euler angles.

    Args:
        R_c2w: 3x3 camera-to-world rotation matrix.

    Returns:
        Tuple of (roll, pitch, yaw) in degrees.
    """
    cv_to_xyz = Rotation.from_euler("X", -90, degrees=True)
    world_to_cam = cv_to_xyz * Rotation.from_matrix(R_c2w).inv()
    roll, pitch, yaw = world_to_cam.as_euler("YXZ", degrees=True)
    return roll, pitch, yaw
def normalize_angles(angles):
    """Wrap angles in degrees into the half-open interval [-180, 180)."""
    shifted = np.array(angles) + 180
    return np.mod(shifted, 360) - 180
def compose_rotmat(roll, pitch, yaw):
    """Build a camera-to-world rotation matrix from Euler angles in degrees.

    This is the inverse of decompose_rotmat: the intrinsic Y-X-Z rotation is
    mapped back through the fixed -90 degree X alignment and inverted.

    Args:
        roll: rotation about Y in degrees.
        pitch: rotation about X in degrees.
        yaw: rotation about Z in degrees.

    Returns:
        3x3 numpy camera-to-world rotation matrix.
    """
    cv_to_xyz = Rotation.from_euler("X", -90, degrees=True)
    euler_rot = Rotation.from_euler("YXZ", [roll, pitch, yaw], degrees=True)
    world_to_cam = cv_to_xyz.inv() * euler_rot
    return world_to_cam.inv().as_matrix()
def fov_size2intrinsics(fov, img_size):
    """Build a 3x3 pinhole intrinsic matrix from a field of view and image size.

    Args:
        fov: field of view in degrees; a scalar (applied to both axes) or a
            pair [fov_x, fov_y].
        img_size: [width, height] in pixels; the principal point is the
            image centre.

    Returns:
        np.ndarray (3, 3): [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
    """
    if isinstance(fov, (int, float)):
        fov = [fov, fov]
    half_x = np.deg2rad(fov[0]) / 2
    half_y = np.deg2rad(fov[1]) / 2
    cx = img_size[0] / 2
    cy = img_size[1] / 2
    # Focal length follows from tan(fov/2) = (size/2) / f.
    fx = cx / np.tan(half_x)
    fy = cy / np.tan(half_y)
    return np.array([[fx, 0, cx],
                     [0, fy, cy],
                     [0, 0, 1]])
def from_Euler_and_position_to_c2w(roll_pitch_yaw, position):
    """Assemble a 4x4 homogeneous camera-to-world transform.

    Args:
        roll_pitch_yaw: iterable of (roll, pitch, yaw) in degrees.
        position: length-3 camera position; numpy array or torch tensor.

    Returns:
        np.ndarray (4, 4) camera-to-world matrix.
    """
    roll, pitch, yaw = roll_pitch_yaw
    T_c2w = np.eye(4)
    T_c2w[:3, :3] = compose_rotmat(roll, pitch, yaw)
    if isinstance(position, torch.Tensor):
        # Torch tensors must be moved to host memory before numpy assignment.
        position = position.cpu().numpy()
    T_c2w[:3, 3] = position
    return T_c2w
class PointSamplerPerspective(torch.nn.Module):
    """Samples 3-D points along per-pixel camera rays for volume rendering.

    Rays are cast from a perspective camera; sample depths are evenly spaced
    up to the scene AABB exit distance, then perturbed per ray.
    """

    def __init__(self, num_points, aabb_strict=True, perturbation_strategy='uniform', render_size=(128, 128)):
        """
        Args:
            num_points: number of points to sample along each ray.
            aabb_strict: whether to clip sampling to the scene AABB
                (the non-strict path is not implemented and raises).
            perturbation_strategy: strategy for perturbing points along the ray.
            render_size: [H, W] of the rendered image.
        """
        super().__init__()
        # BUGFIX: the original line ended with a stray comma, storing the
        # 1-tuple (aabb_strict,) which is always truthy, so aabb_strict=False
        # was silently ignored in forward().
        self.aabb_strict = aabb_strict
        # Diagonal length of the scene bounding box: sqrt(1.5^2+1.5^2+1.9^2).
        self.sample_total_length = np.sqrt(1.5**2 + 1.5**2 + 1.9**2)
        self.num_points = num_points
        self.perturbation_strategy = perturbation_strategy
        self.render_size = render_size

    @torch.no_grad()
    def forward(self, intrinsics, c2w):
        """Cast rays and sample points along each of them.

        Args:
            intrinsics: (B, 3, 3) camera intrinsic matrices.
            c2w: (B, 4, 4) camera-to-world transforms.

        Returns:
            edict with:
                ray_origins:  (B, H, W, 3) per-pixel ray origins.
                rays_world:   (B, H, W, 3) unit ray directions in world space.
                radii_raw:    (B, H, W, K) unperturbed sample distances.
                radii:        (B, H, W, K) perturbed sample distances.
                points_world: (B, H, W, K, 3) sampled 3-D points.
            The second (height) axis of all world-space outputs is negated at
            the end to convert from (w, -h, z) to (w, h, z) convention.
        """
        batch_size = c2w.shape[0]
        device = c2w.device
        H, W = self.render_size[0], self.render_size[1]

        output = edict()
        # Broadcast camera centres to one origin per pixel; clone so the
        # in-place sign flip below does not write through the expanded view.
        t = c2w[:, :3, 3].clone()
        output.ray_origins = repeat(t, 'b c -> b h w c', h=H, w=W).to(device).clone()
        output.rays_world = compute_ray_directions(c2w.to(device), intrinsics.to(device), H, W)

        if self.aabb_strict:
            # Per-ray far bound: the distance at which each ray leaves the AABB.
            sample_total_length = intersect_aabb_end(
                output.ray_origins, output.rays_world,
                min=0, max=self.sample_total_length)
            sample_total_length = rearrange(
                sample_total_length, '(b h w) -> b h w 1',
                b=batch_size, h=H, w=W)
            # Evenly spaced depths: (k/K) * far for k = 1..K -> (B, H, W, K).
            steps = torch.arange(self.num_points, device=sample_total_length.device) + 1
            output.radii_raw = steps[None, None, None, :] * (sample_total_length / self.num_points)
        else:
            raise NotImplementedError

        output.radii = perturb_points_per_ray(output.radii_raw, strategy=self.perturbation_strategy)
        # origin + direction * depth, broadcast over the sample axis.
        sample_point = output.ray_origins.unsqueeze(-1) + output.rays_world.unsqueeze(-1) * output.radii.unsqueeze(-2)
        output.points_world = rearrange(sample_point, 'b h w c k -> b h w k c')
        # Convert from (w, -h, z) to (w, h, z) by negating the height axis.
        output.ray_origins[..., 1] = -output.ray_origins[..., 1]
        output.rays_world[..., 1] = -output.rays_world[..., 1]
        output.points_world[..., 1] = -output.points_world[..., 1]
        return output
def generate_pixel_coordinates(H, W, device=None):
    """
    Generate a homogeneous pixel-coordinate grid on the image plane.

    Parameters:
    - H: Image height
    - W: Image width
    - device: target device for the grid. Defaults to None, which keeps the
      original behavior: the current CUDA device when CUDA is available,
      otherwise CPU.

    Returns:
    - pixel_coords: float tensor of shape [H, W, 3]; each entry is (x, y, 1)
    """
    if device is None:
        device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
    y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    pixel_coords = torch.stack([x, y, torch.ones_like(x)], dim=-1).float().to(device)
    return pixel_coords
def compute_ray_directions(camera2world, intrinsics, H, W):
    """
    Compute normalized world-space ray directions for every pixel.

    Parameters:
    - camera2world: Camera-to-world transformation matrix with shape [B, 4, 4]
    - intrinsics: Intrinsic matrix with shape [B, 3, 3]
    - H: Image height
    - W: Image width

    Returns:
    - ray_directions: Unit-length ray directions with shape [B, H, W, 3]
    """
    B = camera2world.shape[0]
    # BUGFIX: keep the pixel grid on the same device as the camera tensors.
    # generate_pixel_coordinates defaults to the current CUDA device whenever
    # CUDA is available, which raised a device-mismatch error for
    # CPU-resident camera2world/intrinsics on CUDA machines.
    pixel_coords = generate_pixel_coordinates(H, W).to(camera2world.device)  # [H, W, 3]
    pixel_coords = pixel_coords.unsqueeze(0).expand(B, -1, -1, -1)  # [B, H, W, 3]
    # Back-project pixels through the inverse intrinsics to camera-space rays,
    # then rotate into world space with the camera-to-world rotation.
    inv_intrinsics = torch.inverse(intrinsics)  # [B, 3, 3]
    normalized_coords = torch.einsum('bij,bhwj->bhwi', inv_intrinsics, pixel_coords)  # [B, H, W, 3]
    ray_directions = torch.einsum('bij,bhwj->bhwi', camera2world[:, :3, :3], normalized_coords)  # [B, H, W, 3]
    ray_directions = ray_directions / torch.norm(ray_directions, dim=-1, keepdim=True)
    return ray_directions