Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch,math | |
| from PIL import Image | |
| import torchvision | |
| from easydict import EasyDict as edict | |
def position_produce(opt):
    """Build the positional-encoding volume for the voxel grid.

    Normalised (z, x, y) coordinates in [0, 1) are generated for every cell of
    a grid with ``depth_channel`` layers and ``sat_size[1]`` x ``sat_size[0]``
    cells per layer, then passed through :func:`positional_encoding`.

    Reads ``opt.arch.gen.depth_arch.output_nc``, ``opt.optim.ground_prior``,
    ``opt.data.sat_size`` and ``opt.device``.

    Returns:
        Tensor of shape [C, depth_channel, sat_size[1], sat_size[0]], where C
        is the size of the last dimension produced by ``positional_encoding``.
    """
    num_layers = opt.arch.gen.depth_arch.output_nc
    if opt.optim.ground_prior:
        # The ground prior contributes one additional layer.
        num_layers = num_layers + 1

    size_x = opt.data.sat_size[1]
    size_y = opt.data.sat_size[0]
    z_axis = torch.arange(num_layers) / num_layers
    x_axis = torch.arange(size_x) / size_x
    y_axis = torch.arange(size_y) / size_y

    grid_z, grid_x, grid_y = torch.meshgrid(z_axis, x_axis, y_axis)
    # Last dimension holds the (z, x, y) coordinate of each cell.
    coords = torch.stack((grid_z, grid_x, grid_y), dim=-1).to(opt.device)

    encoded = positional_encoding(opt, coords)
    # Move the encoding channels to the front: [D, X, Y, C] -> [C, D, X, Y].
    return encoded.permute(3, 0, 1, 2)
def positional_encoding(opt, input):  # input: [B,...,N]
    """Apply a sinusoidal positional encoding to the last dimension of ``input``.

    Each of the N input values is multiplied by the L frequencies
    ``2**k * pi`` (k = 0 .. L-1, L = ``opt.arch.gen.PE_channel``) and mapped
    through sin and cos.

    Args:
        opt: options; reads ``opt.arch.gen.PE_channel`` and ``opt.device``.
        input: tensor of shape [B,...,N].

    Returns:
        Tensor of shape [B,...,2*N*L]; for every input value the L sine terms
        come first, followed by the L cosine terms.
    """
    leading_dims = input.shape[:-1]
    exponents = torch.arange(
        opt.arch.gen.PE_channel, dtype=torch.float32, device=opt.device
    )
    freq = 2 ** exponents * np.pi  # [L]
    phase = input.unsqueeze(-1) * freq  # [B,...,N,L]
    sin_cos = torch.stack((phase.sin(), phase.cos()), dim=-2)  # [B,...,N,2,L]
    return sin_cos.view(*leading_dims, -1)  # [B,...,2NL]
def get_original_coord(opt):
    """Compute the unit ray direction of every panorama pixel.

    pano_direction is [X, Y, Z]: x right, y up, z out.

    Args:
        opt: options; reads ``opt.data.pano_size`` as (W, H) and
            ``opt.data.dataset``.

    Returns:
        ``np.ndarray`` of shape [H, W, 3] holding one direction per pixel.

    Raises:
        NotImplementedError: if ``opt.data.dataset`` is not one of the
            supported datasets (previously this surfaced as an
            ``UnboundLocalError`` on ``_theta``).
    """
    W, H = opt.data.pano_size
    dataset = opt.data.dataset

    # The latitude of the top row is pi / latitude_divisor; it decreases
    # linearly with the row index.
    if dataset in ('CVACT_Shi', 'CVACT', 'CVACThalf'):
        latitude_divisor = 2
    elif dataset in ('CVUSA',):
        latitude_divisor = 4
    else:
        raise NotImplementedError(
            'get_original_coord() does not support dataset %r' % (dataset,)
        )

    rows = np.arange(H).reshape(H, 1)  # pixel row index, broadcast over columns
    cols = np.arange(W).reshape(1, W)  # pixel column index, broadcast over rows

    _theta = (1 - 2 * rows / H) * np.pi / latitude_divisor  # latitude, [H,1]
    _phi = math.pi * (-0.5 - 2 * cols / W)  # longitude, [1,W]

    axis0 = np.cos(_theta) * np.cos(_phi)  # [H,W]
    axis1 = np.broadcast_to(np.sin(_theta), (H, W))
    axis2 = -np.cos(_theta) * np.sin(_phi)  # [H,W]
    return np.stack((axis0, axis1, axis2), axis=2)
