File size: 2,836 Bytes
168ec29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import argparse
import numpy as np
import cv2

from evaluator import Eval_thread
from dataloader import EvalDataset

import sys
sys.path.append('..')


def _match_height(img, height):
    """Return ``img`` rescaled (aspect ratio kept) so it has ``height`` rows.

    ``cv2.hconcat`` requires every panel to have the same number of rows; the
    image is returned untouched when it already matches.
    """
    rows, cols = img.shape[:2]
    if rows == height:
        return img
    new_cols = max(1, int(round(cols * height / float(rows))))
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(img, (new_cols, height))


def main(cfg):
    """Copy well-predicted saliency maps and save side-by-side comparisons.

    For every dataset named in ``cfg.datasets`` the predictions found in the
    first (alphabetically) ``*gconet_*`` directory of the current working
    directory are evaluated against the ground truth under ``cfg.gt_dir``.
    The maps returned by ``Eval_thread.select_by_Smeasure`` are copied to
    ``good_ones/<dataset>/<category>/`` together with a ``*_comp`` image laid
    out as ``image | GT | prediction | competitor prediction``.

    Args:
        cfg: Namespace with the attributes
            ``datasets`` ('+'-separated dataset names),
            ``gt_dir`` (ground-truth root; ``/gts`` is swapped for ``/gconet``
            and ``/images`` to locate competitor predictions and RGB images),
            ``cuda`` (forwarded to ``Eval_thread``).

    Raises:
        FileNotFoundError: If no ``*gconet_*`` directory exists in the current
            working directory.
    """
    dataset_names = cfg.datasets.split('+')
    # os.listdir() order is arbitrary: sort so the chosen directory is
    # reproducible, and ignore plain files that merely contain 'gconet_'.
    root_dir_predictions = sorted(
        dr for dr in os.listdir('.') if 'gconet_' in dr and os.path.isdir(dr)
    )
    if not root_dir_predictions:
        raise FileNotFoundError(
            "No prediction directory containing 'gconet_' found in {}".format(os.getcwd())
        )
    root_dir_prediction_comp = cfg.gt_dir.replace('/gts', '/gconet')
    print('root_dir_predictions:', root_dir_predictions)
    root_dir_prediction = root_dir_predictions[0]
    root_dir_good_ones = 'good_ones'
    for dataset in dataset_names:
        dir_prediction = os.path.join(root_dir_prediction, dataset)
        dir_prediction_comp = os.path.join(root_dir_prediction_comp, dataset)
        dir_gt = os.path.join(cfg.gt_dir, dataset)
        loader = EvalDataset(
            dir_prediction,        # preds
            dir_gt,                   # GT
            return_predpath=True,
            return_gtpath=True
        )
        loader_comp = EvalDataset(
            dir_prediction_comp,        # preds
            dir_gt,                   # GT
            return_predpath=True
        )
        print('Selecting predictions from {}'.format(dir_prediction))
        thread = Eval_thread(loader, cuda=cfg.cuda)
        # NOTE(review): presumably keeps samples whose S-measure is >= bar for
        # our predictions and <= bar_comp for the competitor -- confirm in
        # Eval_thread.select_by_Smeasure. The first return value is unused here.
        _, good_ones, good_ones_comp, good_ones_gt = thread.select_by_Smeasure(
            bar=0.95, loader_comp=loader_comp, bar_comp=0.2
        )
        dir_good_ones = os.path.join(root_dir_good_ones, dataset)
        os.makedirs(dir_good_ones, exist_ok=True)
        print('have good_ones {}'.format(len(good_ones)))
        for good_one, good_one_comp, good_one_gt in zip(good_ones, good_ones_comp, good_ones_gt):
            # Predictions are laid out as <dataset>/<category>/<file>.
            category = os.path.basename(os.path.dirname(good_one))
            dir_category = os.path.join(dir_good_ones, category)
            os.makedirs(dir_category, exist_ok=True)
            save_path = os.path.join(dir_category, os.path.basename(good_one))

            # cv2.imread signals failure by returning None instead of raising.
            sal_map = cv2.imread(good_one)
            if sal_map is None:
                print('Skipping unreadable prediction {}'.format(good_one))
                continue
            cv2.imwrite(save_path, sal_map)

            image_path = good_one_gt.replace('/gts', '/images').replace('.png', '.jpg')
            other_paths = (image_path, good_one_gt, good_one_comp)
            others = [cv2.imread(path) for path in other_paths]
            missing = [path for path, img in zip(other_paths, others) if img is None]
            if missing:
                print('Skipping comparison for {}: cannot read {}'.format(good_one, missing))
                continue

            height = sal_map.shape[0]
            image, sal_map_gt, sal_map_comp = (_match_height(img, height) for img in others)
            # 10-pixel wide mid-gray strip separating the panels.
            split_line = np.full((height, 10, 3), 127, dtype=sal_map.dtype)
            comp = cv2.hconcat([image, split_line, sal_map_gt, split_line, sal_map, split_line, sal_map_comp])
            # splitext copes with extensions of any length (e.g. '.jpeg').
            stem, ext = os.path.splitext(save_path)
            save_path_comp = ''.join((stem, '_comp', ext))
            cv2.imwrite(save_path_comp, comp)


if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

        ``type=bool`` would treat every non-empty string (including 'False')
        as True, so the flag could never be switched off with a readable value.
        """
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        # The empty string stays False, as it was with ``type=bool``.
        if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    # Command-line entry point: collect the options and run the selection.
    parser = argparse.ArgumentParser()
    parser.add_argument('--datasets', type=str, default='CoCA+CoSOD3k+CoSal2015')
    parser.add_argument('--gt_dir', type=str, default='/root/datasets/sod/gts', help='GT')
    parser.add_argument('--cuda', type=_str2bool, default=True)
    config = parser.parse_args()
    main(config)