# File: RustCoSeg / select_results.py
# Hosting-page metadata: HirraA's picture; commit "Upload 30 files" (168ec29, verified)
import os
import argparse
import numpy as np
import cv2
from evaluator import Eval_thread
from dataloader import EvalDataset
import sys
sys.path.append('..')
def main(cfg):
    """Copy selected predictions and build side-by-side comparison images.

    For each dataset in ``cfg.datasets`` the predictions found under the
    first ``gconet_*`` directory of the current working directory are
    evaluated against the ground truth in ``cfg.gt_dir``. The predictions
    returned by ``Eval_thread.select_by_Smeasure`` are copied to
    ``good_ones/<dataset>/<category>/``; next to each copy a ``*_comp`` image
    is written showing, left to right: input image, ground truth, the
    selected prediction and the comparison prediction.

    The selection criterion itself lives in ``select_by_Smeasure`` (not part
    of this file); this function only forwards the two thresholds below.

    Args:
        cfg: Object exposing ``datasets`` ('+'-separated dataset names),
            ``gt_dir`` (ground-truth root; its ``/gts`` part is rewritten to
            ``/gconet`` and ``/images`` to locate the comparison predictions
            and the input images) and ``cuda`` (forwarded to ``Eval_thread``).

    Raises:
        FileNotFoundError: If the working directory holds no directory whose
            name contains ``gconet_``.
    """
    # Thresholds forwarded to Eval_thread.select_by_Smeasure.
    smeasure_bar = 0.95
    smeasure_bar_comp = 0.2
    # Gray vertical strip placed between the panels of a comparison image.
    separator_width = 10
    separator_gray = 127

    dataset_names = cfg.datasets.split('+')
    # Sorted so that the directory picked below does not depend on the
    # arbitrary order of os.listdir; plain files that match are ignored.
    root_dir_predictions = sorted(
        dr for dr in os.listdir('.')
        if 'gconet_' in dr and os.path.isdir(dr)
    )
    root_dir_prediction_comp = cfg.gt_dir.replace('/gts', '/gconet')
    print('root_dir_predictions:', root_dir_predictions)
    if not root_dir_predictions:
        raise FileNotFoundError(
            "No prediction directory containing 'gconet_' found in {}".format(
                os.path.abspath('.')
            )
        )
    root_dir_prediction = root_dir_predictions[0]
    root_dir_good_ones = 'good_ones'

    for dataset in dataset_names:
        dir_prediction = os.path.join(root_dir_prediction, dataset)
        dir_prediction_comp = os.path.join(root_dir_prediction_comp, dataset)
        dir_gt = os.path.join(cfg.gt_dir, dataset)
        loader = EvalDataset(
            dir_prediction,         # preds
            dir_gt,                 # GT
            return_predpath=True,
            return_gtpath=True
        )
        loader_comp = EvalDataset(
            dir_prediction_comp,    # preds
            dir_gt,                 # GT
            return_predpath=True
        )
        print('Selecting predictions from {}'.format(dir_prediction))
        thread = Eval_thread(loader, cuda=cfg.cuda)
        # The first returned value (the S-measure) is not needed here.
        _, good_ones, good_ones_comp, good_ones_gt = thread.select_by_Smeasure(
            bar=smeasure_bar, loader_comp=loader_comp, bar_comp=smeasure_bar_comp
        )
        dir_good_ones = os.path.join(root_dir_good_ones, dataset)
        os.makedirs(dir_good_ones, exist_ok=True)
        print('have good_ones {}'.format(len(good_ones)))

        for good_one, good_one_comp, good_one_gt in zip(good_ones, good_ones_comp, good_ones_gt):
            # cv2.imread signals failure by returning None instead of raising.
            sal_map = cv2.imread(good_one)
            if sal_map is None:
                print('Skipping unreadable prediction: {}'.format(good_one))
                continue

            # Category = name of the folder that directly contains the file.
            category = os.path.basename(os.path.dirname(good_one))
            dir_category = os.path.join(dir_good_ones, category)
            os.makedirs(dir_category, exist_ok=True)
            save_path = os.path.join(dir_category, os.path.basename(good_one))
            cv2.imwrite(save_path, sal_map)

            image_path = good_one_gt.replace('/gts', '/images').replace('.png', '.jpg')
            panel_paths = (image_path, good_one_gt, good_one_comp)
            image, sal_map_gt, sal_map_comp = [cv2.imread(path) for path in panel_paths]
            unreadable = [
                path
                for path, panel in zip(panel_paths, (image, sal_map_gt, sal_map_comp))
                if panel is None
            ]
            if unreadable:
                print('Skipping comparison for {}; unreadable: {}'.format(
                    good_one, ', '.join(unreadable)))
                continue

            # hconcat needs every panel to share the height and dtype of the
            # prediction, so the separator is built from sal_map.
            split_line = np.full(
                (sal_map.shape[0], separator_width, 3),
                separator_gray,
                dtype=sal_map.dtype,
            )
            comp = cv2.hconcat([
                image, split_line,
                sal_map_gt, split_line,
                sal_map, split_line,
                sal_map_comp,
            ])
            # Insert '_comp' before the extension, whatever its length.
            stem, extension = os.path.splitext(save_path)
            save_path_comp = ''.join((stem, '_comp', extension))
            cv2.imwrite(save_path_comp, comp)
def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value (including ``'False'``) becomes ``True``. This
    converter interprets the usual spellings instead. The empty string still
    maps to ``False``, as it did with ``type=bool``.

    Args:
        value: The raw argument string (a ``bool`` is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value)
    )


if __name__ == "__main__":
    # Command-line entry point: parse the options and run the selection.
    parser = argparse.ArgumentParser()
    parser.add_argument('--datasets', type=str, default='CoCA+CoSOD3k+CoSal2015')
    parser.add_argument('--gt_dir', type=str, default='/root/datasets/sod/gts', help='GT')
    # type=str2bool rather than type=bool, so that '--cuda False' disables CUDA.
    parser.add_argument('--cuda', type=str2bool, default=True)
    config = parser.parse_args()
    main(config)