def render(opt, feature, voxel, pano_direction, PE=None):
    """Render ground-level (panorama) images from satellite features.

    Rays are cast from the ground camera along ``pano_direction``, the density
    volume and the satellite features are sampled along each ray with
    ``grid_sample``, and the samples are merged by ``composite``.

    Args:
        opt: options. Reads ``opt.data.sat_size``, ``opt.data.dataset``,
            ``opt.data.sample_total_length``, ``opt.data.max_height``,
            ``opt.data.sample_number``, ``opt.optim.ground_prior``,
            ``opt.device`` and, if present, ``opt.origin_H_W``.
        feature: [B, C, H_sat, W_sat] feature map or an input RGB image.
        voxel: [B, N, H_sat, W_sat] density of each grid cell.
        pano_direction: per-pixel panorama ray directions (already
            normalised); presumably [1 or B, H_pano, W_pano, 3] as built from
            ``get_original_coord`` - confirm against the caller.
        PE: optional positional-encoding volume; when given it is sampled
            along the rays and concatenated to the sampled features.
            Default ``None`` disables it.

    Returns:
        The dict returned by ``composite`` ('rgb', 'opacity', 'depth') with an
        extra 'voxel' entry holding the density volume actually rendered
        (including the ground-prior layer when enabled).

    Raises:
        NotImplementedError: if ``opt.data.dataset`` is not supported.
    """
    sat_W, sat_H = opt.data.sat_size
    BS = feature.size(0)

    # Per-dataset constants: height at which the photo was taken and the
    # real-world extent that corresponds to the [-1, 1] regular coordinates.
    if opt.data.dataset == 'CVACT_Shi':
        origin_height = 2
        realworld_scale = 30
    elif opt.data.dataset == 'CVUSA':
        origin_height = 2
        realworld_scale = 55
    else:
        # The original `assert Exception(...)` was always true and never fired.
        raise NotImplementedError(
            'render() does not support dataset %r' % (opt.data.dataset,)
        )
    assert sat_W == sat_H, 'render() expects a square satellite image'

    half_scale = realworld_scale / 2
    if opt.data.sample_total_length:
        sample_total_length = opt.data.sample_total_length
    else:
        # Longest possible ray, measured in satellite pixels and then
        # converted to regular coordinates.
        pixel_resolution = realworld_scale / sat_W
        horizontal_sq = half_scale ** 2 + half_scale ** 2
        # NOTE(review): the literal 2 presumably stands for origin_height - confirm.
        longest_ray = max(
            np.sqrt(horizontal_sq + 2 ** 2),
            np.sqrt(horizontal_sq + (opt.data.max_height - origin_height) ** 2),
        )
        sample_total_length = int(longest_ray / pixel_resolution) / (sat_W / 2)

    # Camera position in regular coordinates; -1 is the lowest position.
    origin_z = torch.ones([BS, 1]) * (-1 + origin_height / half_scale)
    # origin_H_W is optional: a missing attribute is treated like None
    # (camera at the centre of the satellite image).
    origin_hw = getattr(opt, 'origin_H_W', None)
    if origin_hw is None:
        origin_H, origin_w = torch.zeros([BS, 1]), torch.zeros([BS, 1])
    else:
        origin_H = torch.ones([BS, 1]) * origin_hw[0]
        origin_w = torch.ones([BS, 1]) * origin_hw[1]
    # (w, z, h) ordering, similar to the NeRF coordinate definition.
    origin = torch.cat([origin_w, origin_z, origin_H], dim=1).to(opt.device)
    origin = origin[:, None, None, :, None]  # [B,1,1,3,1]

    # Samples are equally spaced, so their distances follow directly from the
    # total ray length and the number of samples.
    intv = sample_total_length / opt.data.sample_number
    sample_len = ((torch.arange(opt.data.sample_number) + 1) * intv).to(opt.device)
    depth = sample_len[None, None, None, None, :]
    sample_point = origin + pano_direction[..., None] * depth  # [...,3,S] in (w,z,h)

    if opt.optim.ground_prior:
        # Prepend a very dense layer below the predicted volume.
        ground = torch.ones(
            voxel.size(0), 1, voxel.size(2), voxel.size(3), device=opt.device
        ) * 1000
        voxel = torch.cat([ground, voxel], 1)

    N = voxel.size(1)
    voxel_low = -1
    # Highest position covered by the voxel volume, in regular coordinates.
    voxel_max = -1 + opt.data.max_height / half_scale

    # Samples first, and (w, z, h) -> (w, h, z) as expected by grid_sample.
    grid = sample_point.permute(0, 4, 1, 2, 3)[..., [0, 2, 1]]
    # Rescale z so that [voxel_low, voxel_max] spans the sampling range [-1, 1].
    grid[..., 2] = ((grid[..., 2] - voxel_low) / (voxel_max - voxel_low)) * 2 - 1
    grid = grid.float()

    # Replicate the 2-D feature map over every height layer of the volume.
    color_input = feature.unsqueeze(2).repeat(1, 1, N, 1, 1)
    alpha_grid = torch.nn.functional.grid_sample(voxel.unsqueeze(1), grid)
    color_grid = torch.nn.functional.grid_sample(color_input, grid)
    if PE is not None:
        # The encoding is shared by the whole batch: sample once, then repeat.
        PE_grid = torch.nn.functional.grid_sample(PE[None, ...], grid[:1, ...])
        color_grid = torch.cat([color_grid, PE_grid.repeat(BS, 1, 1, 1, 1)], dim=1)

    depth_sample = depth.permute(0, 1, 2, 4, 3).view(1, -1, opt.data.sample_number, 1)
    feature_size = color_grid.size(1)
    # Flatten the panorama pixels: [B,C,S,H,W] -> [B,HW,S,C] and [B,HW,S].
    color_grid = color_grid.permute(0, 3, 4, 2, 1).view(
        BS, -1, opt.data.sample_number, feature_size
    )
    alpha_grid = alpha_grid.permute(0, 3, 4, 2, 1).view(BS, -1, opt.data.sample_number)

    output = composite(
        opt,
        rgb_samples=color_grid,
        density_samples=alpha_grid,
        depth_samples=depth_sample,
        intv=intv,
    )
    output['voxel'] = voxel
    return output
