| import os |
| import sys |
| import numpy as np |
| from tqdm import tqdm |
| import glob |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| import time |
| from torchvision import transforms |
| import cv2 |
| from PIL import Image |
|
|
| import sys |
| sys.path.append('../../') |
| import utils.utils as utils |
| import utils.visualize as vis_utils |
|
|
| |
| |
| import projects.dsine.config as config |
| from projects.baseline_normal.dataloader import * |
| from utils.projection import intrins_from_fov, intrins_from_txt |
| |
|
|
|
|
def test(args, model, test_loader, device, results_dir=None):
    """Evaluate surface-normal predictions on a benchmark split.

    Args:
        args: parsed config namespace (unused here directly; kept for API parity).
        model: network whose ``model(img, intrins=..., mode='test')`` returns a
            list of predictions; the last element is used.
        test_loader: iterable of batch dicts with keys 'img', 'scene_name',
            'img_name', 'intrins' and optionally 'normal'/'normal_mask'.
        device: torch device to run inference on.
        results_dir: if not None, visualizations are written there.

    Side effects: prints aggregate normal-error metrics when ground truth
    is available; optionally writes visualization images.
    """
    with torch.no_grad():
        # Per-batch error tensors; concatenated once at the end. Calling
        # torch.cat inside the loop would copy the accumulator every
        # iteration (quadratic in total number of pixels).
        error_chunks = []

        for data_dict in tqdm(test_loader):
            img = data_dict['img'].to(device)
            scene_names = data_dict['scene_name']
            img_names = data_dict['img_name']
            intrins = data_dict['intrins'].to(device)

            # Pad input so H/W satisfy the network's size constraints;
            # intrinsics are adjusted for the padding as well.
            _, _, orig_H, orig_W = img.shape
            lrtb = utils.get_padding(orig_H, orig_W)
            img, intrins = utils.pad_input(img, intrins, lrtb)

            pred_list = model(img, intrins=intrins, mode='test')
            norm_out = pred_list[-1]

            # Crop the prediction back to the original (pre-padding) size.
            norm_out = norm_out[:, :, lrtb[2]:lrtb[2]+orig_H, lrtb[0]:lrtb[0]+orig_W]

            # Channels 0-2 are the normal vector; any extra channel is the
            # predicted concentration (kappa). No kappa channel -> None.
            pred_norm, pred_kappa = norm_out[:, :3, :, :], norm_out[:, 3:, :, :]
            pred_kappa = None if pred_kappa.size(1) == 0 else pred_kappa

            if 'normal' in data_dict.keys():
                gt_norm = data_dict['normal'].to(device)
                gt_norm_mask = data_dict['normal_mask'].to(device)

                pred_error = utils.compute_normal_error(pred_norm, gt_norm)
                error_chunks.append(pred_error[gt_norm_mask])

                if results_dir is not None:
                    prefixs = ['%s_%s' % (i, j) for (i, j) in zip(scene_names, img_names)]
                    vis_utils.visualize_normal(results_dir, prefixs, img, pred_norm, pred_kappa,
                                               gt_norm, gt_norm_mask, pred_error)

        if error_chunks:
            total_normal_errors = torch.cat(error_chunks, dim=0)
            metrics = utils.compute_normal_metrics(total_normal_errors)
            print("mean median rmse 5 7.5 11.25 22.5 30")
            print("%.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f" % (
                metrics['mean'], metrics['median'], metrics['rmse'],
                metrics['a1'], metrics['a2'], metrics['a3'], metrics['a4'], metrics['a5']))
|
|
|
|
def test_samples(args, model, device):
    """Run inference on ./samples/img/*.{png,jpg} and save RGB-coded normal maps.

    For each image, camera intrinsics are read from a sibling ``<name>.txt``
    if present; otherwise a 60-degree FoV is assumed. Outputs are written to
    ./samples/output/<name>.png.

    Args:
        args: parsed config namespace (unused here directly; kept for API parity).
        model: network whose ``model(img, intrins=..., mode='test')`` returns a
            list of predictions; the last element is used.
        device: torch device to run inference on.
    """
    img_paths = glob.glob('./samples/img/*.png') + glob.glob('./samples/img/*.jpg')
    img_paths.sort()
    os.makedirs('./samples/output/', exist_ok=True)

    # Standard ImageNet normalization (applied after padding).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    with torch.no_grad():
        for img_path in img_paths:
            print(img_path)
            # splitext gives the path minus its extension; using it instead of
            # str.replace(ext, ...) avoids corrupting paths where the extension
            # substring also appears elsewhere (e.g. './img.png.dir/a.png').
            base = os.path.splitext(img_path)[0]

            img = Image.open(img_path).convert('RGB')
            img = np.array(img).astype(np.float32) / 255.0
            img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)

            # Pad input so H/W satisfy the network's size constraints.
            _, _, orig_H, orig_W = img.shape
            lrtb = utils.get_padding(orig_H, orig_W)
            img = F.pad(img, lrtb, mode="constant", value=0.0)
            img = normalize(img)

            # Intrinsics: per-image .txt when available, else assume 60-deg FoV.
            intrins_path = base + '.txt'
            if os.path.exists(intrins_path):
                intrins = intrins_from_txt(intrins_path, device=device).unsqueeze(0)
            else:
                intrins = intrins_from_fov(new_fov=60.0, H=orig_H, W=orig_W, device=device).unsqueeze(0)
            # Shift the principal point to account for the left/top padding.
            intrins[:, 0, 2] += lrtb[0]
            intrins[:, 1, 2] += lrtb[2]

            norm_out = model(img, intrins=intrins, mode='test')[-1]
            # Crop back to the original resolution; keep only the normal channels.
            norm_out = norm_out[:, :, lrtb[2]:lrtb[2]+orig_H, lrtb[0]:lrtb[0]+orig_W]
            pred_norm = norm_out[:, :3, :, :]

            target_path = base.replace('/img/', '/output/') + '.png'
            im = Image.fromarray(vis_utils.normal_to_rgb(pred_norm)[0, ...])
            im.save(target_path)
