PEAR / utils /camera_utils.py
BestWJH's picture
Upload 205 files
2c68f56 verified
import torch
import numpy as np
from utils.graphics_utils import get_full_proj_matrix
from tqdm import tqdm
from copy import deepcopy
def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
    """Return ``vectors`` rescaled to unit Euclidean length along the last dimension.

    The per-vector norm keeps a trailing singleton dimension so it broadcasts
    back over ``vectors``. No epsilon guards the division, so a zero-length
    vector produces NaN components.
    """
    magnitude = vectors.norm(dim=-1, keepdim=True)
    return vectors / magnitude
def create_cam2world_matrix(forward_vector, origin):
    """
    Build batched camera-to-world matrices from viewing directions and positions.

    Assumes the world y-axis is up and that there is no camera roll. The
    rotation block has three columns:

    * ``right = normalize(forward x (0, 1, 0))``;
    * ``normalize(forward x right)``, which points *against* world +y;
    * ``forward``.

    Together these give a right-handed x-right / y-down / z-forward
    (OpenCV-style) camera frame.

    Args:
        forward_vector: Viewing direction(s), indexed as ``(N, 3)``. It is
            normalized here and its dtype is used for the output.
        origin: Camera position(s) in world coordinates, indexed as ``(N, 3)``.
            Its device is used for the helper tensors.

    Returns:
        Tensor of shape ``(N, 4, 4)``, where ``N = forward_vector.shape[0]``.
    """
    forward_vector = normalize_vecs(forward_vector)
    dtype = forward_vector.dtype
    device = origin.device
    batch_size = forward_vector.shape[0]

    # Build the helper axis and the output in the input's dtype. A hard-coded
    # float32 would mix dtypes in torch.cross for float64 inputs and would
    # silently truncate the resulting poses to float32.
    world_up = torch.tensor([0, 1, 0], dtype=dtype, device=device).expand_as(forward_vector)
    right_vector = -normalize_vecs(torch.cross(world_up, forward_vector, dim=-1))
    # forward x right is minus the component of world-up orthogonal to forward,
    # so the camera's local y-axis points down in world terms.
    down_vector = normalize_vecs(torch.cross(forward_vector, right_vector, dim=-1))

    cam2world = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
    cam2world[:, :3, :3] = torch.stack((right_vector, down_vector, forward_vector), dim=-1)
    # Same result as translation @ rotation: rotation block plus the origin in
    # the last column.
    cam2world[:, :3, 3] = origin
    return cam2world
class LookAtPoseSampler:
    """
    Sample camera poses on a sphere around the world origin, each oriented to
    look at ``lookat_position`` (a 3-vector).

    Angles (the sampling frame has y as the polar axis):

    * ``theta ~ N(horizontal_mean, horizontal_stddev)`` is the azimuth.
    * ``v ~ N(vertical_mean, vertical_stddev)`` is clamped to
      ``(1e-5, pi - 1e-5)`` and mapped to the polar angle
      ``phi = arccos(1 - 2 v / pi)``.
    * The camera origin is::

          radius * (sin(phi) cos(pi - theta), cos(phi), sin(phi) sin(pi - theta))

    The returned ``w2c`` is ``inverse(cam2world) @ diag(1, -1, -1, 1)``, so the
    returned poses live in a world frame whose y and z axes are negated
    relative to the sampling frame. For example, the returned camera centre is
    ``(x, -y, -z)`` of the sampled origin. Presumably this matches the
    convention of the tracking data; confirm against the caller.

    Example:
        A camera sampled at [0, 0, 1] looking at the origin::

            w2c, c2w = LookAtPoseSampler.sample(
                math.pi / 2, math.pi / 2, torch.tensor([0., 0., 0.]),
                FoVx=None, FoVy=None, radius=1, device='cpu')
            # c2w[:3, 3] is approximately [0, 0, -1] because of the y/z flip.
    """

    @staticmethod
    def sample(horizontal_mean, vertical_mean, lookat_position, FoVx, FoVy,
               horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1,
               device='cuda:0'):
        """
        Draw ``batch_size`` look-at poses.

        Args:
            horizontal_mean / horizontal_stddev: Mean and std of the azimuth
                ``theta``, in radians.
            vertical_mean / vertical_stddev: Mean and std of ``v``. ``v`` is
                clamped to ``(1e-5, pi - 1e-5)`` before being mapped to the
                polar angle.
            lookat_position: Point the cameras look at, in the sampling frame.
                Any tensor or sequence is accepted; it is moved to ``device``.
            FoVx, FoVy: Accepted for interface compatibility; not used here.
            radius: Distance of the camera origins from the world origin. May
                be a Python number or a 0-dim tensor.
            batch_size: Number of poses to draw.
            device: Device for all created tensors.

        Returns:
            ``(w2c, c2w)``: world-to-camera and camera-to-world matrices. Each
            has shape ``(4, 4)`` when ``batch_size == 1`` (the batch dimension
            is squeezed), otherwise ``(batch_size, 4, 4)``.
        """
        h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
        v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
        # Keep the polar angle strictly inside (0, pi). At the poles the viewing
        # direction would be parallel to world-up and the look-at basis would
        # degenerate.
        v = torch.clamp(v, 1e-5, np.pi - 1e-5)

        theta = h
        phi = torch.arccos(1 - 2 * (v / np.pi))

        camera_origins = torch.zeros((batch_size, 3), device=device)
        camera_origins[:, 0:1] = radius * torch.sin(phi) * torch.cos(np.pi - theta)
        camera_origins[:, 2:3] = radius * torch.sin(phi) * torch.sin(np.pi - theta)
        camera_origins[:, 1:2] = radius * torch.cos(phi)

        # Align the target with the origins so a CPU tensor or a plain list
        # works when sampling on another device.
        lookat = torch.as_tensor(lookat_position, dtype=camera_origins.dtype, device=device)
        forward_vectors = normalize_vecs(lookat - camera_origins)
        c2w = create_cam2world_matrix(forward_vectors, camera_origins)

        # Right-multiplying negates the world y and z axes (see class docstring).
        world_axis_flip = torch.tensor([[1, 0, 0, 0],
                                        [0, -1, 0, 0],
                                        [0, 0, -1, 0],
                                        [0, 0, 0, 1]], dtype=torch.float32, device=device)
        w2c = torch.linalg.inv(c2w).squeeze(0) @ world_axis_flip
        c2w = torch.linalg.inv(w2c)
        return w2c, c2w