def composite(opt, rgb_samples, density_samples, depth_samples, intv):
    """Composite per-ray samples into 2-D ground images.

    Args:
        opt: option dict; reads ``opt.data.pano_size`` as (W, H).
        rgb_samples: [B,HW,N,C] features sampled from the satellite image
            along each ray that starts at the ground camera.
        density_samples: [B,HW,N] densities sampled from the predicted voxel
            volume along the same rays.
        depth_samples: depth of every sample along the ray, broadcastable to
            [B,HW,N,1].
        intv: distance between two consecutive samples on a ray.

    Returns:
        Dict with 'rgb' [B,C,H,W], 'opacity' [B,1,H,W] and 'depth' [B,1,H,W].
    """
    pano_w, pano_h = opt.data.pano_size[0], opt.data.pano_size[1]

    optical_depth = density_samples * intv  # [B,HW,N]
    alpha = 1 - torch.exp(-optical_depth)  # [B,HW,N]
    # Transmittance: density accumulated over all *previous* samples, i.e. a
    # cumulative sum shifted right by one with a leading zero.
    preceding = torch.nn.functional.pad(optical_depth[..., :-1], (1, 0))
    transmittance = torch.exp(-preceding.cumsum(dim=2))  # [B,HW,N]
    weights = (transmittance * alpha).unsqueeze(-1)  # [B,HW,N,1]

    # Integrate RGB and depth weighted by the per-sample probability.
    depth_flat = (depth_samples * weights).sum(dim=2)  # [B,HW,1]
    rgb_flat = (rgb_samples * weights).sum(dim=2)  # [B,HW,C]
    opacity_flat = weights.sum(dim=2)  # [B,HW,1]

    batch = depth_flat.size(0)
    return {
        'rgb': rgb_flat.permute(0, 2, 1).view(rgb_flat.size(0), -1, pano_h, pano_w),
        'opacity': opacity_flat.view(opacity_flat.size(0), 1, pano_h, pano_w),
        'depth': depth_flat.permute(0, 2, 1).view(batch, -1, pano_h, pano_w),
    }
