Matharrr committed on
Commit
1d92877
·
1 Parent(s): ee70a2d

first commit

Browse files
Files changed (6) hide show
  1. util/__init__.py +2 -0
  2. util/get_data.py +110 -0
  3. util/html.py +86 -0
  4. util/image_pool.py +54 -0
  5. util/util.py +166 -0
  6. util/visualizer.py +242 -0
util/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """This package includes a miscellaneous collection of useful helper functions."""
2
+ from util import *
util/get_data.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import tarfile
4
+ import requests
5
+ from warnings import warn
6
+ from zipfile import ZipFile
7
+ from bs4 import BeautifulSoup
8
+ from os.path import abspath, isdir, join, basename
9
+
10
+
11
class GetData(object):
    """A Python script for downloading CycleGAN or pix2pix datasets.

    Parameters:
        technique (str) -- One of: 'cyclegan' or 'pix2pix'.
        verbose (bool)  -- If True, print additional information.

    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan')
        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.

    Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh'
    and 'scripts/download_cyclegan_model.sh'.
    """

    def __init__(self, technique='cyclegan', verbose=True):
        """Store the dataset index URL for the chosen technique.

        Raises:
            ValueError -- if `technique` is neither 'cyclegan' nor 'pix2pix'.
        """
        url_dict = {
            'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
        }
        technique = technique.lower()
        if technique not in url_dict:
            # Fail fast here instead of carrying self.url = None into
            # requests.get() later (the original silently stored None).
            raise ValueError("Unknown technique: {0}. Use one of: {1}.".format(
                technique, ', '.join(sorted(url_dict))))
        self.url = url_dict[technique]
        self._verbose = verbose

    def _print(self, text):
        """Print `text` only when verbose mode is on."""
        if self._verbose:
            print(text)

    @staticmethod
    def _get_options(r):
        """Scrape the dataset index page response `r` for archive file names."""
        soup = BeautifulSoup(r.text, 'lxml')
        options = [h.text for h in soup.find_all('a', href=True)
                   if h.text.endswith(('.zip', 'tar.gz'))]
        return options

    def _present_options(self):
        """Interactively ask the user to pick one of the available datasets."""
        r = requests.get(self.url)
        options = self._get_options(r)
        print('Options:\n')
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))
        choice = input("\nPlease enter the number of the "
                       "dataset above you wish to download:")
        return options[int(choice)]

    def _download_data(self, dataset_url, save_path):
        """Download an archive from `dataset_url`, unpack it into `save_path`,
        then delete the temporary archive file.

        Raises:
            ValueError -- if the archive is neither .zip nor .tar.gz.
        """
        if not isdir(save_path):
            os.makedirs(save_path)

        base = basename(dataset_url)
        temp_save_path = join(save_path, base)

        with open(temp_save_path, "wb") as f:
            r = requests.get(dataset_url)
            f.write(r.content)

        self._print("Unpacking Data...")
        # Context managers guarantee the archive handle is closed even when
        # extraction raises (the original leaked the handle on error).
        if base.endswith('.tar.gz'):
            with tarfile.open(temp_save_path) as obj:
                obj.extractall(save_path)
        elif base.endswith('.zip'):
            with ZipFile(temp_save_path, 'r') as obj:
                obj.extractall(save_path)
        else:
            raise ValueError("Unknown File Type: {0}.".format(base))
        os.remove(temp_save_path)

    def get(self, save_path, dataset=None):
        """Download a dataset.

        Parameters:
            save_path (str) -- A directory to save the data to.
            dataset (str)   -- (optional). A specific dataset to download.
                               Note: this must include the file extension.
                               If None, options will be presented for you
                               to choose from.

        Returns:
            save_path_full (str) -- the absolute path to the downloaded data.
        """
        if dataset is None:
            selected_dataset = self._present_options()
        else:
            selected_dataset = dataset

        # The dataset name without its archive extension becomes the folder name.
        save_path_full = join(save_path, selected_dataset.split('.')[0])

        if isdir(save_path_full):
            # Best-effort behavior: warn and skip rather than re-download.
            warn("\n'{0}' already exists. Voiding Download.".format(
                save_path_full))
        else:
            self._print('Downloading Data...')
            url = "{0}/{1}".format(self.url, selected_dataset)
            self._download_data(url, save_path=save_path)

        return abspath(save_path_full)
