Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| # convert arg line to args | |
def convert_arg_line_to_args(arg_line):
    """Split one line of text into whitespace-separated argument tokens.

    Presumably installed as an ``argparse.ArgumentParser`` hook of the same
    name so that an argument file may hold several arguments per line —
    confirm at the call site.

    Args:
        arg_line: A single line of text.

    Yields:
        Each whitespace-delimited token, as a ``str``.
    """
    # split() with no separator collapses runs of whitespace and never
    # returns empty tokens, so no blank-token filtering is required.
    for token in arg_line.split():
        yield str(token)
| # save args | |
def save_args(args, filename):
    """Write every attribute of ``args`` to ``filename`` as ``name: value`` lines.

    Args:
        args: Any object accepted by ``vars()`` (e.g. a parsed-arguments
            namespace).
        filename: Destination path; an existing file is overwritten.
    """
    with open(filename, 'w') as out_file:
        for name, value in vars(args).items():
            out_file.write('{}: {}\n'.format(name, value))
| # concatenate images | |
def concat_image(image_path_list, concat_image_path, size=(640, 480), gap=20):
    """Tile several images side by side, separated by white strips, and save.

    Each image is converted to RGB and resized (bilinear) to ``size`` before
    being placed in a single horizontal row.

    Args:
        image_path_list: Paths of the images to combine, left to right.
        concat_image_path: Where the combined image is written.
        size: ``(width, height)`` every image is resized to.
        gap: Width in pixels of the white separator between neighbours.

    Raises:
        ValueError: If ``image_path_list`` is empty.
    """
    if not image_path_list:
        raise ValueError('concat_image needs at least one image path')

    _, height = size
    spacer = np.full((height, gap, 3), 255, dtype=np.uint8)

    tiles = []
    for path in image_path_list:
        # Context manager releases the underlying file handle right away
        # instead of leaving it to the garbage collector.
        with Image.open(path) as opened:
            resized = opened.convert("RGB").resize(size, resample=Image.BILINEAR)
            tile = np.asarray(resized)
        if tiles:
            # Separators go between images only, never after the last one.
            tiles.append(spacer)
        tiles.append(tile)

    Image.fromarray(np.hstack(tiles)).save(concat_image_path)
| # load model | |
def load_checkpoint(fpath, model):
    """Load the weights stored under a checkpoint's ``'model'`` key into ``model``.

    A leading ``'module.'`` on a parameter name (the prefix that wrapper
    modules such as ``torch.nn.DataParallel`` typically add) is stripped so
    the weights fit an unwrapped model. Only the leading prefix is removed;
    a submodule that is itself named ``module`` keeps its name.

    Args:
        fpath: Path to a file saved with ``torch.save`` that holds a mapping
            with a ``'model'`` entry (a state dict).
        model: Module whose parameters are overwritten.

    Returns:
        The same ``model`` instance, with the weights loaded.
    """
    # NOTE(review): torch.load unpickles the file. Only load checkpoints from
    # trusted sources, or pass weights_only=True where the installed PyTorch
    # supports it.
    ckpt = torch.load(fpath, map_location='cpu')['model']

    prefix = 'module.'
    load_dict = {}
    for key, value in ckpt.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        load_dict[key] = value

    model.load_state_dict(load_dict)
    return model
| # compute normal errors | |
def compute_normal_errors(total_normal_errors):
    """Summarise an array of per-sample errors into scalar metrics.

    Args:
        total_normal_errors: NumPy array of error values. The element count
            is taken from ``shape[0]``, so a 1-D array is expected. The
            thresholds are presumably angles in degrees — confirm with the
            caller.

    Returns:
        dict with keys ``'mean'``, ``'median'``, ``'rmse'`` and ``'a1'`` ..
        ``'a5'``; the latter are the percentages of entries strictly below
        5, 7.5, 11.25, 22.5 and 30 respectively.
    """
    # shape is a tuple; dividing by it directly broadcasts and turns the
    # result into a 1-element array, so take the leading dimension.
    num_samples = total_normal_errors.shape[0]

    metrics = {
        'mean': np.average(total_normal_errors),
        'median': np.median(total_normal_errors),
        'rmse': np.sqrt(np.sum(total_normal_errors * total_normal_errors) / num_samples),
    }
    thresholds = (('a1', 5), ('a2', 7.5), ('a3', 11.25), ('a4', 22.5), ('a5', 30))
    for key, limit in thresholds:
        metrics[key] = 100.0 * (np.sum(total_normal_errors < limit) / num_samples)
    return metrics