def get_sat_ori(opt):
    """Return one ray origin per satellite pixel for the top-down rendering.

    Origins sit at the pixel centres in normalised [-1, 1] coordinates and are
    stored as (x, 1, y), i.e. the middle component is fixed to 1.

    Args:
        opt: options; reads ``opt.data.sat_size`` as (W, H).

    Returns:
        float32 tensor of shape [1, H, W, 3] (on the default device).
    """
    W, H = opt.data.sat_size
    y_range = (torch.arange(H, dtype=torch.float32) + 0.5) / (0.5 * H) - 1
    # Normalise x by the width; the original divided by H, which is only
    # correct for square satellite images.
    x_range = (torch.arange(W, dtype=torch.float32) + 0.5) / (0.5 * W) - 1

    # Broadcast both axes to a common [H, W] grid (row = y, column = x).
    X = x_range[None, :].expand(H, W)
    Y = y_range[:, None].expand(H, W)
    Z = torch.ones(H, W, dtype=torch.float32)
    return torch.stack([X, Z, Y], dim=-1)[None]
def render_sat(opt, voxel):
    """Render the density volume top-down into satellite-view opacity and depth.

    One vertical ray per satellite pixel is cast from the top of the regular
    coordinate range straight down, the density is sampled along it and the
    samples are merged by ``composite_sat``.

    Args:
        opt: options. Reads ``opt.data.dataset``, ``opt.data.max_height``,
            ``opt.data.sample_number``, ``opt.data.sat_size`` (through the
            helpers) and ``opt.device``.
        voxel: [B, N, H_sat, W_sat] density volume that has already been
            processed by the caller.

    Returns:
        Tuple ``(opacity, depth)`` as produced by ``composite_sat``.

    Raises:
        NotImplementedError: if ``opt.data.dataset`` is not supported.
    """
    # Real-world extent that corresponds to the [-1, 1] regular coordinates.
    if opt.data.dataset == 'CVACT_Shi':
        realworld_scale = 30
    elif opt.data.dataset == 'CVUSA':
        realworld_scale = 55
    else:
        # The original `assert Exception(...)` was always true and never fired.
        raise NotImplementedError(
            'render_sat() does not support dataset %r' % (opt.data.dataset,)
        )

    batch = voxel.size(0)
    num_samples = opt.data.sample_number
    # Rays start at height 1 and travel the full [-1, 1] range downwards.
    sample_total_length = 2

    # (w, z, h) ordering, similar to the NeRF coordinate definition.
    origin = get_sat_ori(opt).to(opt.device)[..., None]  # [1,H,W,3,1]
    # Straight down; the direction is already normalised.
    direction = torch.tensor([0, -1, 0])[None, None, None, :, None].to(opt.device)

    # Samples are equally spaced, so their distances follow directly from the
    # total ray length and the number of samples.
    intv = sample_total_length / num_samples
    sample_len = ((torch.arange(num_samples) + 1) * intv).to(opt.device)
    depth = sample_len[None, None, None, None, :]
    sample_point = origin + direction * depth  # [1,H,W,3,S]

    voxel_low = -1
    # Highest position covered by the voxel volume, in regular coordinates.
    voxel_max = -1 + opt.data.max_height / (realworld_scale / 2)

    # Samples first, and (w, z, h) -> (w, h, z) as expected by grid_sample.
    grid = sample_point.permute(0, 4, 1, 2, 3)[..., [0, 2, 1]]  # [1,S,H,W,3]
    # Rescale z so that [voxel_low, voxel_max] spans the sampling range [-1, 1].
    grid[..., 2] = ((grid[..., 2] - voxel_low) / (voxel_max - voxel_low)) * 2 - 1
    # grid_sample needs the same batch size for input and grid; the rays are
    # identical for every item, so share them through a zero-copy expand.
    grid = grid.float().expand(batch, -1, -1, -1, -1)

    alpha_grid = torch.nn.functional.grid_sample(voxel.unsqueeze(1), grid)
    depth_sample = depth.permute(0, 1, 2, 4, 3).view(1, -1, num_samples, 1)
    # Use the actual batch of `voxel` rather than opt.batch_size, which may
    # differ (e.g. a smaller last batch) and would silently mis-shape the rays.
    alpha_grid = alpha_grid.permute(0, 3, 4, 2, 1).view(batch, -1, num_samples)

    output = composite_sat(
        opt, density_samples=alpha_grid, depth_samples=depth_sample, intv=intv
    )
    return output['opacity'], output['depth']
