|
|
import os |
|
|
import skimage |
|
|
import argparse |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
from PIL import Image |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torchvision.transforms.functional as tf |
|
|
|
|
|
from . import model |
|
|
|
|
|
|
|
|
def load_iHarmony4_subset(dataset_dir, mode):
    """Collect the samples of one iHarmony4 subset.

    Reads the split file '<subset>_<mode>.txt' inside ``dataset_dir`` (the
    subset name is taken from the last path component) and returns, for every
    listed composite image actually present in 'composite_images', the paths
    of the composite image, its foreground mask, and the real image.

    Args:
        dataset_dir (str): Root directory of one iHarmony4 subset
            (e.g. '.../iHarmony4/HCOCO').
        mode (str): Dataset split, either 'train' or 'test'. Any other value
            prints an error and terminates the process.

    Returns:
        list[dict]: One dict per sample with keys 'comp', 'mask', 'real'
        mapping to the corresponding file paths.
    """
    if not mode in ['train', 'test']:
        print('Invalid mode: {0} for the dataset: {1}'.format(mode, dataset_dir))
        exit()

    # The split file is named after the subset directory, e.g. 'HCOCO_test.txt'.
    with open(os.path.join(dataset_dir, '{0}_{1}.txt'.format(dataset_dir.split('/')[-1], mode)), 'r') as f:
        # A set gives O(1) membership tests below; with a list the directory
        # scan would be O(n * m) over all composite images.
        sample_names = set(_.strip() for _ in f.readlines())

    comp_dir = os.path.join(dataset_dir, 'composite_images')
    mask_dir = os.path.join(dataset_dir, 'masks')
    real_dir = os.path.join(dataset_dir, 'real_images')

    samples = []
    for comp_name in os.listdir(comp_dir):
        if comp_name in sample_names:
            # Composite names follow '<image>_<mask-id>_<variant>.<ext>':
            # the mask name drops the last field, the real image the last two.
            mask_name = '_'.join(comp_name.split('_')[:-1]) + '.png'
            real_name = '_'.join(comp_name.split('_')[:-2]) + '.jpg'

            samples.append({
                'comp': os.path.join(comp_dir, comp_name),
                'mask': os.path.join(mask_dir, mask_name),
                'real': os.path.join(real_dir, real_name),
            })

    return samples
|
|
|
|
|
|
|
|
def calc_metrics(pred, gt, mask):
    """Compute harmonization metrics for a single sample.

    Args:
        pred (torch.Tensor): Predicted image, shape (1, C, H, W), values
            assumed in [0, 1] (rescaled to [0, 255] below).
        gt (torch.Tensor): Ground-truth image, same shape/range as ``pred``.
        mask (torch.Tensor): Foreground mask; only mask[0][0] is summed for
            the pixel count, so shape (1, 1, H, W) is assumed — TODO confirm.

    Returns:
        tuple: (mse, fmse, psnr, ssim) computed on the [0, 255] scale,
        where fMSE is the MSE restricted to the masked foreground region.
    """
    n, c, h, w = pred.shape
    assert n == 1  # metrics are defined per-sample; batching is not supported
    total_pixels = h * w
    fg_pixels = int(torch.sum(mask, dim=(2, 3))[0][0].cpu().numpy())
    # Guard against an empty mask so the fMSE renormalization below cannot
    # divide by zero (the masked difference is all-zero in that case anyway).
    fg_pixels = max(fg_pixels, 1)

    # Rescale from [0, 1] to the conventional 8-bit range.
    pred = torch.clamp(pred * 255, 0, 255)
    gt = torch.clamp(gt * 255, 0, 255)

    # Drop the batch dim and convert to HWC numpy arrays for skimage.
    pred = pred[0].permute(1, 2, 0).cpu().numpy()
    gt = gt[0].permute(1, 2, 0).cpu().numpy()
    mask = mask[0].permute(1, 2, 0).cpu().numpy()

    mse = skimage.metrics.mean_squared_error(pred, gt)
    # fMSE: MSE over the full image with the background zeroed out, then
    # renormalized from all pixels to foreground pixels only.
    fmse = skimage.metrics.mean_squared_error(pred * mask, gt * mask) * total_pixels / fg_pixels
    # NOTE(review): data_range is derived from the prediction's own span
    # rather than the fixed 255 that is conventional for 8-bit images —
    # confirm this is intended before comparing against published numbers.
    psnr = skimage.metrics.peak_signal_noise_ratio(pred, gt, data_range=pred.max() - pred.min())
    # NOTE(review): 'multichannel' is deprecated and removed in
    # scikit-image >= 0.19 (use channel_axis=-1 there); kept as-is for
    # compatibility with the older skimage this script was written against.
    ssim = skimage.metrics.structural_similarity(pred, gt, multichannel=True)

    return mse, fmse, psnr, ssim
