Spaces:
Runtime error
Runtime error
| # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # This work is licensed under the Creative Commons Attribution-NonCommercial | |
| # 4.0 International License. To view a copy of this license, visit | |
| # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to | |
| # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. | |
| import os | |
| import sys | |
| import glob | |
| import datetime | |
| import pickle | |
| import re | |
| import numpy as np | |
| from collections import OrderedDict | |
| import scipy.ndimage | |
| import PIL.Image | |
| import config | |
| import dataset | |
| import legacy | |
| #---------------------------------------------------------------------------- | |
| # Convenience wrappers for pickle that are able to load data produced by | |
| # older versions of the code. | |
def load_pkl(filename):
    """Unpickle and return the object stored in `filename`.

    The file is read through `legacy.LegacyUnpickler` with
    `encoding='latin1'` instead of plain `pickle.load`.
    """
    with open(filename, 'rb') as stream:
        unpickler = legacy.LegacyUnpickler(stream, encoding='latin1')
        return unpickler.load()
def save_pkl(obj, filename):
    """Pickle `obj` into `filename` using `pickle.HIGHEST_PROTOCOL`."""
    with open(filename, 'wb') as stream:
        pickle.dump(obj, stream, protocol=pickle.HIGHEST_PROTOCOL)
| #---------------------------------------------------------------------------- | |
| # Image utils. | |
def adjust_dynamic_range(data, drange_in, drange_out):
    """Linearly remap `data` from the range `drange_in` to `drange_out`.

    Each range is a two-element sequence `[low, high]`. When both ranges
    compare equal, `data` is returned as-is without any arithmetic.
    """
    if drange_in == drange_out:
        return data

    in_lo, in_hi = np.float32(drange_in[0]), np.float32(drange_in[1])
    out_lo, out_hi = np.float32(drange_out[0]), np.float32(drange_out[1])
    scale = (out_hi - out_lo) / (in_hi - in_lo)
    bias = out_lo - in_lo * scale
    return data * scale + bias