util/html.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dominate
2
+ from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3
+ import os
4
+
5
+
6
+ class HTML:
7
+ """This HTML class allows us to save images and write texts into a single HTML file.
8
+
9
+ It consists of functions such as <add_header> (add a text header to the HTML file),
10
+ <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11
+ It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12
+ """
13
+
14
+ def __init__(self, web_dir, title, refresh=0):
15
+ """Initialize the HTML classes
16
+
17
+ Parameters:
18
+ web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/
19
+ title (str) -- the webpage name
20
+ refresh (int) -- how often the website refresh itself; if 0; no refreshing
21
+ """
22
+ self.title = title
23
+ self.web_dir = web_dir
24
+ self.img_dir = os.path.join(self.web_dir, 'images')
25
+ if not os.path.exists(self.web_dir):
26
+ os.makedirs(self.web_dir)
27
+ if not os.path.exists(self.img_dir):
28
+ os.makedirs(self.img_dir)
29
+
30
+ self.doc = dominate.document(title=title)
31
+ if refresh > 0:
32
+ with self.doc.head:
33
+ meta(http_equiv="refresh", content=str(refresh))
34
+
35
+ def get_image_dir(self):
36
+ """Return the directory that stores images"""
37
+ return self.img_dir
38
+
39
+ def add_header(self, text):
40
+ """Insert a header to the HTML file
41
+
42
+ Parameters:
43
+ text (str) -- the header text
44
+ """
45
+ with self.doc:
46
+ h3(text)
47
+
48
+ def add_images(self, ims, txts, links, width=400):
49
+ """add images to the HTML file
50
+
51
+ Parameters:
52
+ ims (str list) -- a list of image paths
53
+ txts (str list) -- a list of image names shown on the website
54
+ links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55
+ """
56
+ self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57
+ self.doc.add(self.t)
58
+ with self.t:
59
+ with tr():
60
+ for im, txt, link in zip(ims, txts, links):
61
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
62
+ with p():
63
+ with a(href=os.path.join('images', link)):
64
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
65
+ br()
66
+ p(txt)
67
+
68
+ def save(self):
69
+ """save the current content to the HMTL file"""
70
+ html_file = '%s/index.html' % self.web_dir
71
+ f = open(html_file, 'wt')
72
+ f.write(self.doc.render())
73
+ f.close()
74
+
75
+
76
if __name__ == '__main__':  # we show an example usage here.
    page = HTML('web/', 'test_html')
    page.add_header('hello world')

    # Build four rows of matching image / caption / link entries.
    ims = ['image_%d.png' % n for n in range(4)]
    txts = ['text_%d' % n for n in range(4)]
    links = ['image_%d.png' % n for n in range(4)]
    page.add_images(ims, txts, links)
    page.save()
