Spaces:
Running
Running
| import torch | |
| import torch.utils.data | |
| import numpy as np | |
| import torchvision.utils as vutils | |
| import cv2 | |
| from matplotlib.cm import get_cmap | |
| import matplotlib as mpl | |
| import matplotlib.cm as cm | |
def vis_disparity(disp, return_rgb=False):
    """Render a disparity map as an 8-bit INFERNO false-colour image.

    The map is min-max normalised to [0, 255] before colouring. A constant
    (or NaN-containing) map, whose value range is not positive, is rendered
    as all zeros instead of dividing 0 by 0.

    Args:
        disp: disparity map as a NumPy array (needs ``.min()``, ``.max()``,
            ``.shape`` and ``.astype``); presumably (H, W) — cv2.applyColorMap
            expects an 8-bit single- or three-channel image. TODO confirm.
        return_rgb: if True, convert OpenCV's BGR output to RGB.

    Returns:
        uint8 colour image produced by ``cv2.applyColorMap`` (BGR order unless
        ``return_rgb`` is True).
    """
    lo = disp.min()
    span = disp.max() - lo
    if span > 0:
        disp_vis = ((disp - lo) / span * 255.0).astype("uint8")
    else:
        # Zero (or NaN) range: 0/0 would give NaN, whose uint8 cast is undefined.
        disp_vis = np.zeros(disp.shape, dtype=np.uint8)
    disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
    if return_rgb:
        disp_vis = cv2.cvtColor(disp_vis, cv2.COLOR_BGR2RGB)
    return disp_vis
def gen_error_colormap():
    """Build the colour table for binned (normalised) disparity errors.

    Returns:
        float32 array of shape (10, 5). Row ``k`` is ``[lo, hi, r, g, b]``:
        errors in ``[lo, hi)`` get colour ``(r, g, b)``, with the channels
        scaled to [0, 1]. The bin edges are 0, 0.1875, 0.375, ..., 48, each
        divided by 3, and the last bin is open-ended (``hi = inf``).
    """
    # Lower bin edges before the division by 3.
    raw_edges = [0, 0.1875, 0.375, 0.75, 1.5, 3, 6, 12, 24, 48]
    # 8-bit RGB per bin, from blue (small error) to dark red (large error).
    rgb_8bit = [
        (49, 54, 149),
        (69, 117, 180),
        (116, 173, 209),
        (171, 217, 233),
        (224, 243, 248),
        (254, 224, 144),
        (253, 174, 97),
        (244, 109, 67),
        (215, 48, 39),
        (165, 0, 38),
    ]
    lower = [edge / 3.0 for edge in raw_edges]
    upper = lower[1:] + [np.inf]
    table = np.array(
        [[lo, hi, r, g, b] for lo, hi, (r, g, b) in zip(lower, upper, rgb_8bit)],
        dtype=np.float32,
    )
    table[:, 2:5] /= 255.
    return table
def disp_error_img(D_est_tensor, D_gt_tensor, abs_thres=3., rel_thres=0.05, dilate_radius=1):
    """Render a colour-coded disparity-error image for a batch.

    The per-pixel error is normalised as
    ``min(|gt - est| / abs_thres, (|gt - est| / gt) / rel_thres)``. A value
    <= 1 therefore means the estimate is within ``abs_thres`` pixels OR within
    ``rel_thres`` relative error of the ground truth. The normalised error is
    binned with the table from ``gen_error_colormap()``. Pixels whose ground
    truth is not positive are set to 0 (black). A legend of the colour bins
    (10 px tall, 20 px wide each) is then stamped into the top-left corner of
    every image.

    Args:
        D_est_tensor: estimated disparity tensor; expected to match the ground
            truth's shape (TODO confirm callers never rely on broadcasting).
        D_gt_tensor: ground-truth disparity tensor of shape (B, H, W); values
            <= 0 are treated as invalid.
        abs_thres: absolute error threshold, in pixels.
        rel_thres: relative error threshold, as a fraction of the ground truth.
        dilate_radius: currently unused; dilation is not implemented.

    Returns:
        float32 CPU ``torch.Tensor`` of shape (B, 3, H, W) holding the colour
        values from the ``gen_error_colormap()`` table.
    """
    gt = D_gt_tensor.detach().cpu().numpy()
    est = D_est_tensor.detach().cpu().numpy()
    batch, height, width = gt.shape

    valid = gt > 0
    invalid = np.logical_not(valid)

    # Normalised error; invalid pixels are pinned to 0 first.
    err = np.abs(gt - est)
    err[invalid] = 0
    diff = err[valid]
    err[valid] = np.minimum(diff / abs_thres, (diff / gt[valid]) / rel_thres)

    palette = gen_error_colormap()
    canvas = np.zeros((batch, height, width, 3), dtype=np.float32)
    for row in palette:
        lo, hi, rgb = row[0], row[1], row[2:]
        canvas[(err >= lo) & (err < hi)] = rgb
    # TODO: dilate by dilate_radius (MATLAB original: imdilate with strel('disk', r)).
    canvas[invalid] = 0.

    # Colour legend in the top-left corner: one swatch per error bin.
    swatch_w = 20
    for k, row in enumerate(palette):
        canvas[:, :10, k * swatch_w:(k + 1) * swatch_w, :] = row[2:]

    return torch.from_numpy(np.ascontiguousarray(canvas.transpose(0, 3, 1, 2)))
