RDNet / util /util.py
lime-j's picture
Upload 89 files
347b44e
from __future__ import print_function
import math
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import yaml
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity
def get_config(config):
    """Load a YAML config file and return its parsed contents.

    Uses yaml.safe_load: the one-argument yaml.load form is unsafe on
    untrusted files, warns on PyYAML 5.x, and raises TypeError on PyYAML 6.
    """
    with open(config, 'r') as stream:
        return yaml.safe_load(stream)
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8):
    """Convert the first image of a batched CHW tensor in [-1, 1] to an HWC numpy image.

    Grayscale (1-channel) inputs are replicated to 3 channels. 6- and
    7-channel inputs are split into side-by-side 3-channel panels (the
    7th channel is treated as an edge map and tiled to RGB).

    :param image_tensor: (N, C, H, W) tensor; only the first sample is used
    :param imtype: target numpy dtype of the output (default np.uint8)
    :return: (H, W*k, 3) numpy array, k in {1, 2, 3} depending on channels
    """
    image_numpy = image_tensor[0].cpu().float().numpy()
    if image_numpy.shape[0] == 1:
        # grayscale -> replicate to 3 channels
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    # Clip before casting: a bare astype(np.uint8) wraps out-of-range
    # values (e.g. 510 -> 254) instead of saturating them.
    image_numpy = np.clip(image_numpy, 0, 255).astype(imtype)
    if image_numpy.shape[-1] == 6:
        image_numpy = np.concatenate([image_numpy[:, :, :3], image_numpy[:, :, 3:]], axis=1)
    if image_numpy.shape[-1] == 7:
        edge_map = np.tile(image_numpy[:, :, 6:7], (1, 1, 3))
        image_numpy = np.concatenate([image_numpy[:, :, :3], image_numpy[:, :, 3:6], edge_map], axis=1)
    return image_numpy
def tensor2numpy(image_tensor):
    """Squeeze a tensor in [-1, 1] and return an HWC float32 array in [0, 255]."""
    chw = torch.squeeze(image_tensor).cpu().float().numpy()
    hwc = np.transpose(chw, (1, 2, 0))
    return ((hwc + 1) / 2.0 * 255.0).astype(np.float32)
# Get model list for resume
def get_model_list(dirname, key, epoch=None):
    """Locate a checkpoint file for *key* in *dirname* for resuming.

    :param dirname: checkpoint directory
    :param key: model name prefix (used for the '_latest.pt' path;
        NOTE(review): on-disk files are not filtered by key — confirm
        whether checkpoints of different models share a directory)
    :param epoch: epoch number to resume from; None selects '<key>_latest.pt'
    :return: checkpoint path, or None if the directory or epoch is missing
    """
    if epoch is None:
        return os.path.join(dirname, key + '_latest.pt')
    if not os.path.exists(dirname):
        return None
    print(dirname, key)
    # every non-'latest' .pt file; the epoch is the second-to-last '_' field
    gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if
                  os.path.isfile(os.path.join(dirname, f)) and ".pt" in f and 'latest' not in f]
    epoch_index = [int(os.path.basename(model_name).split('_')[-2]) for model_name in gen_models]
    print('[i] available epoch list: %s' % epoch_index, gen_models)
    # Return None (consistent with the missing-directory case) instead of
    # letting list.index raise ValueError for an unavailable epoch.
    if int(epoch) not in epoch_index:
        return None
    return gen_models[epoch_index.index(int(epoch))]
def vgg_preprocess(batch):
    """Map a (N, 3, H, W) batch from [-1, 1] to ImageNet-normalized range.

    Rescales to [0, 1], then subtracts the ImageNet per-channel mean and
    divides by the per-channel std. The input tensor is not modified.
    """
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    rescaled = (batch + 1) / 2
    return (rescaled - mean) / std
def diagnose_network(net, name='network'):
    """Print the mean absolute gradient across all parameters that have one."""
    grad_means = [torch.mean(torch.abs(p.grad.data))
                  for p in net.parameters() if p.grad is not None]
    mean = sum(grad_means) / len(grad_means) if grad_means else 0.0
    print(name)
    print(mean)
def save_image(image_numpy, image_path):
    """Write a numpy image array to *image_path* via PIL."""
    Image.fromarray(image_numpy).save(image_path)
