|
|
from __future__ import print_function |
|
|
|
|
|
import math |
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import yaml |
|
|
from PIL import Image |
|
|
from skimage.metrics import peak_signal_noise_ratio as compare_psnr |
|
|
from skimage.metrics import structural_similarity |
|
|
|
|
|
def get_config(config):
    """Load a YAML configuration file.

    Args:
        config: path to the YAML file.

    Returns:
        The parsed top-level YAML node (whatever the file contains).
    """
    with open(config, 'r') as stream:
        # safe_load: a bare yaml.load() without a Loader is rejected by
        # PyYAML >= 6 and, on older versions, can construct arbitrary
        # Python objects from the file's tags.
        return yaml.safe_load(stream)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor2im(image_tensor, imtype=np.uint8):
    """Convert the first image of a batched tensor to an (H, W', 3) array.

    Values are mapped from [-1, 1] to [0, 255] via ``(x + 1) / 2 * 255``.
    A single-channel image is tiled to 3 channels.  6- and 7-channel images
    are laid out side by side along the width.  For a 7-channel image, the
    last channel is tiled to 3 channels (an edge map).

    Args:
        image_tensor: tensor indexed as ``[0]`` to get a (C, H, W) image.
        imtype: numpy dtype of the returned array.

    Returns:
        numpy array of shape (H, W * k, 3), where k is 1, 2 or 3.
    """
    # detach(): .numpy() refuses tensors that require grad.
    image_numpy = image_tensor[0].detach().cpu().float().numpy()
    if image_numpy.shape[0] == 1:
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    if np.issubdtype(imtype, np.integer):
        # Without clipping, out-of-range values wrap around on the integer
        # cast (e.g. 256 -> 0 for uint8) and produce speckle artefacts.
        image_numpy = np.clip(image_numpy, 0, 255)
    image_numpy = image_numpy.astype(imtype)
    if image_numpy.shape[-1] == 6:
        image_numpy = np.concatenate([image_numpy[:, :, :3], image_numpy[:, :, 3:]], axis=1)
    elif image_numpy.shape[-1] == 7:
        edge_map = np.tile(image_numpy[:, :, 6:7], (1, 1, 3))
        image_numpy = np.concatenate([image_numpy[:, :, :3], image_numpy[:, :, 3:6], edge_map], axis=1)
    return image_numpy
