FE2E-CPU / infer /visualize.py
Nekochu's picture
FE2E depth+normal CPU Space: FP8 dynamic INT8, single denoise
405d2b1
Raw
History Blame Contribute Delete
4.91 kB
import cv2
import numpy as np
import torch
from matplotlib import cm
import matplotlib.pyplot as plt
import logging
# Module-level logger, looked up under the literal name 'root'.
# NOTE(review): on Python 3.9+ this name resolves to the root logger itself
# (earlier versions create a separate logger called 'root') - confirm that is
# the intent rather than logging.getLogger(__name__).
logger = logging.getLogger('root')
def tensor_to_numpy(tensor_in):
    """Convert a channel-first torch tensor to a channel-last numpy array.

    A 3-D input is treated as (C, H, W) and returned as (H, W, C); a 4-D
    input is treated as (B, C, H, W) and returned as (B, H, W, C). The
    tensor is detached and moved to the CPU before conversion.

    ``None`` is passed through unchanged so optional inputs can be forwarded
    without a separate check.

    Raises:
        Exception: if the input has neither 3 nor 4 dimensions.
    """
    if tensor_in is None:
        return tensor_in

    # Axis permutation that moves the channel axis to the end, keyed by rank.
    channel_last_order = {
        3: (1, 2, 0),
        4: (0, 2, 3, 1),
    }
    order = channel_last_order.get(tensor_in.ndim)
    if order is None:
        raise Exception('invalid tensor size')

    return tensor_in.detach().cpu().permute(*order).numpy()
# def unnormalize(img_in, img_stats={'mean': [0.485, 0.456, 0.406],
# 'std': [0.229, 0.224, 0.225]}):
def unnormalize(img_in, img_stats={'mean': [0.5, 0.5, 0.5],
                                   'std': [0.5, 0.5, 0.5]}):
    """Undo input normalization and return an 8-bit image.

    Args:
        img_in: image as a numpy array with channels on the last axis, or a
            torch tensor (converted with ``tensor_to_numpy``).
        img_stats: dict with per-channel ``'mean'`` and ``'std'`` sequences
            used for de-normalization (``x * std + mean``). The default dict
            is only read, never modified.

    Returns:
        ``np.uint8`` array of the same shape with values in [0, 255].

    Notes:
        The value range of the input decides which path is taken. If every
        value lies in [-0.1, 1.1], the image is assumed to already be in
        [0, 1] and is only clipped and scaled. Otherwise the first three
        channels are de-normalized with ``img_stats``. An image normalized
        to [-1, 1] whose values all happen to fall inside [-0.1, 1.1] is
        therefore treated as already un-normalized.
    """
    if torch.is_tensor(img_in):
        img_in = tensor_to_numpy(img_in)

    # Inspect the value range to decide whether de-normalization is needed.
    img_min, img_max = img_in.min(), img_in.max()

    # Small tolerance around [0, 1] to allow for floating-point error.
    if img_min >= -0.1 and img_max <= 1.1:
        img_out = img_in
    else:
        # De-normalize in floating point. An integer input would otherwise
        # give an integer buffer from zeros_like, truncating the scaled
        # values and making the in-place "+= mean" raise a casting error.
        if np.issubdtype(img_in.dtype, np.floating):
            work = img_in
        else:
            work = img_in.astype(np.float32)

        img_out = np.zeros_like(work)
        for ich in range(3):
            img_out[..., ich] = work[..., ich] * img_stats['std'][ich]
            img_out[..., ich] += img_stats['mean'][ich]

    # Both paths end with the same clip to [0, 1] and scale to [0, 255].
    img_out = np.clip(img_out, 0, 1)
    return (img_out * 255.0).astype(np.uint8)