|
|
|
|
|
|
|
|
if __name__ == '__main__':

    # Root directory that is expected to contain the benchmark datasets.
    DATASET_DIR = './dataset'
    if not os.path.exists(DATASET_DIR):
        print('Cannot find the dataset dir')
        exit()

    # Supported iHarmony4 subsets: name -> subset root directory.
    DATASETS = {
        'HCOCO': os.path.join(DATASET_DIR, 'harmonization/iHarmony4/HCOCO'),
        'HFlickr': os.path.join(DATASET_DIR, 'harmonization/iHarmony4/HFlickr'),
        'Hday2night': os.path.join(DATASET_DIR, 'harmonization/iHarmony4/Hday2night'),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrained', type=str, default='./pretrained/harmonizer.pth', help='')
    parser.add_argument('--datasets', type=str, nargs='+', required=True, choices=DATASETS.keys(), help='')
    parser.add_argument('--metric-size', type=int, default=0, help='')
    args = parser.parse_known_args()[0]

    # A non-positive --metric-size means: compute metrics at the original resolution.
    metric_size = (args.metric_size, args.metric_size) if args.metric_size > 0 else None
    cuda = torch.cuda.is_available()

    print('\n')
    print('Evaluation Harmonizer:')
    print(' - Pretrained Model: {0}'.format(args.pretrained))
    print(' - Validation Datasets: {0}'.format(args.datasets))
    print(' - Metric Calculation Size: {0}'.format(metric_size if args.metric_size > 0 else 'original'))

    # Build the model and load the pretrained weights.
    harmonizer = model.Harmonizer()
    if cuda:
        harmonizer = harmonizer.cuda()
    # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only machines;
    # load_state_dict then copies the weights onto the model's current device.
    harmonizer.load_state_dict(torch.load(args.pretrained, map_location='cpu'), strict=True)
    harmonizer.eval()

    # Load the test split of every requested subset.
    datasets = {}
    for d in args.datasets:
        datasets[d] = load_iHarmony4_subset(DATASETS[d], 'test')

    metrics = {}
    for dkey, dvalue in datasets.items():
        print('\n')
        print('================================================================================')
        print('Validation Dataset: {0}'.format(dkey))
        print('--------------------------------------------------------------------------------')
        # Running sums over the subset; averaged for display/reporting below.
        metric = {'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}
        sample_num = len(dvalue)
        pbar = tqdm(dvalue, total=sample_num, unit='sample')

        for i, sample in enumerate(pbar):

            comp = Image.open(sample['comp']).convert('RGB')
            mask = Image.open(sample['mask']).convert('1')
            image = Image.open(sample['real']).convert('RGB')

            # NCHW float tensors in [0, 1].
            _comp = tf.to_tensor(comp)[None, ...]
            _mask = tf.to_tensor(mask)[None, ...]
            _image = tf.to_tensor(image)[None, ...]
            if cuda:
                _comp, _mask, _image = _comp.cuda(), _mask.cuda(), _image.cuda()

            # Predict the harmonization arguments at the original resolution.
            with torch.no_grad():
                arguments = harmonizer.predict_arguments(_comp, _mask)

            # Optionally re-load the inputs at the metric resolution so that
            # restoration and metric computation happen at that size.
            if metric_size is not None:
                _comp = tf.to_tensor(tf.resize(comp, metric_size))[None, ...]
                _mask = tf.to_tensor(tf.resize(mask, metric_size))[None, ...]
                _image = tf.to_tensor(tf.resize(image, metric_size))[None, ...]
                if cuda:
                    _comp, _mask, _image = _comp.cuda(), _mask.cuda(), _image.cuda()

            # Apply the predicted arguments; the last output is the final image.
            with torch.no_grad():
                _harmonized = harmonizer.restore_image(_comp, _mask, arguments)[-1]

            mse, fmse, psnr, ssim = calc_metrics(_harmonized, _image, _mask)

            metric['MSE'] += mse
            metric['fMSE'] += fmse
            metric['PSNR'] += psnr
            metric['SSIM'] += ssim
            # Show running averages in the progress bar.
            pbar.set_description('MSE: {0:.4f} fMSE: {1:.4f} PSNR: {2:.4f} SSIM: {3:.4f}'.format(
                metric['MSE']/(i+1), metric['fMSE']/(i+1), metric['PSNR']/(i+1), metric['SSIM']/(i+1)))

        print('--------------------------------------------------------------------------------')
        print('{0} - MSE: {1:.4f} fMSE: {2:.4f} PSNR: {3:.4f} SSIM: {4:.4f}'.format(
            dkey, metric['MSE']/sample_num, metric['fMSE']/sample_num, metric['PSNR']/sample_num, metric['SSIM']/sample_num))
        print('================================================================================')

        metrics[dkey] = metric

    # Overall averages across every evaluated subset (sample-weighted, since
    # the per-subset values above are sums, not means).
    sample_num = sum([len(dvalue) for dvalue in datasets.values()])
    mse = sum([metric['MSE'] for metric in metrics.values()]) / sample_num
    fmse = sum([metric['fMSE'] for metric in metrics.values()]) / sample_num
    psnr = sum([metric['PSNR'] for metric in metrics.values()]) / sample_num
    ssim = sum([metric['SSIM'] for metric in metrics.values()]) / sample_num

    print('\n')
    print('================================================================================')
    print('All - MSE: {0:.4f} fMSE: {1:.4f} PSNR: {2:.4f} SSIM: {3:.4f}'.format(mse, fmse, psnr, ssim))
    print('================================================================================')
    print('\n')
|
|
|