def print_numpy(x, val=True, shp=False):
    """Print summary statistics (and optionally the shape) of an array."""
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        flat = x.flatten()
        stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)
def mkdirs(paths):
    """Create one directory, or each directory in a list of paths."""
    path_list = paths if isinstance(paths, list) and not isinstance(paths, str) else [paths]
    for path in path_list:
        mkdir(path)
def mkdir(path):
    """Create *path* (including parents) if it does not already exist.

    exist_ok=True avoids the check-then-create race of the previous
    os.path.exists() guard and makes repeated calls safe.
    """
    os.makedirs(path, exist_ok=True)
def set_opt_param(optimizer, key, value):
    """Set *key* to *value* in every param group of *optimizer*."""
    for param_group in optimizer.param_groups:
        param_group.update({key: value})
def vis(x):
    """Display a tensor or numpy image with the system image viewer.

    :param x: torch.Tensor (converted via tensor2im) or numpy array
    :raises NotImplementedError: for any other type
    """
    if isinstance(x, torch.Tensor):
        Image.fromarray(tensor2im(x)).show()
    elif isinstance(x, np.ndarray):
        Image.fromarray(x.astype(np.uint8)).show()
    else:
        # %-format the type into the message; the original passed type(x)
        # as a stray second constructor argument, leaving '%s' unfilled.
        raise NotImplementedError('vis for type [%s] is not implemented' % type(x))
