Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| def coords_grid(b, h, w, homogeneous=False, device=None): | |
| y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] | |
| stacks = [x, y] | |
| if homogeneous: | |
| ones = torch.ones_like(x) # [H, W] | |
| stacks.append(ones) | |
| grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] | |
| grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] | |
| if device is not None: | |
| grid = grid.to(device) | |
| return grid | |
| def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): | |
| assert device is not None | |
| x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), | |
| torch.linspace(h_min, h_max, len_h, device=device)], | |
| ) | |
| grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] | |
| return grid | |
| def normalize_coords(coords, h, w): | |
| # coords: [B, H, W, 2] | |
| c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) | |
| return (coords - c) / c # [-1, 1] | |
| def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): | |
| # img: [B, C, H, W] | |
| # sample_coords: [B, 2, H, W] in image scale | |
| if sample_coords.size(1) != 2: # [B, H, W, 2] | |
| sample_coords = sample_coords.permute(0, 3, 1, 2) | |
| b, _, h, w = sample_coords.shape | |
| # Normalize to [-1, 1] | |
| x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 | |
| y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 | |
| grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] | |
| img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) | |
| if return_mask: | |
| mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] | |
| return img, mask | |
| return img | |
| def flow_warp(feature, flow, mask=False, padding_mode='zeros'): | |
| b, c, h, w = feature.size() | |
| assert flow.size(1) == 2 | |
| grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] | |
| return bilinear_sample(feature, grid, padding_mode=padding_mode, | |
| return_mask=mask) | |
| def forward_backward_consistency_check(fwd_flow, bwd_flow, | |
| alpha=0.01, | |
| beta=0.5 | |
| ): | |
| # fwd_flow, bwd_flow: [B, 2, H, W] | |
| # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) | |
| assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 | |
| assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 | |
| flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] | |
| warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] | |
| warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] | |
| diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] | |
| diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) | |
| threshold = alpha * flow_mag + beta | |
| fwd_occ = (diff_fwd > threshold).float() # [B, H, W] | |
| bwd_occ = (diff_bwd > threshold).float() | |
| return fwd_occ, bwd_occ | |
| def back_project(depth, intrinsics): | |
| # Back project 2D pixel coords to 3D points | |
| # depth: [B, H, W] | |
| # intrinsics: [B, 3, 3] | |
| b, h, w = depth.shape | |
| grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] | |
| intrinsics_inv = torch.inverse(intrinsics) # [B, 3, 3] | |
| points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(1) # [B, 3, H, W] | |
| return points | |
| def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None): | |
| # Transform 3D points from reference camera to target camera | |
| # points_ref: [B, 3, H, W] | |
| # extrinsics_ref: [B, 4, 4] | |
| # extrinsics_tgt: [B, 4, 4] | |
| # extrinsics_rel: [B, 4, 4], relative pose transform | |
| b, _, h, w = points_ref.shape | |
| if extrinsics_rel is None: | |
| extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref)) # [B, 4, 4] | |
| points_tgt = torch.bmm(extrinsics_rel[:, :3, :3], | |
| points_ref.view(b, 3, -1)) + extrinsics_rel[:, :3, -1:] # [B, 3, H*W] | |
| points_tgt = points_tgt.view(b, 3, h, w) # [B, 3, H, W] | |
| return points_tgt | |
| def reproject(points_tgt, intrinsics, return_mask=False): | |
| # reproject to target view | |
| # points_tgt: [B, 3, H, W] | |
| # intrinsics: [B, 3, 3] | |
| b, _, h, w = points_tgt.shape | |
| proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w) # [B, 3, H, W] | |
| X = proj_points[:, 0] | |
| Y = proj_points[:, 1] | |
| Z = proj_points[:, 2].clamp(min=1e-3) | |
| pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(b, 2, h, w) # [B, 2, H, W] in image scale | |
| if return_mask: | |
| # valid mask in pixel space | |
| mask = (pixel_coords[:, 0] >= 0) & (pixel_coords[:, 0] <= (w - 1)) & ( | |
| pixel_coords[:, 1] >= 0) & (pixel_coords[:, 1] <= (h - 1)) # [B, H, W] | |
| return pixel_coords, mask | |
| return pixel_coords | |
| def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, | |
| return_mask=False): | |
| # Compute reprojection sample coords | |
| points_ref = back_project(depth_ref, intrinsics) # [B, 3, H, W] | |
| points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel) | |
| if return_mask: | |
| reproj_coords, mask = reproject(points_tgt, intrinsics, | |
| return_mask=return_mask) # [B, 2, H, W] in image scale | |
| return reproj_coords, mask | |
| reproj_coords = reproject(points_tgt, intrinsics, | |
| return_mask=return_mask) # [B, 2, H, W] in image scale | |
| return reproj_coords | |
| def compute_flow_with_depth_pose(depth_ref, intrinsics, | |
| extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, | |
| return_mask=False): | |
| b, h, w = depth_ref.shape | |
| coords_init = coords_grid(b, h, w, device=depth_ref.device) # [B, 2, H, W] | |
| if return_mask: | |
| reproj_coords, mask = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, | |
| extrinsics_rel=extrinsics_rel, | |
| return_mask=return_mask) # [B, 2, H, W] | |
| rigid_flow = reproj_coords - coords_init | |
| return rigid_flow, mask | |
| reproj_coords = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, | |
| extrinsics_rel=extrinsics_rel, | |
| return_mask=return_mask) # [B, 2, H, W] | |
| rigid_flow = reproj_coords - coords_init | |
| return rigid_flow | |
| def forward_backward_consistency_check(fwd_flow, bwd_flow, | |
| alpha=0.01, | |
| beta=0.5, | |
| return_flow_diff=False, | |
| ): | |
| # fwd_flow, bwd_flow: [B, 2, H, W] | |
| # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) | |
| assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 | |
| assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 | |
| flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] | |
| warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] | |
| warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] | |
| diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] | |
| diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) | |
| if return_flow_diff: | |
| return diff_fwd, diff_bwd | |
| threshold = alpha * flow_mag + beta | |
| fwd_occ = (diff_fwd > threshold).float() # [B, H, W] | |
| bwd_occ = (diff_bwd > threshold).float() | |
| return fwd_occ, bwd_occ | |
| def warp_with_depth_pose(feature1, intrinsics, pose, depth, | |
| padding_mode='zeros', | |
| return_rigid_flow=False, | |
| return_mask=False, | |
| ): | |
| assert depth.dim() == 3 # [B, H, W] | |
| sample_coords = reproject_coords(depth, | |
| intrinsics, | |
| extrinsics_rel=pose, | |
| return_mask=return_mask, | |
| ) # [B, 2, H, W] | |
| if return_mask: | |
| sample_coords, mask = sample_coords | |
| sample_coords = sample_coords.permute(0, 2, 3, 1) # [B, H, W, 2] | |
| warped_feature1 = bilinear_sample(feature1, sample_coords, | |
| padding_mode=padding_mode) # [B, C, H, W] | |
| if return_mask: | |
| return warped_feature1, mask | |
| if return_rigid_flow: | |
| b, h, w = depth.size() | |
| coords_init = coords_grid(b, h, w, device=depth.device) # [B, 2, H, W] | |
| rigid_flow = sample_coords.permute(0, 3, 1, 2) - coords_init | |
| return warped_feature1, rigid_flow | |
| return warped_feature1 | |
| def warp_with_pose_depth_candidates(feature1, intrinsics, pose, depth, | |
| padding_mode='zeros', | |
| rigid_flow_to_subtract=None, | |
| ): | |
| # pixel-specific depth candidates, useful for refinement | |
| # feature1: [B, C, H, W] | |
| # intrinsics: [B, 3, 3] | |
| # pose: [B, 4, 4] | |
| # depth: [B, D, H, W] | |
| assert intrinsics.size(1) == intrinsics.size(2) == 3 | |
| assert pose.size(1) == pose.size(2) == 4 | |
| assert depth.dim() == 4 | |
| b, d, h, w = depth.size() | |
| c = feature1.size(1) | |
| # stop gradient | |
| with torch.no_grad(): | |
| # pixel coordinates | |
| grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] | |
| # back project to 3D and transform viewpoint | |
| points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W] | |
| points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat( | |
| 1, 1, d, 1) * depth.view(b, 1, d, h * w) # [B, 3, D, H*W] | |
| points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W] | |
| # reproject to 2D image plane | |
| points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(b, 3, d, h * w) # [B, 3, D, H*W] | |
| pixel_coords = points[:, :2] / points[:, -1:].clamp(min=MIN_DEPTH) # [B, 2, D, H*W] | |
| if rigid_flow_to_subtract is not None: | |
| assert rigid_flow_to_subtract.dim() == 4 # [B, 2, H, W] | |
| assert rigid_flow_to_subtract.size(1) == 2 | |
| pixel_coords = pixel_coords - rigid_flow_to_subtract.view(b, 2, h * w).unsqueeze(2) | |
| # normalize to [-1, 1] | |
| x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1 | |
| y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1 | |
| grid = torch.stack([x_grid, y_grid], dim=-1) # [B, D, H*W, 2] | |
| # sample features | |
| warped_feature = F.grid_sample(feature1, grid.view(b, d * h, w, 2), mode='bilinear', | |
| padding_mode=padding_mode, | |
| align_corners=True).view(b, c, d, h, w) # [B, C, D, H, W] | |
| return warped_feature | |