# Env_mixer/src/val_harmonizer.py
# (Hugging Face upload metadata: uploaded by Inmental via huggingface_hub,
#  commit 4c62147 verified)
import os
import skimage
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as tf
from . import model
def load_iHarmony4_subset(dataset_dir, mode):
    """Collect (composite, mask, real) sample paths for one iHarmony4 subset.

    Args:
        dataset_dir: Root directory of a subset (e.g. ``.../HCOCO``). It must
            contain a split file ``<subset>_<mode>.txt`` listing composite
            image file names, plus the ``composite_images``, ``masks`` and
            ``real_images`` sub-directories.
        mode: Dataset split, either ``'train'`` or ``'test'``.

    Returns:
        A list of dicts with keys ``'comp'``, ``'mask'`` and ``'real'``
        mapping to the paths of each sample's images.

    Raises:
        ValueError: If ``mode`` is not ``'train'`` or ``'test'``.
    """
    if mode not in ('train', 'test'):
        # Raise instead of print()+exit() so callers can handle the error.
        raise ValueError('Invalid mode: {0} for the dataset: {1}'.format(mode, dataset_dir))

    # os.path.basename is portable (the original '/'-split breaks on Windows).
    split_file = os.path.join(
        dataset_dir, '{0}_{1}.txt'.format(os.path.basename(dataset_dir), mode))
    with open(split_file, 'r') as f:
        # A set makes the membership test below O(1); with a list the loop
        # over the directory listing was O(n^2) overall.
        sample_names = {line.strip() for line in f}

    comp_dir = os.path.join(dataset_dir, 'composite_images')
    mask_dir = os.path.join(dataset_dir, 'masks')
    real_dir = os.path.join(dataset_dir, 'real_images')

    samples = []
    for comp_name in os.listdir(comp_dir):
        if comp_name not in sample_names:
            continue
        # Composite names look like '<image>_<fg>_<variant>.jpg': the mask
        # name drops the last field, the real image drops the last two.
        parts = comp_name.split('_')
        mask_name = '_'.join(parts[:-1]) + '.png'
        real_name = '_'.join(parts[:-2]) + '.jpg'
        samples.append({
            'comp': os.path.join(comp_dir, comp_name),
            'mask': os.path.join(mask_dir, mask_name),
            'real': os.path.join(real_dir, real_name),
        })
    return samples
def calc_metrics(pred, gt, mask):
    """Compute MSE / fMSE / PSNR / SSIM between a prediction and ground truth.

    Args:
        pred: Predicted image, float tensor of shape (1, C, H, W); values are
            scaled by 255 and clamped, so inputs are expected in [0, 1].
        gt: Ground-truth image, same shape/range as ``pred``.
        mask: Foreground mask tensor of shape (1, 1, H, W); summed to count
            foreground pixels, so values are expected in {0, 1}.

    Returns:
        Tuple ``(mse, fmse, psnr, ssim)`` computed on the 0-255 scaled
        images. ``fmse`` is the MSE restricted to the foreground region
        (0.0 when the mask is empty).
    """
    n, c, h, w = pred.shape
    assert n == 1  # metrics are defined per-sample

    total_pixels = h * w
    fg_pixels = int(torch.sum(mask, dim=(2, 3))[0][0].cpu().numpy())

    pred = torch.clamp(pred * 255, 0, 255)
    gt = torch.clamp(gt * 255, 0, 255)

    # (C, H, W) tensors -> (H, W, C) numpy arrays, as skimage expects.
    pred = pred[0].permute(1, 2, 0).cpu().numpy()
    gt = gt[0].permute(1, 2, 0).cpu().numpy()
    mask = mask[0].permute(1, 2, 0).cpu().numpy()

    mse = skimage.metrics.mean_squared_error(pred, gt)
    # The masked MSE is rescaled by total/foreground pixel count so it
    # reflects the foreground error only. Guard against an empty mask,
    # which previously raised ZeroDivisionError.
    if fg_pixels > 0:
        fmse = skimage.metrics.mean_squared_error(pred * mask, gt * mask) * total_pixels / fg_pixels
    else:
        fmse = 0.0
    # NOTE(review): data_range is taken from the prediction's own value
    # spread rather than the nominal 255 scale — confirm this matches the
    # paper's evaluation protocol.
    psnr = skimage.metrics.peak_signal_noise_ratio(pred, gt, data_range=pred.max() - pred.min())
    # NOTE(review): `multichannel` is deprecated since scikit-image 0.19 and
    # removed later in favour of `channel_axis=-1`; update when the pinned
    # skimage version allows it.
    ssim = skimage.metrics.structural_similarity(pred, gt, multichannel=True)
    return mse, fmse, psnr, ssim
