InPeerReview's picture
Upload 9 files
032c113 verified
raw
history blame
4.65 kB
import torch
from torch import nn
from tqdm import tqdm
import prettytable
import time
import os
import multiprocessing.pool as mpp
import multiprocessing as mp
from train import *
import argparse
from utils.config import Config
from tools.mask_convert import mask_save
def _build_parser():
    """Return the command-line parser for the change-detection test script."""
    # `description` must be passed by keyword: the first positional argument of
    # ArgumentParser is `prog`, which would otherwise replace the program name.
    parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py",
                        help="path of the config file")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="checkpoint to evaluate (falls back to the config's test_ckpt_path)")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="directory for test_res.txt and mask_rgb/ "
                             "(defaults to the checkpoint's directory)")
    return parser


def get_args(argv=None):
    """Parse the script's command-line options.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``config`` (default ``"configs/cdlama.py"``),
        ``ckpt`` (``None`` unless given) and ``output_dir`` (``None`` unless given).
    """
    return _build_parser().parse_args(argv)
if __name__ == "__main__":
    # Evaluate a trained checkpoint on the test split, print/append a metrics
    # table to <base_dir>/test_res.txt and save the predicted masks to
    # <base_dir>/mask_rgb using a process pool.
    args = get_args()
    cfg = Config.fromfile(args.config)

    # Checkpoint: the --ckpt flag wins, otherwise the config's test_ckpt_path.
    ckpt = args.ckpt if args.ckpt is not None else cfg.test_ckpt_path
    if ckpt is None:
        # Explicit error instead of `assert`, which is removed under `python -O`.
        raise SystemExit("No checkpoint given: pass --ckpt or set test_ckpt_path in the config.")

    # Outputs go next to the checkpoint unless --output_dir is given.
    base_dir = args.output_dir if args.output_dir else os.path.dirname(ckpt)
    masks_output_dir = os.path.join(base_dir, "mask_rgb")

    # Tensors saved on cuda:1 are remapped onto cuda:0 while loading.
    model = myTrain.load_from_checkpoint(ckpt, map_location={'cuda:1': 'cuda:0'}, cfg=cfg)
    model = model.to('cuda')
    model.eval()

    metric_cfg_1 = cfg.metric_cfg1
    metric_cfg_2 = cfg.metric_cfg2
    test_oa = torchmetrics.Accuracy(**metric_cfg_1).to('cuda')
    test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda')
    test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda')
    test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda')
    test_iou = torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda')

    # One (ground truth, prediction, output dir, name) tuple per image; consumed by mask_save.
    results = []
    with torch.no_grad():
        test_loader = build_dataloader(cfg.dataset_config, mode='test')
        for batch in tqdm(test_loader):
            # Assumes batch = (image A, image B, label mask, image ids), inferred
            # from the indexing here -- TODO confirm against the dataset.
            raw_predictions = model(batch[0].cuda(), batch[1].cuda())
            mask = batch[2].cuda()
            img_id = batch[3]

            if cfg.net == 'SARASNet':
                # Resize the labels to the prediction's spatial size (dims 2+);
                # only the shape is needed, so no GPU->CPU copy of the prediction.
                mask = Variable(resize_label(mask.data.cpu().numpy(),
                                             size=tuple(raw_predictions.shape[2:])).to('cuda')).long()
                # Bias channel 1 by +1 to balance precision and recall for a higher F1-score.
                param = 1
                raw_predictions[:, 1, :, :] = raw_predictions[:, 1, :, :] + param

            if cfg.argmax:
                pred = raw_predictions.argmax(dim=1)
            elif cfg.net == 'maskcd':
                # For maskcd only the first element of the model output is scored.
                pred = (raw_predictions[0] > 0.5).squeeze(1)
            else:
                pred = raw_predictions.squeeze(1) > 0.5

            for metric in (test_oa, test_iou, test_prec, test_f1, test_recall):
                metric(pred, mask)

            # Walk the batch dimension of the prediction/label themselves: for
            # maskcd the raw output's leading dim is not the batch (see [0] above).
            # Move each batch to the CPU once rather than once per sample.
            for mask_real, mask_pred, name in zip(mask.cpu().numpy(), pred.cpu().numpy(), img_id):
                results.append((mask_real, mask_pred, masks_output_dir, str(name)))

    # Per-class values; assumes metric_cfg2 yields 1-D (per-class) tensors -- TODO confirm.
    metrics = [test_prec.compute(),
               test_recall.compute(),
               test_f1.compute(),
               test_iou.compute()]
    total_metrics = [test_oa.compute().cpu().numpy()]
    total_metrics += [np.mean([item.cpu() for item in m]) for m in metrics]

    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']
    # One row per class the metrics report, instead of a fixed two rows.
    for class_idx in range(len(metrics[0])):
        row = [class_idx, '--']
        row.extend(np.round(m[class_idx].cpu().numpy(), 4) for m in metrics)
        result_table.add_row(row)
    total = ['total'] + [np.round(v, 4) for v in total_metrics]
    result_table.add_row(total)
    print(result_table)

    # Create the output tree before writing anything: with --output_dir the base
    # directory may not exist yet, and test_res.txt is written into it.
    os.makedirs(masks_output_dir, exist_ok=True)
    print(masks_output_dir)

    file_name = os.path.join(base_dir, "test_res.txt")
    # Format the timestamp separately so a '%' in cfg.net is not treated as a directive.
    current_time = '{} {}'.format(
        time.strftime('%Y_%m_%d %H:%M:%S', time.localtime(time.time())), cfg.net)
    with open(file_name, "a") as f:
        f.write(current_time + '\n')
        f.write(str(result_table) + '\n')

    t0 = time.time()
    # The context manager shuts the worker pool down once all masks are written.
    with mpp.Pool(processes=mp.cpu_count()) as pool:
        pool.map(mask_save, results)
    img_write_time = time.time() - t0
    print('images writing spends: {} s'.format(img_write_time))