| import matplotlib.pyplot as plt |
| import os, cv2 |
| import numpy as np |
| from mono.utils.transform import gray_to_colormap |
| import shutil |
| import glob |
| from mono.utils.running import main_process |
| import torch |
| from html4vision import Col, imagetable |
|
|
def save_raw_imgs(
    pred: np.ndarray,
    rgb: np.ndarray,
    filename: str,
    save_dir: str,
    scale: float=200.0,
    target: np.ndarray=None,
):
    """
    Save the raw RGB image, the scaled depth prediction and (optionally) the
    scaled ground-truth depth as separate files in ``save_dir``.

    Output names are ``filename`` with its last 4 characters dropped, plus
    ``_rgb.jpg``, ``_d.png`` and ``_gt.png``. Depth maps are multiplied by
    ``scale`` and written as 16-bit images.

    Args:
        pred: depth prediction. ``.astype`` is applied to it, so it must be a
            NumPy array, not a torch tensor.
        rgb: image array passed unchanged to ``cv2.imwrite``.
        filename: source filename. Presumably it ends with a 3-letter
            extension such as ``.jpg``; TODO confirm.
        save_dir: output directory.
        scale: multiplier applied to the depth values before the uint16 cast.
        target: optional ground-truth depth, handled like ``pred``.

    Note:
        The boolean return value of ``cv2.imwrite`` is not checked.
    """
    uint16_max = np.iinfo(np.uint16).max
    stem = filename[:-4]
    cv2.imwrite(os.path.join(save_dir, stem + '_rgb.jpg'), rgb)
    # Clip before casting: negative or > 65535 values would otherwise wrap
    # around (or be platform-dependent) in the float -> uint16 conversion.
    cv2.imwrite(
        os.path.join(save_dir, stem + '_d.png'),
        np.clip(pred * scale, 0, uint16_max).astype(np.uint16),
    )
    if target is not None:
        cv2.imwrite(
            os.path.join(save_dir, stem + '_gt.png'),
            np.clip(target * scale, 0, uint16_max).astype(np.uint16),
        )
| |
|
|
def save_val_imgs(
    iter: int,
    pred: torch.tensor,
    target: torch.tensor,
    rgb: torch.tensor,
    filename: str,
    save_dir: str,
    tb_logger=None
):
    """
    Save the RGB image, the colormapped prediction and the colormapped GT,
    stacked vertically, as one image.

    The merged image goes to ``<save_dir>/<filename[:-4]>_merge.jpg``. When
    ``tb_logger`` is given, it is also logged in CHW layout at step ``iter``.
    """
    rgb_chw, _, _, pred_vis, gt_vis = get_data_for_log(pred, target, rgb)
    # get_data_for_log returns the image channel-first; stack in HWC layout.
    merged = np.concatenate(
        [np.transpose(rgb_chw, (1, 2, 0)), pred_vis, gt_vis], axis=0
    )
    stem = filename[:-4]
    plt.imsave(os.path.join(save_dir, stem + '_merge.jpg'), merged)

    if tb_logger is None:
        return
    tb_logger.add_image(f'{stem}_merge.jpg', np.transpose(merged, (2, 0, 1)), iter)
|
|
def save_normal_val_imgs(
    iter: int,
    pred: torch.Tensor,
    targ: torch.Tensor,
    rgb: torch.Tensor,
    filename: str,
    save_dir: str,
    tb_logger=None,
    mask=None,
):
    """
    Save the RGB image, the predicted normals and the GT normals, stacked
    vertically, as one image.

    The merged image is written with ``plt.imsave`` to
    ``<save_dir>/<filename[:-4]>_merge.jpg``. When ``tb_logger`` is given,
    it is also logged via ``tb_logger.add_image`` in CHW layout at step
    ``iter``.

    Args:
        iter: step passed to ``tb_logger.add_image``.
        pred: predicted surface normals. Squeezed, then permuted to
            channel-last when the first dimension has size 3.
        targ: ground-truth surface normals, handled like ``pred``.
        rgb: normalized input image. Squeezed, permuted like ``pred``, then
            de-normalized with the mean/std below.
        filename: source filename. The last 4 characters (presumably an
            extension such as ``.jpg``) are dropped.
        save_dir: output directory.
        tb_logger: optional logger exposing ``add_image(tag, img, step)``.
        mask: optional validity mask, forwarded to ``vis_surface_normal`` for
            both pred and GT.
    """
    # Per-channel mean/std (0-255 units), shaped to broadcast over (h, w, 3).
    mean = np.array([123.675, 116.28, 103.53])[np.newaxis, np.newaxis, :]
    std = np.array([58.395, 57.12, 57.375])[np.newaxis, np.newaxis, :]
    pred = pred.squeeze()
    targ = targ.squeeze()
    rgb = rgb.squeeze()

    # Channel-first (3, h, w) -> channel-last (h, w, 3).
    if pred.size(0) == 3:
        pred = pred.permute(1, 2, 0)
    if targ.size(0) == 3:
        targ = targ.permute(1, 2, 0)
    if rgb.size(0) == 3:
        rgb = rgb.permute(1, 2, 0)

    pred_color = vis_surface_normal(pred, mask)
    targ_color = vis_surface_normal(targ, mask)
    # Undo normalization. Clip so out-of-range values saturate instead of
    # wrapping around in the uint8 cast. detach() allows grad-tracking tensors.
    rgb_color = np.clip(
        rgb.detach().cpu().numpy() * std + mean, 0, 255
    ).astype(np.uint8)

    try:
        cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0)
    except ValueError:
        # Shape mismatch (e.g. normals predicted at a different resolution):
        # resize both normal maps to the RGB's (width, height) and retry.
        size = (rgb.shape[1], rgb.shape[0])
        pred_color = cv2.resize(pred_color, size)
        targ_color = cv2.resize(targ_color, size)
        cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0)

    stem = filename[:-4]
    plt.imsave(os.path.join(save_dir, stem + '_merge.jpg'), cat_img)

    if tb_logger is not None:
        tb_logger.add_image(f'{stem}_merge.jpg', cat_img.transpose((2, 0, 1)), iter)