| # log normal errors | |
def log_normal_errors(metrics, where_to_write, first_line):
    """Print a metrics table to stdout and append the same table to a file.

    Args:
        metrics: Mapping with keys ``'mean'``, ``'median'``, ``'rmse'`` and
            ``'a1'`` .. ``'a5'``.
        where_to_write: Path of the log file; it is opened in append mode.
        first_line: Title line emitted above the table.
    """
    header = "mean median rmse 5 7.5 11.25 22.5 30"
    row_format = "%.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f"
    columns = ('mean', 'median', 'rmse', 'a1', 'a2', 'a3', 'a4', 'a5')

    print(first_line)
    print(header)
    row = row_format % tuple(metrics[name] for name in columns)
    print(row)

    with open(where_to_write, 'a') as log_file:
        log_file.write('%s\n' % first_line)
        log_file.write(header + "\n")
        # Trailing blank line separates successive entries in the log.
        log_file.write(row + "\n\n")
| # makedir | |
def makedir(dirpath):
    """Create ``dirpath`` (and any missing parents) if it does not exist yet.

    Args:
        dirpath: Directory path to create.

    Raises:
        FileExistsError: If ``dirpath`` exists but is not a directory.
    """
    # exist_ok avoids the check-then-create race when several processes
    # try to create the same directory at once.
    os.makedirs(dirpath, exist_ok=True)
| # makedir from list | |
def make_dir_from_list(dirpath_list):
    """Call ``makedir`` on every path in ``dirpath_list``, in order.

    Args:
        dirpath_list: Iterable of directory paths.
    """
    for path in dirpath_list:
        makedir(path)
| ######################################################################################################################## | |
| # Visualization | |
| ######################################################################################################################## | |
| # unnormalize image | |
# Per-channel mean / std used to undo the input normalisation.
__imagenet_stats = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}


def unnormalize(img_in):
    """Undo mean/std normalisation and return an 8-bit image.

    Args:
        img_in: Array indexed as ``[row, col, channel]`` with at least three
            channels; only the first three are de-normalised.

    Returns:
        ``uint8`` array with the same shape as ``img_in``, values in
        [0, 255].
    """
    img_out = np.zeros(img_in.shape)
    channel_stats = zip(__imagenet_stats['std'], __imagenet_stats['mean'])
    for ich, (std, mean) in enumerate(channel_stats):
        img_out[:, :, ich] = img_in[:, :, ich] * std
        img_out[:, :, ich] += mean
    # Clip before the cast: values outside [0, 255] would otherwise wrap
    # around (or be undefined) when converted to uint8.
    img_out = np.clip(img_out * 255, 0, 255).astype(np.uint8)
    return img_out
| # kappa to exp error (only applicable to AngMF distribution) | |
def kappa_to_alpha(pred_kappa):
    """Convert a concentration value ``kappa`` to an expected error in degrees.

    Only applicable to the AngMF distribution.

    Args:
        pred_kappa: NumPy array (or scalar) of kappa values.

    Returns:
        Value(s) of the same shape, converted from radians to degrees.
    """
    decay = np.exp(- pred_kappa * np.pi)
    rational_term = (2 * pred_kappa) / ((pred_kappa ** 2.0) + 1)
    exponential_term = (decay * np.pi) / (1 + decay)
    return np.degrees(rational_term + exponential_term)
