| | import torch |
| | from torch.nn.functional import grid_sample |
| |
|
| |
|
def back_project_sparse_type(coords, origin, voxel_size, feats, KRcam, sizeH=None, sizeW=None, only_mask=False,
                             with_proj_z=False):
    '''
    Unproject the image features to form a 3D (sparse) feature volume

    Every voxel is mapped to world space (coords * voxel_size + origin), projected
    into each view with KRcam, and the feature map is bilinearly sampled at the
    resulting pixel position.

    :param coords: coordinates of voxels,
    dim: (num of voxels, 4) (4 : batch ind, x, y, z)
    :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0))
    dim: (batch size, 3) (3: x, y, z)
    :param voxel_size: floats specifying the size of a voxel
    :param feats: image features
    dim: (num of views, batch size, C, H, W)
    :param KRcam: projection matrix
    dim: (num of views, batch size, 4, 4)
    :param sizeH: image height used to normalise projected y coordinates;
    when None, both sizeH and sizeW are taken from the feature map (H, W)
    :param sizeW: image width used to normalise projected x coordinates
    :param only_mask: if True, skip feature sampling and return only mask_volume_all
    :param with_proj_z: if True, additionally return the projected depth of every voxel
    :return: feature_volume_all: 3D feature volumes
    dim: (num of voxels, num_of_views, c)
    :return: mask_volume_all: indicate the voxel of sampled feature volume is valid or not
    dim: (num of voxels, num_of_views)
    :return: proj_z_all (only when with_proj_z): projected depth per voxel and view
    dim: (num of voxels, num_of_views, 1)
    '''
    n_views, bs, c, h, w = feats.shape
    device = feats.device
    num_voxels = coords.shape[0]

    if sizeH is None:
        sizeH, sizeW = h, w

    # The feature buffer is only needed when features are actually returned.
    feature_volume_all = None
    if not only_mask:
        feature_volume_all = torch.zeros(num_voxels, n_views, c, device=device)
    mask_volume_all = torch.zeros([num_voxels, n_views], dtype=torch.int32, device=device)
    # Allocated lazily so that it inherits the dtype/device of the projected depth.
    proj_z_all = None

    for batch in range(bs):
        batch_ind = torch.nonzero(coords[:, 0] == batch).squeeze(1)
        if batch_ind.numel() == 0:
            # No voxel belongs to this batch element: nothing to project or sample.
            continue

        coords_batch = coords[batch_ind][:, 1:].view(-1, 3)
        origin_batch = origin[batch].unsqueeze(0)

        # Voxel indices -> world coordinates, replicated per view as homogeneous
        # column vectors of shape (n_views, 4, nV).
        grid_batch = coords_batch * voxel_size + origin_batch.float()
        rs_grid = grid_batch.unsqueeze(0).expand(n_views, -1, -1)
        rs_grid = rs_grid.permute(0, 2, 1).contiguous()
        nV = rs_grid.shape[-1]
        rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV], device=device)], dim=1)

        # Project the grid into every view.
        im_p = KRcam[:, batch] @ rs_grid
        im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2]

        # Keep non-negative depths away from zero before dividing; negative depths
        # are left untouched so the `im_z > 0` test below still rejects them.
        im_z = torch.where(im_z >= 0, im_z.clamp(min=1e-6), im_z)

        im_x = im_x / im_z
        im_y = im_y / im_z

        # Normalise pixel coordinates to [-1, 1] as expected by grid_sample.
        im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1)
        mask = (im_grid.abs() <= 1).all(dim=-1) & (im_z > 0)
        mask = mask.view(n_views, -1).permute(1, 0).contiguous()

        mask_volume_all[batch_ind] = mask.to(torch.int32)

        if only_mask:
            # Keep looping so that the masks of *all* batch elements get filled.
            continue

        feats_batch = feats[:, batch].view(n_views, c, h, w)
        im_grid = im_grid.view(n_views, 1, -1, 2)
        features = grid_sample(feats_batch, im_grid, padding_mode='zeros', align_corners=True)

        features = features.view(n_views, c, -1)
        features = features.permute(2, 0, 1).contiguous()  # (nV, n_views, c)

        feature_volume_all[batch_ind] = features

        if with_proj_z:
            if proj_z_all is None:
                proj_z_all = im_z.new_zeros((num_voxels, n_views, 1))
            # Scatter per batch so the rows stay aligned with `coords`.
            proj_z_all[batch_ind] = im_z.view(n_views, 1, -1).permute(2, 0, 1)

    if only_mask:
        return mask_volume_all

    if with_proj_z:
        if proj_z_all is None:
            proj_z_all = torch.zeros(num_voxels, n_views, 1, device=device)
        return feature_volume_all, mask_volume_all, proj_z_all

    return feature_volume_all, mask_volume_all
