import torch
import torch.nn.functional as F

# feature1 tensors with more elements than this always run grid_sample with
# cuDNN disabled. Hardcoded workaround;
# ref: https://github.com/pytorch/pytorch/issues/88380
_CUDNN_GRID_SAMPLE_MAX_NUMEL = 1000000


def coords_grid(b, h, w, homogeneous=False, device=None):
    """Return a batch of pixel-coordinate grids.

    Args:
        b: batch size; the same grid is copied ``b`` times.
        h: grid height (number of rows).
        w: grid width (number of columns).
        homogeneous: if True, append a channel of ones so each pixel holds
            ``(x, y, 1)``.
        device: device for the result; ``None`` uses torch's default device.

    Returns:
        float32 tensor of shape ``[B, 2, H, W]`` (channels ``x``, ``y``) or
        ``[B, 3, H, W]`` (channels ``x``, ``y``, ``1``). Channel 0 is the
        column index and channel 1 the row index.
    """
    # Build directly on the target device as float32 (exact for pixel indices).
    ys = torch.arange(h, dtype=torch.float32, device=device)
    xs = torch.arange(w, dtype=torch.float32, device=device)
    # "ij" indexing: y varies along dim 0 (rows), x along dim 1 (columns).
    y, x = torch.meshgrid(ys, xs, indexing="ij")  # [H, W]

    stacks = [x, y]
    if homogeneous:
        stacks.append(torch.ones_like(x))  # [H, W]

    grid = torch.stack(stacks, dim=0)  # [2, H, W] or [3, H, W]
    # repeat (not expand) so each batch entry is a real, independent copy.
    return grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]


def warp_with_pose_depth_candidates(
    feature1,
    intrinsics,
    pose,
    depth,
    clamp_min_depth=1e-3,
    grid_sample_disable_cudnn=False,
):
    """Warp ``feature1`` into the reference view for every depth candidate.

    Each reference pixel ``(x, y)`` is back-projected with ``intrinsics``,
    rotated by ``pose[:, :3, :3]``, scaled by each of its D candidate depths,
    translated by ``pose[:, :3, 3]``, projected with the same ``intrinsics``,
    and ``feature1`` is bilinearly sampled at the result (zeros outside the
    image, ``align_corners=True``).

    Args:
        feature1: ``[B, C, H, W]`` features to sample from.
        intrinsics: ``[B, 3, 3]`` camera matrix, used for both back-projection
            and re-projection.
        pose: ``[B, 4, 4]`` transform applied to the back-projected reference
            points; only the top three rows are read.
        depth: ``[B, D, H, W]`` candidate depths for each reference pixel.
        clamp_min_depth: projected ``z`` values below this (including negative
            ones, i.e. points behind the camera) are replaced by it before the
            perspective division.
        grid_sample_disable_cudnn: run ``grid_sample`` with cuDNN disabled.
            Forced to True when ``feature1`` has more than
            ``_CUDNN_GRID_SAMPLE_MAX_NUMEL`` elements.

    Returns:
        ``[B, C, D, H, W]`` warped features.

    Note:
        The sampling grid is computed under ``torch.no_grad()``, so gradients
        flow only into ``feature1``, not into ``intrinsics``, ``pose`` or
        ``depth``.
    """
    assert intrinsics.size(1) == intrinsics.size(2) == 3
    assert pose.size(1) == pose.size(2) == 4
    assert depth.dim() == 4

    b, d, h, w = depth.size()
    c = feature1.size(1)

    with torch.no_grad():
        # Homogeneous pixel coordinates (x, y, 1) of the reference view.
        pixel_grid = coords_grid(
            b, h, w, homogeneous=True, device=depth.device
        )  # [B, 3, H, W]

        # Back-project to rays and rotate into the other view.
        rays = torch.inverse(intrinsics).bmm(pixel_grid.reshape(b, 3, -1))  # [B, 3, H*W]
        rays = torch.bmm(pose[:, :3, :3], rays)  # [B, 3, H*W]
        # Scale by every depth candidate; broadcasting over D avoids an
        # explicit repeat of the rays.
        points = rays.unsqueeze(2) * depth.reshape(b, 1, d, h * w)  # [B, 3, D, H*W]
        points = points + pose[:, :3, -1:].unsqueeze(-1)  # [B, 3, D, H*W]

        # Reproject to the 2D image plane.
        points = torch.bmm(intrinsics, points.reshape(b, 3, -1)).reshape(
            b, 3, d, h * w
        )  # [B, 3, D, H*W]
        pixel_coords = points[:, :2] / points[:, -1:].clamp(
            min=clamp_min_depth
        )  # [B, 2, D, H*W]

        # Normalize to [-1, 1] for align_corners=True. max(., 1) avoids a 0/0
        # (NaN) on a size-1 axis, where the only valid pixel index is 0 and
        # grid_sample maps any coordinate to it anyway.
        x_grid = 2 * pixel_coords[:, 0] / max(w - 1, 1) - 1
        y_grid = 2 * pixel_coords[:, 1] / max(h - 1, 1) - 1
        sample_grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, D, H*W, 2]

    # Sample features (outside no_grad so feature1 receives gradients).
    if feature1.numel() > _CUDNN_GRID_SAMPLE_MAX_NUMEL:
        grid_sample_disable_cudnn = True

    with torch.backends.cudnn.flags(enabled=not grid_sample_disable_cudnn):
        warped_feature = F.grid_sample(
            feature1,
            sample_grid.view(b, d * h, w, 2),
            mode="bilinear",
            padding_mode="zeros",
            align_corners=True,
        ).view(b, c, d, h, w)  # [B, C, D, H, W]

    return warped_feature