def create_image_grid(images, grid_size=None):
    """Tile a batch of images into one grid array.

    `images` must have 3 or 4 dimensions: the first axis indexes the
    images and the last two axes are taken as height and width. Any axis
    in between is kept as a leading axis of the result.

    `grid_size` is `(grid_w, grid_h)` in tiles. When omitted, the width is
    ceil(sqrt(num)) and the height is just large enough to hold every
    image. Tiles are filled row by row; unused cells stay zero.
    """
    assert images.ndim in (3, 4)
    count = images.shape[0]
    tile_h, tile_w = images.shape[-2], images.shape[-1]

    if grid_size is None:
        cols = max(int(np.ceil(np.sqrt(count))), 1)
        rows = max((count - 1) // cols + 1, 1)
    else:
        cols, rows = tuple(grid_size)

    canvas_shape = list(images.shape[1:-2]) + [rows * tile_h, cols * tile_w]
    canvas = np.zeros(canvas_shape, dtype=images.dtype)
    for index in range(count):
        row, col = divmod(index, cols)
        top, left = row * tile_h, col * tile_w
        canvas[..., top : top + tile_h, left : left + tile_w] = images[index]
    return canvas
def convert_to_pil_image(image, drange=[0,1]):
    """Convert a 2-D (HW) or 3-D (CHW) array into a `PIL.Image`.

    A 3-D array with a single channel is reduced to HW; any other 3-D
    array is transposed from CHW to HWC. Values are remapped from `drange`
    to [0, 255], rounded, clipped and cast to uint8. The image mode is
    'RGB' for a 3-D result and 'L' otherwise.
    """
    assert image.ndim in (2, 3)
    if image.ndim == 3:
        # grayscale CHW => HW, otherwise CHW -> HWC
        image = image[0] if image.shape[0] == 1 else image.transpose(1, 2, 0)

    scaled = adjust_dynamic_range(image, drange, [0,255])
    pixels = np.rint(scaled).clip(0, 255).astype(np.uint8)
    mode = 'L' if pixels.ndim != 3 else 'RGB'
    return PIL.Image.fromarray(pixels, mode)
def save_image(image, filename, drange=[0,1], quality=95):
    """Convert `image` with `convert_to_pil_image` and write it to `filename`.

    Files whose extension is `.jpg` or `.jpeg` (case-insensitive) are
    written as JPEG with the given `quality` and `optimize=True`. Every
    other name is passed to `PIL.Image.save` without extra options.

    The extension is compared rather than searching for '.jpg' anywhere in
    the path, so names such as 'x.jpg.png' or directories containing
    '.jpg' no longer force JPEG encoding.
    """
    img = convert_to_pil_image(image, drange)
    extension = os.path.splitext(filename)[1].lower()
    if extension in ('.jpg', '.jpeg'):
        img.save(filename, "JPEG", quality=quality, optimize=True)
    else:
        img.save(filename)
def save_image_grid(images, filename, drange=[0,1], grid_size=None):
    """Tile `images` with `create_image_grid` and save the result to `filename`."""
    grid = create_image_grid(images, grid_size)
    pil_image = convert_to_pil_image(grid, drange)
    pil_image.save(filename)
| #---------------------------------------------------------------------------- | |
| # Logging of stdout and stderr to a file. | |
class OutputLogger(object):
    """Write-only sink that buffers text in memory until a log file is set.

    Before `set_log_file` is called, written text accumulates in
    `self.buffer`. Once a file is attached the buffered text is written to
    it, the buffer is dropped, and later writes go to the file.
    """

    def __init__(self):
        self.buffer = ''
        self.file = None

    def set_log_file(self, filename, mode='wt'):
        """Open `filename` and write out any text buffered so far.

        May only be called while no file is attached.
        """
        assert self.file is None
        self.file = open(filename, mode)
        if self.buffer is None:
            return
        self.file.write(self.buffer)
        self.buffer = None

    def write(self, data):
        """Send `data` to the file if attached, and to the buffer if still active."""
        if self.file is not None:
            self.file.write(data)
        if self.buffer is not None:
            self.buffer += data

    def flush(self):
        """Flush the attached file; do nothing when no file is attached."""
        if self.file is None:
            return
        self.file.flush()
class TeeOutputStream(object):
    """Stream that forwards every write and flush to all `child_streams`."""

    def __init__(self, child_streams, autoflush=False):
        self.child_streams = child_streams
        self.autoflush = autoflush

    def write(self, data):
        """Write `data` to each child in order, then flush if `autoflush` is set."""
        for child in self.child_streams:
            child.write(data)
        if not self.autoflush:
            return
        self.flush()

    def flush(self):
        """Flush each child stream in order."""
        for child in self.child_streams:
            child.flush()
# Shared OutputLogger instance; stays None until init_output_logging() runs.
output_logger = None

def init_output_logging():
    """Route `sys.stdout` and `sys.stderr` through the shared `OutputLogger`.

    On the first call a logger is created and both streams are replaced by
    auto-flushing `TeeOutputStream`s that write to the previous stream and
    to the logger. Later calls do nothing.
    """
    global output_logger
    if output_logger is not None:
        return
    output_logger = OutputLogger()
    for stream_name in ('stdout', 'stderr'):
        previous = getattr(sys, stream_name)
        setattr(sys, stream_name, TeeOutputStream([previous, output_logger], autoflush=True))
def set_output_log_file(filename, mode='wt'):
    """Attach `filename` to the shared output logger, if one has been created."""
    if output_logger is None:
        return
    output_logger.set_log_file(filename, mode)
| #---------------------------------------------------------------------------- | |
| # Reporting results. | |
def create_result_subdir(result_dir, desc):
    """Create and return a fresh `<result_dir>/<NNN>-<desc>` directory.

    The run ID `NNN` is one more than the largest integer prefix found
    among existing `<int>-<anything>` entries in `result_dir` (0 when there
    are none). Entries without a '-' separator are ignored; previously the
    last character of such a name was dropped and the remainder parsed as
    an ID.

    Side effects: prints the chosen directory, points the output logger at
    `log.txt` inside it, and writes the public attributes of the `config`
    module to `config.txt`. The config export is best-effort: a failure is
    reported on stdout but does not abort the run.
    """
    # Select run ID and create subdir.
    while True:
        run_id = 0
        for fname in glob.glob(os.path.join(result_dir, '*')):
            prefix, dash, _ = os.path.basename(fname).partition('-')
            if not dash:
                continue  # not of the form '<id>-<desc>'
            try:
                run_id = max(run_id, int(prefix) + 1)
            except ValueError:
                pass

        result_subdir = os.path.join(result_dir, '%03d-%s' % (run_id, desc))
        try:
            os.makedirs(result_subdir)
            break
        except OSError:
            # The directory exists after all: rescan and pick the next ID.
            # Any other failure propagates.
            if os.path.isdir(result_subdir):
                continue
            raise

    print("Saving results to", result_subdir)
    set_output_log_file(os.path.join(result_subdir, 'log.txt'))

    # Export config.
    try:
        with open(os.path.join(result_subdir, 'config.txt'), 'wt') as fout:
            for k, v in sorted(config.__dict__.items()):
                if not k.startswith('_'):
                    fout.write("%s = %s\n" % (k, str(v)))
    except Exception as exc:
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
        print("Warning: failed to export config:", exc)

    return result_subdir
def format_time(seconds):
    """Format a duration in seconds as a short human-readable string.

    The value is rounded to whole seconds first. The result shows seconds
    only below one minute, minutes and seconds below one hour, hours,
    minutes and seconds below one day, and days, hours and minutes beyond.
    """
    total = int(np.rint(seconds))
    if total < 60:
        return '%ds' % (total)

    minutes, secs = divmod(total, 60)
    if total < 60*60:
        return '%dm %02ds' % (minutes, secs)

    hours, minutes = divmod(minutes, 60)
    if total < 24*60*60:
        return '%dh %02dm %02ds' % (hours, minutes, secs)

    days, hours = divmod(hours, 24)
    return '%dd %02dh %02dm' % (days, hours, minutes)
| #---------------------------------------------------------------------------- | |
| # Locating results. | |
def locate_result_subdir(run_id_or_result_subdir):
    """Resolve a run ID or directory name to an existing result directory.

    A string naming an existing directory is returned unchanged.
    Otherwise `config.result_dir` and its `results` and `networks`
    subdirectories are searched in that order, first for an entry named
    exactly `str(run_id_or_result_subdir)`, then for a single directory
    matching `<prefix>-*`, where the prefix is the zero-padded three-digit
    ID for ints and the string form otherwise.

    Raises IOError when nothing matches.
    """
    if isinstance(run_id_or_result_subdir, str) and os.path.isdir(run_id_or_result_subdir):
        return run_id_or_result_subdir

    for searchdir in ('', 'results', 'networks'):
        if searchdir == '':
            base = config.result_dir
        else:
            base = os.path.join(config.result_dir, searchdir)
        candidate = os.path.join(base, str(run_id_or_result_subdir))
        if os.path.isdir(candidate):
            return candidate

        if isinstance(run_id_or_result_subdir, int):
            prefix = '%03d' % run_id_or_result_subdir
        else:
            prefix = str(run_id_or_result_subdir)
        pattern = os.path.join(config.result_dir, searchdir, prefix + '-*')
        matches = [path for path in sorted(glob.glob(pattern)) if os.path.isdir(path)]
        # Only an unambiguous prefix match counts.
        if len(matches) == 1:
            return matches[0]
    raise IOError('Cannot locate result subdir for run', run_id_or_result_subdir)
def list_network_pkls(run_id_or_result_subdir, include_final=True):
    """Return the sorted `network-*.pkl` paths of a run.

    If the first sorted entry is `network-final.pkl`, it is moved to the
    end of the list when `include_final` is true and dropped otherwise.
    """
    result_subdir = locate_result_subdir(run_id_or_result_subdir)
    pkls = sorted(glob.glob(os.path.join(result_subdir, 'network-*.pkl')))
    if pkls and os.path.basename(pkls[0]) == 'network-final.pkl':
        final_pkl = pkls.pop(0)
        if include_final:
            pkls.append(final_pkl)
    return pkls
def locate_network_pkl(run_id_or_result_subdir_or_network_pkl, snapshot=None):
    """Resolve a run, result directory or pickle path to a network pickle path.

    A string naming an existing file is returned unchanged. Otherwise the
    run's pickles are listed with `list_network_pkls`: with `snapshot=None`
    the last entry is returned, else the first pickle whose file name ends
    in `-<snapshot>` (compared as an integer).

    Raises IOError when no pickle qualifies.
    """
    target = run_id_or_result_subdir_or_network_pkl
    if isinstance(target, str) and os.path.isfile(target):
        return target

    pkls = list_network_pkls(target)
    if pkls and snapshot is None:
        return pkls[-1]

    for pkl in pkls:
        try:
            stem = os.path.splitext(os.path.basename(pkl))[0]
            if int(stem.split('-')[-1]) == snapshot:
                return pkl
        except (ValueError, IndexError):
            # Name has no integer suffix (e.g. 'network-final'); skip it.
            pass
    raise IOError('Cannot locate network pkl for snapshot', snapshot)
def get_id_string_for_network_pkl(network_pkl):
    """Build a short ID from the last two path components of `network_pkl`.

    '.pkl' is removed, backslashes are treated as path separators, and the
    final two components (or the only one) are joined with '-'.
    """
    normalized = network_pkl.replace('.pkl', '').replace('\\', '/')
    parts = normalized.split('/')
    return '-'.join(parts[-2:])
| #---------------------------------------------------------------------------- | |
| # Loading and using trained networks. | |
def load_network_pkl(run_id_or_result_subdir_or_network_pkl, snapshot=None):
    """Locate a network pickle with `locate_network_pkl` and load it with `load_pkl`."""
    network_pkl = locate_network_pkl(run_id_or_result_subdir_or_network_pkl, snapshot)
    return load_pkl(network_pkl)
def random_latents(num_latents, G, random_state=None):
    """Draw `num_latents` standard-normal latent vectors as float32.

    The per-latent shape is `G.input_shape[1:]`. Samples come from
    `random_state.randn` when a random state is supplied and from the
    global `np.random.randn` otherwise.
    """
    source = np.random if random_state is None else random_state
    latents = source.randn(num_latents, *G.input_shape[1:])
    return latents.astype(np.float32)
def load_dataset_for_previous_run(run_id, **kwargs): # => dataset_obj, mirror_augment
    """Recreate the dataset used by an earlier run from its `config.txt`.

    The `dataset = ...` and `train = ...` lines of the run's config are
    evaluated, legacy option names and legacy dataset names are translated,
    `kwargs` override the resulting dataset options, and the dataset is
    loaded with `dataset.load_dataset`.

    Returns `(dataset_obj, mirror_augment)`.
    """
    result_subdir = locate_result_subdir(run_id)

    # Parse config.txt.
    # NOTE(review): matching lines are passed to exec(), so config.txt must
    # come from a trusted run directory.
    parsed_cfg = dict()
    config_path = os.path.join(result_subdir, 'config.txt')
    with open(config_path, 'rt') as f:
        for line in f:
            if line.startswith(('dataset =', 'train =')):
                exec(line, parsed_cfg, parsed_cfg)
    dataset_cfg = parsed_cfg.get('dataset', dict())
    train_cfg = parsed_cfg.get('train', dict())
    mirror_augment = train_cfg.get('mirror_augment', False)

    # Handle legacy options.
    if 'h5_path' in dataset_cfg:
        dataset_cfg['tfrecord_dir'] = dataset_cfg.pop('h5_path').replace('.h5', '')
    if 'mirror_augment' in dataset_cfg:
        mirror_augment = dataset_cfg.pop('mirror_augment')
    if 'max_labels' in dataset_cfg:
        max_labels = dataset_cfg.pop('max_labels')
        if max_labels is None:
            max_labels = 0
        if max_labels == 'all':
            max_labels = 'full'
        dataset_cfg['max_label_size'] = max_labels
    dataset_cfg.pop('max_images', None)

    # Handle legacy dataset names. The substitutions are applied in order.
    legacy_renames = (
        ('-32x32', ''), ('-32', ''),
        ('-128x128', ''), ('-128', ''),
        ('-256x256', ''), ('-256', ''),
        ('-1024x1024', ''), ('-1024', ''),
        ('celeba-hq', 'celebahq'),
        ('cifar-10', 'cifar10'),
        ('cifar-100', 'cifar100'),
        ('mnist-rgb', 'mnistrgb'),
    )
    name = dataset_cfg['tfrecord_dir']
    for old, new in legacy_renames:
        name = name.replace(old, new)
    name = re.sub('lsun-100k-([^-]*)', 'lsun-\\1-100k', name)
    name = re.sub('lsun-full-([^-]*)', 'lsun-\\1-full', name)
    dataset_cfg['tfrecord_dir'] = name

    # Load dataset.
    dataset_cfg.update(kwargs)
    dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **dataset_cfg)
    return dataset_obj, mirror_augment
