Spaces:
Runtime error
Runtime error
first commit
Browse files- util/__init__.py +2 -0
- util/get_data.py +110 -0
- util/html.py +86 -0
- util/image_pool.py +54 -0
- util/util.py +166 -0
- util/visualizer.py +242 -0
util/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This package includes a miscellaneous collection of useful helper functions."""
|
| 2 |
+
from util import *
|
util/get_data.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import os
|
| 3 |
+
import tarfile
|
| 4 |
+
import requests
|
| 5 |
+
from warnings import warn
|
| 6 |
+
from zipfile import ZipFile
|
| 7 |
+
from bs4 import BeautifulSoup
|
| 8 |
+
from os.path import abspath, isdir, join, basename
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class GetData(object):
|
| 12 |
+
"""A Python script for downloading CycleGAN or pix2pix datasets.
|
| 13 |
+
|
| 14 |
+
Parameters:
|
| 15 |
+
technique (str) -- One of: 'cyclegan' or 'pix2pix'.
|
| 16 |
+
verbose (bool) -- If True, print additional information.
|
| 17 |
+
|
| 18 |
+
Examples:
|
| 19 |
+
>>> from util.get_data import GetData
|
| 20 |
+
>>> gd = GetData(technique='cyclegan')
|
| 21 |
+
>>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
|
| 22 |
+
|
| 23 |
+
Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh'
|
| 24 |
+
and 'scripts/download_cyclegan_model.sh'.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, technique='cyclegan', verbose=True):
|
| 28 |
+
url_dict = {
|
| 29 |
+
'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
|
| 30 |
+
'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
|
| 31 |
+
}
|
| 32 |
+
self.url = url_dict.get(technique.lower())
|
| 33 |
+
self._verbose = verbose
|
| 34 |
+
|
| 35 |
+
def _print(self, text):
|
| 36 |
+
if self._verbose:
|
| 37 |
+
print(text)
|
| 38 |
+
|
| 39 |
+
@staticmethod
|
| 40 |
+
def _get_options(r):
|
| 41 |
+
soup = BeautifulSoup(r.text, 'lxml')
|
| 42 |
+
options = [h.text for h in soup.find_all('a', href=True)
|
| 43 |
+
if h.text.endswith(('.zip', 'tar.gz'))]
|
| 44 |
+
return options
|
| 45 |
+
|
| 46 |
+
def _present_options(self):
|
| 47 |
+
r = requests.get(self.url)
|
| 48 |
+
options = self._get_options(r)
|
| 49 |
+
print('Options:\n')
|
| 50 |
+
for i, o in enumerate(options):
|
| 51 |
+
print("{0}: {1}".format(i, o))
|
| 52 |
+
choice = input("\nPlease enter the number of the "
|
| 53 |
+
"dataset above you wish to download:")
|
| 54 |
+
return options[int(choice)]
|
| 55 |
+
|
| 56 |
+
def _download_data(self, dataset_url, save_path):
|
| 57 |
+
if not isdir(save_path):
|
| 58 |
+
os.makedirs(save_path)
|
| 59 |
+
|
| 60 |
+
base = basename(dataset_url)
|
| 61 |
+
temp_save_path = join(save_path, base)
|
| 62 |
+
|
| 63 |
+
with open(temp_save_path, "wb") as f:
|
| 64 |
+
r = requests.get(dataset_url)
|
| 65 |
+
f.write(r.content)
|
| 66 |
+
|
| 67 |
+
if base.endswith('.tar.gz'):
|
| 68 |
+
obj = tarfile.open(temp_save_path)
|
| 69 |
+
elif base.endswith('.zip'):
|
| 70 |
+
obj = ZipFile(temp_save_path, 'r')
|
| 71 |
+
else:
|
| 72 |
+
raise ValueError("Unknown File Type: {0}.".format(base))
|
| 73 |
+
|
| 74 |
+
self._print("Unpacking Data...")
|
| 75 |
+
obj.extractall(save_path)
|
| 76 |
+
obj.close()
|
| 77 |
+
os.remove(temp_save_path)
|
| 78 |
+
|
| 79 |
+
def get(self, save_path, dataset=None):
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
Download a dataset.
|
| 83 |
+
|
| 84 |
+
Parameters:
|
| 85 |
+
save_path (str) -- A directory to save the data to.
|
| 86 |
+
dataset (str) -- (optional). A specific dataset to download.
|
| 87 |
+
Note: this must include the file extension.
|
| 88 |
+
If None, options will be presented for you
|
| 89 |
+
to choose from.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
save_path_full (str) -- the absolute path to the downloaded data.
|
| 93 |
+
|
| 94 |
+
"""
|
| 95 |
+
if dataset is None:
|
| 96 |
+
selected_dataset = self._present_options()
|
| 97 |
+
else:
|
| 98 |
+
selected_dataset = dataset
|
| 99 |
+
|
| 100 |
+
save_path_full = join(save_path, selected_dataset.split('.')[0])
|
| 101 |
+
|
| 102 |
+
if isdir(save_path_full):
|
| 103 |
+
warn("\n'{0}' already exists. Voiding Download.".format(
|
| 104 |
+
save_path_full))
|
| 105 |
+
else:
|
| 106 |
+
self._print('Downloading Data...')
|
| 107 |
+
url = "{0}/{1}".format(self.url, selected_dataset)
|
| 108 |
+
self._download_data(url, save_path=save_path)
|
| 109 |
+
|
| 110 |
+
return abspath(save_path_full)
|
util/html.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dominate
|
| 2 |
+
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class HTML:
|
| 7 |
+
"""This HTML class allows us to save images and write texts into a single HTML file.
|
| 8 |
+
|
| 9 |
+
It consists of functions such as <add_header> (add a text header to the HTML file),
|
| 10 |
+
<add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
|
| 11 |
+
It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, web_dir, title, refresh=0):
|
| 15 |
+
"""Initialize the HTML classes
|
| 16 |
+
|
| 17 |
+
Parameters:
|
| 18 |
+
web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/
|
| 19 |
+
title (str) -- the webpage name
|
| 20 |
+
refresh (int) -- how often the website refresh itself; if 0; no refreshing
|
| 21 |
+
"""
|
| 22 |
+
self.title = title
|
| 23 |
+
self.web_dir = web_dir
|
| 24 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
| 25 |
+
if not os.path.exists(self.web_dir):
|
| 26 |
+
os.makedirs(self.web_dir)
|
| 27 |
+
if not os.path.exists(self.img_dir):
|
| 28 |
+
os.makedirs(self.img_dir)
|
| 29 |
+
|
| 30 |
+
self.doc = dominate.document(title=title)
|
| 31 |
+
if refresh > 0:
|
| 32 |
+
with self.doc.head:
|
| 33 |
+
meta(http_equiv="refresh", content=str(refresh))
|
| 34 |
+
|
| 35 |
+
def get_image_dir(self):
|
| 36 |
+
"""Return the directory that stores images"""
|
| 37 |
+
return self.img_dir
|
| 38 |
+
|
| 39 |
+
def add_header(self, text):
|
| 40 |
+
"""Insert a header to the HTML file
|
| 41 |
+
|
| 42 |
+
Parameters:
|
| 43 |
+
text (str) -- the header text
|
| 44 |
+
"""
|
| 45 |
+
with self.doc:
|
| 46 |
+
h3(text)
|
| 47 |
+
|
| 48 |
+
def add_images(self, ims, txts, links, width=400):
|
| 49 |
+
"""add images to the HTML file
|
| 50 |
+
|
| 51 |
+
Parameters:
|
| 52 |
+
ims (str list) -- a list of image paths
|
| 53 |
+
txts (str list) -- a list of image names shown on the website
|
| 54 |
+
links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
|
| 55 |
+
"""
|
| 56 |
+
self.t = table(border=1, style="table-layout: fixed;") # Insert a table
|
| 57 |
+
self.doc.add(self.t)
|
| 58 |
+
with self.t:
|
| 59 |
+
with tr():
|
| 60 |
+
for im, txt, link in zip(ims, txts, links):
|
| 61 |
+
with td(style="word-wrap: break-word;", halign="center", valign="top"):
|
| 62 |
+
with p():
|
| 63 |
+
with a(href=os.path.join('images', link)):
|
| 64 |
+
img(style="width:%dpx" % width, src=os.path.join('images', im))
|
| 65 |
+
br()
|
| 66 |
+
p(txt)
|
| 67 |
+
|
| 68 |
+
def save(self):
|
| 69 |
+
"""save the current content to the HMTL file"""
|
| 70 |
+
html_file = '%s/index.html' % self.web_dir
|
| 71 |
+
f = open(html_file, 'wt')
|
| 72 |
+
f.write(self.doc.render())
|
| 73 |
+
f.close()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == '__main__': # we show an example usage here.
|
| 77 |
+
html = HTML('web/', 'test_html')
|
| 78 |
+
html.add_header('hello world')
|
| 79 |
+
|
| 80 |
+
ims, txts, links = [], [], []
|
| 81 |
+
for n in range(4):
|
| 82 |
+
ims.append('image_%d.png' % n)
|
| 83 |
+
txts.append('text_%d' % n)
|
| 84 |
+
links.append('image_%d.png' % n)
|
| 85 |
+
html.add_images(ims, txts, links)
|
| 86 |
+
html.save()
|
util/image_pool.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ImagePool():
|
| 6 |
+
"""This class implements an image buffer that stores previously generated images.
|
| 7 |
+
|
| 8 |
+
This buffer enables us to update discriminators using a history of generated images
|
| 9 |
+
rather than the ones produced by the latest generators.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
def __init__(self, pool_size):
|
| 13 |
+
"""Initialize the ImagePool class
|
| 14 |
+
|
| 15 |
+
Parameters:
|
| 16 |
+
pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
|
| 17 |
+
"""
|
| 18 |
+
self.pool_size = pool_size
|
| 19 |
+
if self.pool_size > 0: # create an empty pool
|
| 20 |
+
self.num_imgs = 0
|
| 21 |
+
self.images = []
|
| 22 |
+
|
| 23 |
+
def query(self, images):
|
| 24 |
+
"""Return an image from the pool.
|
| 25 |
+
|
| 26 |
+
Parameters:
|
| 27 |
+
images: the latest generated images from the generator
|
| 28 |
+
|
| 29 |
+
Returns images from the buffer.
|
| 30 |
+
|
| 31 |
+
By 50/100, the buffer will return input images.
|
| 32 |
+
By 50/100, the buffer will return images previously stored in the buffer,
|
| 33 |
+
and insert the current images to the buffer.
|
| 34 |
+
"""
|
| 35 |
+
if self.pool_size == 0: # if the buffer size is 0, do nothing
|
| 36 |
+
return images
|
| 37 |
+
return_images = []
|
| 38 |
+
for image in images:
|
| 39 |
+
image = torch.unsqueeze(image.data, 0)
|
| 40 |
+
if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
|
| 41 |
+
self.num_imgs = self.num_imgs + 1
|
| 42 |
+
self.images.append(image)
|
| 43 |
+
return_images.append(image)
|
| 44 |
+
else:
|
| 45 |
+
p = random.uniform(0, 1)
|
| 46 |
+
if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
|
| 47 |
+
random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
|
| 48 |
+
tmp = self.images[random_id].clone()
|
| 49 |
+
self.images[random_id] = image
|
| 50 |
+
return_images.append(tmp)
|
| 51 |
+
else: # by another 50% chance, the buffer will return the current image
|
| 52 |
+
return_images.append(image)
|
| 53 |
+
return_images = torch.cat(return_images, 0) # collect all the images and return
|
| 54 |
+
return return_images
|
util/util.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module contains simple helper functions """
|
| 2 |
+
from __future__ import print_function
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import os
|
| 7 |
+
import importlib
|
| 8 |
+
import argparse
|
| 9 |
+
from argparse import Namespace
|
| 10 |
+
import torchvision
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def str2bool(v):
|
| 14 |
+
if isinstance(v, bool):
|
| 15 |
+
return v
|
| 16 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
| 17 |
+
return True
|
| 18 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
| 19 |
+
return False
|
| 20 |
+
else:
|
| 21 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def copyconf(default_opt, **kwargs):
|
| 25 |
+
conf = Namespace(**vars(default_opt))
|
| 26 |
+
for key in kwargs:
|
| 27 |
+
setattr(conf, key, kwargs[key])
|
| 28 |
+
return conf
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def find_class_in_module(target_cls_name, module):
|
| 32 |
+
target_cls_name = target_cls_name.replace('_', '').lower()
|
| 33 |
+
clslib = importlib.import_module(module)
|
| 34 |
+
cls = None
|
| 35 |
+
for name, clsobj in clslib.__dict__.items():
|
| 36 |
+
if name.lower() == target_cls_name:
|
| 37 |
+
cls = clsobj
|
| 38 |
+
|
| 39 |
+
assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)
|
| 40 |
+
|
| 41 |
+
return cls
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def tensor2im(input_image, imtype=np.uint8):
|
| 45 |
+
""""Converts a Tensor array into a numpy image array.
|
| 46 |
+
|
| 47 |
+
Parameters:
|
| 48 |
+
input_image (tensor) -- the input image tensor array
|
| 49 |
+
imtype (type) -- the desired type of the converted numpy array
|
| 50 |
+
"""
|
| 51 |
+
if not isinstance(input_image, np.ndarray):
|
| 52 |
+
if isinstance(input_image, torch.Tensor): # get the data from a variable
|
| 53 |
+
image_tensor = input_image.data
|
| 54 |
+
else:
|
| 55 |
+
return input_image
|
| 56 |
+
image_numpy = image_tensor[0].clamp(-1.0, 1.0).cpu().float().numpy() # convert it into a numpy array
|
| 57 |
+
if image_numpy.shape[0] == 1: # grayscale to RGB
|
| 58 |
+
image_numpy = np.tile(image_numpy, (3, 1, 1))
|
| 59 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling
|
| 60 |
+
else: # if it is a numpy array, do nothing
|
| 61 |
+
image_numpy = input_image
|
| 62 |
+
return image_numpy.astype(imtype)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def diagnose_network(net, name='network'):
|
| 66 |
+
"""Calculate and print the mean of average absolute(gradients)
|
| 67 |
+
|
| 68 |
+
Parameters:
|
| 69 |
+
net (torch network) -- Torch network
|
| 70 |
+
name (str) -- the name of the network
|
| 71 |
+
"""
|
| 72 |
+
mean = 0.0
|
| 73 |
+
count = 0
|
| 74 |
+
for param in net.parameters():
|
| 75 |
+
if param.grad is not None:
|
| 76 |
+
mean += torch.mean(torch.abs(param.grad.data))
|
| 77 |
+
count += 1
|
| 78 |
+
if count > 0:
|
| 79 |
+
mean = mean / count
|
| 80 |
+
print(name)
|
| 81 |
+
print(mean)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def save_image(image_numpy, image_path, aspect_ratio=1.0):
|
| 85 |
+
"""Save a numpy image to the disk
|
| 86 |
+
|
| 87 |
+
Parameters:
|
| 88 |
+
image_numpy (numpy array) -- input numpy array
|
| 89 |
+
image_path (str) -- the path of the image
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
image_pil = Image.fromarray(image_numpy)
|
| 93 |
+
h, w, _ = image_numpy.shape
|
| 94 |
+
|
| 95 |
+
if aspect_ratio is None:
|
| 96 |
+
pass
|
| 97 |
+
elif aspect_ratio > 1.0:
|
| 98 |
+
image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
|
| 99 |
+
elif aspect_ratio < 1.0:
|
| 100 |
+
image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
|
| 101 |
+
image_pil.save(image_path)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def print_numpy(x, val=True, shp=False):
|
| 105 |
+
"""Print the mean, min, max, median, std, and size of a numpy array
|
| 106 |
+
|
| 107 |
+
Parameters:
|
| 108 |
+
val (bool) -- if print the values of the numpy array
|
| 109 |
+
shp (bool) -- if print the shape of the numpy array
|
| 110 |
+
"""
|
| 111 |
+
x = x.astype(np.float64)
|
| 112 |
+
if shp:
|
| 113 |
+
print('shape,', x.shape)
|
| 114 |
+
if val:
|
| 115 |
+
x = x.flatten()
|
| 116 |
+
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
|
| 117 |
+
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def mkdirs(paths):
|
| 121 |
+
"""create empty directories if they don't exist
|
| 122 |
+
|
| 123 |
+
Parameters:
|
| 124 |
+
paths (str list) -- a list of directory paths
|
| 125 |
+
"""
|
| 126 |
+
if isinstance(paths, list) and not isinstance(paths, str):
|
| 127 |
+
for path in paths:
|
| 128 |
+
mkdir(path)
|
| 129 |
+
else:
|
| 130 |
+
mkdir(paths)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def mkdir(path):
|
| 134 |
+
"""create a single empty directory if it didn't exist
|
| 135 |
+
|
| 136 |
+
Parameters:
|
| 137 |
+
path (str) -- a single directory path
|
| 138 |
+
"""
|
| 139 |
+
if not os.path.exists(path):
|
| 140 |
+
os.makedirs(path)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def correct_resize_label(t, size):
|
| 144 |
+
device = t.device
|
| 145 |
+
t = t.detach().cpu()
|
| 146 |
+
resized = []
|
| 147 |
+
for i in range(t.size(0)):
|
| 148 |
+
one_t = t[i, :1]
|
| 149 |
+
one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
|
| 150 |
+
one_np = one_np[:, :, 0]
|
| 151 |
+
one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
|
| 152 |
+
resized_t = torch.from_numpy(np.array(one_image)).long()
|
| 153 |
+
resized.append(resized_t)
|
| 154 |
+
return torch.stack(resized, dim=0).to(device)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def correct_resize(t, size, mode=Image.BICUBIC):
|
| 158 |
+
device = t.device
|
| 159 |
+
t = t.detach().cpu()
|
| 160 |
+
resized = []
|
| 161 |
+
for i in range(t.size(0)):
|
| 162 |
+
one_t = t[i:i + 1]
|
| 163 |
+
one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)
|
| 164 |
+
resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
|
| 165 |
+
resized.append(resized_t)
|
| 166 |
+
return torch.stack(resized, dim=0).to(device)
|
util/visualizer.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import ntpath
|
| 5 |
+
import time
|
| 6 |
+
from . import util, html
|
| 7 |
+
from subprocess import Popen, PIPE
|
| 8 |
+
|
| 9 |
+
if sys.version_info[0] == 2:
|
| 10 |
+
VisdomExceptionBase = Exception
|
| 11 |
+
else:
|
| 12 |
+
VisdomExceptionBase = ConnectionError
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
|
| 16 |
+
"""Save images to the disk.
|
| 17 |
+
|
| 18 |
+
Parameters:
|
| 19 |
+
webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)
|
| 20 |
+
visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
|
| 21 |
+
image_path (str) -- the string is used to create image paths
|
| 22 |
+
aspect_ratio (float) -- the aspect ratio of saved images
|
| 23 |
+
width (int) -- the images will be resized to width x width
|
| 24 |
+
|
| 25 |
+
This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
|
| 26 |
+
"""
|
| 27 |
+
image_dir = webpage.get_image_dir()
|
| 28 |
+
short_path = ntpath.basename(image_path[0])
|
| 29 |
+
name = os.path.splitext(short_path)[0]
|
| 30 |
+
|
| 31 |
+
webpage.add_header(name)
|
| 32 |
+
ims, txts, links = [], [], []
|
| 33 |
+
|
| 34 |
+
for label, im_data in visuals.items():
|
| 35 |
+
im = util.tensor2im(im_data)
|
| 36 |
+
image_name = '%s/%s.png' % (label, name)
|
| 37 |
+
os.makedirs(os.path.join(image_dir, label), exist_ok=True)
|
| 38 |
+
save_path = os.path.join(image_dir, image_name)
|
| 39 |
+
util.save_image(im, save_path, aspect_ratio=aspect_ratio)
|
| 40 |
+
ims.append(image_name)
|
| 41 |
+
txts.append(label)
|
| 42 |
+
links.append(image_name)
|
| 43 |
+
webpage.add_images(ims, txts, links, width=width)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Visualizer():
|
| 47 |
+
"""This class includes several functions that can display/save images and print/save logging information.
|
| 48 |
+
|
| 49 |
+
It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
def __init__(self, opt):
|
| 53 |
+
"""Initialize the Visualizer class
|
| 54 |
+
|
| 55 |
+
Parameters:
|
| 56 |
+
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
| 57 |
+
Step 1: Cache the training/test options
|
| 58 |
+
Step 2: connect to a visdom server
|
| 59 |
+
Step 3: create an HTML object for saveing HTML filters
|
| 60 |
+
Step 4: create a logging file to store training losses
|
| 61 |
+
"""
|
| 62 |
+
self.opt = opt # cache the option
|
| 63 |
+
if opt.display_id is None:
|
| 64 |
+
self.display_id = np.random.randint(100000) * 10 # just a random display id
|
| 65 |
+
else:
|
| 66 |
+
self.display_id = opt.display_id
|
| 67 |
+
self.use_html = opt.isTrain and not opt.no_html
|
| 68 |
+
self.win_size = opt.display_winsize
|
| 69 |
+
self.name = opt.name
|
| 70 |
+
self.port = opt.display_port
|
| 71 |
+
self.saved = False
|
| 72 |
+
if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
|
| 73 |
+
import visdom
|
| 74 |
+
self.plot_data = {}
|
| 75 |
+
self.ncols = opt.display_ncols
|
| 76 |
+
if "tensorboard_base_url" not in os.environ:
|
| 77 |
+
self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
|
| 78 |
+
else:
|
| 79 |
+
self.vis = visdom.Visdom(port=2004,
|
| 80 |
+
base_url=os.environ['tensorboard_base_url'] + '/visdom')
|
| 81 |
+
if not self.vis.check_connection():
|
| 82 |
+
self.create_visdom_connections()
|
| 83 |
+
|
| 84 |
+
if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
|
| 85 |
+
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
|
| 86 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
| 87 |
+
print('create web directory %s...' % self.web_dir)
|
| 88 |
+
util.mkdirs([self.web_dir, self.img_dir])
|
| 89 |
+
# create a logging file to store training losses
|
| 90 |
+
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
|
| 91 |
+
with open(self.log_name, "a") as log_file:
|
| 92 |
+
now = time.strftime("%c")
|
| 93 |
+
log_file.write('================ Training Loss (%s) ================\n' % now)
|
| 94 |
+
|
| 95 |
+
def reset(self):
|
| 96 |
+
"""Reset the self.saved status"""
|
| 97 |
+
self.saved = False
|
| 98 |
+
|
| 99 |
+
def create_visdom_connections(self):
|
| 100 |
+
"""If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
|
| 101 |
+
cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
|
| 102 |
+
print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
|
| 103 |
+
print('Command: %s' % cmd)
|
| 104 |
+
Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
|
| 105 |
+
|
| 106 |
+
def display_current_results(self, visuals, epoch, save_result):
|
| 107 |
+
"""Display current results on visdom; save current results to an HTML file.
|
| 108 |
+
|
| 109 |
+
Parameters:
|
| 110 |
+
visuals (OrderedDict) - - dictionary of images to display or save
|
| 111 |
+
epoch (int) - - the current epoch
|
| 112 |
+
save_result (bool) - - if save the current results to an HTML file
|
| 113 |
+
"""
|
| 114 |
+
if self.display_id > 0: # show images in the browser using visdom
|
| 115 |
+
ncols = self.ncols
|
| 116 |
+
if ncols > 0: # show all the images in one visdom panel
|
| 117 |
+
ncols = min(ncols, len(visuals))
|
| 118 |
+
h, w = next(iter(visuals.values())).shape[:2]
|
| 119 |
+
table_css = """<style>
|
| 120 |
+
table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
|
| 121 |
+
table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
|
| 122 |
+
</style>""" % (w, h) # create a table css
|
| 123 |
+
# create a table of images.
|
| 124 |
+
title = self.name
|
| 125 |
+
label_html = ''
|
| 126 |
+
label_html_row = ''
|
| 127 |
+
images = []
|
| 128 |
+
idx = 0
|
| 129 |
+
for label, image in visuals.items():
|
| 130 |
+
image_numpy = util.tensor2im(image)
|
| 131 |
+
label_html_row += '<td>%s</td>' % label
|
| 132 |
+
images.append(image_numpy.transpose([2, 0, 1]))
|
| 133 |
+
idx += 1
|
| 134 |
+
if idx % ncols == 0:
|
| 135 |
+
label_html += '<tr>%s</tr>' % label_html_row
|
| 136 |
+
label_html_row = ''
|
| 137 |
+
white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
|
| 138 |
+
while idx % ncols != 0:
|
| 139 |
+
images.append(white_image)
|
| 140 |
+
label_html_row += '<td></td>'
|
| 141 |
+
idx += 1
|
| 142 |
+
if label_html_row != '':
|
| 143 |
+
label_html += '<tr>%s</tr>' % label_html_row
|
| 144 |
+
try:
|
| 145 |
+
self.vis.images(images, ncols, 2, self.display_id + 1,
|
| 146 |
+
None, dict(title=title + ' images'))
|
| 147 |
+
label_html = '<table>%s</table>' % label_html
|
| 148 |
+
self.vis.text(table_css + label_html, win=self.display_id + 2,
|
| 149 |
+
opts=dict(title=title + ' labels'))
|
| 150 |
+
except VisdomExceptionBase:
|
| 151 |
+
self.create_visdom_connections()
|
| 152 |
+
|
| 153 |
+
else: # show each image in a separate visdom panel;
|
| 154 |
+
idx = 1
|
| 155 |
+
try:
|
| 156 |
+
for label, image in visuals.items():
|
| 157 |
+
image_numpy = util.tensor2im(image)
|
| 158 |
+
self.vis.image(
|
| 159 |
+
image_numpy.transpose([2, 0, 1]),
|
| 160 |
+
self.display_id + idx,
|
| 161 |
+
None,
|
| 162 |
+
dict(title=label)
|
| 163 |
+
)
|
| 164 |
+
idx += 1
|
| 165 |
+
except VisdomExceptionBase:
|
| 166 |
+
self.create_visdom_connections()
|
| 167 |
+
|
| 168 |
+
if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
|
| 169 |
+
self.saved = True
|
| 170 |
+
# save images to the disk
|
| 171 |
+
for label, image in visuals.items():
|
| 172 |
+
image_numpy = util.tensor2im(image)
|
| 173 |
+
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
|
| 174 |
+
util.save_image(image_numpy, img_path)
|
| 175 |
+
|
| 176 |
+
# update website
|
| 177 |
+
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
|
| 178 |
+
for n in range(epoch, 0, -1):
|
| 179 |
+
webpage.add_header('epoch [%d]' % n)
|
| 180 |
+
ims, txts, links = [], [], []
|
| 181 |
+
|
| 182 |
+
for label, image_numpy in visuals.items():
|
| 183 |
+
image_numpy = util.tensor2im(image)
|
| 184 |
+
img_path = 'epoch%.3d_%s.png' % (n, label)
|
| 185 |
+
ims.append(img_path)
|
| 186 |
+
txts.append(label)
|
| 187 |
+
links.append(img_path)
|
| 188 |
+
webpage.add_images(ims, txts, links, width=self.win_size)
|
| 189 |
+
webpage.save()
|
| 190 |
+
|
| 191 |
+
def plot_current_losses(self, epoch, counter_ratio, losses):
|
| 192 |
+
"""display the current losses on visdom display: dictionary of error labels and values
|
| 193 |
+
|
| 194 |
+
Parameters:
|
| 195 |
+
epoch (int) -- current epoch
|
| 196 |
+
counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
|
| 197 |
+
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
|
| 198 |
+
"""
|
| 199 |
+
if len(losses) == 0:
|
| 200 |
+
return
|
| 201 |
+
|
| 202 |
+
plot_name = '_'.join(list(losses.keys()))
|
| 203 |
+
|
| 204 |
+
if plot_name not in self.plot_data:
|
| 205 |
+
self.plot_data[plot_name] = {'X': [], 'Y': [], 'legend': list(losses.keys())}
|
| 206 |
+
|
| 207 |
+
plot_data = self.plot_data[plot_name]
|
| 208 |
+
plot_id = list(self.plot_data.keys()).index(plot_name)
|
| 209 |
+
|
| 210 |
+
plot_data['X'].append(epoch + counter_ratio)
|
| 211 |
+
plot_data['Y'].append([losses[k] for k in plot_data['legend']])
|
| 212 |
+
try:
|
| 213 |
+
self.vis.line(
|
| 214 |
+
X=np.stack([np.array(plot_data['X'])] * len(plot_data['legend']), 1),
|
| 215 |
+
Y=np.array(plot_data['Y']),
|
| 216 |
+
opts={
|
| 217 |
+
'title': self.name,
|
| 218 |
+
'legend': plot_data['legend'],
|
| 219 |
+
'xlabel': 'epoch',
|
| 220 |
+
'ylabel': 'loss'},
|
| 221 |
+
win=self.display_id - plot_id)
|
| 222 |
+
except VisdomExceptionBase:
|
| 223 |
+
self.create_visdom_connections()
|
| 224 |
+
|
| 225 |
+
# losses: same format as |losses| of plot_current_losses
|
| 226 |
+
def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
|
| 227 |
+
"""print current losses on console; also save the losses to the disk
|
| 228 |
+
|
| 229 |
+
Parameters:
|
| 230 |
+
epoch (int) -- current epoch
|
| 231 |
+
iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
|
| 232 |
+
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
|
| 233 |
+
t_comp (float) -- computational time per data point (normalized by batch_size)
|
| 234 |
+
t_data (float) -- data loading time per data point (normalized by batch_size)
|
| 235 |
+
"""
|
| 236 |
+
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
|
| 237 |
+
for k, v in losses.items():
|
| 238 |
+
message += '%s: %.3f ' % (k, v)
|
| 239 |
+
|
| 240 |
+
print(message) # print the message
|
| 241 |
+
with open(self.log_name, "a") as log_file:
|
| 242 |
+
log_file.write('%s\n' % message) # save the message
|