|
|
def get_data_for_log(pred: torch.Tensor, target: torch.Tensor, rgb: torch.Tensor):
    """
    Convert a depth prediction, its ground truth and the normalized input
    image into NumPy arrays for logging.

    The input tensors are never modified.

    Args:
        pred: predicted depth. Squeezed before use; negative values are
            clamped to 0.
        target: ground-truth depth, handled like ``pred``.
        rgb: normalized image. After squeezing it is indexed as (3, H, W).

    Returns:
        tuple: ``(rgb, pred_scale, target_scale, pred_color, target_color)``
            - rgb: de-normalized uint8 image, (3, H, W).
            - pred_scale / target_scale: uint16 depth, rescaled so the larger
              of the two maxima maps to 10000.
            - pred_color / target_color: ``gray_to_colormap`` output resized
              to the RGB's (W, H).
    """
    # Per-channel mean/std (0-255 units), shaped to broadcast over (3, H, W).
    mean = np.array([123.675, 116.28, 103.53])[:, np.newaxis, np.newaxis]
    std = np.array([58.395, 57.12, 57.375])[:, np.newaxis, np.newaxis]

    # For CPU tensors .numpy() shares memory with the tensor. np.clip returns
    # a new array, so the caller's depth tensors are not clamped in place.
    # detach() allows tensors that still require grad.
    pred = np.clip(pred.squeeze().detach().cpu().numpy(), 0, None)
    target = np.clip(target.squeeze().detach().cpu().numpy(), 0, None)
    rgb = rgb.squeeze().detach().cpu().numpy()

    max_scale = max(pred.max(), target.max())
    if not max_scale > 0:
        # All-zero depth: avoid 0/0 -> NaN before the uint16 cast.
        max_scale = 1.0
    pred_scale = (pred / max_scale * 10000).astype(np.uint16)
    target_scale = (target / max_scale * 10000).astype(np.uint16)

    pred_color = gray_to_colormap(pred)
    target_color = gray_to_colormap(target)
    # cv2.resize takes (width, height).
    pred_color = cv2.resize(pred_color, (rgb.shape[2], rgb.shape[1]))
    target_color = cv2.resize(target_color, (rgb.shape[2], rgb.shape[1]))

    # Undo normalization. Clip so out-of-range values saturate instead of
    # wrapping around in the uint8 cast.
    rgb = np.clip(rgb * std + mean, 0, 255).astype(np.uint8)
    return rgb, pred_scale, target_scale, pred_color, target_color
|
|
|
|
def create_html(name2path, save_path='index.html', size=(256, 384)):
    """
    Write an HTML image table with one ``'img'`` column per entry of
    ``name2path``.

    Args:
        name2path: mapping from column title to that column's image source.
            Each value is passed as the third argument of ``html4vision.Col``
            (presumably a list of image paths or a glob pattern; TODO confirm).
        save_path: path of the HTML file to write.
        size: forwarded to ``imagetable`` as ``imsize``.
    """
    columns = [Col('img', title, paths) for title, paths in name2path.items()]
    imagetable(columns, out_file=save_path, imsize=size)
|
|
def vis_surface_normal(normal: torch.Tensor, mask: torch.Tensor=None) -> np.ndarray:
    """
    Visualize a surface-normal map as an 8-bit, 3-channel image.

    Each normal is L2-normalized, then its components are mapped from
    [-1, 1] to roughly [1, 255] via ``n * 127 + 128``. An all-zero normal
    maps to 128.

    Args:
        normal (torch.Tensor, [h, w, 3]): surface normals. Singleton
            dimensions are squeezed away.
        mask (torch.Tensor, [h, w], optional): valid-pixel mask of any dtype.
            Non-zero entries are valid; invalid pixels are set to 0.

    Returns:
        np.ndarray: uint8 array of shape [h, w, 3].
    """
    # detach() so tensors that still require grad can be converted to NumPy.
    normal = normal.detach().cpu().numpy().squeeze()
    length = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True))
    # 1e-8 avoids division by zero for all-zero normals.
    unit = normal / (length + 1e-8)
    normal_vis = (unit * 127 + 128).astype(np.uint8)
    if mask is not None:
        # Cast to bool: for integer masks `~` is bitwise NOT, which would
        # produce integer fancy indices instead of a boolean selection.
        mask = mask.detach().cpu().numpy().squeeze().astype(bool)
        normal_vis[~mask] = 0
    return normal_vis
|
|
|
|