util/image_pool.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+
4
+
5
class ImagePool():
    """An image buffer that stores previously generated images.

    The buffer lets discriminators be updated using a history of generated
    images rather than only the ones produced by the latest generators.
    """

    def __init__(self, pool_size):
        """Initialize the ImagePool class

        Parameters:
            pool_size (int) -- the size of image buffer; if pool_size=0, no buffer will be created
        """
        self.pool_size = pool_size
        if self.pool_size > 0:  # buffer attributes exist only when buffering is enabled
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch of images drawn from the pool.

        Parameters:
            images: the latest generated images from the generator

        Returns images from the buffer: with 50% probability an output image
        is the corresponding input; otherwise it is a randomly chosen buffered
        image, and the input replaces it in the buffer. While the buffer is
        still filling, inputs are stored and echoed back unchanged.
        """
        if self.pool_size == 0:  # buffering disabled: pass the batch through untouched
            return images
        selected = []
        for img in images:
            img = torch.unsqueeze(img.data, 0)
            if self.num_imgs < self.pool_size:
                # Buffer not yet full: store the new image and echo it back.
                self.num_imgs = self.num_imgs + 1
                self.images.append(img)
                selected.append(img)
                continue
            if random.uniform(0, 1) > 0.5:
                # 50% chance: return a stored image and keep the new one instead.
                swap_idx = random.randint(0, self.pool_size - 1)  # randint is inclusive
                stored = self.images[swap_idx].clone()
                self.images[swap_idx] = img
                selected.append(stored)
            else:
                # Other 50%: return the current image unchanged.
                selected.append(img)
        return torch.cat(selected, 0)  # collect all the images and return
util/util.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module contains simple helper functions """
2
+ from __future__ import print_function
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import os
7
+ import importlib
8
+ import argparse
9
+ from argparse import Namespace
10
+ import torchvision
11
+
12
+
13
def str2bool(v):
    """Parse a command-line flag value into a bool (argparse `type=` helper).

    Accepts actual bools unchanged; otherwise matches common yes/no spellings
    case-insensitively. Raises argparse.ArgumentTypeError for anything else.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
22
+
23
+
24
def copyconf(default_opt, **kwargs):
    """Return a shallow copy of an options Namespace with overrides applied.

    The original `default_opt` is left untouched; keyword arguments are set
    as attributes on the copy.
    """
    new_conf = Namespace(**vars(default_opt))
    for name, value in kwargs.items():
        setattr(new_conf, name, value)
    return new_conf
29
+
30
+
31
def find_class_in_module(target_cls_name, module):
    """Look up a class in `module` whose name matches `target_cls_name`,
    compared case-insensitively and ignoring underscores.

    If several members match, the last one found wins (dict order).
    """
    wanted = target_cls_name.replace('_', '').lower()
    clslib = importlib.import_module(module)
    found = None
    for member_name, member in clslib.__dict__.items():
        if member_name.lower() == wanted:
            found = member

    assert found is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)

    return found
42
+
43
+
44
def tensor2im(input_image, imtype=np.uint8):
    """Convert a Tensor array into a numpy image array.

    Parameters:
        input_image (tensor) -- the input image tensor array
        imtype (type)        -- the desired type of the converted numpy array

    Tensors are assumed to be in [-1, 1]; only the first element of the batch
    is converted. Inputs that are neither tensors nor ndarrays pass through.
    """
    if isinstance(input_image, np.ndarray):
        # Already a numpy array: just cast to the requested dtype.
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        # Anything else is returned untouched (mirrors the original contract).
        return input_image
    arr = input_image.data[0].clamp(-1.0, 1.0).cpu().float().numpy()
    if arr.shape[0] == 1:  # grayscale to RGB
        arr = np.tile(arr, (3, 1, 1))
    # CHW -> HWC, then rescale [-1, 1] -> [0, 255].
    arr = (np.transpose(arr, (1, 2, 0)) + 1) / 2.0 * 255.0
    return arr.astype(imtype)
63
+
64
+
65
def diagnose_network(net, name='network'):
    """Calculate and print the mean of average absolute(gradients)

    Parameters:
        net (torch network) -- Torch network
        name (str)          -- the name of the network

    Parameters with no gradient are skipped; if none have gradients,
    0.0 is printed.
    """
    grad_sum = 0.0
    n_params = 0
    for param in net.parameters():
        if param.grad is None:
            continue
        grad_sum += torch.mean(torch.abs(param.grad.data))
        n_params += 1
    if n_params > 0:
        grad_sum = grad_sum / n_params
    print(name)
    print(grad_sum)
82
+
83
+
84
def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array (H x W x C)
        image_path (str)          -- the path of the image
        aspect_ratio (float)      -- rescale one axis by this factor; None skips resizing
    """
    pil_image = Image.fromarray(image_numpy)
    h, w, _ = image_numpy.shape

    if aspect_ratio is not None:
        # NOTE(review): PIL's resize() takes a (width, height) tuple; passing
        # (h, ...) here reproduces the original code exactly but looks like
        # swapped axes for non-square images -- confirm against upstream intent.
        if aspect_ratio > 1.0:
            pil_image = pil_image.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
        elif aspect_ratio < 1.0:
            pil_image = pil_image.resize((int(h / aspect_ratio), w), Image.BICUBIC)
    pil_image.save(image_path)
102
+
103
+
104
def print_numpy(x, val=True, shp=False):
    """Print the mean, min, max, median, std, and size of a numpy array

    Parameters:
        val (bool) -- if print the values of the numpy array
        shp (bool) -- if print the shape of the numpy array
    """
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        flat = x.flatten()
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
            np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat)))
118
+
119
+
120
def mkdirs(paths):
    """create empty directories if they don't exist

    Parameters:
        paths (str list) -- a list of directory paths, or a single path string
    """
    # A single path (anything that is not a non-str list) goes straight to mkdir.
    if not isinstance(paths, list) or isinstance(paths, str):
        mkdir(paths)
        return
    for one_path in paths:
        mkdir(one_path)
131
+
132
+
133
def mkdir(path):
    """create a single empty directory if it didn't exist

    Parameters:
        path (str) -- a single directory path
    """
    if os.path.exists(path):
        return
    os.makedirs(path)
