Spaces:
Sleeping
Sleeping
| # Copyright (C) 2020 * Ltd. All rights reserved. | |
| # author : Sanghyeon Jo <josanghyeokn@gmail.com> | |
| import gradio as gr | |
| import os | |
| import sys | |
| import copy | |
| import shutil | |
| import random | |
| import argparse | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from torch.utils.tensorboard import SummaryWriter | |
| from torch.utils.data import DataLoader | |
| from core.puzzle_utils import * | |
| from core.networks import * | |
| from core.datasets import * | |
| from tools.general.io_utils import * | |
| from tools.general.time_utils import * | |
| from tools.general.json_utils import * | |
| from tools.ai.log_utils import * | |
| from tools.ai.demo_utils import * | |
| from tools.ai.optim_utils import * | |
| from tools.ai.torch_utils import * | |
| from tools.ai.evaluate_utils import * | |
| from tools.ai.augment_utils import * | |
| from tools.ai.randaugment import * | |
| import PIL.Image | |
# Command-line interface for the demo. Options are registered from a single
# table so every flag, its default and its converter can be read at a glance.
parser = argparse.ArgumentParser()

_CLI_OPTIONS = (
    ###########################################################################
    # Dataset
    ###########################################################################
    ('--seed', 2606, int),
    ('--num_workers', 4, int),
    ###########################################################################
    # Network
    ###########################################################################
    ('--architecture', 'DeepLabv3+', str),
    ('--backbone', 'resnet50', str),
    ('--mode', 'fix', str),
    ('--use_gn', True, str2bool),
    ###########################################################################
    # Inference parameters
    ###########################################################################
    ('--tag', '', str),
    ('--domain', 'val', str),
    # Comma-separated list; it is split and converted to floats by the script.
    ('--scales', '0.5,1.0,1.5,2.0', str),
    # Passed to crf_inference as `t`; a value <= 0 skips the CRF step.
    ('--iteration', 10, int),
)

for _flag, _default, _converter in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, type=_converter)
# Foreground category names (background is intentionally not listed here).
class_names = [
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]

# Display colour (three 0-255 channel values) for every label, including
# "background".
# NOTE(review): values look like the standard VOC palette in RGB order - confirm.
cmap_dic = {
    "background": [0, 0, 0],
    "aeroplane": [128, 0, 0],
    "bicycle": [0, 128, 0],
    "bird": [128, 128, 0],
    "boat": [0, 0, 128],
    "bottle": [128, 0, 128],
    "bus": [0, 128, 128],
    "car": [128, 128, 128],
    "cat": [64, 0, 0],
    "chair": [192, 0, 0],
    "cow": [64, 128, 0],
    "diningtable": [192, 128, 0],
    "dog": [64, 0, 128],
    "horse": [192, 0, 128],
    "motorbike": [64, 128, 128],
    "person": [192, 128, 128],
    "pottedplant": [0, 64, 0],
    "sheep": [128, 64, 0],
    "sofa": [0, 192, 0],
    "train": [128, 192, 0],
    "tvmonitor": [0, 64, 128],
}

# Lookup table: row i holds the colour of predicted label index i.
# The network is created with one more output than there are foreground
# classes, so the table needs a row for that extra label as well; without it
# every colour is shifted by one and the highest index is out of range.
# Assumes the extra label is background at index 0 - TODO confirm against the
# training label order.
colors = np.asarray(
    [cmap_dic[class_name] for class_name in ["background"] + class_names]
)
if __name__ == '__main__':
    ###################################################################################
    # Arguments
    ###################################################################################
    args = parser.parse_args()

    model_dir = create_directory('./experiments/models/')
    # NOTE(review): the checkpoint name is fixed and does not follow
    # --architecture / --backbone / --mode / --use_gn; confirm this is intended.
    model_path = model_dir + 'DeepLabv3+@ResNet-50@Fix@GN.pth'

    # Build a descriptive run tag from the inference settings.
    if 'train' in args.domain:
        args.tag += '@train'
    else:
        args.tag += '@' + args.domain

    args.tag += '@scale=%s' % args.scales
    args.tag += '@iteration=%d' % args.iteration

    set_seed(args.seed)

    def log_func(string=''):
        """Print a log line (an empty line when called without arguments)."""
        print(string)

    ###################################################################################
    # Transform, Dataset, DataLoader
    ###################################################################################
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    normalize_fn = Normalize(imagenet_mean, imagenet_std)

    # Dataset metadata; only the 'classes' entry is read here.
    meta_dic = read_json('./data/VOC_2012.json')

    ###################################################################################
    # Network
    ###################################################################################
    # One extra output on top of the dataset's class count.
    num_classes = meta_dic['classes'] + 1

    if args.architecture == 'DeepLabv3+':
        model = DeepLabv3_Plus(args.backbone, num_classes=num_classes, mode=args.mode,
                               use_group_norm=args.use_gn)
    elif args.architecture == 'Seg_Model':
        model = Seg_Model(args.backbone, num_classes=num_classes)
    elif args.architecture == 'CSeg_Model':
        model = CSeg_Model(args.backbone, num_classes=num_classes)
    else:
        # Fail with a clear message instead of a NameError on `model` below.
        raise ValueError(
            'Unsupported architecture: {!r} (expected DeepLabv3+, Seg_Model or CSeg_Model)'.format(
                args.architecture))

    log_func('[i] Architecture is {}'.format(args.architecture))
    log_func('[i] Total Params: %.2fM' % (calculate_parameters(model)))
    log_func()

    load_model(model, model_path, parallel=False)

    #################################################################################################
    # Evaluation
    #################################################################################################
    eval_timer = Timer()
    scales = [float(scale) for scale in args.scales.split(',')]

    # Inference only: switch the network to evaluation mode once the weights are loaded.
    model.eval()
    eval_timer.tik()