"""tensorboard"""
from tensorboardX import SummaryWriter
from datetime import datetime
def get_summary_writer(log_dir):
    """Create a tensorboardX SummaryWriter in a timestamped, per-host subdirectory.

    :param log_dir: root log directory (created if missing)
    :return: SummaryWriter writing to log_dir/<MonDD_HH-MM-SS>_<hostname>

    NOTE(review): `socket` is imported further down in this module, so this
    function only works after the module has been fully imported.
    """
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    # Subfolder like 'Jan01_12-00-00_myhost' keeps runs from clashing.
    log_dir = os.path.join(log_dir, datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    writer = SummaryWriter(log_dir)
    return writer
def get_visual(writer, iteration, imgs):
    """Log the clean target (imgs[0]) and the network input (imgs[1]) images."""
    for tag, img in zip(('clean', 'input'), imgs):
        writer.add_image(tag, img, iteration)
class AverageMeters(object):
    """Running averages of named metrics.

    `dic` holds the running sum per key and `total_num` the sample count;
    indexing returns the mean.
    """

    def __init__(self, dic=None, total_num=None):
        # Optional pre-seeded sums/counts; default to fresh empty dicts.
        self.dic = dic or {}
        self.total_num = total_num or {}

    def update(self, new_dic):
        """Fold one dict of metric values into the running sums."""
        for key in new_dic:
            if key not in self.dic:  # first observation of this key
                self.dic[key] = new_dic[key]
                self.total_num[key] = 1
            else:
                self.dic[key] += new_dic[key]
                self.total_num[key] += 1

    def __getitem__(self, key):
        """Return the mean of all values recorded for *key*."""
        return self.dic[key] / self.total_num[key]

    def __str__(self):
        keys = sorted(self.keys())
        res = ''
        for key in keys:
            res += (key + ': %.4f' % self[key] + ' | ')
        return res

    def keys(self):
        return self.dic.keys()
def write_loss(writer, prefix, avg_meters, iteration):
    """Log each averaged meter to tensorboard as a scalar tagged prefix/key."""
    for key in avg_meters.keys():
        tag = os.path.join(prefix, key)
        writer.add_scalar(tag, avg_meters[key], iteration)
"""progress bar"""
import socket
# _, term_width = os.popen('stty size', 'r').read().split()
term_width = 136  # fixed console width (the stty probe above is disabled)
TOTAL_BAR_LENGTH = 65.  # number of character cells used by the bar itself
last_time = time.time()  # timestamp of the previous progress_bar call
begin_time = last_time  # timestamp when the current bar started (step 0)
def progress_bar(current, total, msg=None):
    """Render a single-line console progress bar for step *current* of *total*.

    Uses the module globals `last_time`/`begin_time` for per-step and
    cumulative timings, and `term_width`/`TOTAL_BAR_LENGTH` for layout.

    :param current: zero-based index of the current step
    :param total: total number of steps
    :param msg: optional extra text appended after the timings
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.
    # '=' cells for completed work, one '>' head, '.' for the remainder.
    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')
    cur_time = time.time()
    step_time = cur_time - last_time  # time since the previous call
    last_time = cur_time
    tot_time = cur_time - begin_time  # time since step 0 of this bar
    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)
    msg = ''.join(L)
    sys.stdout.write(msg)
    # Pad with spaces out to the fixed terminal width.
    for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
        sys.stdout.write(' ')
    # Go back to the center of the bar.
    for i in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current + 1, total))
    if current < total - 1:
        sys.stdout.write('\r')  # stay on the same line until the last step
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()
def format_time(seconds):
    """Format a duration in seconds as a compact string with at most two units.

    Units are days ('D'), hours ('h'), minutes ('m'), seconds ('s') and
    milliseconds ('ms'); only the two most significant non-zero units are
    shown. A duration under 1 ms renders as '0ms'.
    """
    remaining = seconds
    parts = []
    for unit_seconds, suffix in ((3600 * 24, 'D'), (3600, 'h'), (60, 'm'), (1, 's')):
        amount = int(remaining / unit_seconds)
        remaining = remaining - amount * unit_seconds
        if amount > 0 and len(parts) < 2:
            parts.append('%d%s' % (amount, suffix))
    millis = int(remaining * 1000)
    if millis > 0 and len(parts) < 2:
        parts.append('%dms' % millis)
    return ''.join(parts) if parts else '0ms'
def parse_args(args):
    """Parse a comma-separated string of ints, keeping only non-negative values."""
    return [value for value in map(int, args.split(',')) if value >= 0]
def weights_init_kaiming(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
elif classname.find('Linear') != -1:
nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
elif classname.find('BatchNorm') != -1:
# nn.init.uniform(m.weight.data, 1.0, 0.02)
m.weight.data.normal_(mean=0, std=math.sqrt(2. / 9. / 64.)).clamp_(-0.025, 0.025)
nn.init.constant(m.bias.data, 0.0)
def batch_PSNR(img, imclean, data_range):
    """Mean PSNR between two batches of NCHW tensors, computed per sample."""
    noisy = img.data.cpu().numpy().astype(np.float32)
    clean = imclean.data.cpu().numpy().astype(np.float32)
    scores = [compare_psnr(clean[i], noisy[i], data_range=data_range)
              for i in range(noisy.shape[0])]
    return sum(scores) / noisy.shape[0]
def batch_SSIM(img, imclean):
    """Mean SSIM between two batches of NCHW tensors, computed per sample.

    Tensors are permuted to NHWC before calling scikit-image.
    NOTE(review): data_range=1 assumes inputs are scaled to [0, 1] — confirm
    against the callers.
    """
    Img = img.data.cpu().permute(0, 2, 3, 1).numpy().astype(np.float32)
    Iclean = imclean.data.cpu().permute(0, 2, 3, 1).numpy().astype(np.float32)
    SSIM = 0
    for i in range(Img.shape[0]):
        # channel_axis=-1 replaces multichannel=True, which was deprecated
        # in scikit-image 0.19 and removed in 0.23.
        SSIM += structural_similarity(Iclean[i, :, :, :], Img[i, :, :, :], win_size=11,
                                      channel_axis=-1, data_range=1)
    return SSIM / Img.shape[0]
def data_augmentation(image, mode):
    """Apply one of 8 flip/rotation augmentations to a CHW image.

    Modes 0-7 encode a number of counter-clockwise quarter turns
    optionally followed by an up/down flip; any other mode is the
    identity. The image is transposed to HWC for the spatial ops and
    back to CHW for the return value.
    """
    transforms = {
        0: (0, False),  # original
        1: (0, True),   # flip up/down
        2: (1, False),  # rotate 90
        3: (1, True),   # rotate 90 + flip
        4: (2, False),  # rotate 180
        5: (2, True),   # rotate 180 + flip
        6: (3, False),  # rotate 270
        7: (3, True),   # rotate 270 + flip
    }
    quarter_turns, flip = transforms.get(mode, (0, False))
    hwc = np.transpose(image, (1, 2, 0))
    if quarter_turns:
        hwc = np.rot90(hwc, k=quarter_turns)
    if flip:
        hwc = np.flipud(hwc)
    return np.transpose(hwc, (2, 0, 1))