def save_images(logger, mode_tag, images_dict, global_step):
    """Log the first image of every entry in ``images_dict`` through ``logger``.

    Args:
        logger: object with an ``add_image(tag, img_tensor, global_step)``
            method — presumably a TensorBoard ``SummaryWriter``; confirm.
        mode_tag: prefix of the logged tag (the tag is ``"<mode_tag>/<key>"``,
            with ``"_<idx>"`` appended when a key holds several images).
        images_dict: maps a key to a tensor / ndarray, or to a list / tuple
            of them.
            - 2-D values are treated as a single (H, W) image.
            - 3-D values are treated as (N, H, W) and get a channel axis.
            - 4-D values are assumed to already be (N, C, H, W).
            Only the first item of each batch is logged, normalised per image
            by ``make_grid``.
        global_step: step value forwarded to ``logger.add_image``.

    The caller's ``images_dict`` is left unmodified.
    """
    # tensor2numpy rewrites the dict it receives; give it a shallow copy so
    # the caller's tensors are not silently replaced with NumPy arrays.
    images_dict = tensor2numpy(dict(images_dict))
    for tag, values in images_dict.items():
        if not isinstance(values, (list, tuple)):
            values = [values]
        for idx, value in enumerate(values):
            if value.ndim == 2:
                # Single (H, W) image -> (1, 1, H, W). Without this, value[:1]
                # below would keep only the first row of pixels.
                value = value[np.newaxis, np.newaxis, :, :]
            elif value.ndim == 3:
                # (N, H, W) -> (N, 1, H, W): add the channel axis.
                value = value[:, np.newaxis, :, :]
            value = torch.from_numpy(value[:1])  # first item of the batch only
            image_name = '{}/{}'.format(mode_tag, tag)
            if len(values) > 1:
                image_name = image_name + "_" + str(idx)
            grid = vutils.make_grid(value, padding=0, nrow=1, normalize=True, scale_each=True)
            logger.add_image(image_name, grid, global_step)
def tensor2numpy(var_dict):
    """Convert every value of ``var_dict`` to NumPy, in place, and return it.

    Supported values:
    - ``np.ndarray``: kept as is.
    - ``torch.Tensor``: detached, moved to the CPU, then converted.
    - list / tuple of the above: converted element-wise into a new list /
      tuple of the same kind. Previously these raised, which left
      ``save_images``' list handling unreachable.

    Raises:
        NotImplementedError: for any other value (or list/tuple element) type.
    """
    def _convert(item):
        # Convert a single ndarray / tensor, rejecting anything else.
        if isinstance(item, np.ndarray):
            return item
        if isinstance(item, torch.Tensor):
            return item.detach().cpu().numpy()
        raise NotImplementedError("invalid input type for tensor2numpy")

    # Re-assigning existing keys while iterating is safe: the dict's size
    # does not change.
    for key, value in var_dict.items():
        if isinstance(value, (list, tuple)):
            converted = [_convert(item) for item in value]
            var_dict[key] = converted if isinstance(value, list) else tuple(converted)
        else:
            var_dict[key] = _convert(value)
    return var_dict
def viz_depth_tensor_from_monodepth2(disp, return_numpy=False, colormap='plasma'):
    """Colour-map an inverse-depth (disparity) tensor, monodepth2-style.

    Values are normalised from ``disp.min()`` to the 95th percentile, so the
    largest ~5% of values saturate. They are then mapped through a matplotlib
    colormap.

    Args:
        disp: 2-D ``torch.Tensor`` of shape (H, W). It may live on any device
            and may require grad; it is detached and copied to the CPU first.
        return_numpy: if True, return the uint8 (H, W, 3) array instead.
        colormap: colormap passed to ``ScalarMappable`` as ``cmap``.

    Returns:
        uint8 ``torch.Tensor`` of shape (3, H, W), or a uint8 ndarray of shape
        (H, W, 3) when ``return_numpy`` is True.
    """
    assert isinstance(disp, torch.Tensor)
    # A bare .numpy() raises for CUDA tensors and for tensors that require grad.
    disp = disp.detach().cpu().numpy()
    vmax = np.percentile(disp, 95)
    normalizer = mpl.colors.Normalize(vmin=disp.min(), vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap)
    # to_rgba appends an RGBA axis; keep RGB and scale [0, 1] -> [0, 255].
    colormapped_im = (mapper.to_rgba(disp)[:, :, :3] * 255).astype(np.uint8)  # [H, W, 3]
    if return_numpy:
        return colormapped_im
    viz = torch.from_numpy(colormapped_im).permute(2, 0, 1)  # [3, H, W]
    return viz