NeAR / trellis /renderers /normal_utils.py
luh1124's picture
restore: full Space tree + assets (recover from minimal force-push); keep ZeroGPU app.py
0d1388f
import torch
import torch.nn.functional as F
from .graphics_utils import depth2point_cam
def get_normal_sign(normals, begin=None, end=None, trans=None, mode='origin', vec=None):
    """Return the per-normal dot product with a reference direction.

    The reference direction is ``vec`` when given, otherwise ``end - begin``.
    When ``begin`` is also missing it is derived, in order of preference, from
    ``trans`` or from the mean of ``end``.

    Args:
        normals: Tensor of normals; multiplied element-wise with the reference
            direction and summed over the last dimension.
        begin: Optional start point of the reference direction.
        end: End point(s) of the reference direction. Required whenever
            ``vec`` is not supplied.
        trans: Optional tensor used to derive ``begin``. A 1-D tensor is used
            as ``begin`` directly; otherwise ``begin`` is
            ``-trans[:3, :3].T @ trans[:3, 3]``.
            NOTE(review): that expression is the camera centre if ``trans`` is
            a world-to-camera matrix -- confirm against callers.
        mode: Only ``'origin'`` is implemented.
        vec: Optional explicit reference direction; when given, ``begin``,
            ``end`` and ``trans`` are ignored.

    Returns:
        Tensor with the last dimension reduced to size 1. Despite the
        historical name ``cos``, the value is an un-normalised dot product; it
        is a true cosine only when both ``normals`` and the reference
        direction are unit length.

    Raises:
        ValueError: If ``mode`` is not ``'origin'``. Previously this path
            failed with an ``UnboundLocalError`` on the return statement.
    """
    if mode != 'origin':
        raise ValueError(
            f"get_normal_sign: unsupported mode {mode!r}; only 'origin' is implemented"
        )

    if vec is None:
        if begin is None:
            if trans is not None:
                # A 1-D `trans` is already a point; anything else is treated
                # as a matrix with a 3x3 block and a translation column.
                begin = - trans[:3, :3].T @ trans[:3, 3] \
                    if trans.ndim != 1 else trans
            else:
                # Fall back to the centroid of `end`, with its second
                # component shifted by +1.
                # NOTE(review): the reason for the +1 offset is not visible
                # here -- confirm before changing.
                begin = end.mean(0)
                begin[1] += 1
        vec = end - begin

    return (normals * vec).sum(-1, keepdim=True)
def compute_gradient(img):
    """Return finite-difference gradients of ``img`` along its first two dims.

    Args:
        img: Tensor with at least two dimensions. Dim 0 is differentiated to
            give ``dy`` and dim 1 to give ``dx``; any trailing dimensions are
            carried through unchanged.

    Returns:
        Tuple ``(dx, dy)`` -- note the order: the dim-1 gradient comes first.
        Both have the same shape as ``img``.
    """
    # torch.gradient returns one tensor per requested dim, in the order asked.
    grad_dim0, grad_dim1 = torch.gradient(img, dim=[0, 1])
    return grad_dim1, grad_dim0
def compute_normals(depth_map, K):
    """Estimate unit normals from a depth map via the cross product of point gradients.

    Args:
        depth_map: Depth tensor; two leading singleton dims are added before
            it is handed to ``depth2point_cam``.
            (assumes shape [H, W] -- TODO confirm against callers)
        K: Camera matrix; one leading singleton dim is added before it is
            handed to ``depth2point_cam``.
            (presumably the 3x3 intrinsics, not their inverse -- confirm)

    Returns:
        Tensor of L2-normalised normals with the same shape as the
        back-projected points (expected [H, W, 3]).
    """
    # Only the second value returned by depth2point_cam (the points) is used.
    _unused, points = depth2point_cam(depth_map[None, None], K[None])

    # Strip up to three leading singleton dims, leaving [H, W, 3].
    for _ in range(3):
        points = points.squeeze(0)

    grad_x, grad_y = compute_gradient(points)

    # The normal direction is the cross product of the two gradient vectors.
    raw_normals = torch.cross(grad_x, grad_y, dim=-1)
    return F.normalize(raw_normals, p=2, dim=-1)
def compute_edge(image, k=11, thr=0.01):
    """Build a dilated edge map from the gradient magnitude of ``image``.

    Steps: gradient magnitude -> normalise by its global maximum -> dilate
    with a ``k`` x ``k`` max-pool (stride 1) -> set every value above ``thr``
    to 1.

    Args:
        image: Tensor whose first two dims are differentiated.
            (assumes a 2-D [H, W] image, so that ``edge[None]`` is a valid
            [1, H, W] input for ``max_pool2d`` -- TODO confirm)
        k: Max-pool window size. Padding is ``(k - 1) // 2``, which preserves
            the spatial size for odd ``k``.
        thr: Threshold applied to the normalised, dilated magnitude.

    Returns:
        Tensor in which entries above ``thr`` are exactly 1. Entries at or
        below ``thr`` keep their normalised value; they are NOT zeroed.
        A gradient-free (constant) image yields all zeros rather than NaN.
    """
    dx, dy = compute_gradient(image)
    edge = torch.sqrt(dx**2 + dy**2)

    # A constant image has a maximum of 0; dividing by it would turn the whole
    # map into NaN. Clamping to the smallest positive normal value keeps the
    # result at 0 there and leaves every other case unchanged, without forcing
    # a host-side check.
    peak = edge.max().clamp_min(torch.finfo(edge.dtype).tiny)
    edge = edge / peak

    p = (k - 1) // 2
    edge = F.max_pool2d(edge[None], kernel_size=k, stride=1, padding=p)[0]
    edge[edge > thr] = 1
    return edge
def get_edge_aware_distortion_map(gt_image, distortion_map):
    """Down-weight ``distortion_map`` where ``gt_image`` has strong local gradients.

    For every interior pixel, the channel-averaged absolute difference to each
    of its four direct neighbours is computed and the largest one is kept as
    ``g``. The pixel's weight is ``exp(-g)``. The one-pixel border, which has
    no full neighbourhood, gets weight 0.

    Args:
        gt_image: Image tensor indexed as [channel, row, column]; the mean is
            taken over dim 0.
        distortion_map: Tensor multiplied element-wise by the [H, W] weight
            map (must broadcast against it).

    Returns:
        ``distortion_map`` scaled by the edge-aware weights.
    """
    centre = gt_image[:, 1:-1, 1:-1]
    shifted_views = (
        gt_image[:, 1:-1, :-2],  # left neighbour
        gt_image[:, 1:-1, 2:],   # right neighbour
        gt_image[:, :-2, 1:-1],  # top neighbour
        gt_image[:, 2:, 1:-1],   # bottom neighbour
    )
    per_direction = [
        torch.mean(torch.abs(centre - neighbour), 0) for neighbour in shifted_views
    ]
    strongest = torch.max(torch.stack(per_direction, dim=-1), dim=-1)[0]

    # Convert gradient strength to a weight in (0, 1], then zero-pad back to
    # the full image size so the border contributes nothing.
    weights = torch.exp(-strongest)
    weights = torch.nn.functional.pad(weights, (1, 1, 1, 1), mode="constant", value=0)
    return distortion_map * weights