if __name__ == '__main__':
    # check dataset dir
    DATASET_DIR = './dataset'
    if not os.path.exists(DATASET_DIR):
        print('Cannot find the dataset dir')
        exit()

    # supported image harmonization validation datasets
    # (keys are the names accepted by the --datasets argument below)
    DATASETS = {
        'HCOCO': os.path.join(DATASET_DIR, 'harmonization/iHarmony4/HCOCO'),
        'HFlickr': os.path.join(DATASET_DIR, 'harmonization/iHarmony4/HFlickr'),
        'Hday2night': os.path.join(DATASET_DIR, 'harmonization/iHarmony4/Hday2night'),
    }

    # define cmd arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrained', type=str, default='./pretrained/harmonizer.pth', help='')
    parser.add_argument('--datasets', type=str, nargs='+', required=True, choices=DATASETS.keys(), help='')
    parser.add_argument('--metric-size', type=int, default=0, help='')
    args = parser.parse_known_args()[0]

    # pre-process the required arguments
    # --metric-size 0 (the default) means "evaluate at original resolution"
    metric_size = (args.metric_size, args.metric_size) if args.metric_size > 0 else None
    cuda = torch.cuda.is_available()

    # print arguments
    print('\n')
    print('Evaluation Harmonizer:')
    print(' - Pretrained Model: {0}'.format(args.pretrained))
    print(' - Validation Datasets: {0}'.format(args.datasets))
    print(' - Metric Calculation Size: {0}'.format(metric_size if args.metric_size > 0 else 'original'))

    # create/load the harmonizer model
    harmonizer = model.Harmonizer()
    if cuda:
        harmonizer = harmonizer.cuda()
    harmonizer.load_state_dict(torch.load(args.pretrained), strict=True)
    harmonizer.eval()

    # load validation datasets (test split only)
    datasets = {}
    for d in args.datasets:
        datasets[d] = load_iHarmony4_subset(DATASETS[d], 'test')

    # validation
    metrics = {}
    for dkey, dvalue in datasets.items():
        print('\n')
        print('================================================================================')
        print('Validation Dataset: {0}'.format(dkey))
        print('--------------------------------------------------------------------------------')

        # running sums of per-sample metrics for this dataset
        metric = {'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}

        sample_num = len(dvalue)
        pbar = tqdm(dvalue, total=sample_num, unit='sample')
        for i, sample in enumerate(pbar):
            # load inputs
            comp = Image.open(sample['comp']).convert('RGB')
            mask = Image.open(sample['mask']).convert('1')
            image = Image.open(sample['real']).convert('RGB')

            # prepare inputs for argument prediction
            # ([None, ...] adds a batch dimension: (C, H, W) -> (1, C, H, W))
            _comp = tf.to_tensor(comp)[None, ...]
            _mask = tf.to_tensor(mask)[None, ...]
            _image = tf.to_tensor(image)[None, ...]
            if cuda:
                _comp, _mask, _image = _comp.cuda(), _mask.cuda(), _image.cuda()

            # predict arguments
            with torch.no_grad():
                arguments = harmonizer.predict_arguments(_comp, _mask)

            # prepare inputs for metric calculation
            # (re-tensorize at the fixed metric resolution; the arguments
            # predicted at the original size above are reused as-is)
            if metric_size is not None:
                _comp = tf.to_tensor(tf.resize(comp, metric_size))[None, ...]
                _mask = tf.to_tensor(tf.resize(mask, metric_size))[None, ...]
                _image = tf.to_tensor(tf.resize(image, metric_size))[None, ...]
                if cuda:
                    _comp, _mask, _image = _comp.cuda(), _mask.cuda(), _image.cuda()

            # harmonize the composite; [-1] takes the last output of
            # restore_image (presumably the final refinement stage —
            # confirm in model.Harmonizer)
            with torch.no_grad():
                _harmonized = harmonizer.restore_image(_comp, _mask, arguments)[-1]

            # calculate metrics and accumulate per-dataset sums
            mse, fmse, psnr, ssim = calc_metrics(_harmonized, _image, _mask)
            metric['MSE'] += mse
            metric['fMSE'] += fmse
            metric['PSNR'] += psnr
            metric['SSIM'] += ssim

            # show running averages over the i+1 samples processed so far
            pbar.set_description('MSE: {0:.4f} fMSE: {1:.4f} PSNR: {2:.4f} SSIM: {3:.4f}'.format(
                metric['MSE']/(i+1), metric['fMSE']/(i+1), metric['PSNR']/(i+1), metric['SSIM']/(i+1)))

        # per-dataset averages
        print('--------------------------------------------------------------------------------')
        print('{0} - MSE: {1:.4f} fMSE: {2:.4f} PSNR: {3:.4f} SSIM: {4:.4f}'.format(
            dkey, metric['MSE']/sample_num, metric['fMSE']/sample_num, metric['PSNR']/sample_num, metric['SSIM']/sample_num))
        print('================================================================================')

        metrics[dkey] = metric

    # overall metrics: averages across all samples of all requested datasets
    sample_num = sum([len(dvalue) for dvalue in datasets.values()])
    mse = sum([metric['MSE'] for metric in metrics.values()]) / sample_num
    fmse = sum([metric['fMSE'] for metric in metrics.values()]) / sample_num
    psnr = sum([metric['PSNR'] for metric in metrics.values()]) / sample_num
    ssim = sum([metric['SSIM'] for metric in metrics.values()]) / sample_num

    print('\n')
    print('================================================================================')
    print('All - MSE: {0:.4f} fMSE: {1:.4f} PSNR: {2:.4f} SSIM: {3:.4f}'.format(mse, fmse, psnr, ssim))
    print('================================================================================')
    print('\n')