Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import sys | |
| import tqdm | |
| import torch | |
| import argparse | |
| import numpy as np | |
| from PIL import Image | |
# Put the repository root (the parent of this script's directory) on sys.path
# so the project packages star-imported below can be resolved.
filepath = os.path.dirname(os.path.abspath(__file__))
repopath = os.path.dirname(filepath)
sys.path.append(repopath)
| from lib import * | |
| from utils.misc import * | |
| from data.dataloader import * | |
| from data.custom_transforms import * | |
# Turn off TF32 for CUDA matmuls and cuDNN kernels so GPU inference keeps
# full float32 precision (at some speed cost on hardware that supports TF32).
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
| def _args(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--config', '-c', type=str, default='configs/InSPyReNet_SwinB.yaml') | |
| parser.add_argument('--source', '-s', type=str) | |
| parser.add_argument('--dest', '-d', type=str, default=None) | |
| parser.add_argument('--type', '-t', type=str, default='map') | |
| parser.add_argument('--gpu', '-g', action='store_true', default=False) | |
| parser.add_argument('--jit', '-j', action='store_true', default=False) | |
| parser.add_argument('--verbose', '-v', action='store_true', default=False) | |
| return parser.parse_args() | |
def get_format(source):
    """Classify a collection of file names as an image set or a video set.

    Extensions are matched case-insensitively.

    Args:
        source: Iterable of file names or paths.

    Returns:
        'Image' if the names include images (.jpg/.png/.jpeg) but no videos,
        'Video' if they include videos (.mp4/.avi/.mov) but no images,
        and '' when both kinds are present or neither is.
    """
    image_exts = ('.jpg', '.png', '.jpeg')
    video_exts = ('.mp4', '.avi', '.mov')
    lowered = [name.lower() for name in source]
    has_image = any(name.endswith(image_exts) for name in lowered)
    has_video = any(name.endswith(video_exts) for name in lowered)

    if has_image and not has_video:
        return 'Image'
    if has_video and not has_image:
        return 'Video'
    # Mixed media, or nothing recognisable.
    return ''
def _load_model(opt, args):
    """Build the configured model, load its weights and optionally TorchScript it.

    Weights are read from ``latest.pth`` in ``opt.Test.Checkpoint.checkpoint_dir``.
    With ``args.jit`` a trace cached as ``jit.pt`` in that directory is reused
    when it exists; otherwise the model is simplified, traced and saved there.

    Args:
        opt: Loaded config (reads ``Model``, ``Test.Checkpoint`` and
            ``Test.Dataset.transforms.static_resize.size``).
        args: Parsed CLI namespace (reads ``gpu`` and ``jit``).

    Returns:
        The eval-mode model, or a TorchScript module, on the requested device.
    """
    ckpt_dir = opt.Test.Checkpoint.checkpoint_dir
    jit_path = os.path.join(ckpt_dir, 'jit.pt')
    device = torch.device('cuda' if args.gpu is True else 'cpu')

    if args.jit is True and os.path.isfile(jit_path):
        # map_location moves the stored tensors to the requested device.
        # NOTE(review): a trace can also bake in the device it was recorded on;
        # re-trace (delete jit.pt) when switching between --gpu and CPU runs.
        return torch.jit.load(jit_path, map_location=device)

    # NOTE(review): eval() of a config string - only run configs you trust.
    model = eval(opt.Model.name)(**opt.Model)
    state_dict = torch.load(os.path.join(ckpt_dir, 'latest.pth'), map_location=torch.device('cpu'))
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    model.eval()

    if args.jit is True:
        model = Simplify(model)
        # The dummy input must live on the model's device; the previous
        # hard-coded .cuda() made --jit fail whenever --gpu was not given.
        dummy = torch.rand(1, 3, *opt.Test.Dataset.transforms.static_resize.size, device=device)
        model = torch.jit.trace(model, dummy, strict=False)
        torch.jit.save(model, jit_path)
    return model


def _resolve_source(args):
    """Classify ``args.source`` and create the output directory.

    A purely numeric source is treated as a webcam index. A directory or a
    single file is classified by extension via ``get_format``. Outputs go to
    ``results/<directory name>`` for a directory and ``results`` for a file,
    unless ``args.dest`` overrides them.

    Returns:
        ``(format, save_dir)``: format is 'Webcam', 'Image' or 'Video'.
        save_dir is None for a webcam without ``--dest``.

    Raises:
        ValueError: if the source does not exist, or contains no supported
            media, or mixes images and videos.
    """
    source = args.source
    save_dir = None
    _format = None
    if source.isnumeric():
        _format = 'Webcam'
    elif os.path.isdir(source):
        # normpath drops a trailing separator, so 'imgs/' maps to results/imgs
        # rather than to results/ itself.
        save_dir = os.path.join('results', os.path.basename(os.path.normpath(source)))
        _format = get_format(os.listdir(source))
    elif os.path.isfile(source):
        save_dir = 'results'
        _format = get_format([source])

    if not _format:
        # Previously this surfaced as an opaque TypeError/NameError from the loader lookup.
        raise ValueError(
            f"Unsupported --source {source!r}: expected a webcam index, an existing image/video "
            "file, or a directory containing only images or only videos")

    if args.dest is not None:
        save_dir = args.dest
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
    return _format, save_dir


def _read_background(path):
    """Read an image file as an RGB array, raising FileNotFoundError if OpenCV cannot decode it."""
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f'could not read background image: {path}')
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _compose(img, pred, out_type, background=None):
    """Render one output frame from the original ``img`` and the soft mask ``pred``.

    Args:
        img: Original frame as an (H, W, 3) array.
        pred: Mask of shape (H, W), used directly as the blending weight
            (presumably in [0, 1]).
        out_type: 'map', 'rgba', 'green', 'blur' or 'overlay'. Any other value
            blends over ``background`` when one is given, and otherwise
            returns ``img`` unchanged.
        background: (H, W, 3) backdrop already sized like ``img``.

    Returns:
        The rendered frame. It may be a float array; the caller casts it to uint8.
    """
    alpha = pred[..., np.newaxis]
    key_color = [120, 255, 155]
    if out_type == 'map':
        return (np.stack([pred] * 3, axis=-1) * 255).astype(np.uint8)
    if out_type == 'rgba':
        r, g, b = cv2.split(img)
        return cv2.merge([r, g, b, (pred * 255).astype(np.uint8)])
    if out_type == 'green':
        bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * key_color
        return img * alpha + bg * (1 - alpha)
    if out_type == 'blur':
        return img * alpha + cv2.GaussianBlur(img, (0, 0), 15) * (1 - alpha)
    if out_type == 'overlay':
        # Tint the foreground half-way toward the key colour and outline the mask edge.
        bg = (np.stack([np.ones_like(pred)] * 3, axis=-1) * key_color + img) // 2
        blended = bg * alpha + img * (1 - alpha)
        border = cv2.Canny(((pred > .5) * 255).astype(np.uint8), 50, 100)
        blended[border != 0] = key_color
        return blended
    if background is not None:
        return img * alpha + background * (1 - alpha)
    return img