|
|
|
|
def measure_throughput(model, img, intrins, dtype='fp32', nwarmup=50, nruns=1000):
    """Benchmark model inference throughput on CUDA.

    Args:
        model: network called as ``model(img, intrins, mode='test')``.
        img: input image batch tensor (moved to CUDA here).
        intrins: camera intrinsics tensor (moved to CUDA here).
        dtype: 'fp32' (default) or 'fp16' to cast inputs to half precision.
        nwarmup: number of untimed warm-up iterations.
        nruns: number of timed iterations.

    Side effects: prints per-iteration progress, input shape and the
    average images/second.
    """
    img = img.to("cuda")
    intrins = intrins.to("cuda")
    if dtype == 'fp16':
        img = img.half()
        intrins = intrins.half()

    print("Warm up ...")
    with torch.no_grad():
        for _ in range(nwarmup):
            model(img, intrins, mode='test')

    # Ensure warm-up kernels have finished before timing starts.
    torch.cuda.synchronize()
    print("Start timing ...")
    timings = []
    with torch.no_grad():
        for i in range(1, nruns + 1):
            # perf_counter is monotonic and high-resolution, unlike time.time()
            # which can jump with wall-clock adjustments.
            start_time = time.perf_counter()
            model(img, intrins, mode='test')
            # Wait for the async GPU work to complete before stopping the clock.
            torch.cuda.synchronize()
            end_time = time.perf_counter()
            timings.append(end_time - start_time)
            if i % 10 == 0:
                print('Iteration %d/%d, avg batch time %.2f ms' % (i, nruns, np.mean(timings) * 1000))

    print("Input shape:", img.size())
    print('Average throughput: %.2f images/second' % (img.shape[0] / np.mean(timings)))