def normal_to_rgb(normal, normal_mask=None):
    """Convert a surface-normal map to an 8-bit RGB image for visualization.

    The x, y, z components are mapped to the R, G, B channels and the value
    range [-1, 1] is mapped to [0, 255]. Normals are re-normalized to unit
    length first; vectors with (near-)zero length are left at zero and end
    up as mid-gray.

    Args:
        normal: normal map with the vector components on the last axis, as a
            numpy array, or a torch tensor (converted with
            ``tensor_to_numpy``).
        normal_mask: optional mask multiplied into the RGB image; must
            broadcast against it. May be a numpy array or a torch tensor,
            independently of the type of ``normal``.

    Returns:
        ``np.uint8`` array with the same shape as the converted ``normal``.
    """
    # Convert each argument on its own: a tensor mask must be converted even
    # when the normals are already numpy, and a numpy mask must not be pushed
    # through the tensor conversion just because the normals are a tensor.
    if torch.is_tensor(normal):
        normal = tensor_to_numpy(normal)
    if torch.is_tensor(normal_mask):
        normal_mask = tensor_to_numpy(normal_mask)

    # Clamp the length from below so zero vectors do not divide by zero.
    length = np.linalg.norm(normal, axis=-1, keepdims=True)
    length[length < 1e-12] = 1e-12
    unit_normal = normal / length

    normal_rgb = (((unit_normal + 1) * 0.5) * 255).astype(np.uint8)

    if normal_mask is not None:
        # A float mask would promote the product to float (values 0..255);
        # cast back so the result is always a valid uint8 image.
        normal_rgb = (normal_rgb * normal_mask).astype(np.uint8)

    return normal_rgb
def kappa_to_alpha(pred_kappa, to_numpy=True):
    """Convert confidence kappa into expected angular uncertainty alpha.

    Assumes the AngMF distribution introduced in
    https://arxiv.org/abs/2109.09881. The result is expressed in degrees.

    Args:
        pred_kappa: kappa values as a torch tensor or numpy array.
        to_numpy: when true, a tensor input is first converted with
            ``tensor_to_numpy`` and the result is a numpy array; when false,
            a tensor input is processed with torch and a tensor is returned.

    Returns:
        alpha in degrees, of the same array type used for the computation.
    """
    if to_numpy and torch.is_tensor(pred_kappa):
        pred_kappa = tensor_to_numpy(pred_kappa)

    # Choose the math backend that matches the (possibly converted) input.
    if torch.is_tensor(pred_kappa):
        exp_fn, to_degrees = torch.exp, torch.rad2deg
    else:
        exp_fn, to_degrees = np.exp, np.degrees

    decay = exp_fn(- pred_kappa * np.pi)
    rational_term = (2 * pred_kappa) / ((pred_kappa ** 2.0) + 1)
    exponential_term = (decay * np.pi) / (1 + decay)

    return to_degrees(rational_term + exponential_term)
def visualize_normal(target_dir, prefixs, img, pred_norm, pred_kappa,
                     gt_norm, gt_norm_mask, pred_error, num_vis=-1):
    """Write predicted-normal images and, if ground truth exists, error maps.

    For each visualized sample ``i`` the following files are written:

    * ``<target_dir>/<prefixs[i]>_norm.png`` - predicted normals as RGB.
    * ``<target_dir>/<prefixs[i]>_pred_error.png`` - masked prediction error
      with the ``jet`` colormap, only when ``gt_norm`` is not ``None``.

    Args:
        target_dir: directory the PNG files are written into.
        prefixs: per-sample file name prefixes.
        img: accepted but not used here (saving the input image is disabled).
        pred_norm: batch of predicted normals; converted with
            ``tensor_to_numpy``.
        pred_kappa: accepted but not used here (saving the uncertainty map
            is disabled).
        gt_norm: ground-truth normals or ``None``; only tested for presence
            (saving the ground-truth image is disabled).
        gt_norm_mask: validity mask multiplied into the error map; indexed
            as ``[i, :, :, 0]`` after conversion.
        pred_error: prediction error; indexed as ``[i, :, :, 0]`` after
            conversion.
        num_vis: number of samples to write; ``-1`` means ``len(prefixs)``.
    """
    # Upper end of the error colormap (presumably degrees - TODO confirm).
    error_max = 60.0

    # Bring the inputs that are used below into channel-last numpy layout.
    pred_norm, gt_norm, gt_norm_mask, pred_error = (
        tensor_to_numpy(item)
        for item in (pred_norm, gt_norm, gt_norm_mask, pred_error)
    )

    sample_count = num_vis if num_vis != -1 else len(prefixs)
    has_ground_truth = gt_norm is not None

    for idx in range(sample_count):
        prefix = prefixs[idx]

        # Predicted normal map.
        norm_path = '%s/%s_norm.png' % (target_dir, prefix)
        plt.imsave(norm_path, normal_to_rgb(pred_norm[idx, ...]))

        if not has_ground_truth:
            continue

        # Error map, zeroed wherever the ground-truth mask is zero.
        masked_error = pred_error[idx, :, :, 0] * gt_norm_mask[idx, :, :, 0]
        error_path = '%s/%s_pred_error.png' % (target_dir, prefix)
        plt.imsave(error_path, masked_error, vmin=0, vmax=error_max, cmap='jet')