def apply_mirror_augment(minibatch):
    """Return a copy of `minibatch` with about half of the samples mirrored.

    Each sample along the first axis is independently selected with
    probability 0.5 using the global NumPy RNG; selected samples are
    reversed along the last of their four axes. The input is not modified.
    """
    flip = np.random.rand(minibatch.shape[0]) < 0.5
    mirrored = np.array(minibatch)
    mirrored[flip] = mirrored[flip][:, :, :, ::-1]
    return mirrored
| #---------------------------------------------------------------------------- | |
| # Text labels. | |
# LRU cache of rendered labels, keyed by the rendering parameters.
_text_label_cache = OrderedDict()

def draw_text_label(img, text, x, y, alignx=0.5, aligny=0.5, color=255, opacity=1.0, glow_opacity=1.0, **kwargs):
    """Return a copy of `img` with `text` composited onto it.

    `img` must be a 3-D array indexed as (row, column, channel). `color` is
    either a scalar or has one value per channel. `(x, y)` is the anchor
    position; `alignx`/`aligny` give the fraction of the label's width and
    height that lies left of / above the anchor. Remaining keyword
    arguments are forwarded to `setup_text_label`.

    The label is clipped to the image bounds. A label that falls entirely
    outside the image leaves the returned copy unchanged.
    """
    color = np.array(color).flatten().astype(np.float32)
    # Parenthesised so that the ndim check also applies to scalar colors.
    assert img.ndim == 3 and (img.shape[2] == color.size or color.size == 1)
    alpha, glow = setup_text_label(text, **kwargs)
    xx, yy = int(np.rint(x - alpha.shape[1] * alignx)), int(np.rint(y - alpha.shape[0] * aligny))
    xb, yb = max(-xx, 0), max(-yy, 0)
    xe, ye = min(alpha.shape[1], img.shape[1] - xx), min(alpha.shape[0], img.shape[0] - yy)
    img = np.array(img)
    if xe <= xb or ye <= yb:
        # Nothing of the label is visible; the slices below would otherwise
        # have mismatched shapes.
        return img
    region = img[yy+yb : yy+ye, xx+xb : xx+xe, :]
    # Darken the background under the text and its glow, then add the text color.
    region[:] = region * (1.0 - (1.0 - (1.0 - alpha[yb:ye, xb:xe]) * (1.0 - glow[yb:ye, xb:xe] * glow_opacity)) * opacity)[:, :, np.newaxis]
    region[:] = region + alpha[yb:ye, xb:xe, np.newaxis] * (color * opacity)[np.newaxis, np.newaxis, :]
    return img
