File size: 3,597 Bytes
c1651d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import numpy as np
import argparse
import glob
import os
import sys
import torch
import cv2
import random
import time
import multiprocessing.pool as mpp
import multiprocessing as mp
# Global random seed; presumably meant to be passed to seed_everything() (the
# commented-out entry point at the bottom of this file does exactly that).
SEED = 66
def seed_everything(seed):
    """Seed every RNG this script relies on so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU
    and CUDA generators. It also records the seed in ``PYTHONHASHSEED`` and
    puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to every generator.

    Note:
        ``PYTHONHASHSEED`` only affects hash randomization of Python
        interpreters started *after* this call; the current process's string
        hashing was fixed at startup.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current device; PyTorch silently
    # ignores this call when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes and may select different (non-deterministic)
    # convolution algorithms from run to run, which would defeat
    # ``deterministic = True`` above, so it must stay off.
    torch.backends.cudnn.benchmark = False
def label2rgb(mask, mask_pred):
    """Render a binary prediction against ground truth as a confusion-colour image.

    Each pixel is coloured by its confusion category (RGB channel order):

    * TP (truth 1, pred 1) -> white ``[255, 255, 255]``
    * TN (truth 0, pred 0) -> black ``[0, 0, 0]``
    * FN (truth 1, pred 0) -> green ``[0, 255, 0]``
    * FP (truth 0, pred 1) -> red   ``[255, 0, 0]``

    Pixels whose truth or prediction is neither 0 nor 1 (e.g. an ignore
    label) match no category and stay black.

    Args:
        mask: Ground-truth label map (array-like). Only ``shape[:2]`` sets
            the output size, so a 2-D ``(H, W)`` map is expected.
        mask_pred: Predicted label map with the same shape as ``mask``.

    Returns:
        ``np.uint8`` array of shape ``(H, W, 3)`` in RGB order.
    """
    mask = np.asarray(mask)
    mask_pred = np.asarray(mask_pred)
    truth_pos, truth_neg = mask == 1, mask == 0
    pred_pos, pred_neg = mask_pred == 1, mask_pred == 0

    h, w = mask.shape[:2]
    # Zero-initialised, so TN pixels (and anything outside {0, 1}) are
    # already black; the three categories below are mutually exclusive.
    mask_rgb = np.zeros((h, w, 3), dtype=np.uint8)
    mask_rgb[truth_pos & pred_pos] = [255, 255, 255]  # TP
    mask_rgb[truth_pos & pred_neg] = [0, 255, 0]      # FN
    mask_rgb[truth_neg & pred_pos] = [255, 0, 0]      # FP
    return mask_rgb
def parse_args():
    """Parse the command-line options for this script.

    Options (all optional strings):
        --dataset          dataset name, default "Vaihingen"
        --mask-dir         input mask directory, default "data/Test/masks"
        --output-mask-dir  output directory, default "data/Test/masks_rgb"

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()
    option_defaults = (
        ("--dataset", "Vaihingen"),
        ("--mask-dir", "data/Test/masks"),
        ("--output-mask-dir", "data/Test/masks_rgb"),
    )
    for flag, fallback in option_defaults:
        cli.add_argument(flag, default=fallback)
    return cli.parse_args()
def mask_save(inp):
    """Write the confusion-colour visualisation for one image as a PNG.

    Takes a single tuple, presumably so it can be used directly as a
    ``Pool.map`` worker; confirm against callers.

    Args:
        inp: ``(mask, mask_pred, masks_output_dir, file_name)``. These are
            the ground-truth and predicted label arrays (passed to
            ``label2rgb``), the output directory, and the file stem, to which
            ``.png`` is appended.

    Raises:
        OSError: If OpenCV reports that the image could not be written
            (for example, when the output directory does not exist).
    """
    mask, mask_pred, masks_output_dir, file_name = inp
    out_mask_path = os.path.join(masks_output_dir, "{}.png".format(file_name))
    label = label2rgb(mask.copy(), mask_pred.copy())
    # label2rgb produces RGB while cv2.imwrite expects BGR, so swap R and B.
    bgr_label = cv2.cvtColor(label, cv2.COLOR_RGB2BGR)
    # cv2.imwrite signals most failures via a False return rather than raising;
    # surface them instead of silently producing no output file.
    if not cv2.imwrite(out_mask_path, bgr_label):
        raise OSError("cv2.imwrite failed to write {}".format(out_mask_path))
# def get_rgb(inp):
# (mask_path, masks_output_dir,dataset) = inp
# mask_filename = os.path.splitext(os.path.basename(mask_path))[0]
# mask_bgr = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
# mask = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2RGB)
# if dataset == "LoveDA":
# rgb_label = loveda_label2rgb(mask.copy())
# elif dataset == "Vaihingen":
# rgb_label = vaihingen_label2rgb(mask.copy())
# elif dataset == "Potsdam":
# rgb_label = potsdam_label2rgb(mask.copy())
# elif dataset == "uavid":
# rgb_label = uavid_label2rgb(mask.copy())
# else: return
# #rgb_label = cv2.cvtColor(rgb_label, cv2.COLOR_RGB2BGR)
# out_mask_path_rgb = os.path.join(masks_output_dir, "{}.png".format(mask_filename))
# rgb_label = cv2.cvtColor(rgb_label, cv2.COLOR_BGR2RGB)
# cv2.imwrite(out_mask_path_rgb, rgb_label)
# if __name__ == '__main__':
# base_path = "/home/xwma/lrr/rssegmentation/"
# args = parse_args()
# dataset = args.dataset
# seed_everything(SEED)
# masks_dir = args.mask_dir
# masks_output_dir = args.output_mask_dir
# masks_dir = base_path + masks_dir
# masks_output_dir = base_path + masks_output_dir
# mask_paths = glob.glob(os.path.join(masks_dir, "*.png"))
# inp = [(mask_path, masks_output_dir, dataset) for mask_path in mask_paths]
# if not os.path.exists(masks_output_dir):
# os.makedirs(masks_output_dir)
# t0 = time.time()
# mpp.Pool(processes=mp.cpu_count()).map(get_rgb, inp)
# t1 = time.time()
# split_time = t1 - t0
# print('images spliting spends: {} s'.format(split_time))
|