| |
|
| |
|
def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode, sizeH=None, sizeW=None, with_depth=False):
    """Project 3D points into normalised pixel coordinates.

    Args:
        cam_coords: point coordinates -- [B, 3, H, W]
        proj_c2p_rot: optional matrices applied to the points via bmm -- [B, 3, 3]
        proj_c2p_tr: optional offsets added after the rotation -- [B, 3, 1]
        padding_mode: when equal to 'zeros', every normalised coordinate outside
            [-1, 1] is overwritten with 2
        sizeH: height used for normalising y; when None, both sizeH and sizeW
            are taken from cam_coords (H, W)
        sizeW: width used for normalising x
        with_depth: if True, append the clamped depth as a third channel
    Returns:
        array of [-1,1] coordinates -- [B, H, W, 2], or [B, H, W, 3] when
        with_depth is set
    """
    n_batch, _, rows, cols = cam_coords.size()
    if sizeH is None:
        sizeH, sizeW = rows, cols

    # Flatten the spatial dimensions: [B, 3, H*W].
    points = cam_coords.view(n_batch, 3, -1)
    if proj_c2p_rot is not None:
        points = proj_c2p_rot.bmm(points)
    if proj_c2p_tr is not None:
        points = points + proj_c2p_tr

    # Depth is clamped from below so the perspective division stays finite.
    depth = points[:, 2].clamp(min=1e-3)
    u_norm = 2 * (points[:, 0] / depth) / (sizeW - 1) - 1
    v_norm = 2 * (points[:, 1] / depth) / (sizeH - 1) - 1

    if padding_mode == 'zeros':
        # Push every out-of-range coordinate to 2, i.e. outside [-1, 1].
        for axis_norm in (u_norm, v_norm):
            outside = ((axis_norm > 1) | (axis_norm < -1)).detach()
            axis_norm[outside] = 2

    channels = [u_norm, v_norm, depth] if with_depth else [u_norm, v_norm]
    stacked = torch.stack(channels, dim=2)  # [B, H*W, 2 or 3]
    return stacked.view(n_batch, rows, cols, len(channels))
| |
|
| |
|
| | |
def back_project_dense_type(coords, origin, voxel_size, feats, proj_matrix, sizeH=None, sizeW=None):
    '''
    Unproject the image features to form a 3D (dense) feature volume

    :param coords: coordinates of voxels,
    dim: (batch, nviews, 3, X,Y,Z)
    :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0))
    dim: (batch size, 3) (3: x, y, z)
    :param voxel_size: floats specifying the size of a voxel
    :param feats: image features
    dim: (batch size, num of views, C, H, W)
    :param proj_matrix: projection matrix
    dim: (batch size, num of views, 4, 4)
    :param sizeH: image height used to normalise projected y coordinates;
    when None, both sizeH and sizeW are taken from the last two dims of feats
    :param sizeW: image width used to normalise projected x coordinates
    :return: feature_volume_all: 3D feature volumes
    dim: (batch, nviews, C, X,Y,Z)
    :return: count: an all-ones image sampled at the same locations as the features
    (zero where the projection falls outside the image)
    dim: (batch, nviews, 1, X,Y,Z)
    '''
    batch, nviews, _, wX, wY, wZ = coords.shape
    n_images = batch * nviews
    n_voxels = wX * wY * wZ

    if sizeH is None:
        sizeH, sizeW = feats.shape[-2:]

    # Fold the batch and view axes together so every image is handled independently.
    proj_flat = proj_matrix.view(n_images, *proj_matrix.shape[2:])

    # Voxel indices -> world coordinates, laid out as (b*nviews, 3, X*Y*Z, 1).
    world_pts = coords * voxel_size + origin.view(batch, 1, 3, 1, 1, 1)
    world_pts = world_pts.view(n_images, 3, n_voxels, 1)

    sample_grid = cam2pixel(world_pts, proj_flat[:, :3, :3], proj_flat[:, :3, 3:],
                            'zeros', sizeH=sizeH, sizeW=sizeW)
    sample_grid = sample_grid.view(n_images, 1, n_voxels, 2)

    feats_flat = feats.view(n_images, *feats.shape[2:])

    # Sampling a constant-one image with the same grid yields the visibility counts.
    ones = torch.ones((n_images, 1, *feats_flat.shape[2:]), dtype=feats_flat.dtype, device=feats_flat.device)

    sampled_feats = torch.nn.functional.grid_sample(feats_flat, sample_grid, padding_mode='zeros',
                                                    align_corners=True)
    sampled_ones = torch.nn.functional.grid_sample(ones, sample_grid, padding_mode='zeros',
                                                   align_corners=True)

    volume_shape = (batch, nviews, -1, wX, wY, wZ)
    return sampled_feats.view(*volume_shape), sampled_ones.view(*volume_shape)
| |
|
| |
|