141
+
142
+
143
def correct_resize_label(t, size):
    """Resize a batch of integer label maps with nearest-neighbour sampling.

    Only channel 0 of each batch element is used; the result is a long tensor
    on the same device as the input.
    """
    device = t.device
    cpu_batch = t.detach().cpu()
    out = []
    for idx in range(cpu_batch.size(0)):
        label_slice = cpu_batch[idx, :1]
        # CHW uint8 -> HWC, then drop the channel axis for PIL.
        label_np = np.transpose(label_slice.numpy().astype(np.uint8), (1, 2, 0))[:, :, 0]
        resized_img = Image.fromarray(label_np).resize(size, Image.NEAREST)
        out.append(torch.from_numpy(np.array(resized_img)).long())
    return torch.stack(out, dim=0).to(device)
155
+
156
+
157
def correct_resize(t, size, mode=Image.BICUBIC):
    """Resize a batch of [-1, 1] image tensors through PIL bicubic resampling.

    NOTE: `mode` is accepted for interface compatibility but is not consulted;
    resampling is always Image.BICUBIC (mirrors the original behavior).
    """
    device = t.device
    cpu_batch = t.detach().cpu()
    out = []
    for idx in range(cpu_batch.size(0)):
        single = cpu_batch[idx:idx + 1]
        resized_img = Image.fromarray(tensor2im(single)).resize(size, Image.BICUBIC)
        # Back to a tensor rescaled from [0, 1] to [-1, 1].
        out.append(torchvision.transforms.functional.to_tensor(resized_img) * 2 - 1.0)
    return torch.stack(out, dim=0).to(device)
util/visualizer.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import sys
4
+ import ntpath
5
+ import time
6
+ from . import util, html
7
+ from subprocess import Popen, PIPE
8
+
9
# Visdom signals connection failures with ConnectionError on Python 3;
# Python 2 has no such built-in class, so fall back to broad Exception there.
VisdomExceptionBase = Exception if sys.version_info[0] == 2 else ConnectionError
13
+
14
+
15
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    """Save images to the disk.

    Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
        image_path (str)         -- the string is used to create image paths
        aspect_ratio (float)     -- the aspect ratio of saved images
        width (int)              -- the images will be resized to width x width

    This function saves images stored in 'visuals' to the HTML file specified by 'webpage'.
    """
    image_dir = webpage.get_image_dir()
    # The basename of the first source path (without extension) names this row.
    name = os.path.splitext(ntpath.basename(image_path[0]))[0]

    webpage.add_header(name)
    ims, txts, links = [], [], []

    for label, im_data in visuals.items():
        converted = util.tensor2im(im_data)
        image_name = '%s/%s.png' % (label, name)
        os.makedirs(os.path.join(image_dir, label), exist_ok=True)
        util.save_image(converted, os.path.join(image_dir, image_name), aspect_ratio=aspect_ratio)
        ims.append(image_name)
        txts.append(label)
        links.append(image_name)
    webpage.add_images(ims, txts, links, width=width)
