# source/rendering/sat2density_transform_eg3d.py
import math
import random

import numpy as np
import torch
import torch.nn.functional as F
from easydict import EasyDict as edict
from einops import repeat, rearrange

from source.rendering.point_sampler import perturb_points_per_ray
from source.rendering.aabb import intersect_aabb_end
from source.rendering.transform_perspective import compose_rotmat
def get_normal_coord(W, H, device='cpu'):
    '''
    Standard equirectangular panorama coordinate normalization.
    W: panorama width
    H: panorama height
    device: target device, usually `cpu` or `cuda`
    Returns:
        normalized_coords: tensor with shape (H, W, 3)
    '''
    # Create linear coordinates from 0 to W-1 and 0 to H-1.
    y = torch.linspace(0, W - 1, W, device=device)
    x = torch.linspace(0, H - 1, H, device=device)
    # Build the mesh grid.
    Y, X = torch.meshgrid(y, x, indexing='ij')  # Y: (W, H), X: (W, H)
    # Convert the grid to longitude and latitude.
    phi = -(Y / (W - 1) - 0.5) * 2 * math.pi + (math.pi / 2)  # longitude, offset by pi/2: range [-pi/2, 3*pi/2]
    theta = -(0.5 - X / (H - 1)) * math.pi  # latitude in [-pi/2, pi/2]
    # Convert the spherical angles to unit direction vectors.
    cos_theta = torch.cos(theta)
    sin_theta = torch.sin(theta)
    cos_phi = torch.cos(phi)
    sin_phi = torch.sin(phi)
    normalized_coords = torch.stack([
        cos_theta * cos_phi,  # x axis
        sin_theta,            # y axis
        cos_theta * sin_phi,  # z axis
    ], dim=2)  # shape: (W, H, 3)
    # Permute to (H, W, 3).
    normalized_coords = normalized_coords.permute(1, 0, 2)
    return normalized_coords
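
# Illustrative sanity check (a hypothetical `_demo_*` helper, not part of the
# original file): every per-pixel direction returned above should be unit length.
def _demo_get_normal_coord():
    dirs = get_normal_coord(W=64, H=32)
    assert dirs.shape == (32, 64, 3)
    assert torch.allclose(dirs.norm(dim=-1), torch.ones(32, 64), atol=1e-5)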
def get_original_coord(W, H, full=True, c2w=None):
    '''
    W: width of the panorama
    H: height of the panorama
    full=True: return the original coordinates of the CVACT dataset.
    full=False: only used for the CVUSA dataset. (`full` is currently unused in the body.)
    c2w: optional camera-to-world rotation; defaults to zero roll/pitch/yaw.
    '''
    normalized_coords = get_normal_coord(W, H)
    if c2w is None:
        RollPitchYaw = [0, 0, 0]
        R_c2w = compose_rotmat(RollPitchYaw[0], RollPitchYaw[1], RollPitchYaw[2])
    else:
        R_c2w = c2w  # use the caller-supplied camera-to-world rotation
    # If the rotation is a numpy array, convert it to torch on the device of the coordinates.
    if isinstance(R_c2w, np.ndarray):
        R_c2w = torch.from_numpy(R_c2w).to(normalized_coords.device).float()
    # Rotate every per-pixel direction into the world frame: (H, W, 3).
    ray_directions = torch.einsum('ij,hwj->hwi', R_c2w, normalized_coords)
    # Re-normalize the ray directions.
    ray_directions = ray_directions / torch.norm(ray_directions, dim=-1, keepdim=True)
    return ray_directions
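
# Illustrative sanity check (hypothetical helper, assuming compose_rotmat(0, 0, 0)
# is the identity rotation): with no rotation, the rays equal the raw directions.
def _demo_get_original_coord():
    dirs = get_original_coord(W=64, H=32)
    assert dirs.shape == (32, 64, 3)
    assert torch.allclose(dirs, get_normal_coord(64, 32), atol=1e-5)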
class Point_sampler_pano(torch.nn.Module):
    # Ray/point sampler designed for street-view panorama images.
    def __init__(self,
                 pano_direction,
                 sample_total_length=None,
                 num_points=300,
                 perturbation_strategy='uniform',
                 aabb_strict=False,
                 data_type=None,
                 ):
        super().__init__()
        # Honor the constructor argument; fall back to the diagonal of the sampling volume.
        if sample_total_length is None:
            sample_total_length = np.sqrt(1.5**2 + 1.5**2 + 1.9**2)
        self.sample_total_length = sample_total_length
        self.pano_direction = pano_direction
        self.num_points = num_points
        if not aabb_strict:
            self.sample_len = ((torch.arange(self.num_points) + 1) * (self.sample_total_length / self.num_points)).float()
        self.voxel_low = -1
        self.voxel_max = 1
        self.perturbation_strategy = perturbation_strategy
        self.aabb_strict = aabb_strict

    @torch.no_grad()
    def forward(self,
                batch_size,
                position=None,  # required despite the None default
                ):
        device = position.device
        origin_opensfm = position[:, None, None, :].to(device)  # (b, 1, 1, 3), axes: w, -h, z
        pano_direction = self.pano_direction[..., None].to(device)  # (1, h, w, 3, 1), in OpenSfM coordinates
        output = edict()
        H, W = pano_direction.shape[1], pano_direction.shape[2]
        rays_world = repeat(pano_direction, '1 h w c 1 -> b h w c', b=batch_size)
        ray_origins = repeat(origin_opensfm, 'b 1 1 c -> b h w c', h=H, w=W)
        if self.aabb_strict:
            # Clip each ray's total length against the axis-aligned bounding box.
            sample_total_length = intersect_aabb_end(ray_origins, rays_world, min=0, max=self.sample_total_length)
            sample_total_length = rearrange(sample_total_length, '(b h w) -> b h w 1', b=batch_size, h=H, w=W)
            output.radii_raw = (torch.arange(self.num_points) + 1)[None, None, None, :].to(sample_total_length.device) * (sample_total_length / self.num_points)
        else:
            depth = self.sample_len[None, None, None, None, :].to(device)
            output.radii_raw = repeat(depth, '1 1 1 1 k -> b h w k', b=batch_size, h=H, w=W)  # (1, 1, 1, 1, k) -> (b, h, w, k)
        output.radii = perturb_points_per_ray(output.radii_raw, strategy=self.perturbation_strategy)
        sample_point = ray_origins.unsqueeze(-1) + rays_world.unsqueeze(-1) * output.radii.unsqueeze(-2)
        output.points_world = rearrange(sample_point, 'b h w c k -> b h w k c').clone()
        # Flip the y axis to switch coordinate conventions.
        output.ray_origins = ray_origins.clone()
        output.ray_origins[..., 1] = -output.ray_origins[..., 1]
        output.rays_world = rays_world.clone()
        output.rays_world[..., 1] = -output.rays_world[..., 1]
        output.points_world[..., 1] = -output.points_world[..., 1]
        return output
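
# Minimal usage sketch (hypothetical helper, with assumed sizes): build per-pixel
# panorama directions with get_original_coord, then sample num_points positions
# along every ray from a batch of camera origins.
def _demo_point_sampler_pano():
    W, H, batch_size = 64, 32, 2
    pano_direction = get_original_coord(W, H)[None]  # (1, H, W, 3)
    sampler = Point_sampler_pano(pano_direction, num_points=8)
    position = torch.zeros(batch_size, 3)  # all cameras at the origin
    out = sampler(batch_size, position=position)
    assert out.points_world.shape == (batch_size, H, W, 8, 3)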
def get_sat_ori(resolution, position_scale_factor=1):
    # Pixel-centered grid covering [-1, 1] on both axes.
    y_range = (torch.arange(resolution, dtype=torch.float32) + 0.5) / (0.5 * resolution) - 1
    x_range = (torch.arange(resolution, dtype=torch.float32) + 0.5) / (0.5 * resolution) - 1
    Y, X = torch.meshgrid(y_range, x_range, indexing='ij')
    Y = Y * position_scale_factor
    X = X * position_scale_factor
    Z = torch.ones_like(Y)  # z = 1: ray origins sit at the top of the volume (stored in dimension 1)
    xy_grid = torch.stack([X, Z, Y], dim=-1)[None, :, :]
    return xy_grid
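
# Illustrative sanity check (hypothetical helper): one ray origin per satellite
# pixel, with the height channel (dimension 1) fixed at 1.
def _demo_get_sat_ori():
    ori = get_sat_ori(resolution=4)
    assert ori.shape == (1, 4, 4, 3)
    assert torch.all(ori[..., 1] == 1)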
class Point_sampler_ortho(torch.nn.Module):
    '''
    Point sampler designed for orthographic (satellite) images.
    '''
    def __init__(self,
                 num_points,
                 resolution=256,
                 perturbation_strategy='uniform',
                 position_scale_factor=1,
                 render_size=128,
                 ):
        super().__init__()
        self.perturbation_strategy = perturbation_strategy
        self.resolution = resolution
        self.sat_ori = get_sat_ori(self.resolution, position_scale_factor)[..., None]  # (1, res, res, 3, 1)
        self.sat_dir = torch.tensor([0, -1, 0])[None, None, None, :, None]  # every ray points straight down (-y)
        self.sample_total_length = 2
        self.num_points = num_points
        self.sample_len = (torch.arange(self.num_points) + 1) * (self.sample_total_length / self.num_points)
        self.render_size = render_size

    @torch.no_grad()
    def forward(self,
                batch_size,
                random_crop=True,
                crop_type=None,
                ):
        device = self.sat_ori.device if self.sat_ori.is_cuda else (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
        depth = self.sample_len[None, None, None, None, :].to(device).float()
        sat_dir = self.sat_dir.to(device)
        output = edict()
        if random_crop:
            if crop_type == 'crop':
                # Randomly crop a render_size x render_size window of ray origins.
                assert self.render_size < self.resolution
                start_h = random.randint(0, self.resolution - self.render_size - 1)
                start_w = random.randint(0, self.resolution - self.render_size - 1)
                output.idx = [start_h, start_w]
                sat_ori = self.sat_ori[:, start_h:start_h + self.render_size, start_w:start_w + self.render_size, :]
            elif crop_type == 'resize':
                # Downsample the full origin grid by a factor of two instead of cropping.
                sat_ori = rearrange(self.sat_ori, 'b h w c 1 -> b c h w')
                sat_ori = F.interpolate(sat_ori, scale_factor=0.5, mode='bilinear')
                sat_ori = rearrange(sat_ori, 'b c h w -> b h w c 1')
            else:
                raise NotImplementedError
        else:
            sat_ori = self.sat_ori
            self.render_size = self.resolution
        sat_ori = sat_ori.to(device)
        # [..., [0, 2, 1]] swaps the y and z channels into the output convention.
        output.rays_world = repeat(sat_dir, '1 1 1 c 1 -> b h w c', b=batch_size, h=self.render_size, w=self.render_size)[..., [0, 2, 1]]  # (1, 1, 1, 3, 1) -> (b, h, w, 3)
        output.radii_raw = repeat(depth, '1 1 1 1 k -> b h w k', b=batch_size, h=self.render_size, w=self.render_size)  # (1, 1, 1, 1, k) -> (b, h, w, k)
        output.ray_origins = repeat(sat_ori, '1 h w c 1 -> b h w c', b=batch_size)[..., [0, 2, 1]]
        output.radii = perturb_points_per_ray(output.radii_raw, strategy=self.perturbation_strategy)
        sample_point = sat_ori + sat_dir * output.radii.unsqueeze(-2)  # (b, h, w, 3, k)
        grid = sample_point.permute(0, 4, 1, 2, 3)[..., [0, 2, 1]]  # swap channels back so height sits in the last axis
        output.points_world = rearrange(grid, 'b k h w c -> b h w k c')
        return output
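
# Minimal usage sketch (hypothetical helper, with assumed sizes): sample points
# along a downward ray for every pixel of a small ortho grid, without cropping.
def _demo_point_sampler_ortho():
    sampler = Point_sampler_ortho(num_points=4, resolution=8, render_size=8)
    out = sampler(batch_size=2, random_crop=False)
    assert out.points_world.shape == (2, 8, 8, 4, 3)
    # The last channel is height; it decreases along each ray since the
    # direction is (0, -1, 0).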
# class RGB_Reprerenter(torch.nn.Module):
#     def __init__(self):
#         super().__init__()
#
#     def forward(self, points, image):
#         point_h_w = points[..., 0:2].unsqueeze(2)  # (b, N, 1, 2)
#         rgb_feature = F.grid_sample(image, point_h_w).squeeze(-1).permute(0, 2, 1)  # (b, C, N, 1) -> (b, N, C)
#         return rgb_feature