def inference(images, image_size):
    """Run the model on an (image, mirrored image) pair and fuse both outputs.

    Relies on the module-level ``model`` created by the script section.

    Args:
        images: batch whose entry 0 is the image and entry 1 is the same image
            flipped along its last axis (this is how ``predict_image`` builds it).
        image_size: target size handed to ``resize_for_tensors``.

    Returns:
        A NumPy array holding the sum of both predictions, with the first
        axis of the fused tensor moved to the end.
    """
    resized = resize_for_tensors(model(images), image_size)

    # Mirror the second prediction back so it lines up with the first one.
    fused = resized[0] + resized[1].flip(-1)

    return get_numpy_from_tensor(fused).transpose((1, 2, 0))
def predict_image(ori_image):
    """Segment one image and return the colourised prediction.

    The image is evaluated at every factor in the module-level ``scales``
    (each time together with its horizontally flipped copy), the per-scale
    outputs are summed, soft-maxed over the last axis and, when
    ``args.iteration > 0``, passed through ``crf_inference`` before the
    arg-max.

    Relies on the module-level ``scales``, ``args``, ``normalize_fn`` and
    ``colors`` prepared by the script section.

    Args:
        ori_image: image as a NumPy array (what the Gradio "image" input
            delivers), or ``None`` when nothing was submitted.

    Returns:
        A ``PIL.Image.Image`` in RGB mode, or ``None`` if ``ori_image`` is ``None``.
    """
    # Gradio passes None when the form is submitted without an image.
    if ori_image is None:
        return None

    ori_image = PIL.Image.fromarray(ori_image)
    ori_w, ori_h = ori_image.size

    with torch.no_grad():
        cams_list = []

        for scale in scales:
            # resize() already returns a new image, so no defensive copy is
            # needed; clamp to 1 px so tiny inputs never get a zero-sized side.
            scaled_size = (max(1, round(ori_w * scale)), max(1, round(ori_h * scale)))
            image = ori_image.resize(scaled_size, resample=PIL.Image.BICUBIC)

            image = normalize_fn(image)
            image = image.transpose((2, 0, 1))
            image = torch.from_numpy(image)

            # Evaluate the image together with its mirrored copy.
            flipped_image = image.flip(-1)
            images = torch.stack([image, flipped_image])

            cams_list.append(inference(images, (ori_h, ori_w)))

        preds = np.sum(cams_list, axis=0)
        preds = F.softmax(torch.from_numpy(preds), dim=-1).numpy()

        if args.iteration > 0:
            # crf_inference is given the scores with the class axis first.
            preds = crf_inference(np.asarray(ori_image), preds.transpose((2, 0, 1)), t=args.iteration)
            pred_mask = np.argmax(preds, axis=0)
        else:
            pred_mask = np.argmax(preds, axis=-1)

        # NOTE(review): the channel reversal is kept as-is; if `colors` is
        # already RGB this shows the palette with red and blue swapped - confirm.
        pred_mask = decode_from_colormap(pred_mask, colors)[..., ::-1]

        # Use the explicitly imported PIL.Image, matching the conversion at the
        # top of this function; a bare `Image` is not imported by this file.
        return PIL.Image.fromarray(pred_mask.astype(np.uint8)).convert("RGB")
if __name__ == '__main__':
    # Web UI: a single image goes in, the colourised segmentation comes out.
    # NOTE(review): assumes these statements belong to the script's main
    # section (the original indentation is not recoverable) - confirm.
    interface_options = dict(fn=predict_image, inputs="image", outputs="image")
    demo = gr.Interface(**interface_options)
    demo.launch()