|
|
|
|
|
|
|
|
def tensor2numpy(image_tensor):
    """Convert a single image tensor to a float32 (H, W, C) array in [0, 255].

    Accepts (C, H, W), (H, W), or any of these with leading size-1 batch
    dims.  Values are mapped from [-1, 1] to [0, 255] via
    ``(x + 1) / 2 * 255`` without clipping.

    Returns:
        float32 numpy array of shape (H, W, C); C is 1 for (H, W) input.
    """
    image = image_tensor.detach().cpu().float()
    # Drop only leading singleton (batch) dims.  torch.squeeze() would also
    # remove C, H or W when they equal 1, breaking the transpose below.
    while image.dim() > 3 and image.size(0) == 1:
        image = image[0]
    if image.dim() == 2:
        image = image.unsqueeze(0)  # (H, W) grayscale -> (1, H, W)
    image_numpy = (np.transpose(image.numpy(), (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(np.float32)
|
|
|
|
|
|
|
|
|
|
|
def get_model_list(dirname, key, epoch=None):
    """Locate a checkpoint file for ``key`` in ``dirname``.

    Args:
        dirname: checkpoint directory.
        key: model name; files must contain it to be considered.
        epoch: epoch to load.  When None, the path ``<key>_latest.pt`` is
            returned without checking that it exists.

    Returns:
        Path to the checkpoint, or None if ``dirname`` does not exist.

    Raises:
        ValueError: no checkpoint for ``key`` matches ``epoch``, or a
            candidate file name has a non-integer epoch field.
    """
    if epoch is None:
        return os.path.join(dirname, key + '_latest.pt')
    if not os.path.exists(dirname):
        return None

    # Filter by key so a checkpoint of a different model saved for the same
    # epoch is never returned.  Sort so the result does not depend on
    # os.listdir() order.
    gen_models = sorted(
        os.path.join(dirname, f) for f in os.listdir(dirname)
        if os.path.isfile(os.path.join(dirname, f))
        and key in f and '.pt' in f and 'latest' not in f
    )
    # File names are expected to carry the epoch in the second-to-last
    # '_'-separated field, e.g. '<key>_<epoch>_<suffix>.pt'.
    epoch_index = [int(os.path.basename(p).split('_')[-2]) for p in gen_models]
    print('[i] available epoch list: %s' % epoch_index, gen_models)

    target = int(epoch)
    if target not in epoch_index:
        raise ValueError('no checkpoint for key %r at epoch %d in %s (available: %s)'
                         % (key, target, dirname, epoch_index))
    return gen_models[epoch_index.index(target)]
|
|
|
|
|
|
|
|
def vgg_preprocess(batch):
    """Normalise a [-1, 1] image batch with ImageNet mean/std for VGG.

    The batch is first mapped to [0, 1] via ``(x + 1) / 2``, then each
    channel is standardised with the ImageNet statistics.

    Args:
        batch: 4-D tensor (N, 3, H, W).

    Returns:
        New tensor of the same shape, dtype and device; the input is not
        modified.

    Raises:
        RuntimeError: if the channel dimension is not 3 (broadcast fails).
    """
    # Broadcast (1, 3, 1, 1) constants instead of allocating two full-size
    # tensors.  The previous batch.new(size) left any channel >= 3
    # uninitialised (garbage values).
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    batch = (batch + 1) / 2
    return (batch - mean) / std
|
|
|
|
|
|
|
|
def diagnose_network(net, name='network'):
    """Print ``name`` and the mean absolute gradient over parameters.

    Only parameters whose ``.grad`` is set contribute.  The printed value is
    the average of per-parameter mean |grad|, or ``0.0`` if no gradients
    exist.
    """
    grads = [p.grad for p in net.parameters() if p.grad is not None]
    total = 0.0
    for g in grads:
        total += torch.mean(torch.abs(g.data))
    if grads:
        total = total / len(grads)
    print(name)
    print(total)
|
|
|
|
|
|
|
|
def save_image(image_numpy, image_path):
    """Write a numpy image array to ``image_path`` via PIL."""
    Image.fromarray(image_numpy).save(image_path)
|
|
|
|
|
|
|
|
def print_numpy(x, val=True, shp=False):
    """Print summary statistics (and optionally the shape) of a numpy array.

    Args:
        x: numpy array (anything with ``.astype``).
        val: print mean/min/max/median/std of all elements.
        shp: print the array shape.
    """
    data = x.astype(np.float64)
    if shp:
        print('shape,', data.shape)
    if not val:
        return
    flat = data.flatten()
    stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
    print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)
|
|
|
|
|
|
|
|
def mkdirs(paths):
    """Create one directory or each directory in a list/tuple of paths.

    Args:
        paths: a single path, or a list or tuple of paths.  Missing parents
            are created by ``mkdir``.
    """
    # A tuple used to reach mkdir() as a single "path" and raise TypeError.
    # The old extra `not isinstance(paths, str)` test was redundant.
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)
|
|
|
|
|
|
|
|
def mkdir(path):
    """Create ``path`` (including parents) unless something already exists there.

    If ``path`` already exists, whether as a directory or a file, this is a
    no-op.
    """
    # Try-then-check instead of check-then-create.  The old exists() test
    # raced with concurrent creators (e.g. several training workers), and
    # the loser crashed with FileExistsError.
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.exists(path):
            raise
|
|
|
|
|
|
|
|
def set_opt_param(optimizer, key, value):
    """Set hyper-parameter ``key`` to ``value`` in every optimizer param group."""
    for param_group in optimizer.param_groups:
        param_group.update({key: value})
|
|
|
|
|
|
|
|
def vis(x):
    """Display a tensor or numpy image in a PIL viewer window.

    Tensors go through ``tensor2im``; numpy arrays are cast to uint8.

    Raises:
        NotImplementedError: for any other input type.
    """
    if isinstance(x, torch.Tensor):
        Image.fromarray(tensor2im(x)).show()
    elif isinstance(x, np.ndarray):
        Image.fromarray(x.astype(np.uint8)).show()
    else:
        # Format the message: passing type(x) as a second positional arg
        # left the literal '%s' in the exception text.
        raise NotImplementedError('vis for type [%s] is not implemented' % type(x))
|
|
|
|
|
|
|
|
"""tensorboard""" |
|
|
from tensorboardX import SummaryWriter |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
def get_summary_writer(log_dir):
    """Create a tensorboardX SummaryWriter in a fresh per-run subdirectory.

    The run directory is ``<log_dir>/<MonDD_HH-MM-SS>_<hostname>``.
    Missing parent directories are created.

    Returns:
        SummaryWriter logging to the run directory.
    """
    # Local import: the module-level `import socket` appears later in this
    # file, so do not rely on import order.
    import socket

    # makedirs (not mkdir) so a nested log_dir with missing parents works.
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    run_dir = os.path.join(log_dir, datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    if not os.path.exists(run_dir):
        os.makedirs(run_dir)
    return SummaryWriter(run_dir)
|
|
def get_visual(writer, iteration, imgs):
    """Log ``imgs[0]`` as 'clean' and ``imgs[1]`` as 'input' at ``iteration``."""
    for idx, tag in enumerate(('clean', 'input')):
        writer.add_image(tag, imgs[idx], iteration)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AverageMeters(object):
    """Running averages of named metrics.

    ``update({'loss': v, ...})`` accumulates values.  ``meters['loss']``
    returns the running mean of all values seen for that key.
    """

    def __init__(self, dic=None, total_num=None):
        # Running sum per key.
        self.dic = dic or {}
        # Number of updates seen per key.
        self.total_num = total_num or {}

    def update(self, new_dic):
        """Add one observation per key in ``new_dic`` to the running sums."""
        for key in new_dic:
            if key not in self.dic:
                self.dic[key] = new_dic[key]
                self.total_num[key] = 1
            else:
                # Out-of-place add: `+=` on a stored tensor/ndarray would
                # mutate the object the caller passed in on the first update.
                self.dic[key] = self.dic[key] + new_dic[key]
                self.total_num[key] += 1

    def __getitem__(self, key):
        """Return the running mean for ``key``."""
        return self.dic[key] / self.total_num[key]

    def __str__(self):
        return ''.join('%s: %.4f | ' % (key, self[key]) for key in sorted(self.keys()))

    def keys(self):
        return self.dic.keys()
|
|
|
|
|
|
|
|
def write_loss(writer, prefix, avg_meters, iteration):
    """Log every meter in ``avg_meters`` as scalar tag ``<prefix>/<key>``.

    Args:
        writer: object with ``add_scalar(tag, value, step)``.
        prefix: tag namespace.
        avg_meters: mapping-like object with ``keys()`` and ``[key]``.
        iteration: global step.
    """
    import posixpath

    # TensorBoard tags are '/'-separated on every OS.  os.path.join would
    # produce 'prefix\\key' on Windows.
    for key in avg_meters.keys():
        writer.add_scalar(posixpath.join(prefix, key), avg_meters[key], iteration)
|
|
|
|
|
|
|
|
"""progress bar""" |
|
|
import socket |
|
|
|
|
|
|
|
|
# Total width (in columns) that each progress-bar line is padded to.
term_width = 136

# Number of characters used by the bar itself.  Kept as a float; callers
# wrap it in int() where an integer is needed.
TOTAL_BAR_LENGTH = 65.0

# Timestamps shared with progress_bar(): time of the previous call and
# start time of the current run.
last_time = begin_time = time.time()
|
|
|
|
|
|
|
|
def progress_bar(current, total, msg=None):
    """Draw a single-line console progress bar for step ``current`` of ``total``.

    Shows the bar, the time since the previous call, the time since
    ``current == 0``, an optional message and a ``current+1/total`` counter.
    The line ends with '\\r' until the last step, then '\\n'.  Updates the
    module globals ``last_time`` and ``begin_time``.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()

    filled = int(TOTAL_BAR_LENGTH * current / total)
    remaining = int(TOTAL_BAR_LENGTH - filled) - 1
    sys.stdout.write(' [' + '=' * filled + '>' + '.' * remaining + ']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    parts = [' Step: %s' % format_time(step_time), ' | Tot: %s' % format_time(tot_time)]
    if msg:
        parts.append(' | ' + msg)
    msg = ''.join(parts)
    sys.stdout.write(msg)

    # Pad to the terminal width, then back up to overwrite part of the bar
    # with the step counter.
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3))
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2))
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    sys.stdout.write('\r' if current < total - 1 else '\n')
    sys.stdout.flush()
|
|
|
|
|
|
|
|
def format_time(seconds):
    """Format a duration compactly using at most its two largest non-zero units.

    Units are D, h, m, s and ms, e.g. 90061.5 -> '1D1h', 1.5 -> '1s500ms'.
    Returns '0ms' when every unit is zero.
    """
    remaining = seconds
    days = int(remaining / 3600 / 24)
    remaining = remaining - days * 3600 * 24
    hours = int(remaining / 3600)
    remaining = remaining - hours * 3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes * 60
    whole_secs = int(remaining)
    millis = int((remaining - whole_secs) * 1000)

    units = ((days, 'D'), (hours, 'h'), (minutes, 'm'), (whole_secs, 's'), (millis, 'ms'))
    pieces = [str(amount) + suffix for amount, suffix in units if amount > 0][:2]
    return ''.join(pieces) or '0ms'
|
|
|
|
|
|
|
|
def parse_args(args):
    """Parse a comma-separated list of integers, keeping only non-negative ones.

    Typical use is GPU ids, where '-1' means "none": '0,-1,2' -> [0, 2].
    Empty tokens (an empty string or a trailing comma) are skipped.

    Raises:
        ValueError: if a non-empty token is not an integer.
    """
    parsed_args = []
    for token in args.split(','):
        token = token.strip()
        if not token:
            # int('') used to raise for '' or '0,1,'.
            continue
        value = int(token)
        if value >= 0:
            parsed_args.append(value)
    return parsed_args
|
|
|
|
|
|
|
|
def weights_init_kaiming(m): |
|
|
classname = m.__class__.__name__ |
|
|
if classname.find('Conv') != -1: |
|
|
nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in') |
|
|
elif classname.find('Linear') != -1: |
|
|
nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in') |
|
|
elif classname.find('BatchNorm') != -1: |
|
|
|
|
|
m.weight.data.normal_(mean=0, std=math.sqrt(2. / 9. / 64.)).clamp_(-0.025, 0.025) |
|
|
nn.init.constant(m.bias.data, 0.0) |
|
|
|
|
|
|
|
|
def batch_PSNR(img, imclean, data_range):
    """Return the mean PSNR over a batch of (N, C, H, W) tensors.

    Args:
        img: restored images.
        imclean: reference images, same shape as ``img``.
        data_range: value range of the images (e.g. 1 or 255).
    """
    restored = img.data.cpu().numpy().astype(np.float32)
    reference = imclean.data.cpu().numpy().astype(np.float32)
    scores = [
        compare_psnr(reference[k, :, :, :], restored[k, :, :, :], data_range=data_range)
        for k in range(restored.shape[0])
    ]
    return sum(scores) / restored.shape[0]
|
|
|
|
|
|
|
|
def batch_SSIM(img, imclean, data_range=1):
    """Return the mean SSIM over a batch of (N, C, H, W) tensors.

    Each image's SSIM is the average of per-channel SSIMs with an 11x11
    window, so H and W must be at least 11.

    Args:
        img: restored images.
        imclean: reference images, same shape as ``img``.
        data_range: value range of the images (default 1, as before).
    """
    Img = img.data.cpu().permute(0, 2, 3, 1).numpy().astype(np.float32)
    Iclean = imclean.data.cpu().permute(0, 2, 3, 1).numpy().astype(np.float32)
    SSIM = 0
    for i in range(Img.shape[0]):
        # Averaging per-channel SSIM is what `multichannel=True` computed.
        # That flag is deprecated since scikit-image 0.19 and removed in
        # later releases, while older releases lack `channel_axis`.  This
        # form works on all of them.
        per_channel = [
            structural_similarity(Iclean[i, :, :, c], Img[i, :, :, c],
                                  win_size=11, data_range=data_range)
            for c in range(Img.shape[-1])
        ]
        SSIM += np.mean(per_channel)
    return SSIM / Img.shape[0]
|
|
|
|
|
|
|
|
def data_augmentation(image, mode):
    """Apply one of 8 rotation/flip augmentations to a (C, H, W) array.

    ``mode // 2`` counter-clockwise 90-degree rotations are applied in the
    H/W plane, followed by an up-down flip when ``mode`` is odd.  Any mode
    outside 0..7 leaves the image unchanged.
    """
    hwc = np.transpose(image, (1, 2, 0))
    if mode in range(8):
        hwc = np.rot90(hwc, k=mode // 2)
        if mode % 2 == 1:
            hwc = np.flipud(hwc)
    return np.transpose(hwc, (2, 0, 1))
|
|
|