def generate_novel_view_poses(tracking_info, image_size=512, tanfov=1/24.0, pitch_range=0.3,
                              yaw_range=0.35, num_keyframes=120,
                              lookat_position=(0.0, 0.75, 0.0)):
    """
    Generate a closed orbit of camera parameters for novel-view rendering.

    The orbit radius is the distance from the world origin to the first
    tracked camera. Frame ``i`` of ``num_keyframes`` uses the phase
    ``2*pi*i/num_keyframes``, so the path closes exactly after one period:

    * yaw (azimuth) = ``pi/2 + yaw_range * sin(phase)``;
    * pitch mean = ``pi/2 - 0.05 + pitch_range * cos(phase)``.

    Args:
        tracking_info: Mapping with key ``'c2w_cam'``. Its ``[0, :3, 3]`` entry
            is read as the reference camera position. Presumably a
            ``(N, 4, 4)`` camera-to-world stack; confirm against the caller.
        image_size: Stored as both ``image_height`` and ``image_width``.
        tanfov: Passed to ``get_full_proj_matrix`` and stored as
            ``tanfovx``/``tanfovy``. Presumably ``tan(fov / 2)``; confirm.
        pitch_range: Amplitude, in radians, of the pitch oscillation.
        yaw_range: Amplitude, in radians, of the yaw oscillation.
        num_keyframes: Number of poses in one orbit period.
        lookat_position: Target point handed to ``LookAtPoseSampler.sample``.
            It is given in the sampler's frame, before its y/z flip. The
            default is the value previously hard-coded here.

    Returns:
        A list of ``num_keyframes`` dicts. Each dict has the keys
        ``world_view_transform``, ``full_proj_transform``, ``tanfovx``,
        ``tanfovy``, ``image_height``, ``image_width`` and ``camera_center``,
        each holding a tensor with a leading batch dimension of 1.
    """
    c2w_tracked = tracking_info['c2w_cam']
    device = c2w_tracked.device
    camera_center = c2w_tracked[0, :3, 3]
    radius = camera_center.square().sum().sqrt()

    # The "FoV" values are the tangent value itself, stored as tanfovx/tanfovy.
    FoVx = tanfov
    FoVy = tanfov
    # Converted once instead of on every frame.
    lookat = torch.as_tensor(lookat_position, dtype=torch.get_default_dtype(), device=device)

    circle_cam_params = []
    print("Generate multi-view poses for rendering")
    for frame_idx in tqdm(range(num_keyframes)):
        # Use the exact pi so the last keyframe connects smoothly to the first.
        phase = 2 * np.pi * frame_idx / num_keyframes
        w2c_cam, c2w_cam = LookAtPoseSampler.sample(
            horizontal_mean=np.pi / 2 + yaw_range * np.sin(phase),
            # The fixed 0.05 rad offset tilts the whole orbit slightly; its
            # rationale is not recorded in this file.
            vertical_mean=np.pi / 2 - 0.05 + pitch_range * np.cos(phase),
            lookat_position=lookat, FoVx=FoVx, FoVy=FoVy, radius=radius, device=device)
        view_matrix, full_proj_matrix = get_full_proj_matrix(w2c_cam, tanfov)
        circle_cam_params.append({
            "world_view_transform": view_matrix.unsqueeze(0),
            "full_proj_transform": full_proj_matrix.unsqueeze(0),
            'tanfovx': torch.tensor([FoVx], device=device),
            'tanfovy': torch.tensor([FoVy], device=device),
            'image_height': torch.tensor([image_size], device=device),
            'image_width': torch.tensor([image_size], device=device),
            'camera_center': c2w_cam[:3, 3].unsqueeze(0),
        })
    return circle_cam_params