def composite_sat(opt, density_samples, depth_samples, intv):
    """Composite per-ray density samples into satellite-view opacity and depth.

    Args:
        opt: option dict; reads ``opt.data.sat_size`` as (W, H).
        density_samples: [B,HW,N] densities sampled along each ray.
        depth_samples: depth of every sample along the ray, broadcastable to
            [B,HW,N,1].
        intv: distance between two consecutive samples on a ray.

    Returns:
        Dict with 'opacity' [B,1,H,W] and 'depth' [B,1,H,W].
    """
    sat_w, sat_h = opt.data.sat_size[0], opt.data.sat_size[1]

    optical_depth = density_samples * intv  # [B,HW,N]
    alpha = 1 - torch.exp(-optical_depth)  # [B,HW,N]
    # Transmittance: density accumulated over all *previous* samples, i.e. a
    # cumulative sum shifted right by one with a leading zero.
    preceding = torch.nn.functional.pad(optical_depth[..., :-1], (1, 0))
    transmittance = torch.exp(-preceding.cumsum(dim=2))  # [B,HW,N]
    weights = (transmittance * alpha).unsqueeze(-1)  # [B,HW,N,1]

    depth_flat = (depth_samples * weights).sum(dim=2)  # [B,HW,1]
    opacity_flat = weights.sum(dim=2)  # [B,HW,1]

    return {
        'opacity': opacity_flat.view(opacity_flat.size(0), 1, sat_h, sat_w),
        'depth': depth_flat.permute(0, 2, 1).view(depth_flat.size(0), -1, sat_h, sat_w),
    }
if __name__ == '__main__':
    # Demo: render a ground panorama from one satellite image with an empty
    # density volume, profile the call and save the result next to the
    # reference panorama.
    import time

    opt = edict()
    # Fall back to the CPU when no GPU is available.
    opt.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    opt.origin_H_W = None  # camera at the centre of the satellite image
    opt.data = edict()
    opt.data.pano_size = [512, 256]
    opt.data.sat_size = [256, 256]
    opt.data.dataset = 'CVACT_Shi'
    opt.data.max_height = 20
    opt.data.sample_number = 300
    opt.arch = edict()
    opt.arch.gen = edict()  # must exist before transform_mode can be set on it
    opt.optim = edict()
    opt.optim.ground_prior = False
    opt.arch.gen.transform_mode = 'volum_rendering'
    # opt.arch.gen.transform_mode = 'proj_like_radus'
    BS = 1
    opt.data.sample_total_length = 1

    sat_name = './CVACT/satview_correct/__-DFIFxvZBCn1873qkqXA_satView_polish.png'
    with Image.open(sat_name) as sat_image:
        a = np.array(sat_image).astype(np.float32)
    a = torch.from_numpy(a)
    a = a.permute(2, 0, 1).unsqueeze(0).to(opt.device).repeat(BS, 1, 1, 1) / 255.

    pano_name = sat_name.replace('satview_correct', 'streetview').replace('_satView_polish', '_grdView')
    with Image.open(pano_name) as pano_image:
        pano = np.array(pano_image).astype(np.float32)
    pano = torch.from_numpy(pano)
    pano = pano.permute(2, 0, 1).unsqueeze(0).to(opt.device).repeat(BS, 1, 1, 1) / 255.

    voxel = torch.zeros([BS, 65, 256, 256]).to(opt.device)
    pano_direction = torch.from_numpy(get_original_coord(opt)).unsqueeze(0).to(opt.device)

    star = time.time()
    with torch.autograd.profiler.profile(
        enabled=True,
        use_cuda=(opt.device == 'cuda'),
        record_shapes=False,
        profile_memory=False,
    ) as prof:
        # render() returns a dict; unpacking it directly would iterate its keys.
        output = render(opt, a, voxel, pano_direction)
    rgb, opacity = output['rgb'], output['opacity']
    print(prof.table())
    print(time.time() - star)

    # Rendered panorama stacked on top of the real one.
    torchvision.utils.save_image(torch.cat([rgb, pano], 2), opt.arch.gen.transform_mode + '.png')
    print(opt.arch.gen.transform_mode + '.png')
    torchvision.utils.save_image(opacity, 'opa.png')