def setup_text_label(text, font='Calibri', fontsize=32, padding=6, glow_size=2.0, glow_coef=3.0, glow_exp=2.0, cache_size=100): # => (alpha, glow)
    """Render `text` to an `(alpha, glow)` pair of arrays, with LRU caching.

    `alpha` is the text mask frame produced by moviepy, zero-padded by
    `padding` on every side. `glow` is derived from a Gaussian blur of
    `alpha` with sigma `glow_size`, shaped by `glow_coef` and `glow_exp`.
    Results are cached per parameter combination; the cache is trimmed so
    that it holds at most `cache_size` entries after insertion.
    """
    # Lookup from cache.
    key = (text, font, fontsize, padding, glow_size, glow_coef, glow_exp)
    if key in _text_label_cache:
        _text_label_cache.move_to_end(key) # LRU policy
        return _text_label_cache[key]

    # Limit cache size.
    while len(_text_label_cache) >= cache_size:
        _text_label_cache.popitem(last=False)

    # Render text.
    import moviepy.editor # pip install moviepy
    clip = moviepy.editor.TextClip(text, font=font, fontsize=fontsize)
    alpha = clip.mask.make_frame(0)
    alpha = np.pad(alpha, padding, mode='constant', constant_values=0.0)
    blurred = scipy.ndimage.gaussian_filter(alpha, glow_size)
    glow = 1.0 - np.maximum(1.0 - blurred * glow_coef, 0.0) ** glow_exp

    # Add to cache.
    entry = (alpha, glow)
    _text_label_cache[key] = entry
    return entry
| #---------------------------------------------------------------------------- | |