| # normal vector to rgb values | |
def norm_to_rgb(norm):
    """Map the first sample of a batch of normal vectors to 8-bit RGB.

    Components are mapped linearly from [-1, 1] to [0, 255]; anything
    outside that range is clipped.

    Args:
        norm: Array of shape (B, H, W, 3).

    Returns:
        ``uint8`` array of shape (H, W, 3) for batch element 0.
    """
    first_sample = norm[0, ...]
    scaled = ((first_sample + 1) * 0.5) * 255
    return np.clip(scaled, a_min=0, a_max=255).astype(np.uint8)
| # visualize during training | |
def visualize(args, img, gt_norm, gt_norm_mask, norm_out_list, total_iter):
    """Save debugging images for the first sample of a training batch.

    Files written to ``args.exp_vis_dir`` (names prefixed with the
    zero-padded ``total_iter``): the un-normalised input image, the masked
    ground-truth normal map, and, for every entry ``i`` of ``norm_out_list``,
    the predicted normal map, the expected error derived from the predicted
    kappa, and the masked angular error against the ground truth.

    Args:
        args: Object exposing ``exp_vis_dir`` (output directory).
        img: Input image tensor, (B, 3, H, W).
        gt_norm: Ground-truth normals tensor, (B, 3, H, W).
        gt_norm_mask: Validity mask tensor, (B, 1, H, W).
        norm_out_list: Prediction tensors; channels 0-2 hold the normal and
            the remaining channel(s) hold kappa. Each one is resized to the
            ground-truth resolution with nearest-neighbour interpolation.
        total_iter: Iteration counter used in the file names.
    """
    def to_channels_last(tensor):
        # (B, C, H, W) tensor -> (B, H, W, C) NumPy array on the CPU.
        return tensor.detach().cpu().permute(0, 2, 3, 1).numpy()

    _, _, height, width = gt_norm.shape

    pred_norm_list = []
    pred_kappa_list = []
    for raw_out in norm_out_list:
        resized = F.interpolate(raw_out, size=[height, width], mode='nearest')
        pred_norm_list.append(to_channels_last(resized[:, :3, :, :]))   # (B, H, W, 3)
        pred_kappa_list.append(to_channels_last(resized[:, 3:, :, :]))  # (B, H, W, 1)

    # to numpy arrays
    img = to_channels_last(img)                    # (B, H, W, 3)
    gt_norm = to_channels_last(gt_norm)            # (B, H, W, 3)
    gt_norm_mask = to_channels_last(gt_norm_mask)  # (B, H, W, 1)

    out_dir = args.exp_vis_dir

    # input image
    plt.imsave('%s/%08d_img.jpg' % (out_dir, total_iter), unnormalize(img[0, ...]))

    # gt norm (masked-out pixels become black)
    gt_norm_rgb = norm_to_rgb(gt_norm)
    plt.imsave('%s/%08d_gt_norm.jpg' % (out_dir, total_iter), gt_norm_rgb * gt_norm_mask[0, ...])

    for idx, (pred_norm, pred_kappa) in enumerate(zip(pred_norm_list, pred_kappa_list)):
        # pred_norm
        plt.imsave('%s/%08d_pred_norm_%d.jpg' % (out_dir, total_iter, idx),
                   norm_to_rgb(pred_norm))

        # expected error derived from kappa
        pred_alpha = kappa_to_alpha(pred_kappa)
        plt.imsave('%s/%08d_pred_alpha_%d.jpg' % (out_dir, total_iter, idx),
                   pred_alpha[0, :, :, 0], vmin=0, vmax=60, cmap='jet')

        # error in angles: arccos of the clipped dot product, zeroed where
        # the mask is zero
        dot_product = np.sum(gt_norm * pred_norm, axis=3, keepdims=True)  # (B, H, W, 1)
        dot_product = np.clip(dot_product, -1, 1)
        angular_error = np.degrees(np.arccos(dot_product)) * gt_norm_mask  # (B, H, W, 1)
        plt.imsave('%s/%08d_pred_error_%d.jpg' % (out_dir, total_iter, idx),
                   angular_error[0, :, :, 0], vmin=0, vmax=60, cmap='jet')