Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
def coords_grid(b, h, w, homogeneous=False, device=None):
    """Return a batch of pixel-coordinate grids.

    Channel 0 holds the x (column) index and channel 1 the y (row) index of
    every pixel; with ``homogeneous=True`` a third channel of ones is appended.

    Args:
        b: batch size; the grid is repeated ``b`` times along dim 0.
        h: grid height.
        w: grid width.
        homogeneous: append a channel of ones, giving (x, y, 1) coordinates.
        device: device to create the grid on; ``None`` uses torch's default.

    Returns:
        float32 tensor of shape [B, 2, H, W], or [B, 3, H, W] when
        ``homogeneous`` is True.
    """
    # Build the coordinates directly on the target device (no host-side
    # allocation followed by a copy) and via broadcasting instead of
    # torch.meshgrid, whose call without an explicit `indexing=` argument
    # emits a deprecation warning on recent torch versions.
    ys = torch.arange(h, device=device).view(h, 1).expand(h, w)  # ys[i, j] == i
    xs = torch.arange(w, device=device).view(1, w).expand(h, w)  # xs[i, j] == j
    channels = [xs, ys]
    if homogeneous:
        channels.append(torch.ones_like(xs))
    grid = torch.stack(channels, dim=0).float()  # [2, H, W] or [3, H, W]
    return grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]
def warp_with_pose_depth_candidates(
    feature1,
    intrinsics,
    pose,
    depth,
    clamp_min_depth=1e-3,
    grid_sample_disable_cudnn=False,
):
    """Sample ``feature1`` at the re-projection of every pixel for each depth candidate.

    Each pixel (x, y) of an H x W grid is back-projected with
    ``inverse(intrinsics)``, scaled by each of its D candidate depths,
    rotated/translated by ``pose`` and re-projected with ``intrinsics``.
    ``feature1`` is then bilinearly sampled at the resulting locations.

    The sampling grid is computed under ``torch.no_grad()``, so gradients flow
    only into ``feature1``, not into ``intrinsics``, ``pose`` or ``depth``.

    Args:
        feature1: [B, C, H, W] features to sample from. NOTE(review): the
            grid is normalised with ``depth``'s H and W, so this presumably
            must share ``depth``'s spatial size — confirm with callers.
        intrinsics: [B, 3, 3] matrix used for both back-projection and
            re-projection.
        pose: [B, 4, 4]; only ``pose[:, :3, :3]`` (rotation) and
            ``pose[:, :3, 3]`` (translation) are used.
        depth: [B, D, H, W] depth candidates per pixel.
        clamp_min_depth: lower bound applied to the projected z coordinate
            before the perspective divide.
        grid_sample_disable_cudnn: run ``grid_sample`` with cuDNN disabled.
            Forced to True whenever ``feature1`` has more than 1e6 elements.

    Returns:
        [B, C, D, H, W] warped features; samples falling outside
        ``feature1`` are zero (``padding_mode="zeros"``).
    """
    assert intrinsics.size(1) == intrinsics.size(2) == 3
    assert pose.size(1) == pose.size(2) == 4
    assert depth.dim() == 4

    b, d, h, w = depth.size()
    c = feature1.size(1)

    with torch.no_grad():
        # Homogeneous pixel coordinates (x, y, 1).
        pixels = coords_grid(
            b, h, w, homogeneous=True, device=depth.device
        ).view(b, 3, -1)  # [B, 3, H*W]

        # Back-project to unit-depth rays, then rotate into the other view.
        rays = torch.bmm(
            pose[:, :3, :3], torch.inverse(intrinsics).bmm(pixels)
        )  # [B, 3, H*W]

        # Scale every ray by each depth candidate. Broadcasting over D gives
        # the same values as repeating the rays D times, without the extra copy.
        points = rays.unsqueeze(2) * depth.view(b, 1, d, h * w)  # [B, 3, D, H*W]
        points = points + pose[:, :3, -1:].unsqueeze(-1)  # [B, 3, D, H*W]

        # Re-project onto the image plane; clamp z to avoid dividing by ~0.
        points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(
            b, 3, d, h * w
        )  # [B, 3, D, H*W]
        pixel_coords = points[:, :2] / points[:, -1:].clamp(
            min=clamp_min_depth
        )  # [B, 2, D, H*W]

        # Normalise pixel coordinates to [-1, 1] (align_corners=True convention).
        x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1
        y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1
        grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, D, H*W, 2]

    # Hardcoded workaround, see https://github.com/pytorch/pytorch/issues/88380:
    # cuDNN is disabled for grid_sample when feature1 has more than 1e6 elements.
    disable_cudnn = grid_sample_disable_cudnn or feature1.numel() > 1000000
    with torch.backends.cudnn.flags(enabled=not disable_cudnn):
        warped_feature = F.grid_sample(
            feature1,
            grid.view(b, d * h, w, 2),
            mode="bilinear",
            padding_mode="zeros",
            align_corners=True,
        ).view(b, c, d, h, w)  # [B, C, D, H, W]

    return warped_feature