44
+
45
+
46
class Visualizer():
    """This class includes several functions that can display/save images and print/save logging information.

    It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
    """

    def __init__(self, opt):
        """Initialize the Visualizer class

        Parameters:
            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
        Step 1: Cache the training/test options
        Step 2: connect to a visdom server
        Step 3: create an HTML object for saving HTML files
        Step 4: create a logging file to store training losses
        """
        self.opt = opt  # cache the option
        if opt.display_id is None:
            self.display_id = np.random.randint(100000) * 10  # just a random display id
        else:
            self.display_id = opt.display_id
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        self.port = opt.display_port
        self.saved = False  # whether current results were already saved to HTML this epoch
        if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
            import visdom
            self.plot_data = {}  # per-plot history, keyed by joined loss names (see plot_current_losses)
            self.ncols = opt.display_ncols
            # NOTE(review): the env var switches to a fixed port 2004 behind a
            # tensorboard-style reverse proxy -- confirm the deployment expects this.
            if "tensorboard_base_url" not in os.environ:
                self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
            else:
                self.vis = visdom.Visdom(port=2004,
                                         base_url=os.environ['tensorboard_base_url'] + '/visdom')
            if not self.vis.check_connection():
                self.create_visdom_connections()

        if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])
        # create a logging file to store training losses
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    def reset(self):
        """Reset the self.saved status"""
        self.saved = False

    def create_visdom_connections(self):
        """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
        cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
        print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
        print('Command: %s' % cmd)
        # Fire-and-forget: the server is left running in the background.
        Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)

    def display_current_results(self, visuals, epoch, save_result):
        """Display current results on visdom; save current results to an HTML file.

        Parameters:
            visuals (OrderedDict) - - dictionary of images to display or save
            epoch (int) - - the current epoch
            save_result (bool) - - if save the current results to an HTML file
        """
        if self.display_id > 0:  # show images in the browser using visdom
            ncols = self.ncols
            if ncols > 0:        # show all the images in one visdom panel
                ncols = min(ncols, len(visuals))
                # assumes every visual shares the (h, w) of the first one -- TODO confirm
                h, w = next(iter(visuals.values())).shape[:2]
                table_css = """<style>
                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
                        table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
                        </style>""" % (w, h)  # create a table css
                # create a table of images.
                title = self.name
                label_html = ''
                label_html_row = ''
                images = []
                idx = 0
                for label, image in visuals.items():
                    image_numpy = util.tensor2im(image)
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:  # row full: flush it into the table
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                # pad the last row with blank (white) cells so the grid is rectangular
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                try:
                    self.vis.images(images, ncols, 2, self.display_id + 1,
                                    None, dict(title=title + ' images'))
                    label_html = '<table>%s</table>' % label_html
                    self.vis.text(table_css + label_html, win=self.display_id + 2,
                                  opts=dict(title=title + ' labels'))
                except VisdomExceptionBase:
                    self.create_visdom_connections()

            else:  # show each image in a separate visdom panel;
                idx = 1
                try:
                    for label, image in visuals.items():
                        image_numpy = util.tensor2im(image)
                        self.vis.image(
                            image_numpy.transpose([2, 0, 1]),
                            self.display_id + idx,
                            None,
                            dict(title=label)
                        )
                        idx += 1
                except VisdomExceptionBase:
                    self.create_visdom_connections()

        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
            self.saved = True
            # save images to the disk
            for label, image in visuals.items():
                image_numpy = util.tensor2im(image)
                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                util.save_image(image_numpy, img_path)

            # update website
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                ims, txts, links = [], [], []

                for label, image_numpy in visuals.items():
                    # NOTE(review): this converts `image` -- the leftover variable
                    # from the loop above -- not this row's value, so every entry
                    # converts the same image. Suspected bug; confirm upstream intent
                    # (only the converted array is discarded; paths use `label`/`n`).
                    image_numpy = util.tensor2im(image)
                    img_path = 'epoch%.3d_%s.png' % (n, label)
                    ims.append(img_path)
                    txts.append(label)
                    links.append(img_path)
                webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()

    def plot_current_losses(self, epoch, counter_ratio, losses):
        """display the current losses on visdom display: dictionary of error labels and values

        Parameters:
            epoch (int)           -- current epoch
            counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
            losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
        """
        if len(losses) == 0:
            return

        # One plot per distinct set of loss names; history accumulates across calls.
        plot_name = '_'.join(list(losses.keys()))

        if plot_name not in self.plot_data:
            self.plot_data[plot_name] = {'X': [], 'Y': [], 'legend': list(losses.keys())}

        plot_data = self.plot_data[plot_name]
        # Window id is offset by the plot's insertion order so plots don't collide.
        plot_id = list(self.plot_data.keys()).index(plot_name)

        plot_data['X'].append(epoch + counter_ratio)
        plot_data['Y'].append([losses[k] for k in plot_data['legend']])
        try:
            self.vis.line(
                X=np.stack([np.array(plot_data['X'])] * len(plot_data['legend']), 1),
                Y=np.array(plot_data['Y']),
                opts={
                    'title': self.name,
                    'legend': plot_data['legend'],
                    'xlabel': 'epoch',
                    'ylabel': 'loss'},
                win=self.display_id - plot_id)
        except VisdomExceptionBase:
            self.create_visdom_connections()

    # losses: same format as |losses| of plot_current_losses
    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
        """print current losses on console; also save the losses to the disk

        Parameters:
            epoch (int)          -- current epoch
            iters (int)          -- current training iteration during this epoch (reset to 0 at the end of every epoch)
            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
            t_comp (float)       -- computational time per data point (normalized by batch_size)
            t_data (float)       -- data loading time per data point (normalized by batch_size)
        """
        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)

        print(message)  # print the message
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)  # save the message