| import torch |
| import torch.nn.functional as F |
|
|
| from .graphics_utils import depth2point_cam |
|
|
|
|
def get_normal_sign(normals, begin=None, end=None, trans=None, mode='origin', vec=None):
    """Project ``normals`` onto a reference direction to judge their orientation.

    If ``vec`` is given, it is used directly as the reference direction.
    Otherwise (only supported for ``mode='origin'``) the direction is
    ``end - begin``. When ``begin`` is omitted, it is derived as follows:

    * ``trans`` is 1-D: ``trans`` itself is taken as the start point.
    * ``trans`` is a matrix: ``-R^T @ t`` is used, with ``R = trans[:3, :3]``
      and ``t = trans[:3, 3]``. If ``trans`` is a world-to-camera ``[R | t]``
      matrix, this is the camera centre (NOTE(review): confirm callers pass
      world-to-camera extrinsics).
    * ``trans`` is None: the centroid ``end.mean(0)`` is used, with component
      1 shifted by +1 (presumably the y axis; confirm).

    Args:
        normals: Normal vectors. The last dimension is reduced against ``vec``.
        begin: Optional start point of the reference direction.
        end: End point(s) of the reference direction. Required when ``vec``
            is None.
        trans: Optional position (1-D) or ``[R | t]`` matrix used to derive
            ``begin``.
        mode: Only ``'origin'`` can derive ``vec``. Any other mode needs an
            explicit ``vec``.
        vec: Explicit reference direction, broadcastable against ``normals``.

    Returns:
        ``(normals * vec).sum(-1, keepdim=True)``. A positive value means the
        normal points along ``vec``; a negative value means it points against it.

    Raises:
        ValueError: If no reference direction can be formed, i.e. ``vec`` is
            None and either ``mode != 'origin'`` or ``end`` is None.
    """
    if vec is None:
        # Without these checks the call would fail later with an opaque
        # UnboundLocalError / AttributeError / TypeError.
        if mode != 'origin':
            raise ValueError(f"mode={mode!r} requires an explicit `vec`")
        if end is None:
            raise ValueError("`end` is required to derive the reference direction")
        if begin is None:
            if trans is not None:
                begin = trans if trans.ndim == 1 else -trans[:3, :3].T @ trans[:3, 3]
            else:
                # mean() returns a fresh tensor, so the in-place shift does not touch `end`.
                begin = end.mean(0)
                begin[1] += 1
        vec = end - begin
    return (normals * vec).sum(-1, keepdim=True)
|
|
|
|
def compute_gradient(img):
    """Return central-difference gradients of ``img`` along its first two axes.

    Uses ``torch.gradient`` with unit spacing (one-sided differences at the
    borders). Each axis differentiated needs at least two samples.

    Returns:
        ``(dx, dy)``: ``dx`` is the derivative along axis 1 and ``dy`` is the
        derivative along axis 0. Both have the same shape as ``img``.
    """
    along_rows, along_cols = torch.gradient(img, dim=(0, 1))
    return along_cols, along_rows
|
|
|
|
def compute_normals(depth_map, K):
    """Estimate unit surface normals from a depth map via finite differences.

    The depth map is back-projected with ``depth2point_cam``. The spatial
    derivatives of the resulting point map are crossed, and the result is
    L2-normalised along the last axis.

    NOTE(review): assumes ``depth_map`` is (H, W), ``K`` is a 3x3 intrinsic
    matrix, and the point map returned by ``depth2point_cam`` becomes
    (H, W, 3) once three leading singleton dims are dropped. Confirm against
    graphics_utils. The normal sign follows from ``cross(d/dx, d/dy)`` and
    the camera convention.
    """
    # Add batch/channel dims for the depth map and a batch dim for K.
    _, points = depth2point_cam(depth_map[None, None], K[None])
    # Strip leading singleton dims (squeeze(0) is a no-op on a non-1 dim).
    for _ in range(3):
        points = points.squeeze(0)

    d_horizontal, d_vertical = compute_gradient(points)
    raw_normals = torch.cross(d_horizontal, d_vertical, dim=-1)
    return F.normalize(raw_normals, p=2, dim=-1)
| |
|
|
def compute_edge(image, k=11, thr=0.01):
    """Build a dilated edge map from the gradient magnitude of ``image``.

    The gradient magnitude is normalised by its maximum and dilated with a
    ``k x k`` max-pool (stride 1). Values above ``thr`` are then set to 1.
    Values at or below ``thr`` keep their normalised magnitude.

    NOTE(review): assumes ``image`` is 2-D (H, W). The max-pool treats
    ``edge[None]`` as an unbatched single-channel (C, H, W) input. With an
    odd ``k`` the output stays (H, W); an even ``k`` makes the output one
    pixel smaller in each spatial dimension.

    Args:
        image: Input image, differentiated along axes 0 and 1.
        k: Max-pool kernel size (dilation window).
        thr: Threshold on the normalised, dilated magnitude.

    Returns:
        The edge map, values in [0, 1]. A constant image yields all zeros.
    """
    dx, dy = compute_gradient(image)
    edge = torch.sqrt(dx ** 2 + dy ** 2)

    # A constant image has zero gradient everywhere, so dividing by max()
    # would give 0/0 = NaN. Divide by 1 in that case; for max > 0 the result
    # is identical to edge / edge.max().
    edge_max = edge.max()
    edge = edge / torch.where(edge_max > 0, edge_max, torch.ones_like(edge_max))

    p = (k - 1) // 2
    edge = F.max_pool2d(edge[None], kernel_size=k, stride=1, padding=p)[0]

    edge[edge > thr] = 1
    return edge
|
|
|
|
def get_edge_aware_distortion_map(gt_image, distortion_map):
    """Down-weight ``distortion_map`` where ``gt_image`` has strong local edges.

    For every interior pixel, the absolute difference to each of its four
    neighbours (left, right, top, bottom) is averaged over axis 0 (the
    channel axis of a (C, H, W) image). The largest of the four is then
    turned into a weight ``exp(-max_diff)``. The one-pixel border receives
    weight 0.

    Args:
        gt_image: Reference image, indexed as (C, H, W).
        distortion_map: Map broadcastable against an (H, W) weight.

    Returns:
        ``distortion_map`` multiplied element-wise by the edge-aware weight.
    """
    centre = gt_image[:, 1:-1, 1:-1]
    neighbours = (
        gt_image[:, 1:-1, :-2],  # left
        gt_image[:, 1:-1, 2:],   # right
        gt_image[:, :-2, 1:-1],  # top
        gt_image[:, 2:, 1:-1],   # bottom
    )
    diffs = torch.stack(
        [torch.abs(centre - other).mean(0) for other in neighbours], dim=-1
    )
    weight = torch.exp(-torch.max(diffs, dim=-1).values)
    # Zero-pad back to (H, W): border pixels contribute nothing.
    weight = F.pad(weight, (1, 1, 1, 1), mode="constant", value=0)
    return distortion_map * weight