def _save_debug_maps(out, keys, size_hw, save_dir, name):
    """Save each intermediate output listed under ``keys`` as ``<name>_<i>.png``.

    Each map is passed through a sigmoid, min-max scaled to 0-255 and resized
    to ``size_hw`` (height, width).
    """
    debs = []
    for k in keys:
        debs.extend(out[k])
    for i, deb in enumerate(debs):
        log = torch.sigmoid(deb).cpu().detach().numpy().squeeze()
        lo, hi = log.min(), log.max()
        # A constant map would divide by zero (NaN -> garbage pixels); write it as black instead.
        log = (log - lo) / (hi - lo) if hi > lo else np.zeros_like(log)
        log = (log * 255).astype(np.uint8)
        log = cv2.cvtColor(log, cv2.COLOR_GRAY2RGB)
        log = cv2.resize(log, size_hw[::-1])
        Image.fromarray(log).save(os.path.join(save_dir, name + '_' + str(i) + '.png'))


def inference(opt, args):
    """Run the model over ``args.source``, then write or display the results.

    Image results are saved as ``<name>.png`` and video results as ``<name>.mp4``
    in the directory chosen by ``_resolve_source``. Webcam frames are shown in
    an OpenCV window instead. ``args.type`` selects the rendering (see
    ``_compose``), and an image path there is used as a replacement background.
    'debug' additionally saves the intermediate maps named by
    ``opt.Train.Debug.keys``.

    Raises:
        ValueError: if ``args.source`` cannot be classified.
        FileNotFoundError: if a background image given via ``--type`` cannot be read.
    """
    # Resolve the source first so a bad --source fails before the model is loaded.
    _format, save_dir = _resolve_source(args)
    model = _load_model(opt, args)

    # _format is one of 'Image'/'Video'/'Webcam', so this picks the matching
    # *Loader class from the module's star imports without going through eval().
    sample_list = globals()[_format + 'Loader'](args.source, opt.Test.Dataset.transforms)
    if args.verbose is True:
        samples = tqdm.tqdm(sample_list, desc='Inference', total=len(sample_list), position=0,
                            leave=False, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
    else:
        samples = sample_list

    writer = None
    background_src = None  # background image as read from disk (RGB)
    background = None      # background resized to the current frame size
    try:
        for sample in samples:
            if _format == 'Video' and writer is None:
                # First frame of a new video: open its writer at the loader's fps and frame size.
                writer = cv2.VideoWriter(os.path.join(save_dir, sample['name'] + '.mp4'),
                                         cv2.VideoWriter_fourcc(*'mp4v'), sample_list.fps,
                                         sample['shape'][::-1])
                # ``samples`` is a tqdm bar only in verbose mode. Without this guard the
                # non-verbose path looked up ``total`` on the raw loader.
                if args.verbose is True:
                    samples.total += int(sample_list.cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if _format == 'Video' and sample['image'] is None:
                # A sample without an image ends the current video: finalise its file.
                if writer is not None:
                    writer.release()
                writer = None
                continue

            if args.gpu is True:
                sample = to_cuda(sample)
            with torch.no_grad():
                # A TorchScript trace takes the image tensor; the eager model takes the whole sample.
                out = model(sample['image']) if args.jit is True else model(sample)
                pred = to_numpy(out['pred'], sample['shape'])
            img = np.array(sample['original'])

            if args.type == 'debug':
                _save_debug_maps(out, opt.Train.Debug.keys, img.shape[:2], save_dir, sample['name'])
            elif args.type.lower().endswith(('.jpg', '.jpeg', '.png')):
                if background_src is None:
                    background_src = _read_background(args.type)
                # Resize whenever the frame size changes. Caching only the first frame's
                # size crashed on later images of a different shape.
                if background is None or background.shape[:2] != img.shape[:2]:
                    background = cv2.resize(background_src, img.shape[:2][::-1])
                img = _compose(img, pred, args.type, background)
            else:
                img = _compose(img, pred, args.type)

            img = img.astype(np.uint8)
            if _format == 'Image':
                Image.fromarray(img).save(os.path.join(save_dir, sample['name'] + '.png'))
            elif _format == 'Video' and writer is not None:
                writer.write(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            elif _format == 'Webcam':
                # NOTE(review): unlike the Video branch, no RGB->BGR swap is applied before
                # imshow; confirm the webcam loader's channel order.
                cv2.imshow('InSPyReNet', img)
                # imshow only repaints while the HighGUI event loop is pumped.
                cv2.waitKey(1)
    finally:
        # Flush a partially written video and close any preview window, even on error.
        if writer is not None:
            writer.release()
        if _format == 'Webcam':
            cv2.destroyAllWindows()
if __name__ == "__main__":
    # Script entry point: parse the CLI, load the config it names, run inference.
    cli_args = _args()
    inference(load_config(cli_args.config), cli_args)