|
|
import os
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
import math
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Visualizer(object):
    """Console/file logging of training progress plus optional visdom plots.

    Call :meth:`initialize` with an options object before using any other
    method. Visdom output is enabled only when ``opt.visdom_display_id > 0``;
    otherwise the visdom display methods do nothing.
    """

    def __init__(self):
        super().__init__()

    def initialize(self, opt):
        """Store *opt* and, if requested, connect to a local visdom server.

        Reads ``opt.visdom_display_id``, ``opt.visdom_port`` and
        ``opt.visdom_env``. When the display id is not positive, ``self.vis``
        is set to ``None`` and no visdom import/connection is attempted.
        """
        self.opt = opt
        self.display_id = self.opt.visdom_display_id
        self.ncols = 8  # images per row in the visdom image grid
        self.vis = None
        if self.display_id > 0:
            import visdom  # optional dependency, only needed when plotting

            self.vis = visdom.Visdom(server="http://localhost",
                                     port=self.opt.visdom_port,
                                     env=self.opt.visdom_env)

    def throw_visdom_connection_error(self):
        """Report that the visdom server is unreachable and exit with status 1."""
        import sys  # sys.exit works even when the `site` builtins are absent

        print('\n\nno visdom server.')
        sys.exit(1)

    def print_losses_info(self, info_dict):
        """Print a one-line training summary and append it to a log file.

        *info_dict* must provide ``epoch``, ``epoch_len``, ``epoch_steps``,
        ``epoch_steps_len``, ``step_time``, ``cur_lr``, ``losses`` (a mapping
        of loss name to a number) and ``log_path``. ``opt.name`` and
        ``opt.batch_size`` are included in the header.
        """
        header = ('[{}][Epoch: {:0>3}/{:0>3}; Images: {:0>4}/{:0>4}; '
                  'Time: {:.3f}s/Batch({}); LR: {:.7f}] ').format(
            self.opt.name, info_dict['epoch'], info_dict['epoch_len'],
            info_dict['epoch_steps'], info_dict['epoch_steps_len'],
            info_dict['step_time'], self.opt.batch_size, info_dict['cur_lr'])
        loss_text = ''.join('| {}: {:.4f} '.format(name, value)
                            for name, value in info_dict['losses'].items())
        msg = header + loss_text + '|'
        print(msg)
        with open(info_dict['log_path'], 'a', encoding='utf-8') as f:
            f.write(msg + '\n')

    def display_current_losses(self, epoch, counter_ratio, losses_dict):
        """Record the current losses and redraw the visdom loss-curve plot.

        The x coordinate is ``epoch + counter_ratio``. The legend (the set and
        order of plotted losses) is fixed by the keys of the first
        *losses_dict* seen; later dicts must contain all of those keys.
        Does nothing when visdom display is disabled.
        """
        if getattr(self, 'vis', None) is None:
            return
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X': [], 'Y': [], 'legend': list(losses_dict.keys())}
        legend = self.plot_data['legend']
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([losses_dict[k] for k in legend])
        # visdom expects one x column per plotted line, same shape as Y.
        xs = np.stack([np.array(self.plot_data['X'])] * len(legend), 1)
        try:
            self.vis.line(
                X=xs,
                Y=np.array(self.plot_data['Y']),
                opts={
                    'title': self.opt.name + ' loss over time',
                    'legend': legend,
                    'xlabel': 'epoch',
                    'ylabel': 'loss'},
                win=self.display_id)
        except ConnectionError:
            self.throw_visdom_connection_error()

    def display_online_results(self, visuals, epoch):
        """Show *visuals* as a grid of images in a visdom window.

        *visuals* maps labels to images. Tensors are converted with
        :meth:`tensor2im`; any other value is used as-is and must already be
        an H x W x C array. Entries whose label contains ``'mask'`` are
        rescaled with ``(x - 0.5) / 0.5`` first (presumably masks lie in
        [0, 1], which this maps to [-1, 1]). *epoch* is not used here.
        Does nothing when visdom display is disabled.
        """
        if getattr(self, 'vis', None) is None:
            return
        # Offset so this window does not collide with the loss plot window,
        # which uses self.display_id directly.
        win_id = self.display_id + 24
        images = []
        labels = []
        for label, image in visuals.items():
            if 'mask' in label:
                image = (image - 0.5) / 0.5
            # visdom wants channel-first (C, H, W) images.
            images.append(self.tensor2im(image).transpose([2, 0, 1]))
            labels.append(label)
        try:
            self.vis.images(images, nrow=self.ncols, win=win_id,
                            padding=5, opts=dict(title=' || '.join(labels)))
        except ConnectionError:
            self.throw_visdom_connection_error()

    def tensor2im(self, input_image, imtype=np.uint8):
        """Convert the first item of a batched tensor to an 80x80 H x W x C array.

        Non-tensor inputs are returned unchanged. The tensor's first index is
        taken as the batch dimension and the remainder is handed to
        :meth:`numpy2im`.
        """
        if not isinstance(input_image, torch.Tensor):
            return input_image
        image_numpy = input_image.detach()[0].cpu().float().numpy()
        # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
        resample = getattr(Image, 'Resampling', Image).LANCZOS
        im = self.numpy2im(image_numpy, imtype).resize((80, 80), resample)
        return np.array(im)

    @staticmethod
    def _to_display_array(image_numpy, imtype=np.uint8):
        """Map a C x H x W array (expected range [-1, 1]) to H x W x C in [0, 255].

        Single-channel input is replicated to three channels.
        """
        if image_numpy.shape[0] == 1:
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        scaled = (np.transpose(image_numpy, (1, 2, 0)) / 2. + 0.5) * 255.0
        # Clip before casting: the float -> uint8 cast is not saturating, so
        # out-of-range values would otherwise wrap around to wrong colours.
        return np.clip(scaled, 0.0, 255.0).astype(imtype)

    def numpy2im(self, image_numpy, imtype=np.uint8):
        """Convert a C x H x W array in [-1, 1] to a PIL image."""
        return Image.fromarray(self._to_display_array(image_numpy, imtype))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|