File size: 2,302 Bytes
0d1388f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torch.nn.functional as F

from .graphics_utils import depth2point_cam


def get_normal_sign(normals, begin=None, end=None, trans=None, mode='origin', vec=None):
    """Return the dot product of each normal with a reference direction.

    The sign of the result tells whether a normal points along (positive)
    or against (negative) the reference direction ``vec``.

    Args:
        normals: tensor of normal vectors; the dot product is taken over
            the last axis.
        begin: start point of the reference direction. Only consulted when
            ``vec`` is None. When it is also None it is derived as:
              * ``trans`` 2-D: ``-trans[:3, :3].T @ trans[:3, 3]``
                (presumably the camera centre of a world-to-camera
                matrix -- confirm against callers);
              * ``trans`` 1-D: ``trans`` itself;
              * ``trans`` None: the mean of ``end`` over axis 0;
            and the derived point is then shifted by +1 along index 1.
        end: end point(s) of the reference direction; required whenever
            ``vec`` is None.
        trans: optional transform or 1-D position used to derive ``begin``.
        mode: only ``'origin'`` is supported.
        vec: explicit reference direction. When given, ``begin``, ``end``
            and ``trans`` are ignored.

    Returns:
        ``(normals * vec).sum(-1, keepdim=True)``.

    Raises:
        ValueError: if ``mode`` is not ``'origin'``.
    """
    if mode != 'origin':
        # Previously this fell through to ``return cos`` and surfaced as an
        # opaque UnboundLocalError.
        raise ValueError(f"Unsupported mode: {mode!r}")

    if vec is None:
        if begin is None:
            if trans is None:
                begin = end.mean(0)
            elif trans.ndim != 1:
                begin = -trans[:3, :3].T @ trans[:3, 3]
            else:
                # Copy first: the in-place offset below must not write
                # through to the caller's tensor.
                begin = trans.clone()
            begin[1] += 1
        vec = end - begin

    return (normals * vec).sum(-1, keepdim=True)


def compute_gradient(img):
    """Finite-difference gradients of ``img`` along its first two axes.

    Args:
        img: tensor with at least two dimensions.

    Returns:
        Tuple ``(dx, dy)`` where ``dx`` is the gradient along dim 1 and
        ``dy`` the gradient along dim 0, both computed with
        ``torch.gradient`` and shaped like ``img``.
    """
    # torch.gradient returns one tensor per requested dimension.
    (grad_dim0,) = torch.gradient(img, dim=0)
    (grad_dim1,) = torch.gradient(img, dim=1)
    return grad_dim1, grad_dim0


def compute_normals(depth_map, K):
    """Estimate unit surface normals from a depth map.

    The depth map is lifted to camera-space points with
    ``depth2point_cam``; the normal at each location is the L2-normalized
    cross product of the point map's gradients along dim 1 and dim 0.

    Args:
        depth_map: depth tensor, assumed to be of shape [H, W].
        K: camera intrinsics forwarded to ``depth2point_cam`` with a
            leading axis added (presumably 3x3 -- confirm against that
            helper).

    Returns:
        Tensor of unit normals, expected to be of shape [H, W, 3].
    """
    batched_depth = depth_map[None, None]
    batched_intrinsics = K[None]
    _, points = depth2point_cam(batched_depth, batched_intrinsics)

    # Strip up to three leading singleton axes so the point map is [H, W, 3].
    for _ in range(3):
        points = points.squeeze(0)

    grad_x, grad_y = compute_gradient(points)
    # The cross product of the two tangent directions is the surface normal.
    raw_normals = torch.cross(grad_x, grad_y, dim=-1)
    return F.normalize(raw_normals, p=2, dim=-1)
    

def compute_edge(image, k=11, thr=0.01):
    """Build a dilated edge map from the gradient magnitude of ``image``.

    The gradient magnitude is normalized by its maximum, dilated with a
    ``k`` x ``k`` max-pool (stride 1), and every value above ``thr`` is set
    to 1. Values at or below ``thr`` keep their normalized magnitude, so the
    result is not strictly binary.

    Args:
        image: 2-D tensor (the code adds one leading axis before pooling,
            so a [H, W] input is assumed).
        k: max-pool kernel size. With padding ``(k - 1) // 2`` and stride 1
            the output keeps the input's spatial size for odd ``k``.
        thr: threshold on the normalized, dilated magnitude.

    Returns:
        Edge map tensor as described above.
    """
    dx, dy = compute_gradient(image)

    edge = torch.sqrt(dx**2 + dy**2)
    # A constant image has zero gradient everywhere; dividing by a zero
    # maximum would turn the whole map into NaN. Clamping the divisor to the
    # smallest positive normal value leaves every non-degenerate input
    # unchanged and yields an all-zero map for the degenerate one.
    peak = edge.max().clamp_min(torch.finfo(edge.dtype).tiny)
    edge = edge / peak

    p = (k - 1) // 2
    # Max-pooling acts as a morphological dilation of the edge response.
    edge = F.max_pool2d(edge[None], kernel_size=k, stride=1, padding=p)[0]

    edge[edge > thr] = 1
    return edge


def get_edge_aware_distortion_map(gt_image, distortion_map):
    """Down-weight a distortion map where the reference image has edges.

    For every interior pixel the absolute difference to each of its four
    direct neighbours is averaged over axis 0 of ``gt_image``; the largest
    of the four is mapped through ``exp(-x)`` to a weight. The one-pixel
    border gets weight 0.

    Args:
        gt_image: 3-D image tensor (presumably [C, H, W] -- the mean is
            taken over axis 0).
        distortion_map: tensor multiplied element-wise by the [H, W]
            weight map.

    Returns:
        ``distortion_map`` scaled by the edge-aware weights.
    """
    center = gt_image[:, 1:-1, 1:-1]
    neighbours = (
        gt_image[:, 1:-1, :-2],  # left
        gt_image[:, 1:-1, 2:],   # right
        gt_image[:, :-2, 1:-1],  # top
        gt_image[:, 2:, 1:-1],   # bottom
    )
    diffs = [(center - shifted).abs().mean(dim=0) for shifted in neighbours]
    strongest = torch.stack(diffs, dim=-1).max(dim=-1).values

    # Strong image gradients give small weights; restore the border with zeros.
    weight = F.pad(torch.exp(-strongest), (1, 1, 1, 1), mode="constant", value=0)
    return distortion_map * weight