Spaces:
Running on Zero
Running on Zero
| import torch | |
| import torch.nn.functional as F | |
def estimate_surface_normal(
    points: torch.Tensor, d: int = 2, mode: str = "closest"
) -> torch.Tensor:
    """Estimate per-pixel surface normals from an organized point cloud.

    Re-implements the following codes with PyTorch:
    https://github.com/wkentaro/morefusion/blob/master/morefusion/geometry/estimate_pointcloud_normals.py
    https://github.com/jmccormac/pySceneNetRGBD/blob/master/calculate_surface_normals.py

    For every pixel (the "anchor"), eight neighbours ``d`` pixels away are
    gathered in a fixed circular order. Neighbour ``k`` is paired with
    neighbour ``k + 2`` (a quarter turn further round the ring), and a normal
    is the cross product of the two anchor-to-neighbour vectors.

    Rows are padded by replicating the border, and columns wrap around
    circularly, so the left and right image edges are treated as adjacent.
    This is presumably intended for panoramic inputs; confirm with callers.

    Args:
        points: ``(B, 3, H, W)`` tensor holding one 3D point per pixel.
        d: pixel distance between the anchor and its neighbours.
        mode: ``"closest"`` uses the single neighbour pair whose summed
            distance to the anchor is smallest. Pairs reaching into the
            replicated top/bottom padding are skipped whenever a fully
            in-image pair exists for that row. ``"mean"`` averages the
            (unnormalized) cross products of all eight pairs.

    Returns:
        ``(B, 3, H, W)`` tensor of normals scaled to unit length. A small
        epsilon guards the division, so a degenerate (zero) normal stays zero.
        The sign follows from the neighbour ordering below.

    Raises:
        AssertionError: if ``points`` is not 4-D with 3 channels.
        NotImplementedError: if ``mode`` is neither ``"closest"`` nor ``"mean"``.
    """
    assert points.dim() == 4, f"expected (B,3,H,W), but got {points.shape}"
    B, C, H, W = points.shape
    assert C == 3, f"expected C==3, but got {C}"
    device = points.device

    # Rows: replicate the border. This keeps "mean" finite, whereas the
    # upstream code padded with inf/NaN. Columns: wrap around.
    points = F.pad(points, (0, 0, d, d), mode="replicate")
    points = F.pad(points, (d, d, 0, 0), mode="circular")
    points = points.permute(0, 2, 3, 1)  # (B, H + 2d, W + 2d, 3)

    # 8 neighbour offsets as (dh, dw), with h growing downwards (A = anchor):
    #  -----------
    # | 7 | 0 | 1 |
    #  -----------
    # | 6 | A | 2 |
    #  -----------
    # | 5 | 4 | 3 |
    #  -----------
    offsets = torch.tensor(
        [
            # (dh, dw)
            (-d, 0),  # 0
            (-d, d),  # 1
            (0, d),  # 2
            (d, d),  # 3
            (d, 0),  # 4
            (d, -d),  # 5
            (0, -d),  # 6
            (-d, -d),  # 7
        ],
        device=device,
    )
    # Row k holds the offset of neighbour k + 2, i.e. the pair partner.
    offsets_next = offsets[(torch.arange(8, device=device) + 2) % 8]

    # Broadcastable (B, H, W) index grids into the unpadded image.
    b = torch.arange(B, device=device)[:, None, None]  # (B,1,1)
    h = torch.arange(H, device=device)[None, :, None]  # (1,H,1)
    w = torch.arange(W, device=device)[None, None, :]  # (1,1,W)

    # Anchor indices in padded coordinates, with a neighbour axis of size 1.
    b0 = b[:, None]  # (B,1,1,1)
    h0 = h[:, None] + d  # (1,1,H,1)
    w0 = w[:, None] + d  # (1,1,1,W)
    anchors = points[b0, h0, w0]  # (B,1,H,W,3)

    # First neighbour of each pair.
    h1 = h0 + offsets[None, :, 0, None, None]  # (1,8,H,1)
    w1 = w0 + offsets[None, :, 1, None, None]  # (1,8,1,W)
    points1 = points[b0, h1, w1]  # (B,8,H,W,3)

    # Second neighbour of each pair (a quarter turn further round the ring).
    h2 = h0 + offsets_next[None, :, 0, None, None]  # (1,8,H,1)
    w2 = w0 + offsets_next[None, :, 1, None, None]  # (1,8,1,W)
    points2 = points[b0, h2, w2]  # (B,8,H,W,3)

    if mode == "closest":
        # Find the closest neighbour pair.
        diff = torch.norm(points1 - anchors, dim=4)
        diff = diff + torch.norm(points2 - anchors, dim=4)  # (B,8,H,W)
        # Neighbours taken from the replicated padding rows are not real
        # pixels at offset d. On the first/last row they equal the anchor
        # itself: the pair then has the smallest distance and a zero cross
        # product, which used to yield a zero normal. Exclude such pairs
        # whenever the row still has a fully in-image pair; otherwise keep
        # the unmasked distances.
        valid = (h1 >= d) & (h1 < H + d) & (h2 >= d) & (h2 < H + d)  # (1,8,H,1)
        has_valid = valid.any(dim=1, keepdim=True)  # (1,1,H,1)
        diff = torch.where(
            has_valid, diff.masked_fill(~valid, float("inf")), diff
        )
        i = torch.argmin(diff, dim=1)  # (B,H,W)
        # Get the normal by cross product of the selected pair.
        anchors = anchors[b, 0, h, w]  # (B,H,W,3)
        points1 = points1[b, i, h, w]  # (B,H,W,3)
        points2 = points2[b, i, h, w]  # (B,H,W,3)
        vector1 = points1 - anchors
        vector2 = points2 - anchors
        normals = torch.cross(vector1, vector2, dim=-1)  # (B,H,W,3)
    elif mode == "mean":
        # Average the (unnormalized) cross products of all eight pairs.
        vector1 = points1 - anchors
        vector2 = points2 - anchors
        normals = torch.cross(vector1, vector2, dim=-1)  # (B,8,H,W,3)
        normals = normals.mean(dim=1)  # (B,H,W,3)
    else:
        raise NotImplementedError(mode)

    normals = normals / (torch.norm(normals, dim=3, keepdim=True) + 1e-8)
    normals = normals.permute(0, 3, 1, 2)  # (B,3,H,W)
    return normals