|
|
|
|
def demo(args, model, InputStream, frame_name):
    """Fullscreen live demo: stream frames, predict normals, display side by side.

    Args:
        args: parsed config namespace; ``args.loss_fn`` selects how the
            uncertainty channel is colorized.
        model: network whose ``model(img, intrins=..., mode='test')`` returns a
            list of predictions; the last element is used.
        InputStream: frame source exposing ``get_sample()``, ``lrtb`` (padding)
            and ``new_H``/``new_W`` (post-padding frame size).
        frame_name: OpenCV window title.

    Controls: space toggles pause, 'q' exits the process. Runs until quit.
    """
    cv2.namedWindow(frame_name, cv2.WINDOW_NORMAL)
    cv2.setWindowProperty(frame_name, cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
    pause = False

    # Padding and post-padding frame size are fixed properties of the stream.
    lrtb = InputStream.lrtb
    H, W = InputStream.new_H, InputStream.new_W

    while True:
        with torch.no_grad():
            if pause:
                # While paused, skip grabbing/inference; keep showing the
                # last rendered frame so the window stays responsive.
                pass
            else:
                data_dict = InputStream.get_sample()
                color_image = data_dict['color_image']

                img = data_dict['img']
                intrins = data_dict['intrins']

                norm_out = model(img, intrins=intrins, mode='test')[-1]
                # Crop the padded prediction back to the stream resolution.
                norm_out = norm_out[:, :, lrtb[2]:lrtb[2]+H, lrtb[0]:lrtb[0]+W]
                # Channels 0-2: normal vector; any remainder: uncertainty (kappa).
                pred_norm = norm_out[:, :3, :, :]
                pred_kappa = norm_out[:, 3:, :, :]

                # RGB-coded normals; [..., ::-1] converts RGB -> BGR for OpenCV.
                pred_norm_rgb = vis_utils.normal_to_rgb(pred_norm)[0,...][...,::-1]
                if pred_kappa.size(1) == 0:
                    # No uncertainty channel predicted by this architecture.
                    pred_uncertainty = []
                elif 'NLL_angmf' in args.loss_fn:
                    # Angular vMF loss: convert kappa to expected angular error,
                    # rendered with a jet colormap.
                    pred_uncertainty = [vis_utils.alpha_to_jet(vis_utils.kappa_to_alpha(pred_kappa))]
                else:
                    # Fallback: grayscale rendering of the raw kappa map.
                    pred_uncertainty = [vis_utils.depth_to_rgb(pred_kappa[0,...], None, 0.0, None, 'gray')[...,::-1]]
                # Input | normals | (optional) uncertainty, stacked horizontally.
                out = np.hstack([color_image, pred_norm_rgb]+pred_uncertainty)

            cv2.imshow(frame_name, out)

            # Keyboard controls: space toggles pause, 'q' quits the process.
            k = cv2.waitKey(1)
            if k == ord(' '):
                pause = not pause
            elif k == ord('q'):
                exit()
|
|
|
|
if __name__ == '__main__':
    # Entry point: load config + checkpoint, then dispatch on args.mode to
    # benchmark evaluation, sample inference, throughput measurement, or a
    # live demo (screen capture / webcam / RealSense / YouTube).
    device = torch.device('cuda')
    args = config.get_args(test=True)

    # No explicit checkpoint: fall back to the lexicographically last *.pt
    # under <output_dir>/models (assumed to be the most recent).
    if args.ckpt_path is None:
        ckpt_paths = glob.glob(os.path.join(args.output_dir, 'models', '*.pt'))
        ckpt_paths.sort()
        args.ckpt_path = ckpt_paths[-1]
    assert os.path.exists(args.ckpt_path)

    # Select the architecture variant; imported lazily so only the requested
    # model module is loaded.
    if args.NNET_architecture == 'v00':
        from models.dsine.v00 import DSINE_v00 as DSINE
    elif args.NNET_architecture == 'v01':
        from models.dsine.v01 import DSINE_v01 as DSINE
    elif args.NNET_architecture == 'v02':
        from models.dsine.v02 import DSINE_v02 as DSINE
    elif args.NNET_architecture == 'v02_kappa':
        from models.dsine.v02_kappa import DSINE_v02_kappa as DSINE
    else:
        raise Exception('invalid arch')
    model = DSINE(args).to(device)

    model = utils.load_checkpoint(args.ckpt_path, model)
    model.eval()

    if args.mode == 'benchmark':
        # Evaluate at each dataset's native resolution: disable resizing and
        # FoV augmentation before building the loaders.
        args.input_height = args.input_width = 0
        args.data_augmentation_same_fov = 0

        for dataset_name, split in [('nyuv2', 'test'),
                                    ('scannet', 'test'),
                                    ('ibims', 'ibims'),
                                    ('sintel', 'sintel'),
                                    ('vkitti', 'vkitti'),
                                    ('oasis', 'val')
                                    ]:

            args.dataset_name_test = dataset_name
            args.test_split = split
            test_loader = TestLoader(args).data

            # Visualization output dir is created only when requested.
            results_dir = None
            if args.visualize:
                results_dir = os.path.join(args.output_dir, 'test', dataset_name)
                os.makedirs(results_dir, exist_ok=True)

            test(args, model, test_loader, device, results_dir)

    elif args.mode == 'samples':
        # Run inference on the images under ./samples/img/.
        test_samples(args, model, device)

    elif args.mode == 'throughput':
        # Benchmark with a fixed-size dummy batch (640x480, batch of 8).
        H, W = 480, 640
        batch_size = 8
        dummy_img = torch.rand(batch_size, 3, H, W).float().to(device)
        dummy_intrins = torch.cat([intrins_from_fov(60.0, H, W).unsqueeze(0)]*batch_size, dim=0).float().to(device)
        measure_throughput(model, dummy_img, dummy_intrins, dtype='fp32')

    else:
        # Remaining modes are live demo sources; build the matching input
        # stream and hand it to the display loop.
        from utils.demo_data import define_input
        if args.mode == 'screen':
            # Capture a 640x480 region centered on a 1920x1080 screen.
            input_name = 'screen'
            kwargs = dict(
                intrins = None,
                top = (1080-480) // 2,
                left = (1920-640) // 2,
                height = 480,
                width = 640,
            )

        elif args.mode == 'webcam':
            # NOTE(review): webcam_index=1 assumes a second camera — confirm.
            input_name = 'webcam'
            kwargs = dict(
                intrins = None,
                new_width = -1,
                webcam_index = 1,
            )

        elif args.mode == 'rs':
            # Intel RealSense with auto exposure / white balance enabled.
            input_name = 'rs'
            kwargs = dict(
                enable_auto_exposure = True,
                enable_auto_white_balance = True,
            )

        elif 'youtube.com' in args.mode:
            # args.mode is a YouTube watch URL; extract the video id after
            # 'watch?v='.
            input_name = 'youtube'
            kwargs = dict(
                intrins = None,
                new_width = 1024,
                video_id = args.mode.split('watch?v=')[1],
            )

        else:
            raise Exception('invalid input option for demo')

        InputStream = define_input(input=input_name, device=device, **kwargs)
        demo(args, model, InputStream, frame_name=args.ckpt_path)
|
|
|
|
|
|