| import os |
| import argparse |
| import numpy as np |
| import cv2 |
|
|
| from evaluator import Eval_thread |
| from dataloader import EvalDataset |
|
|
| import sys |
| sys.path.append('..') |
|
|
|
|
def _pick_prediction_root(search_dir='.'):
    """Return the prediction directory (name containing 'gconet_') under *search_dir*.

    Candidates are sorted so the choice no longer depends on ``os.listdir``
    order, and only directories are considered.

    Raises:
        FileNotFoundError: if no matching directory exists.
    """
    candidates = sorted(
        dr for dr in os.listdir(search_dir)
        if 'gconet_' in dr and os.path.isdir(os.path.join(search_dir, dr))
    )
    print('root_dir_predictions:', candidates)
    if not candidates:
        raise FileNotFoundError(
            "No prediction directory containing 'gconet_' found in {}".format(
                os.path.abspath(search_dir)))
    return candidates[0]


def _image_path_for_gt(gt_path):
    """Map a GT mask path to its RGB image path ('/gts' -> '/images', '.png' suffix -> '.jpg')."""
    image_path = gt_path.replace('/gts', '/images')
    root, ext = os.path.splitext(image_path)
    # Only swap the file extension; the old str.replace hit '.png' anywhere in the path.
    if ext == '.png':
        image_path = root + '.jpg'
    return image_path


def _to_height(img, height):
    """Resize *img* to *height* rows keeping its aspect ratio; no-op if already that tall."""
    if img.shape[0] == height:
        return img
    width = max(1, int(round(img.shape[1] * height / float(img.shape[0]))))
    return cv2.resize(img, (width, height))


def _save_good_one(good_one, good_one_comp, good_one_gt, dir_good_ones):
    """Save one selected prediction plus a side-by-side comparison strip.

    The prediction is written to ``dir_good_ones/<category>/<file>``, where
    <category> is the prediction's parent directory name. The strip is written
    next to it with a ``_comp`` suffix. Left to right, it shows: RGB image | GT |
    selected prediction | comparison prediction, separated by 10-px gray bars.

    Returns:
        True if both files were written, False if an input could not be read
        (a warning is printed instead of aborting the whole run).
    """
    category = os.path.basename(os.path.dirname(good_one))
    dir_category = os.path.join(dir_good_ones, category)
    os.makedirs(dir_category, exist_ok=True)
    save_path = os.path.join(dir_category, os.path.basename(good_one))

    panels = []
    for path in (_image_path_for_gt(good_one_gt), good_one_gt, good_one, good_one_comp):
        img = cv2.imread(path)
        if img is None:  # cv2.imread signals failure by returning None, not raising
            print('Warning: cannot read {}; skipping {}'.format(path, good_one))
            return False
        panels.append(img)
    sal_map = panels[2]

    cv2.imwrite(save_path, sal_map)

    height = sal_map.shape[0]
    split_line = np.full((height, 10, 3), 127, dtype=sal_map.dtype)
    strip = []
    for panel in panels:
        if strip:
            strip.append(split_line)
        # hconcat requires equal heights; scale any mismatched panel to the prediction's height.
        strip.append(_to_height(panel, height))
    comp = cv2.hconcat(strip)

    root, ext = os.path.splitext(save_path)
    cv2.imwrite(root + '_comp' + ext, comp)
    return True


def main(cfg):
    """Select high-scoring predictions per dataset and dump them with comparison strips.

    Args:
        cfg: parsed arguments providing ``datasets`` ('+'-separated dataset
            names), ``gt_dir`` (GT root; the comparison-prediction root is
            derived by replacing '/gts' with '/gconet') and ``cuda``.

    Outputs are written under ``./good_ones/<dataset>/``.

    Raises:
        FileNotFoundError: if no 'gconet_*' prediction directory exists in the cwd.
    """
    dataset_names = cfg.datasets.split('+')
    root_dir_prediction = _pick_prediction_root('.')
    root_dir_prediction_comp = cfg.gt_dir.replace('/gts', '/gconet')
    root_dir_good_ones = 'good_ones'
    for dataset in dataset_names:
        dir_prediction = os.path.join(root_dir_prediction, dataset)
        dir_prediction_comp = os.path.join(root_dir_prediction_comp, dataset)
        dir_gt = os.path.join(cfg.gt_dir, dataset)
        loader = EvalDataset(
            dir_prediction,
            dir_gt,
            return_predpath=True,
            return_gtpath=True
        )
        loader_comp = EvalDataset(
            dir_prediction_comp,
            dir_gt,
            return_predpath=True
        )
        print('Selecting predictions from {}'.format(dir_prediction))
        thread = Eval_thread(loader, cuda=cfg.cuda)
        # NOTE(review): presumably keeps predictions with S-measure >= bar whose
        # comparison prediction scores below bar_comp -- confirm in evaluator.Eval_thread.
        _s_measure, good_ones, good_ones_comp, good_ones_gt = thread.select_by_Smeasure(
            bar=0.95, loader_comp=loader_comp, bar_comp=0.2)
        dir_good_ones = os.path.join(root_dir_good_ones, dataset)
        os.makedirs(dir_good_ones, exist_ok=True)
        print('have good_ones {}'.format(len(good_ones)))
        for good_one, good_one_comp, good_one_gt in zip(good_ones, good_ones_comp, good_ones_gt):
            _save_good_one(good_one, good_one_comp, good_one_gt, dir_good_ones)
|
|
|
|
def str2bool(value):
    """Parse a command-line boolean value.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for any
    non-empty string. That made ``--cuda False`` enable CUDA. This converter
    accepts the common textual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got {!r}'.format(value))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--datasets', type=str, default='CoCA+CoSOD3k+CoSal2015')
    parser.add_argument('--gt_dir', type=str, default='/root/datasets/sod/gts', help='GT')
    # nargs='?' + const=True lets a bare `--cuda` mean True; `--cuda False` now really disables it.
    parser.add_argument('--cuda', type=str2bool, nargs='?', const=True, default=True)
    config = parser.parse_args()
    main(config)