Added base files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Demo.ipynb +0 -0
- align_images.py +57 -0
- config.py +22 -0
- dnnlib/__init__.py +20 -0
- dnnlib/__pycache__/__init__.cpython-36.pyc +0 -0
- dnnlib/__pycache__/__init__.cpython-37.pyc +0 -0
- dnnlib/__pycache__/util.cpython-36.pyc +0 -0
- dnnlib/__pycache__/util.cpython-37.pyc +0 -0
- dnnlib/submission/__init__.py +9 -0
- dnnlib/submission/__pycache__/__init__.cpython-36.pyc +0 -0
- dnnlib/submission/__pycache__/__init__.cpython-37.pyc +0 -0
- dnnlib/submission/__pycache__/run_context.cpython-36.pyc +0 -0
- dnnlib/submission/__pycache__/run_context.cpython-37.pyc +0 -0
- dnnlib/submission/__pycache__/submit.cpython-36.pyc +0 -0
- dnnlib/submission/__pycache__/submit.cpython-37.pyc +0 -0
- dnnlib/submission/_internal/run.py +45 -0
- dnnlib/submission/run_context.py +99 -0
- dnnlib/submission/submit.py +290 -0
- dnnlib/tflib/__init__.py +16 -0
- dnnlib/tflib/__pycache__/__init__.cpython-36.pyc +0 -0
- dnnlib/tflib/__pycache__/__init__.cpython-37.pyc +0 -0
- dnnlib/tflib/__pycache__/autosummary.cpython-36.pyc +0 -0
- dnnlib/tflib/__pycache__/autosummary.cpython-37.pyc +0 -0
- dnnlib/tflib/__pycache__/network.cpython-36.pyc +0 -0
- dnnlib/tflib/__pycache__/network.cpython-37.pyc +0 -0
- dnnlib/tflib/__pycache__/optimizer.cpython-36.pyc +0 -0
- dnnlib/tflib/__pycache__/optimizer.cpython-37.pyc +0 -0
- dnnlib/tflib/__pycache__/tfutil.cpython-36.pyc +0 -0
- dnnlib/tflib/__pycache__/tfutil.cpython-37.pyc +0 -0
- dnnlib/tflib/autosummary.py +184 -0
- dnnlib/tflib/network.py +628 -0
- dnnlib/tflib/optimizer.py +214 -0
- dnnlib/tflib/tfutil.py +242 -0
- dnnlib/util.py +408 -0
- encode_images.py +242 -0
- encoder/__init__.py +0 -0
- encoder/__pycache__/__init__.cpython-36.pyc +0 -0
- encoder/__pycache__/__init__.cpython-37.pyc +0 -0
- encoder/__pycache__/generator_model.cpython-36.pyc +0 -0
- encoder/__pycache__/generator_model.cpython-37.pyc +0 -0
- encoder/__pycache__/perceptual_model.cpython-36.pyc +0 -0
- encoder/__pycache__/perceptual_model.cpython-37.pyc +0 -0
- encoder/generator_model.py +137 -0
- encoder/perceptual_model.py +304 -0
- ffhq_dataset/__init__.py +0 -0
- ffhq_dataset/__pycache__/__init__.cpython-36.pyc +0 -0
- ffhq_dataset/__pycache__/__init__.cpython-37.pyc +0 -0
- ffhq_dataset/__pycache__/face_alignment.cpython-36.pyc +0 -0
- ffhq_dataset/__pycache__/face_alignment.cpython-37.pyc +0 -0
- ffhq_dataset/__pycache__/landmarks_detector.cpython-36.pyc +0 -0
Demo.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
align_images.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import bz2
|
| 4 |
+
import argparse
|
| 5 |
+
from keras.utils import get_file
|
| 6 |
+
from ffhq_dataset.face_alignment import image_align
|
| 7 |
+
from ffhq_dataset.landmarks_detector import LandmarksDetector
|
| 8 |
+
import multiprocessing
|
| 9 |
+
|
| 10 |
+
def unpack_bz2(src_path):
    """Decompress a .bz2 archive next to itself and return the output path.

    The destination file name is *src_path* with the trailing ".bz2"
    (4 characters) stripped off.
    """
    dst_path = src_path[:-4]  # drop the ".bz2" suffix
    with bz2.BZ2File(src_path) as src, open(dst_path, 'wb') as dst:
        dst.write(src.read())
    return dst_path
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
if __name__ == "__main__":
    """
    Extracts and aligns all faces from images using DLib and a function from original FFHQ dataset preparation step
    python align_images.py /raw_images /aligned_images
    """
    def str2bool(value):
        # Bug fix: argparse's type=bool treats ANY non-empty string (including
        # "False") as True; parse the common textual spellings explicitly.
        return str(value).lower() in ('1', 'true', 'yes', 'y')

    parser = argparse.ArgumentParser(description='Align faces from input images', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('raw_dir', help='Directory with raw images for face alignment')
    parser.add_argument('aligned_dir', help='Directory for storing aligned images')
    parser.add_argument('--output_size', default=1024, help='The dimension of images for input to the model', type=int)
    parser.add_argument('--x_scale', default=1, help='Scaling factor for x dimension', type=float)
    parser.add_argument('--y_scale', default=1, help='Scaling factor for y dimension', type=float)
    parser.add_argument('--em_scale', default=0.1, help='Scaling factor for eye-mouth distance', type=float)
    parser.add_argument('--use_alpha', default=False, help='Add an alpha channel for masking', type=str2bool)

    args, other_args = parser.parse_known_args()

    landmarks_model_path = unpack_bz2("shape_predictor_68_face_landmarks.dat.bz2")
    RAW_IMAGES_DIR = args.raw_dir
    ALIGNED_IMAGES_DIR = args.aligned_dir
    # Robustness: make sure the output directory exists before writing to it.
    os.makedirs(ALIGNED_IMAGES_DIR, exist_ok=True)

    landmarks_detector = LandmarksDetector(landmarks_model_path)
    for img_name in os.listdir(RAW_IMAGES_DIR):
        print('Aligning %s ...' % img_name)
        try:
            raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name)
            # Skip images whose first aligned face already exists.
            # Bug fix: the original tested the bare file name against the
            # current working directory instead of the output directory.
            fn = '%s_%02d.png' % (os.path.splitext(img_name)[0], 1)
            if os.path.isfile(os.path.join(ALIGNED_IMAGES_DIR, fn)):
                continue
            print('Getting landmarks...')
            for i, face_landmarks in enumerate(landmarks_detector.get_landmarks(raw_img_path), start=1):
                try:
                    print('Starting face alignment...')
                    face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i)
                    aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name)
                    image_align(raw_img_path, aligned_face_path, face_landmarks, output_size=args.output_size, x_scale=args.x_scale, y_scale=args.y_scale, em_scale=args.em_scale, alpha=args.use_alpha)
                    print('Wrote result %s' % aligned_face_path)
                except Exception as e:
                    # Narrowed from a bare except (which also swallowed
                    # KeyboardInterrupt/SystemExit); report the cause.
                    print("Exception in face alignment! %s" % e)
        except Exception as e:
            print("Exception in landmark detection! %s" % e)
|
config.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
| 4 |
+
# 4.0 International License. To view a copy of this license, visit
|
| 5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
| 6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
| 7 |
+
|
| 8 |
+
"""Global configuration."""
|
| 9 |
+
|
| 10 |
+
#----------------------------------------------------------------------------
|
| 11 |
+
# Paths.
|
| 12 |
+
|
| 13 |
+
result_dir = 'results'  # presumably the root for run output dirs — confirm against SubmitConfig.run_dir_root usage
data_dir = 'datasets'  # input dataset location (TODO confirm: not referenced in this chunk)
cache_dir = 'cache'  # cache for downloaded/derived files (TODO confirm)
run_dir_ignore = ['results', 'datasets', 'cache']  # directories never copied into a run dir

# experimental - replace Dense layers with TreeConnect
use_treeconnect = False
treeconnect_threshold = 1024  # NOTE(review): threshold meaning not visible here; presumably min layer width for TreeConnect substitution
|
| 21 |
+
|
| 22 |
+
#----------------------------------------------------------------------------
|
dnnlib/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
| 4 |
+
# 4.0 International License. To view a copy of this license, visit
|
| 5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
| 6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
| 7 |
+
|
| 8 |
+
from . import submission

from .submission.run_context import RunContext

from .submission.submit import SubmitTarget
from .submission.submit import PathType
from .submission.submit import SubmitConfig
from .submission.submit import get_path_from_template
from .submission.submit import submit_run

from .util import EasyDict

# Package-level variable holding the active SubmitConfig. It is populated by
# submission.submit.run_wrapper() for the duration of the run function and is
# reset to None afterwards; it is only valid inside a running submission.
submit_config: SubmitConfig = None
|
dnnlib/__pycache__/__init__.cpython-36.pyc
ADDED
|
Binary file (468 Bytes). View file
|
|
|
dnnlib/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (498 Bytes). View file
|
|
|
dnnlib/__pycache__/util.cpython-36.pyc
ADDED
|
Binary file (12.1 kB). View file
|
|
|
dnnlib/__pycache__/util.cpython-37.pyc
ADDED
|
Binary file (12.1 kB). View file
|
|
|
dnnlib/submission/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
| 4 |
+
# 4.0 International License. To view a copy of this license, visit
|
| 5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
| 6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
| 7 |
+
|
| 8 |
+
from . import run_context
|
| 9 |
+
from . import submit
|
dnnlib/submission/__pycache__/__init__.cpython-36.pyc
ADDED
|
Binary file (188 Bytes). View file
|
|
|
dnnlib/submission/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (192 Bytes). View file
|
|
|
dnnlib/submission/__pycache__/run_context.cpython-36.pyc
ADDED
|
Binary file (4.35 kB). View file
|
|
|
dnnlib/submission/__pycache__/run_context.cpython-37.pyc
ADDED
|
Binary file (4.35 kB). View file
|
|
|
dnnlib/submission/__pycache__/submit.cpython-36.pyc
ADDED
|
Binary file (9.19 kB). View file
|
|
|
dnnlib/submission/__pycache__/submit.cpython-37.pyc
ADDED
|
Binary file (9.19 kB). View file
|
|
|
dnnlib/submission/_internal/run.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
| 4 |
+
# 4.0 International License. To view a copy of this license, visit
|
| 5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
| 6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
| 7 |
+
|
| 8 |
+
"""Helper for launching run functions in computing clusters.
|
| 9 |
+
|
| 10 |
+
During the submit process, this file is copied to the appropriate run dir.
|
| 11 |
+
When the job is launched in the cluster, this module is the first thing that
|
| 12 |
+
is run inside the docker container.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import pickle
|
| 17 |
+
import sys
|
| 18 |
+
|
| 19 |
+
# PYTHONPATH should have been set so that the run_dir/src is in it
|
| 20 |
+
import dnnlib
|
| 21 |
+
|
| 22 |
+
def main():
    """Entry point executed from inside a run dir (first thing launched in the container).

    Expects three command-line arguments: run_dir, task_name and host_name.
    Loads the SubmitConfig pickled into the run dir during submit, patches in
    the runtime task/host names, and hands control to dnnlib's run wrapper.

    Raises:
        RuntimeError: if arguments are missing or the pickle file is absent.
    """
    if len(sys.argv) < 4:  # script name + 3 required arguments
        raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!")

    run_dir = str(sys.argv[1])
    task_name = str(sys.argv[2])
    host_name = str(sys.argv[3])

    submit_config_path = os.path.join(run_dir, "submit_config.pkl")

    # SubmitConfig should have been pickled to the run dir by the submit step.
    if not os.path.exists(submit_config_path):
        raise RuntimeError("SubmitConfig pickle file does not exist!")

    # Bug fix: use a context manager so the file handle is closed
    # deterministically (the original left the open() result dangling).
    with open(submit_config_path, "rb") as f:
        submit_config: dnnlib.SubmitConfig = pickle.load(f)
    dnnlib.submission.submit.set_user_name_override(submit_config.user_name)

    submit_config.task_name = task_name
    submit_config.host_name = host_name

    dnnlib.submission.submit.run_wrapper(submit_config)

if __name__ == "__main__":
    main()
|
dnnlib/submission/run_context.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
| 4 |
+
# 4.0 International License. To view a copy of this license, visit
|
| 5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
| 6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
| 7 |
+
|
| 8 |
+
"""Helpers for managing the run/training loop."""
|
| 9 |
+
|
| 10 |
+
import datetime
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
import pprint
|
| 14 |
+
import time
|
| 15 |
+
import types
|
| 16 |
+
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
from . import submit
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class RunContext(object):
    """Helper class for managing the run/training loop.

    The context will hide the implementation details of a basic run/training loop.
    It will set things up properly, tell if run should be stopped, and then cleans up.
    User should call update periodically and use should_stop to determine if run should be stopped.

    Args:
        submit_config: The SubmitConfig that is used for the current run.
        config_module: The whole config module that is used for the current run.
        max_epoch: Optional cached value for the max_epoch variable used in update.
    """

    def __init__(self, submit_config: "submit.SubmitConfig", config_module: types.ModuleType = None, max_epoch: Any = None):
        self.submit_config = submit_config
        self.should_stop_flag = False    # set when an external abort is requested via abort.txt
        self.has_closed = False          # guards against double close()
        self.start_time = time.time()
        self.last_update_time = time.time()
        self.last_update_interval = 0.0  # seconds between the two most recent update() calls
        self.max_epoch = max_epoch

        # Pretty print all relevant (non-private, non-module/function/class)
        # content of the config module to a text file for reproducibility.
        if config_module is not None:
            with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f:
                filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))}
                pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False)

        # Write out details about the run to a text file.
        self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")}
        with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f:
            pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

    def __enter__(self) -> "RunContext":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
        """Do general housekeeping and keep the state of the context up-to-date.
        Should be called often enough but not in a tight loop.

        loss/cur_epoch/max_epoch are accepted for interface compatibility but
        are currently unused (leftover from a removed progress-reporting
        feature; the original also computed an unused max_epoch_val here,
        which has been dropped).
        """
        assert not self.has_closed

        self.last_update_interval = time.time() - self.last_update_time
        self.last_update_time = time.time()

        # The run can be aborted externally by dropping an "abort.txt" file
        # into the run dir.
        if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
            self.should_stop_flag = True

    def should_stop(self) -> bool:
        """Tell whether a stopping condition has been triggered one way or another."""
        return self.should_stop_flag

    def get_time_since_start(self) -> float:
        """How much time has passed since the creation of the context."""
        return time.time() - self.start_time

    def get_time_since_last_update(self) -> float:
        """How much time has passed since the last call to update."""
        return time.time() - self.last_update_time

    def get_last_update_interval(self) -> float:
        """How much time passed between the previous two calls to update."""
        return self.last_update_interval

    def close(self) -> None:
        """Close the context and clean up.
        Should only be called once (subsequent calls are no-ops)."""
        if not self.has_closed:
            # Update run.txt with the stopping time.
            self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
            with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
                pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

            self.has_closed = True
|
dnnlib/submission/submit.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
| 4 |
+
# 4.0 International License. To view a copy of this license, visit
|
| 5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
| 6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
| 7 |
+
|
| 8 |
+
"""Submit a function to be run either locally or in a computing cluster."""
|
| 9 |
+
|
| 10 |
+
import copy
|
| 11 |
+
import io
|
| 12 |
+
import os
|
| 13 |
+
import pathlib
|
| 14 |
+
import pickle
|
| 15 |
+
import platform
|
| 16 |
+
import pprint
|
| 17 |
+
import re
|
| 18 |
+
import shutil
|
| 19 |
+
import time
|
| 20 |
+
import traceback
|
| 21 |
+
|
| 22 |
+
import zipfile
|
| 23 |
+
|
| 24 |
+
from enum import Enum
|
| 25 |
+
|
| 26 |
+
from .. import util
|
| 27 |
+
from ..util import EasyDict
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SubmitTarget(Enum):
    """The target environment where the submitted run function executes.

    LOCAL: Run it locally, in the current process. This is the only member;
        submit_run() asserts on it, so local execution is the only supported
        target in this version of dnnlib.
    """
    LOCAL = 1
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class PathType(Enum):
    """Determines in which format a path should be formatted.

    WINDOWS: Format with Windows style (PureWindowsPath, backslash separators).
    LINUX: Format with Linux/Posix style (PurePosixPath, forward slashes).
    AUTO: Use current OS type to select either WINDOWS or LINUX.
    """
    WINDOWS = 1
    LINUX = 2
    AUTO = 3
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# Global override for the value returned by get_user_name(); installed via
# set_user_name_override() (done by _internal/run.py when a run is re-launched
# from its pickled SubmitConfig).
_user_name_override = None
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class SubmitConfig(util.EasyDict):
    """Strongly typed config dict needed to submit runs.

    Attributes:
        run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template.
        run_desc: Description of the run. Will be used in the run dir and task name.
        run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir.
        run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir.
        submit_target: Submit target enum value. Used to select where the run is actually launched.
        num_gpus: Number of GPUs used/requested for the run.
        print_info: Whether to print debug information when submitting.
        ask_confirmation: Whether to ask a confirmation before submitting.
        run_id: Automatically populated value during submit.
        run_name: Automatically populated value during submit.
        run_dir: Automatically populated value during submit.
        run_func_name: Automatically populated value during submit.
        run_func_kwargs: Automatically populated value during submit.
        user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value.
        task_name: Automatically populated value during submit.
        host_name: Automatically populated value during submit.
    """

    def __init__(self):
        super().__init__()

        # run (set these before calling submit_run)
        self.run_dir_root = ""  # should always be passed through get_path_from_template
        self.run_desc = ""
        self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"]
        self.run_dir_extra_files = None

        # submit (set these before calling submit_run)
        self.submit_target = SubmitTarget.LOCAL
        self.num_gpus = 1
        self.print_info = False
        self.ask_confirmation = False

        # (automatically populated by submit_run / _create_run_dir_local)
        self.run_id = None  # increasing integer id, see _get_next_run_id_local
        self.run_name = None  # "<run_id>-<run_desc>"
        self.run_dir = None  # path of the created run dir
        self.run_func_name = None  # dotted name of the function to execute
        self.run_func_kwargs = None  # kwargs forwarded to the run function
        self.user_name = None  # taken from the OS unless set explicitly
        self.task_name = None  # "<user_name>-<run_id>-<run_desc>"
        self.host_name = "localhost"
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str:
    """Replace tags in the given path template and return either Windows or Linux formatted path."""
    # Resolve AUTO into a concrete style based on the running OS.
    if path_type == PathType.AUTO:
        system = platform.system()
        if system == "Windows":
            path_type = PathType.WINDOWS
        elif system == "Linux":
            path_type = PathType.LINUX
        else:
            raise RuntimeError("Unknown platform")

    # Substitute the supported template tags.
    path_template = path_template.replace("<USERNAME>", get_user_name())

    # Normalize through pathlib so separators match the requested style.
    if path_type == PathType.WINDOWS:
        return str(pathlib.PureWindowsPath(path_template))
    if path_type == PathType.LINUX:
        return str(pathlib.PurePosixPath(path_template))
    raise RuntimeError("Unknown platform")
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def get_template_from_path(path: str) -> str:
    """Convert a normal path back to its template representation.

    Only separator normalization is performed: backslashes are replaced with
    forward slashes (no tags are re-inserted).
    """
    return path.replace("\\", "/")
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str:
    """Convert a normal path to template form and then back to a normal path of the given path type."""
    return get_path_from_template(get_template_from_path(path), path_type)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def set_user_name_override(name: str) -> None:
    """Set the global username override value (consulted first by get_user_name())."""
    global _user_name_override
    _user_name_override = name
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def get_user_name():
    """Get the current user name.

    Returns the override installed via set_user_name_override() if any,
    otherwise queries the OS; falls back to "unknown" on Linux when the
    passwd lookup fails.

    Raises:
        RuntimeError: on platforms other than Windows/Linux.
    """
    if _user_name_override is not None:
        return _user_name_override
    system = platform.system()
    if system == "Windows":
        return os.getlogin()
    if system == "Linux":
        try:
            import pwd  # pylint: disable=import-error
            return pwd.getpwuid(os.geteuid()).pw_name  # pylint: disable=no-member
        except (ImportError, KeyError, OSError):
            # Narrowed from a bare except (which also trapped
            # KeyboardInterrupt/SystemExit) to the failures the passwd
            # lookup can actually raise.
            return "unknown"
    raise RuntimeError("Unknown platform")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _create_run_dir_local(submit_config: SubmitConfig) -> str:
    """Create a new run dir with increasing ID number at the start."""
    root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO)

    # Lazily create the root the first time anything is submitted.
    if not os.path.exists(root):
        print("Creating the run dir root: {}".format(root))
        os.makedirs(root)

    submit_config.run_id = _get_next_run_id_local(root)
    submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc)
    run_dir = os.path.join(root, submit_config.run_name)

    # A collision here means two submits raced for the same run id.
    if os.path.exists(run_dir):
        raise RuntimeError("The run dir already exists! ({0})".format(run_dir))

    print("Creating the run dir: {}".format(run_dir))
    os.makedirs(run_dir)

    return run_dir
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _get_next_run_id_local(run_dir_root: str) -> int:
    """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names."""
    pattern = re.compile(r"^\d+")  # one or more digits at the start of the name
    next_id = 0
    for entry in os.listdir(run_dir_root):
        # Only subdirectories participate in the numbering.
        if not os.path.isdir(os.path.join(run_dir_root, entry)):
            continue
        match = pattern.match(entry)
        if match is not None:
            next_id = max(next_id, int(match.group()) + 1)
    return next_id
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None:
    """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable."""
    print("Copying files to the run dir")
    files = []

    # Sources of the package containing the run function: walk up one dir
    # level per dotted component (minus the function name itself).
    run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name)
    assert '.' in submit_config.run_func_name
    for _idx in range(submit_config.run_func_name.count('.') - 1):
        run_func_module_dir_path = os.path.dirname(run_func_module_dir_path)
    files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False)

    # dnnlib itself is always copied, kept under its package name in src/.
    dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib")
    files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True)

    if submit_config.run_dir_extra_files is not None:
        files += submit_config.run_dir_extra_files

    # Everything goes under <run_dir>/src, plus the cluster bootstrap script
    # at the run dir top level.
    files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files]
    files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))]

    util.copy_files_and_create_dirs(files)

    # Bug fix: close the pickle file handle deterministically — the original
    # pickle.dump(submit_config, open(..., "wb")) relied on GC to close it.
    with open(os.path.join(run_dir, "submit_config.pkl"), "wb") as f:
        pickle.dump(submit_config, f)

    with open(os.path.join(run_dir, "submit_config.txt"), "w") as f:
        pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def run_wrapper(submit_config: SubmitConfig) -> None:
    """Wrap the actual run function call for handling logging, exceptions, typing, etc."""
    is_local = submit_config.submit_target == SubmitTarget.LOCAL

    # NOTE(review): 'checker' is never assigned anything but None in this
    # version — leftover scaffolding; the final checker.stop() never runs.
    checker = None

    # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing
    if is_local:
        logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True)
    else:  # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh)
        logger = util.Logger(file_name=None, should_flush=True)

    # Publish the active config at package level so the run function can
    # reach it as dnnlib.submit_config.
    import dnnlib
    dnnlib.submit_config = submit_config

    try:
        print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
        start_time = time.time()
        util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)
        print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
    except:
        # Locally, re-raise so the user sees the traceback; on a cluster,
        # print it and copy the log next to the run dir root for inspection.
        if is_local:
            raise
        else:
            traceback.print_exc()

            log_src = os.path.join(submit_config.run_dir, "log.txt")
            log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
            shutil.copyfile(log_src, log_dst)
    finally:
        # Always drop a marker file so external tooling can tell the run ended.
        open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()

    dnnlib.submit_config = None
    logger.close()

    if checker is not None:
        checker.stop()
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
    """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place.

    Args:
        submit_config: Configuration describing what to run and where. A shallow
            copy is taken, so the caller's object is not mutated.
        run_func_name: Fully qualified name of the function to execute.
        **run_func_kwargs: Keyword arguments forwarded to the run function.
    """
    submit_config = copy.copy(submit_config)  # do not mutate the caller's config

    if submit_config.user_name is None:
        submit_config.user_name = get_user_name()

    submit_config.run_func_name = run_func_name
    submit_config.run_func_kwargs = run_func_kwargs

    # Only local execution is supported by this copy of dnnlib; the redundant
    # membership test that used to follow this assert has been removed.
    assert submit_config.submit_target == SubmitTarget.LOCAL
    run_dir = _create_run_dir_local(submit_config)

    submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
    submit_config.run_dir = run_dir
    _populate_run_dir(run_dir, submit_config)

    if submit_config.print_info:
        print("\nSubmit config:\n")
        pprint.pprint(submit_config, indent=4, width=200, compact=False)
        print()

    if submit_config.ask_confirmation:
        if not util.ask_yes_no("Continue submitting the job?"):
            return

    run_wrapper(submit_config)
|
dnnlib/tflib/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
| 4 |
+
# 4.0 International License. To view a copy of this license, visit
|
| 5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
| 6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
| 7 |
+
|
| 8 |
+
from . import autosummary
|
| 9 |
+
from . import network
|
| 10 |
+
from . import optimizer
|
| 11 |
+
from . import tfutil
|
| 12 |
+
|
| 13 |
+
from .tfutil import *
|
| 14 |
+
from .network import Network
|
| 15 |
+
|
| 16 |
+
from .optimizer import Optimizer
|
dnnlib/tflib/__pycache__/__init__.cpython-36.pyc
ADDED
|
Binary file (322 Bytes). View file
|
|
|
dnnlib/tflib/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (326 Bytes). View file
|
|
|
dnnlib/tflib/__pycache__/autosummary.cpython-36.pyc
ADDED
|
Binary file (6.38 kB). View file
|
|
|
dnnlib/tflib/__pycache__/autosummary.cpython-37.pyc
ADDED
|
Binary file (6.38 kB). View file
|
|
|
dnnlib/tflib/__pycache__/network.cpython-36.pyc
ADDED
|
Binary file (31 kB). View file
|
|
|
dnnlib/tflib/__pycache__/network.cpython-37.pyc
ADDED
|
Binary file (31 kB). View file
|
|
|
dnnlib/tflib/__pycache__/optimizer.cpython-36.pyc
ADDED
|
Binary file (8.52 kB). View file
|
|
|
dnnlib/tflib/__pycache__/optimizer.cpython-37.pyc
ADDED
|
Binary file (8.53 kB). View file
|
|
|
dnnlib/tflib/__pycache__/tfutil.cpython-36.pyc
ADDED
|
Binary file (8.47 kB). View file
|
|
|
dnnlib/tflib/__pycache__/tfutil.cpython-37.pyc
ADDED
|
Binary file (8.44 kB). View file
|
|
|
dnnlib/tflib/autosummary.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
| 4 |
+
# 4.0 International License. To view a copy of this license, visit
|
| 5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
| 6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
| 7 |
+
|
| 8 |
+
"""Helper for adding automatically tracked values to Tensorboard.
|
| 9 |
+
|
| 10 |
+
Autosummary creates an identity op that internally keeps track of the input
|
| 11 |
+
values and automatically shows up in TensorBoard. The reported value
|
| 12 |
+
represents an average over input components. The average is accumulated
|
| 13 |
+
constantly over time and flushed when save_summaries() is called.
|
| 14 |
+
|
| 15 |
+
Notes:
|
| 16 |
+
- The output tensor must be used as an input for something else in the
|
| 17 |
+
graph. Otherwise, the autosummary op will not get executed, and the average
|
| 18 |
+
value will not get accumulated.
|
| 19 |
+
- It is perfectly fine to include autosummaries with the same name in
|
| 20 |
+
several places throughout the graph, even if they are executed concurrently.
|
| 21 |
+
- It is ok to also pass in a python scalar or numpy array. In this case, it
|
| 22 |
+
is added to the average immediately.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from collections import OrderedDict
|
| 26 |
+
import numpy as np
|
| 27 |
+
import tensorflow as tf
|
| 28 |
+
from tensorboard import summary as summary_lib
|
| 29 |
+
from tensorboard.plugins.custom_scalar import layout_pb2
|
| 30 |
+
|
| 31 |
+
from . import tfutil
|
| 32 |
+
from .tfutil import TfExpression
|
| 33 |
+
from .tfutil import TfExpressionEx
|
| 34 |
+
|
| 35 |
+
_dtype = tf.float64
|
| 36 |
+
_vars = OrderedDict() # name => [var, ...]
|
| 37 |
+
_immediate = OrderedDict() # name => update_op, update_value
|
| 38 |
+
_finalized = False
|
| 39 |
+
_merge_op = None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
    """Internal helper for creating autosummary accumulators.

    Accumulates [count, sum(x), sum(x**2)] moments for ``value_expr`` into a
    per-call tf.Variable registered under ``name`` in the module-level ``_vars``
    table, and returns the op that performs one accumulation step.
    """
    assert not _finalized
    name_id = name.replace("/", "_")
    # Accumulate in float64 (_dtype) to limit rounding error over long runs.
    v = tf.cast(value_expr, _dtype)

    if v.shape.is_fully_defined():
        size = np.prod(tfutil.shape_to_list(v.shape))
        size_expr = tf.constant(size, dtype=_dtype)
    else:
        # Shape only known at runtime -- compute the element count dynamically.
        size = None
        size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))

    if size == 1:
        if v.shape.ndims != 0:
            v = tf.reshape(v, [])
        v = [size_expr, v, tf.square(v)]
    else:
        v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
    # Skip the update (accumulate zeros) if the sum is NaN/Inf.
    v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))

    # Place the accumulator variable outside any surrounding control dependencies.
    with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
        var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False)  # [sum(1), sum(x), sum(x**2)]
        update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))

    if name in _vars:
        _vars[name].append(var)
    else:
        _vars[name] = [var]
    return update_op
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx:
    """Create a new autosummary.

    Args:
        name: Name to use in TensorBoard
        value: TensorFlow expression or python value to track
        passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.

    Example use of the passthru mechanism:

    n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

    with tf.control_dependencies([autosummary('l2loss', loss)]):
        n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    name_id = name.replace("/", "_")

    if tfutil.is_tf_expression(value):
        # Graph mode: attach the accumulator update as a control dependency so
        # it runs whenever the returned tensor is evaluated.
        with tf.name_scope("summary_" + name_id), tf.device(value.device):
            update_op = _create_var(name, value)
            with tf.control_dependencies([update_op]):
                return tf.identity(value if passthru is None else passthru)

    else:  # python scalar or numpy array
        # Immediate mode: lazily build one placeholder + update op per name,
        # cache it in _immediate, and feed the value through it right away.
        if name not in _immediate:
            with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
                update_value = tf.placeholder(_dtype)
                update_op = _create_var(name, update_value)
                _immediate[name] = update_op, update_value

        update_op, update_value = _immediate[name]
        tfutil.run(update_op, {update_value: value})
        return value if passthru is None else passthru
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def finalize_autosummaries():
    """Create the necessary ops to include autosummaries in TensorBoard report.

    Note: This should be done only once per graph.

    Returns:
        The custom_scalar layout summary protobuf describing the margin charts,
        or None if the autosummaries have already been finalized.
        (The original signature was annotated ``-> None`` even though it
        returns the layout, which save_summaries() consumes.)
    """
    global _finalized
    tfutil.assert_tf_initialized()

    if _finalized:
        return None

    _finalized = True
    tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])

    # Create summary ops.
    with tf.device(None), tf.control_dependencies(None):
        for name, vars_list in _vars.items():
            name_id = name.replace("/", "_")
            with tfutil.absolute_name_scope("Autosummary/" + name_id):
                moments = tf.add_n(vars_list)
                moments /= moments[0]  # -> [1, mean(x), mean(x**2)]
                with tf.control_dependencies([moments]):  # read before resetting
                    reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
                    with tf.name_scope(None), tf.control_dependencies(reset_ops):  # reset before reporting
                        mean = moments[1]
                        std = tf.sqrt(moments[2] - tf.square(moments[1]))
                        tf.summary.scalar(name, mean)
                        tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
                        tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)

    # Group by category and chart name.
    cat_dict = OrderedDict()
    for series_name in sorted(_vars.keys()):
        p = series_name.split("/")
        cat = p[0] if len(p) >= 2 else ""
        chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
        if cat not in cat_dict:
            cat_dict[cat] = OrderedDict()
        if chart not in cat_dict[cat]:
            cat_dict[cat][chart] = []
        cat_dict[cat][chart].append(series_name)

    # Setup custom_scalar layout.
    categories = []
    for cat_name, chart_dict in cat_dict.items():
        charts = []
        for chart_name, series_names in chart_dict.items():
            series = []
            for series_name in series_names:
                series.append(layout_pb2.MarginChartContent.Series(
                    value=series_name,
                    lower="xCustomScalars/" + series_name + "/margin_lo",
                    upper="xCustomScalars/" + series_name + "/margin_hi"))
            margin = layout_pb2.MarginChartContent(series=series)
            charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
        categories.append(layout_pb2.Category(title=cat_name, chart=charts))
    layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
    return layout
|
| 169 |
+
|
| 170 |
+
def save_summaries(file_writer, global_step=None):
    """Call FileWriter.add_summary() with all summaries in the default graph,
    automatically finalizing and merging them on the first call.

    Args:
        file_writer: Summary writer with an add_summary() method.
        global_step: Optional step value forwarded to add_summary().
    """
    global _merge_op
    tfutil.assert_tf_initialized()

    # First call only: finalize the autosummaries and cache the merge op.
    if _merge_op is None:
        layout = finalize_autosummaries()
        if layout is not None:
            file_writer.add_summary(layout)
        with tf.device(None), tf.control_dependencies(None):
            _merge_op = tf.summary.merge_all()

    file_writer.add_summary(_merge_op.eval(), global_step)
|
dnnlib/tflib/network.py
ADDED
|
@@ -0,0 +1,628 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
| 4 |
+
# 4.0 International License. To view a copy of this license, visit
|
| 5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
| 6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
| 7 |
+
|
| 8 |
+
"""Helper for managing networks."""
|
| 9 |
+
|
| 10 |
+
import types
|
| 11 |
+
import inspect
|
| 12 |
+
import re
|
| 13 |
+
import uuid
|
| 14 |
+
import sys
|
| 15 |
+
import numpy as np
|
| 16 |
+
import tensorflow as tf
|
| 17 |
+
|
| 18 |
+
from collections import OrderedDict
|
| 19 |
+
from typing import Any, List, Tuple, Union
|
| 20 |
+
|
| 21 |
+
from . import tfutil
|
| 22 |
+
from .. import util
|
| 23 |
+
|
| 24 |
+
from .tfutil import TfExpression, TfExpressionEx
|
| 25 |
+
|
| 26 |
+
_import_handlers = []  # Custom import handlers for dealing with legacy data in pickle import.
_import_module_src = dict()  # Source code for temporary modules created during pickle import.


def import_handler(handler_func):
    """Function decorator for declaring custom import handlers.

    Registers ``handler_func`` in the module-level handler list and hands the
    function back unchanged, so it can be used as a plain ``@import_handler``
    decorator.
    """
    _import_handlers.append(handler_func)
    return handler_func
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Network:
|
| 37 |
+
"""Generic network abstraction.
|
| 38 |
+
|
| 39 |
+
Acts as a convenience wrapper for a parameterized network construction
|
| 40 |
+
function, providing several utility methods and convenient access to
|
| 41 |
+
the inputs/outputs/weights.
|
| 42 |
+
|
| 43 |
+
Network objects can be safely pickled and unpickled for long-term
|
| 44 |
+
archival purposes. The pickling works reliably as long as the underlying
|
| 45 |
+
network construction function is defined in a standalone Python module
|
| 46 |
+
that has no side effects or application-specific imports.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
name: Network name. Used to select TensorFlow name and variable scopes.
|
| 50 |
+
func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
|
| 51 |
+
static_kwargs: Keyword arguments to be passed in to the network construction function.
|
| 52 |
+
|
| 53 |
+
Attributes:
|
| 54 |
+
name: User-specified name, defaults to build func name if None.
|
| 55 |
+
scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
|
| 56 |
+
static_kwargs: Arguments passed to the user-supplied build func.
|
| 57 |
+
components: Container for sub-networks. Passed to the build func, and retained between calls.
|
| 58 |
+
num_inputs: Number of input tensors.
|
| 59 |
+
num_outputs: Number of output tensors.
|
| 60 |
+
input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
|
| 61 |
+
output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
|
| 62 |
+
input_shape: Short-hand for input_shapes[0].
|
| 63 |
+
output_shape: Short-hand for output_shapes[0].
|
| 64 |
+
input_templates: Input placeholders in the template graph.
|
| 65 |
+
output_templates: Output tensors in the template graph.
|
| 66 |
+
input_names: Name string for each input.
|
| 67 |
+
output_names: Name string for each output.
|
| 68 |
+
own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
|
| 69 |
+
vars: All variables (local_name => var).
|
| 70 |
+
trainables: All trainable variables (local_name => var).
|
| 71 |
+
var_global_to_local: Mapping from variable global names to local names.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
    def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
        """Locate the build function, capture its module source, and build the template graph.

        Args:
            name: Network name; defaults to the build function's name if None.
            func_name: Fully qualified name of the build function, or a top-level function object.
            **static_kwargs: Keyword arguments baked into every invocation of the build function.
        """
        tfutil.assert_tf_initialized()
        assert isinstance(name, str) or name is None
        assert func_name is not None
        assert isinstance(func_name, str) or util.is_top_level_function(func_name)
        assert util.is_pickleable(static_kwargs)  # static_kwargs are stored in pickles -- see __getstate__

        self._init_fields()
        self.name = name
        self.static_kwargs = util.EasyDict(static_kwargs)

        # Locate the user-specified network build function.
        if util.is_top_level_function(func_name):
            func_name = util.get_top_level_function_name(func_name)
        module, self._build_func_name = util.get_module_from_obj_name(func_name)
        self._build_func = util.get_obj_from_module(module, self._build_func_name)
        assert callable(self._build_func)

        # Dig up source code for the module containing the build function,
        # preferring the cached source of modules created during pickle import.
        self._build_module_src = _import_module_src.get(module, None)
        if self._build_module_src is None:
            self._build_module_src = inspect.getsource(module)

        # Init TensorFlow graph.
        self._init_graph()
        self.reset_own_vars()
|
| 100 |
+
|
| 101 |
+
def _init_fields(self) -> None:
|
| 102 |
+
self.name = None
|
| 103 |
+
self.scope = None
|
| 104 |
+
self.static_kwargs = util.EasyDict()
|
| 105 |
+
self.components = util.EasyDict()
|
| 106 |
+
self.num_inputs = 0
|
| 107 |
+
self.num_outputs = 0
|
| 108 |
+
self.input_shapes = [[]]
|
| 109 |
+
self.output_shapes = [[]]
|
| 110 |
+
self.input_shape = []
|
| 111 |
+
self.output_shape = []
|
| 112 |
+
self.input_templates = []
|
| 113 |
+
self.output_templates = []
|
| 114 |
+
self.input_names = []
|
| 115 |
+
self.output_names = []
|
| 116 |
+
self.own_vars = OrderedDict()
|
| 117 |
+
self.vars = OrderedDict()
|
| 118 |
+
self.trainables = OrderedDict()
|
| 119 |
+
self.var_global_to_local = OrderedDict()
|
| 120 |
+
|
| 121 |
+
self._build_func = None # User-supplied build function that constructs the network.
|
| 122 |
+
self._build_func_name = None # Name of the build function.
|
| 123 |
+
self._build_module_src = None # Full source code of the module containing the build function.
|
| 124 |
+
self._run_cache = dict() # Cached graph data for Network.run().
|
| 125 |
+
|
| 126 |
+
    def _init_graph(self) -> None:
        """Build the template graph: inputs, build-func invocation, outputs, and variable tables."""
        # Collect inputs: every positional build-func parameter without a
        # default value becomes a network input.
        self.input_names = []

        for param in inspect.signature(self._build_func).parameters.values():
            if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
                self.input_names.append(param.name)

        self.num_inputs = len(self.input_names)
        assert self.num_inputs >= 1

        # Choose name and scope.
        if self.name is None:
            self.name = self._build_func_name
        assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
        with tf.name_scope(None):
            # unique_name guards against two networks sharing a variable scope.
            self.scope = tf.compat.v1.get_default_graph().unique_name(self.name, mark_as_used=True)

        # Finalize build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs["is_template_graph"] = True
        build_kwargs["components"] = self.components

        # Build template graph.
        with tfutil.absolute_variable_scope(self.scope, reuse=tf.compat.v1.AUTO_REUSE), tfutil.absolute_name_scope(self.scope):  # ignore surrounding scopes
            assert tf.compat.v1.get_variable_scope().name == self.scope
            assert tf.compat.v1.get_default_graph().get_name_scope() == self.scope
            with tf.control_dependencies(None):  # ignore surrounding control dependencies
                self.input_templates = [tf.compat.v1.placeholder(tf.float32, name=name) for name in self.input_names]
                out_expr = self._build_func(*self.input_templates, **build_kwargs)

        # Collect outputs.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        self.num_outputs = len(self.output_templates)
        assert self.num_outputs >= 1
        assert all(tfutil.is_tf_expression(t) for t in self.output_templates)

        # Perform sanity checks.
        if any(t.shape.ndims is None for t in self.input_templates):
            raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
        if any(t.shape.ndims is None for t in self.output_templates):
            raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
        if any(not isinstance(comp, Network) for comp in self.components.values()):
            raise ValueError("Components of a Network must be Networks themselves.")
        if len(self.components) != len(set(comp.name for comp in self.components.values())):
            raise ValueError("Components of a Network must have unique names.")

        # List inputs and outputs.
        self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates]
        self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates]
        self.input_shape = self.input_shapes[0]
        self.output_shape = self.output_shapes[0]
        self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]

        # List variables: own_vars holds this scope's variables keyed by local
        # name (scope prefix stripped); vars additionally includes sub-network
        # variables under "<component_name>/<local_name>".
        self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.compat.v1.global_variables(self.scope + "/"))
        self.vars = OrderedDict(self.own_vars)
        self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
        self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
        self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
|
| 187 |
+
|
| 188 |
+
def reset_own_vars(self) -> None:
|
| 189 |
+
"""Re-initialize all variables of this network, excluding sub-networks."""
|
| 190 |
+
tfutil.run([var.initializer for var in self.own_vars.values()])
|
| 191 |
+
|
| 192 |
+
def reset_vars(self) -> None:
|
| 193 |
+
"""Re-initialize all variables of this network, including sub-networks."""
|
| 194 |
+
tfutil.run([var.initializer for var in self.vars.values()])
|
| 195 |
+
|
| 196 |
+
def reset_trainables(self) -> None:
|
| 197 |
+
"""Re-initialize all trainable variables of this network, including sub-networks."""
|
| 198 |
+
tfutil.run([var.initializer for var in self.trainables.values()])
|
| 199 |
+
|
| 200 |
+
    def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
        """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).

        Args:
            *in_expr: One expression per network input. Individual entries may be
                None, in which case that input is fed zeros; at least one entry
                must be non-None.
            return_as_list: If True, always return a list, even for a single output.
            **dynamic_kwargs: Extra keyword arguments forwarded to the build
                function, overriding static_kwargs for this call only.
        """
        assert len(in_expr) == self.num_inputs
        assert not all(expr is None for expr in in_expr)

        # Finalize build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs.update(dynamic_kwargs)
        build_kwargs["is_template_graph"] = False
        build_kwargs["components"] = self.components

        # Build TensorFlow graph to evaluate the network.
        with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
            assert tf.compat.v1.get_variable_scope().name == self.scope
            valid_inputs = [expr for expr in in_expr if expr is not None]
            final_inputs = []
            for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
                if expr is not None:
                    expr = tf.identity(expr, name=name)
                else:
                    # Substitute zeros with the minibatch size of the first valid input.
                    expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
                final_inputs.append(expr)
            out_expr = self._build_func(*final_inputs, **build_kwargs)

        # Propagate input shapes back to the user-specified expressions.
        for expr, final in zip(in_expr, final_inputs):
            if isinstance(expr, tf.Tensor):
                expr.set_shape(final.shape)

        # Express outputs in the desired format.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        if return_as_list:
            out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        return out_expr
|
| 234 |
+
|
| 235 |
+
def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
|
| 236 |
+
"""Get the local name of a given variable, without any surrounding name scopes."""
|
| 237 |
+
assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
|
| 238 |
+
global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
|
| 239 |
+
return self.var_global_to_local[global_name]
|
| 240 |
+
|
| 241 |
+
def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
|
| 242 |
+
"""Find variable by local or global name."""
|
| 243 |
+
assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
|
| 244 |
+
return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
|
| 245 |
+
|
| 246 |
+
def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
|
| 247 |
+
"""Get the value of a given variable as NumPy array.
|
| 248 |
+
Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
|
| 249 |
+
return self.find_var(var_or_local_name).eval()
|
| 250 |
+
|
| 251 |
+
def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
|
| 252 |
+
"""Set the value of a given variable based on the given NumPy array.
|
| 253 |
+
Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
|
| 254 |
+
tfutil.set_vars({self.find_var(var_or_local_name): new_value})
|
| 255 |
+
|
| 256 |
+
def __getstate__(self) -> dict:
|
| 257 |
+
"""Pickle export."""
|
| 258 |
+
state = dict()
|
| 259 |
+
state["version"] = 3
|
| 260 |
+
state["name"] = self.name
|
| 261 |
+
state["static_kwargs"] = dict(self.static_kwargs)
|
| 262 |
+
state["components"] = dict(self.components)
|
| 263 |
+
state["build_module_src"] = self._build_module_src
|
| 264 |
+
state["build_func_name"] = self._build_func_name
|
| 265 |
+
state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))
|
| 266 |
+
return state
|
| 267 |
+
|
| 268 |
+
    def __setstate__(self, state: dict) -> None:
        """Pickle import: rebuild the network from a state dict produced by __getstate__."""
        # pylint: disable=attribute-defined-outside-init
        tfutil.assert_tf_initialized()
        self._init_fields()

        # Execute custom import handlers (registered via @import_handler) to
        # migrate legacy state dicts.
        for handler in _import_handlers:
            state = handler(state)

        # Set basic fields.
        assert state["version"] in [2, 3]
        self.name = state["name"]
        self.static_kwargs = util.EasyDict(state["static_kwargs"])
        self.components = util.EasyDict(state.get("components", {}))  # may be absent in older pickles
        self._build_module_src = state["build_module_src"]
        self._build_func_name = state["build_func_name"]

        # Create temporary module from the imported source code.
        module_name = "_tflib_network_import_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        sys.modules[module_name] = module
        _import_module_src[module] = self._build_module_src
        # NOTE(review): executes source code embedded in the pickle -- only
        # unpickle files from trusted sources.
        exec(self._build_module_src, module.__dict__)  # pylint: disable=exec-used

        # Locate network build function in the temporary module.
        self._build_func = util.get_obj_from_module(module, self._build_func_name)
        assert callable(self._build_func)

        # Init TensorFlow graph, then restore the saved variable values.
        self._init_graph()
        self.reset_own_vars()
        tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})
|
| 301 |
+
|
| 302 |
+
def clone(self, name: str = None, **new_static_kwargs) -> "Network":
    """Create a clone of this network with its own copy of the variables."""
    # pylint: disable=protected-access
    duplicate = object.__new__(Network)
    duplicate._init_fields()
    duplicate.name = self.name if name is None else name
    duplicate.static_kwargs = util.EasyDict(self.static_kwargs)
    duplicate.static_kwargs.update(new_static_kwargs)
    duplicate._build_module_src = self._build_module_src
    duplicate._build_func_name = self._build_func_name
    duplicate._build_func = self._build_func
    duplicate._init_graph()
    duplicate.copy_vars_from(self)
    return duplicate
|
| 316 |
+
|
| 317 |
+
def copy_own_vars_from(self, src_net: "Network") -> None:
    """Copy the values of all variables from the given network, excluding sub-networks."""
    shared = [n for n in self.own_vars.keys() if n in src_net.own_vars]
    tfutil.set_vars(tfutil.run({self.vars[n]: src_net.vars[n] for n in shared}))
|
| 321 |
+
|
| 322 |
+
def copy_vars_from(self, src_net: "Network") -> None:
    """Copy the values of all variables from the given network, including sub-networks."""
    shared = [n for n in self.vars.keys() if n in src_net.vars]
    tfutil.set_vars(tfutil.run({self.vars[n]: src_net.vars[n] for n in shared}))
|
| 326 |
+
|
| 327 |
+
def copy_trainables_from(self, src_net: "Network") -> None:
    """Copy the values of all trainable variables from the given network, including sub-networks."""
    shared = [n for n in self.trainables.keys() if n in src_net.trainables]
    tfutil.set_vars(tfutil.run({self.vars[n]: src_net.vars[n] for n in shared}))
|
| 331 |
+
|
| 332 |
+
def copy_compatible_trainables_from(self, src_net: "Network") -> None:
    """Copy the compatible values of all trainable variables from the given network, including sub-networks.

    A variable is compatible when it exists in both networks under the same
    local name and has an identical shape; incompatible variables are skipped
    with a diagnostic message.
    """
    names = []
    for name in self.trainables.keys():
        if name not in src_net.trainables:
            print("Not restoring (not present): {}".format(name))
        elif self.trainables[name].shape != src_net.trainables[name].shape:
            print("Not restoring (different shape): {}".format(name))
        else:
            # Compatible: same name and same shape on both sides.
            names.append(name)

    tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
|
| 345 |
+
|
| 346 |
+
def apply_swa(self, src_net, epoch):
    """Perform stochastic weight averaging on the compatible values of all trainable variables from the given network, including sub-networks.

    Blends each compatible trainable as a running mean over epochs:
    new = src * 1/(epoch+1) + self * epoch/(epoch+1).
    """
    names = []
    for name in self.trainables.keys():
        if name not in src_net.trainables:
            print("Not restoring (not present): {}".format(name))
        elif self.trainables[name].shape != src_net.trainables[name].shape:
            print("Not restoring (different shape): {}".format(name))
        else:
            # Compatible: same name and same shape on both sides.
            names.append(name)

    # Weight of the incoming snapshot shrinks as more epochs are averaged in.
    scale_new_data = 1.0 / (epoch + 1)
    scale_moving_average = (1.0 - scale_new_data)
    tfutil.set_vars(tfutil.run({self.vars[name]: (src_net.vars[name] * scale_new_data + self.vars[name] * scale_moving_average) for name in names}))
|
| 361 |
+
|
| 362 |
+
def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
    """Create new network with the given parameters, and copy all variables from this network."""
    merged_kwargs = dict(self.static_kwargs)
    merged_kwargs.update(new_static_kwargs)
    target_name = self.name if new_name is None else new_name
    converted = Network(name=target_name, func_name=new_func_name, **merged_kwargs)
    converted.copy_vars_from(self)
    return converted
|
| 371 |
+
|
| 372 |
+
def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
    """Construct a TensorFlow op that updates the variables of this network
    to be slightly closer to those of the given network."""
    with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
        assign_ops = []
        for name, var in self.vars.items():
            if name not in src_net.vars:
                continue
            # Trainables decay with `beta`; everything else with `beta_nontrainable`.
            decay = beta if name in self.trainables else beta_nontrainable
            blended = tfutil.lerp(src_net.vars[name], var, decay)
            assign_ops.append(var.assign(blended))
        return tf.group(*assign_ops)
|
| 383 |
+
|
| 384 |
+
def run(self,
        *in_arrays: Tuple[Union[np.ndarray, None], ...],
        input_transform: dict = None,
        output_transform: dict = None,
        return_as_list: bool = False,
        print_progress: bool = False,
        minibatch_size: int = None,
        num_gpus: int = 1,
        assume_frozen: bool = False,
        custom_inputs=None,
        **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
    """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).

    Args:
        input_transform:    A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
                            The dict must contain a 'func' field that points to a top-level function. The function is called with the input
                            TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
        output_transform:   A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
                            The dict must contain a 'func' field that points to a top-level function. The function is called with the output
                            TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
        return_as_list:     True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
        print_progress:     Print progress to the console? Useful for very large input arrays.
        minibatch_size:     Maximum minibatch size to use, None = disable batching.
        num_gpus:           Number of GPUs to use.
        assume_frozen:      Improve multi-GPU performance by assuming that the trainable parameters will remain changed between calls.
        dynamic_kwargs:     Additional keyword arguments to be passed into the network build function.
        custom_inputs:      Allow to use another Tensor as input instead of default Placeholders
    """
    assert len(in_arrays) == self.num_inputs
    assert not all(arr is None for arr in in_arrays)
    assert input_transform is None or util.is_top_level_function(input_transform["func"])
    assert output_transform is None or util.is_top_level_function(output_transform["func"])
    # Translate deprecated out_mul/out_add/out_shrink/out_dtype kwargs into an output_transform.
    output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
    # NOTE(review): assumes in_arrays[0] is not None even though the assert
    # above only requires *some* array to be non-None -- confirm with callers.
    num_items = in_arrays[0].shape[0]
    if minibatch_size is None:
        minibatch_size = num_items

    # Construct unique hash key from all arguments that affect the TensorFlow graph.
    key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
    def unwind_key(obj):
        # Recursively turn the key dict into a deterministic, repr()-able structure
        # (callables are replaced by their top-level function name).
        if isinstance(obj, dict):
            return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
        if callable(obj):
            return util.get_top_level_function_name(obj)
        return obj
    key = repr(unwind_key(key))

    # Build graph (once per unique key; cached in self._run_cache).
    if key not in self._run_cache:
        with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
            if custom_inputs is not None:
                # Caller-supplied input builders replace the default placeholders.
                with tf.device("/gpu:0"):
                    in_expr = [input_builder(name) for input_builder, name in zip(custom_inputs, self.input_names)]
                    in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
            else:
                with tf.device("/cpu:0"):
                    in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                    in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))

            # Evaluate one shard of the batch per GPU.
            out_split = []
            for gpu in range(num_gpus):
                with tf.device("/gpu:%d" % gpu):
                    net_gpu = self.clone() if assume_frozen else self
                    in_gpu = in_split[gpu]

                    if input_transform is not None:
                        in_kwargs = dict(input_transform)
                        in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
                        in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)

                    assert len(in_gpu) == self.num_inputs
                    out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)

                    if output_transform is not None:
                        out_kwargs = dict(output_transform)
                        out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
                        out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)

                    assert len(out_gpu) == self.num_outputs
                    out_split.append(out_gpu)

            # Concatenate the per-GPU shards back into full-batch outputs.
            with tf.device("/cpu:0"):
                out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
                self._run_cache[key] = in_expr, out_expr

    # Run minibatches.
    in_expr, out_expr = self._run_cache[key]
    out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr]

    for mb_begin in range(0, num_items, minibatch_size):
        if print_progress:
            print("\r%d / %d" % (mb_begin, num_items), end="")

        mb_end = min(mb_begin + minibatch_size, num_items)
        mb_num = mb_end - mb_begin
        # None inputs are fed as zeros of the declared input shape.
        mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
        mb_out = tf.compat.v1.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))

        for dst, src in zip(out_arrays, mb_out):
            dst[mb_begin: mb_end] = src

    # Done.
    if print_progress:
        print("\r%d / %d" % (num_items, num_items))

    if not return_as_list:
        out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
    return out_arrays
|
| 492 |
+
|
| 493 |
+
def list_ops(self) -> List[TfExpression]:
    """List the TensorFlow ops that live inside this network's scope,
    excluding internal helper scopes (those starting with an underscore)."""
    include_prefix = self.scope + "/"
    exclude_prefix = include_prefix + "_"
    return [
        op for op in tf.get_default_graph().get_operations()
        if op.name.startswith(include_prefix) and not op.name.startswith(exclude_prefix)
    ]
|
| 500 |
+
|
| 501 |
+
def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
    """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
    individual layers of the network. Mainly intended to be used for reporting."""
    layers = []

    def recurse(scope, parent_ops, parent_vars, level):
        # Ignore specific patterns (bookkeeping ops that are not real layers).
        if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
            return

        # Filter ops and vars by scope.
        global_prefix = scope + "/"
        local_prefix = global_prefix[len(self.scope) + 1:]
        cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
        cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
        if not cur_ops and not cur_vars:
            return

        # Filter out all ops related to variables.
        for var in [op for op in cur_ops if op.type.startswith("Variable")]:
            var_prefix = var.name + "/"
            cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]

        # Scope does not contain ops as immediate children => recurse deeper.
        contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type != "Identity" for op in cur_ops)
        if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:
            # Recurse once per distinct immediate child scope; `visited`
            # prevents processing the same child twice.
            visited = set()
            for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
                token = rel_name.split("/")[0]
                if token not in visited:
                    recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
                    visited.add(token)
            return

        # Report layer: the last op's first output is taken as the layer output
        # (falling back to the last variable when the scope has no ops).
        layer_name = scope[len(self.scope) + 1:]
        layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
        layer_trainables = [var for _name, var in cur_vars if var.trainable]
        layers.append((layer_name, layer_output, layer_trainables))

    recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)
    return layers
|
| 543 |
+
|
| 544 |
+
def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
    """Print a summary table of the network structure."""
    rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
    rows += [["---"] * 4]
    total_params = 0

    for layer_name, layer_output, layer_trainables in self.list_layers():
        # Total number of scalar parameters across this layer's trainables.
        num_params = sum(np.prod(tfutil.shape_to_list(var.shape)) for var in layer_trainables)
        # Prefer variables conventionally named ".../weight" for the shape
        # column; shortest name wins (the "main" weight of the layer).
        weights = [var for var in layer_trainables if var.name.endswith("/weight:0") or var.name.endswith("/weight_1:0")]
        weights.sort(key=lambda x: len(x.name))
        if len(weights) == 0 and len(layer_trainables) == 1:
            weights = layer_trainables
        total_params += num_params

        if not hide_layers_with_no_params or num_params != 0:
            num_params_str = str(num_params) if num_params > 0 else "-"
            output_shape_str = str(layer_output.shape)
            weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
            rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]

    rows += [["---"] * 4]
    rows += [["Total", str(total_params), "", ""]]

    # Pad every cell to its column width and print the table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
|
| 572 |
+
|
| 573 |
+
def setup_weight_histograms(self, title: str = None) -> None:
    """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
    if title is None:
        title = self.name

    with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
        for local_name, var in self.trainables.items():
            if "/" in local_name:
                parts = local_name.split("/")
                # Group by trailing component, e.g. "Title_weight/Layer_Sub".
                summary_name = title + "_" + parts[-1] + "/" + "_".join(parts[:-1])
            else:
                summary_name = title + "_toplevel/" + local_name
            tf.summary.histogram(summary_name, var)
|
| 587 |
+
|
| 588 |
+
#----------------------------------------------------------------------------
|
| 589 |
+
# Backwards-compatible emulation of legacy output transformation in Network.run().
|
| 590 |
+
|
| 591 |
+
_print_legacy_warning = True
|
| 592 |
+
|
| 593 |
+
def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
|
| 594 |
+
global _print_legacy_warning
|
| 595 |
+
legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
|
| 596 |
+
if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
|
| 597 |
+
return output_transform, dynamic_kwargs
|
| 598 |
+
|
| 599 |
+
if _print_legacy_warning:
|
| 600 |
+
_print_legacy_warning = False
|
| 601 |
+
print()
|
| 602 |
+
print("WARNING: Old-style output transformations in Network.run() are deprecated.")
|
| 603 |
+
print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
|
| 604 |
+
print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
|
| 605 |
+
print()
|
| 606 |
+
assert output_transform is None
|
| 607 |
+
|
| 608 |
+
new_kwargs = dict(dynamic_kwargs)
|
| 609 |
+
new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
|
| 610 |
+
new_transform["func"] = _legacy_output_transform_func
|
| 611 |
+
return new_transform, new_kwargs
|
| 612 |
+
|
| 613 |
+
def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
    """Legacy elementwise output transform: scale, offset, average-pool shrink,
    and dtype cast, applied in that order to every output expression.

    Note: returns the input tuple unchanged when no transform is requested,
    otherwise a list -- callers normalize the result either way.
    """
    if out_mul != 1.0:
        expr = [x * out_mul for x in expr]

    if out_add != 0.0:
        expr = [x + out_add for x in expr]

    if out_shrink > 1:
        # Downscale spatially by average pooling; data is assumed NCHW here.
        ksize = [1, 1, out_shrink, out_shrink]
        expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]

    if out_dtype is not None:
        # Round before an integer cast so values are not truncated toward zero.
        if tf.as_dtype(out_dtype).is_integer:
            expr = [tf.round(x) for x in expr]
        expr = [tf.saturate_cast(x, out_dtype) for x in expr]
    return expr
|
dnnlib/tflib/optimizer.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
| 4 |
+
# 4.0 International License. To view a copy of this license, visit
|
| 5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
| 6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
| 7 |
+
|
| 8 |
+
"""Helper wrapper for a Tensorflow optimizer."""
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import tensorflow as tf
|
| 12 |
+
|
| 13 |
+
from collections import OrderedDict
|
| 14 |
+
from typing import List, Union
|
| 15 |
+
|
| 16 |
+
from . import autosummary
|
| 17 |
+
from . import tfutil
|
| 18 |
+
from .. import util
|
| 19 |
+
|
| 20 |
+
from .tfutil import TfExpression, TfExpressionEx
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
# TensorFlow 1.13
|
| 24 |
+
from tensorflow.python.ops import nccl_ops
|
| 25 |
+
except:
|
| 26 |
+
# Older TensorFlow versions
|
| 27 |
+
import tensorflow.contrib.nccl as nccl_ops
|
| 28 |
+
|
| 29 |
+
class Optimizer:
    """A Wrapper for tf.train.Optimizer.

    Automatically takes care of:
    - Gradient averaging for multi-GPU training.
    - Dynamic loss scaling and typecasts for FP16 training.
    - Ignoring corrupted gradients that contain NaNs/Infs.
    - Reporting statistics.
    - Well-chosen default settings.
    """

    def __init__(self,
                 name: str = "Train",
                 tf_optimizer: str = "tf.train.AdamOptimizer",
                 learning_rate: TfExpressionEx = 0.001,
                 use_loss_scaling: bool = False,
                 loss_scaling_init: float = 64.0,
                 loss_scaling_inc: float = 0.0005,
                 loss_scaling_dec: float = 1.0,
                 **kwargs):
        """Construct the wrapper.

        Args:
            name:              Base name used for TF scopes and summaries.
            tf_optimizer:      Dotted path of the underlying optimizer class.
            learning_rate:     Constant or tensor learning rate.
            use_loss_scaling:  Enable dynamic loss scaling for FP16 training.
            loss_scaling_init: Initial log2 of the loss scaling factor.
            loss_scaling_inc:  log2 increment applied after a clean step.
            loss_scaling_dec:  log2 decrement applied after an overflow.
            kwargs:            Forwarded to the underlying optimizer class.
        """
        # Init fields.
        self.name = name
        self.learning_rate = tf.convert_to_tensor(learning_rate)
        self.id = self.name.replace("/", ".")
        self.scope = tf.get_default_graph().unique_name(self.id)
        self.optimizer_class = util.get_obj_by_name(tf_optimizer)
        self.optimizer_kwargs = dict(kwargs)
        self.use_loss_scaling = use_loss_scaling
        self.loss_scaling_init = loss_scaling_init
        self.loss_scaling_inc = loss_scaling_inc
        self.loss_scaling_dec = loss_scaling_dec
        self._grad_shapes = None          # [shape, ...]
        self._dev_opt = OrderedDict()     # device => optimizer
        self._dev_grads = OrderedDict()   # device => [[(grad, var), ...], ...]
        self._dev_ls_var = OrderedDict()  # device => variable (log2 of loss scaling factor)
        self._updates_applied = False

    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
        """Register the gradients of the given loss function with respect to the given variables.
        Intended to be called once per GPU."""
        assert not self._updates_applied

        # Validate arguments.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars

        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])

        # The first call fixes the variable shapes; subsequent calls must match.
        if self._grad_shapes is None:
            self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars]

        assert len(trainable_vars) == len(self._grad_shapes)
        assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes))

        dev = loss.device

        assert all(var.device == dev for var in trainable_vars)

        # Register device and compute gradients.
        with tf.name_scope(self.id + "_grad"), tf.device(dev):
            if dev not in self._dev_opt:
                opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt)
                assert callable(self.optimizer_class)
                self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
                self._dev_grads[dev] = []

            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE)  # disable gating to reduce memory usage
            grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads]  # replace disconnected gradients with zeros
            self._dev_grads[dev].append(grads)

    def apply_updates(self) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients."""
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        devices = list(self._dev_grads.keys())
        total_grads = sum(len(grads) for grads in self._dev_grads.values())
        assert len(devices) >= 1 and total_grads >= 1
        ops = []

        with tfutil.absolute_name_scope(self.scope):
            # Cast gradients to FP32 and calculate partial sum within each device.
            dev_grads = OrderedDict()  # device => [(grad, var), ...]

            for dev_idx, dev in enumerate(devices):
                with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev):
                    sums = []

                    for gv in zip(*self._dev_grads[dev]):
                        # All entries for the same position must refer to the same variable.
                        assert all(v is gv[0][1] for g, v in gv)
                        g = [tf.cast(g, tf.float32) for g, v in gv]
                        g = g[0] if len(g) == 1 else tf.add_n(g)
                        sums.append((g, gv[0][1]))

                    dev_grads[dev] = sums

            # Sum gradients across devices.
            if len(devices) > 1:
                with tf.name_scope("SumAcrossGPUs"), tf.device(None):
                    for var_idx, grad_shape in enumerate(self._grad_shapes):
                        g = [dev_grads[dev][var_idx][0] for dev in devices]

                        if np.prod(grad_shape):  # nccl does not support zero-sized tensors
                            g = nccl_ops.all_sum(g)

                        for dev, gg in zip(devices, g):
                            dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1])

            # Apply updates separately on each device.
            for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
                with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev):
                    # Scale gradients as needed (average over registered minibatches,
                    # and undo any dynamic loss scaling).
                    if self.use_loss_scaling or total_grads > 1:
                        with tf.name_scope("Scale"):
                            coef = tf.constant(np.float32(1.0 / total_grads), name="coef")
                            coef = self.undo_loss_scaling(coef)
                            grads = [(g * coef, v) for g, v in grads]

                    # Check for overflows.
                    with tf.name_scope("CheckOverflow"):
                        grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads]))

                    # Update weights and adjust loss scaling.
                    with tf.name_scope("UpdateWeights"):
                        # pylint: disable=cell-var-from-loop
                        opt = self._dev_opt[dev]
                        ls_var = self.get_loss_scaling_var(dev)

                        if not self.use_loss_scaling:
                            # Skip the update entirely when any gradient is NaN/Inf.
                            ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op))
                        else:
                            # Clean step: grow the loss scale and apply; overflow: shrink it.
                            ops.append(tf.cond(grad_ok,
                                               lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)),
                                               lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec))))

                    # Report statistics on the last device.
                    if dev == devices[-1]:
                        with tf.name_scope("Statistics"):
                            ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
                            ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1)))

                            if self.use_loss_scaling:
                                ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var))

            # Initialize variables and group everything into a single op.
            self.reset_optimizer_state()
            tfutil.init_uninitialized_vars(list(self._dev_ls_var.values()))

            return tf.group(*ops, name="TrainingOp")

    def reset_optimizer_state(self) -> None:
        """Reset internal state of the underlying optimizer."""
        tfutil.assert_tf_initialized()
        tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()])

    def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
        """Get or create variable representing log2 of the current dynamic loss scaling factor."""
        if not self.use_loss_scaling:
            return None

        if device not in self._dev_ls_var:
            with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None):
                self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var")

        return self._dev_ls_var[device]

    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Apply dynamic loss scaling for the given expression."""
        assert tfutil.is_tf_expression(value)

        if not self.use_loss_scaling:
            return value

        return value * tfutil.exp2(self.get_loss_scaling_var(value.device))

    def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Undo the effect of dynamic loss scaling for the given expression."""
        assert tfutil.is_tf_expression(value)

        if not self.use_loss_scaling:
            return value

        return value * tfutil.exp2(-self.get_loss_scaling_var(value.device))  # pylint: disable=invalid-unary-operand-type
|
dnnlib/tflib/tfutil.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
| 4 |
+
# 4.0 International License. To view a copy of this license, visit
|
| 5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
| 6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
| 7 |
+
|
| 8 |
+
"""Miscellaneous helper utils for Tensorflow."""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import numpy as np
|
| 12 |
+
import tensorflow as tf
|
| 13 |
+
|
| 14 |
+
from typing import Any, Iterable, List, Union
|
| 15 |
+
|
| 16 |
+
TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
|
| 17 |
+
"""A type that represents a valid Tensorflow expression."""
|
| 18 |
+
|
| 19 |
+
TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
|
| 20 |
+
"""A type that can be converted to a valid Tensorflow expression."""
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def run(*args, **kwargs) -> Any:
    """Run the specified ops in the default session.

    Thin pass-through to `tf.Session.run` on the current default session;
    raises RuntimeError (via assert_tf_initialized) if none exists.
    """
    assert_tf_initialized()
    session = tf.compat.v1.get_default_session()
    return session.run(*args, **kwargs)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def is_tf_expression(x: Any) -> bool:
    """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
    tf_expression_types = (tf.Tensor, tf.Variable, tf.Operation)
    return isinstance(x, tf_expression_types)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def shape_to_list(shape: Iterable[tf.compat.v1.Dimension]) -> List[Union[int, None]]:
    """Convert a Tensorflow shape to a list of ints (None for unknown dimensions)."""
    return list(dimension.value for dimension in shape)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def flatten(x: TfExpressionEx) -> TfExpression:
    """Shortcut function for flattening a tensor to rank 1."""
    with tf.name_scope("Flatten"):
        flat = tf.reshape(x, [-1])
    return flat
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def log2(x: TfExpressionEx) -> TfExpression:
    """Logarithm in base 2.

    Computed as ln(x) / ln(2) with the reciprocal pre-folded into a
    float32 constant.
    """
    with tf.name_scope("Log2"):
        # tf.log was removed from the top-level namespace in TF 2.x;
        # tf.math.log exists in both 1.x and 2.x and matches the
        # tf.compat.v1 style used throughout this file.
        return tf.math.log(x) * np.float32(1.0 / np.log(2.0))
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def exp2(x: TfExpressionEx) -> TfExpression:
    """Exponent in base 2, i.e. 2**x computed as exp(x * ln 2)."""
    with tf.name_scope("Exp2"):
        ln_two = np.float32(np.log(2.0))
        return tf.exp(x * ln_two)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
    """Linear interpolation: returns a when t == 0 and b when t == 1."""
    with tf.name_scope("Lerp"):
        interpolated = a + (b - a) * t
    return interpolated
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
    """Linear interpolation with t clamped to [0, 1] before interpolating."""
    with tf.name_scope("LerpClip"):
        t_clamped = tf.clip_by_value(t, 0.0, 1.0)
        return a + (b - a) * t_clamped
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def absolute_name_scope(scope: str) -> tf.name_scope:
    """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
    # A trailing "/" makes TensorFlow treat the scope name as absolute
    # rather than relative to the currently open scope.
    return tf.name_scope(scope + "/")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def absolute_variable_scope(scope: str, **kwargs) -> tf.compat.v1.variable_scope:
    """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
    # Passing a pre-built VariableScope object (rather than a string) makes the
    # scope absolute; auxiliary_name_scope=False leaves the name scope untouched.
    variable_scope_object = tf.compat.v1.VariableScope(name=scope, **kwargs)
    return tf.compat.v1.variable_scope(variable_scope_object, auxiliary_name_scope=False)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _sanitize_tf_config(config_dict: dict = None) -> dict:
|
| 80 |
+
# Defaults.
|
| 81 |
+
cfg = dict()
|
| 82 |
+
cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
|
| 83 |
+
cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
|
| 84 |
+
cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
|
| 85 |
+
cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
|
| 86 |
+
cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
|
| 87 |
+
|
| 88 |
+
# User overrides.
|
| 89 |
+
if config_dict is not None:
|
| 90 |
+
cfg.update(config_dict)
|
| 91 |
+
return cfg
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def init_tf(config_dict: dict = None) -> None:
    """Initialize TensorFlow session using good default settings.

    Idempotent: does nothing if a default session already exists. Seeds
    NumPy/TF RNGs, exports "env.*" config keys as environment variables,
    then installs a new session as the process-wide default.
    """
    # Skip if already initialized.
    if tf.compat.v1.get_default_session() is not None:
        return

    # Setup config dict and random seeds.
    cfg = _sanitize_tf_config(config_dict)
    np_random_seed = cfg["rnd.np_random_seed"]
    if np_random_seed is not None:
        np.random.seed(np_random_seed)
    tf_random_seed = cfg["rnd.tf_random_seed"]
    if tf_random_seed == "auto":
        # Derive the TF seed from the (possibly just-seeded) NumPy RNG state.
        tf_random_seed = np.random.randint(1 << 31)
    if tf_random_seed is not None:
        tf.compat.v1.set_random_seed(tf_random_seed)

    # Setup environment variables.
    # NOTE(review): set before the session is created — presumably so that
    # TF_CPP_MIN_LOG_LEVEL etc. take effect on session startup; confirm ordering matters.
    for key, value in list(cfg.items()):
        fields = key.split(".")
        if fields[0] == "env":
            assert len(fields) == 2
            os.environ[fields[1]] = str(value)

    # Create default TensorFlow session.
    create_session(cfg, force_as_default=True)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def assert_tf_initialized():
    """Check that TensorFlow session has been initialized; raise RuntimeError otherwise."""
    if tf.compat.v1.get_default_session() is not None:
        return
    raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.compat.v1.Session:
    """Create tf.Session based on config dict.

    Dotted config keys (e.g. "gpu_options.allow_growth") are walked into the
    ConfigProto; "rnd.*" and "env.*" keys are handled by init_tf and skipped here.
    If force_as_default is True, the session is permanently installed as the
    process-wide default (its context manager is entered and never exited).
    """
    # Setup TensorFlow config proto.
    cfg = _sanitize_tf_config(config_dict)
    config_proto = tf.compat.v1.ConfigProto()
    for key, value in cfg.items():
        fields = key.split(".")
        if fields[0] not in ["rnd", "env"]:
            # Walk the dotted path down the proto, then set the leaf field.
            obj = config_proto
            for field in fields[:-1]:
                obj = getattr(obj, field)
            setattr(obj, fields[-1], value)

    # Create session.
    session = tf.compat.v1.Session(config=config_proto)
    if force_as_default:
        # Keep the as_default() context object alive on the session itself so
        # the default-session stack entry is never popped.
        # pylint: disable=protected-access
        session._default_session = session.as_default()
        session._default_session.enforce_nesting = False
        session._default_session.__enter__() # pylint: disable=no-member

    return session
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
    """Initialize all tf.Variables that have not already been initialized.

    Equivalent to the following, but more efficient and does not bloat the tf graph:
    tf.variables_initializer(tf.report_uninitialized_variables()).run()
    """
    assert_tf_initialized()
    if target_vars is None:
        # Use the tf.compat.v1 endpoint for consistency with the rest of this
        # file (tf.global_variables is absent from the TF 2.x top-level namespace).
        target_vars = tf.compat.v1.global_variables()

    test_vars = []
    test_ops = []

    with tf.control_dependencies(None): # ignore surrounding control_dependencies
        for var in target_vars:
            assert is_tf_expression(var)

            try:
                # If an IsVariableInitialized op already exists, the variable
                # was handled by a previous call — skip it.
                tf.compat.v1.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
            except KeyError:
                # Op does not exist => variable may be uninitialized.
                test_vars.append(var)

                with absolute_name_scope(var.name.split(":")[0]):
                    test_ops.append(tf.compat.v1.is_variable_initialized(var))

    # Run all the initialization checks in one session call, then initialize
    # only the variables that reported uninitialized.
    init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
    run([var.initializer for var in init_vars])
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def set_vars(var_to_value_dict: dict) -> None:
    """Set the values of given tf.Variables.

    Equivalent to the following, but more efficient and does not bloat the tf graph:
    tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()]
    """
    assert_tf_initialized()
    ops = []
    feed_dict = {}

    for var, value in var_to_value_dict.items():
        assert is_tf_expression(var)

        try:
            # Reuse the assign op created on a previous call, identified by
            # its deterministic "<var>/setter:0" name.
            setter = tf.compat.v1.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
        except KeyError:
            with absolute_name_scope(var.name.split(":")[0]):
                with tf.control_dependencies(None): # ignore surrounding control_dependencies
                    setter = tf.compat.v1.assign(var, tf.compat.v1.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter

        ops.append(setter)
        # inputs[1] of the assign op is the "new_value" placeholder created above.
        feed_dict[setter.op.inputs[1]] = value

    # Apply every assignment in a single session call.
    run(ops, feed_dict)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
    """Create tf.Variable with large initial value without bloating the tf graph.

    The variable is declared with a zeros initializer and the real data is
    uploaded afterwards through a feed_dict, so the large constant never
    becomes part of the serialized graph definition.
    """
    assert_tf_initialized()
    assert isinstance(initial_value, np.ndarray)
    zero_init = tf.zeros(initial_value.shape, initial_value.dtype)
    variable = tf.Variable(zero_init, *args, **kwargs)
    set_vars({variable: initial_value})
    return variable
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def convert_images_from_uint8(images, drange=(-1, 1), nhwc_to_nchw=False):
    """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
    Can be used as an input transformation for Network.run().

    Args:
        images: Tensor of uint8 images.
        drange: (low, high) target dynamic range; [0, 255] is mapped onto it.
            Default changed from a mutable list to a tuple so the shared
            default cannot be mutated by callers; indexing is unchanged.
        nhwc_to_nchw: If True, transpose the input from NHWC to NCHW first.
    """
    images = tf.cast(images, tf.float32)
    if nhwc_to_nchw:
        images = tf.transpose(images, [0, 3, 1, 2])
    return (images - drange[0]) * ((drange[1] - drange[0]) / 255)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def convert_images_to_uint8(images, drange=(-1, 1), nchw_to_nhwc=False, shrink=1, uint8_cast=True):
    """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
    Can be used as an output transformation for Network.run().

    Args:
        images: Tensor of float images in NCHW layout.
        drange: (low, high) dynamic range of the input; mapped to [0, 255].
            Default changed from a mutable list to a tuple so the shared
            default cannot be mutated by callers; indexing is unchanged.
        nchw_to_nhwc: If True, transpose the result from NCHW to NHWC.
        shrink: Integer downscaling factor applied via average pooling (in NCHW).
        uint8_cast: If True, saturate-cast the result to uint8.
    """
    images = tf.cast(images, tf.float32)
    if shrink > 1:
        ksize = [1, 1, shrink, shrink]
        images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
    if nchw_to_nhwc:
        images = tf.transpose(images, [0, 2, 3, 1])
    # Affine map from drange to [0, 255]; the +0.5 rounds on the later cast.
    scale = 255 / (drange[1] - drange[0])
    images = images * scale + (0.5 - drange[0] * scale)
    if uint8_cast:
        images = tf.saturate_cast(images, tf.uint8)
    return images
|
dnnlib/util.py
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
| 4 |
+
# 4.0 International License. To view a copy of this license, visit
|
| 5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
| 6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
| 7 |
+
|
| 8 |
+
"""Miscellaneous utility classes and functions."""
|
| 9 |
+
|
| 10 |
+
import ctypes
|
| 11 |
+
import fnmatch
|
| 12 |
+
import importlib
|
| 13 |
+
import inspect
|
| 14 |
+
import numpy as np
|
| 15 |
+
import os
|
| 16 |
+
import shutil
|
| 17 |
+
import sys
|
| 18 |
+
import types
|
| 19 |
+
import io
|
| 20 |
+
import pickle
|
| 21 |
+
import re
|
| 22 |
+
import requests
|
| 23 |
+
import html
|
| 24 |
+
import hashlib
|
| 25 |
+
import glob
|
| 26 |
+
import uuid
|
| 27 |
+
|
| 28 |
+
from distutils.util import strtobool
|
| 29 |
+
from typing import Any, List, Tuple, Union
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Util classes
|
| 33 |
+
# ------------------------------------------------------------------------------------------
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        # Only reached for names that are not real attributes; map missing
        # keys to AttributeError so hasattr()/getattr() defaults work.
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        # Optional log file; None when no file_name was given.
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        # Keep the originals so close() can restore them.
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        # Install this object as the process-wide stdout AND stderr; both
        # streams funnel through write() below.
        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: str) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        # Always mirror to the original stdout (even text that arrived via stderr).
        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# Small util functions
|
| 109 |
+
# ------------------------------------------------------------------------------------------
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    total = int(np.rint(seconds))

    minute, hour, day = 60, 60 * 60, 24 * 60 * 60
    if total < minute:
        return "{0}s".format(total)
    if total < hour:
        return "{0}m {1:02}s".format(total // minute, total % minute)
    if total < day:
        return "{0}h {1:02}m {2:02}s".format(total // hour, (total // minute) % 60, total % minute)
    # Seconds are dropped once the duration reaches a full day.
    return "{0}d {1:02}h {2:02}m".format(total // day, (total // hour) % 24, (total // minute) % 60)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer.

    Accepts the same tokens as distutils.util.strtobool, which this function
    previously relied on; distutils was removed from the standard library in
    Python 3.12 (PEP 632), so the token sets are inlined here.
    """
    truthy = ("y", "yes", "t", "true", "on", "1")
    falsy = ("n", "no", "f", "false", "off", "0")
    while True:
        print("{0} [y/n]".format(question))
        answer = input().lower()
        if answer in truthy:
            return True
        if answer in falsy:
            return False
        # Invalid input: re-ask, matching the old ValueError-swallowing loop.
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements (1 for an empty tuple)."""
    product = 1
    for element in t:
        product = product * element
    return product
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# NumPy-style dtype names mapped to the ctypes type of identical byte size.
_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype
    matching_dtype = np.dtype(type_str)
    matching_ctype = _str_to_ctype[type_str]

    # Sanity check: both representations must occupy the same number of bytes.
    assert matching_dtype.itemsize == ctypes.sizeof(matching_ctype)

    return matching_dtype, matching_ctype
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def is_pickleable(obj: Any) -> bool:
    """Return True if obj can be serialized with pickle, False otherwise."""
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        # Narrowed from a bare 'except:' so KeyboardInterrupt/SystemExit
        # still propagate; any pickling failure reports False as before.
        return False
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# Functionality to import modules/objects by name, and call functions by name
|
| 193 |
+
# ------------------------------------------------------------------------------------------
|
| 194 |
+
|
| 195 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed)."""

    # allow convenience shorthands, substitute them by full names
    # NOTE: the dots must be escaped — the previous patterns "^np." / "^tf."
    # would also rewrite unrelated names such as "npx.foo".
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name), longest module first
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except Exception:
            # Narrowed from bare 'except:' so KeyboardInterrupt/SystemExit propagate.
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError:
            # Re-raise import errors coming from INSIDE the module, as opposed
            # to the module simply not existing.
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object."""
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    module, local_name = get_module_from_obj_name(name)
    return get_obj_from_module(module, local_name)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    func = get_obj_by_name(func_name)
    assert callable(func)
    return func(*args, **kwargs)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name."""
    module = get_module_from_obj_name(obj_name)[0]
    return os.path.dirname(inspect.getfile(module))
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
    if not callable(obj):
        return False
    # A top-level function is reachable by its own name in its module's namespace.
    return obj.__name__ in sys.modules[obj.__module__].__dict__
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    return "{}.{}".format(obj.__module__, obj.__name__)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
# File system helpers
|
| 278 |
+
# ------------------------------------------------------------------------------------------
|
| 279 |
+
|
| 280 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.
    Returns list of tuples containing both absolute and relative paths."""
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))
    ignores = [] if ignores is None else ignores
    all_pairs = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for pattern in ignores:
            # Prune matching directories in-place so os.walk skips them entirely.
            for ignored_dir in [d for d in dirs if fnmatch.fnmatch(d, pattern)]:
                dirs.remove(ignored_dir)

            # Drop matching files from this directory's listing.
            files = [f for f in files if not fnmatch.fnmatch(f, pattern)]

        abs_paths = [os.path.join(root, f) for f in files]
        rel_paths = [os.path.relpath(p, dir_path) for p in abs_paths]

        if add_base_to_relative:
            rel_paths = [os.path.join(base_name, p) for p in rel_paths]

        assert len(abs_paths) == len(rel_paths)
        all_pairs.extend(zip(abs_paths, rel_paths))

    return all_pairs
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories."""
    for source_path, target_path in files:
        target_dir_name = os.path.dirname(target_path)

        # exist_ok avoids the TOCTOU race of the previous exists()/makedirs()
        # pair; the truthiness guard also fixes a crash when the destination
        # is a bare filename (dirname == "").
        if target_dir_name:
            os.makedirs(target_dir_name, exist_ok=True)

        shutil.copyfile(source_path, target_path)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
# URL helpers
|
| 327 |
+
# ------------------------------------------------------------------------------------------
|
| 328 |
+
|
| 329 |
+
def is_url(obj: Any) -> bool:
    """Determine whether the given object is a valid URL string."""
    # Cheap rejections first: non-strings and strings without a scheme separator.
    if not isinstance(obj, str) or "://" not in obj:
        return False
    try:
        # Validate both the URL itself and its root ("/") resolution:
        # each must have a scheme and a dotted network location.
        for candidate in (obj, requests.compat.urljoin(obj, "/")):
            parsed = requests.compat.urlparse(candidate)
            if not parsed.scheme or not parsed.netloc or "." not in parsed.netloc:
                return False
    except:
        return False
    return True
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Local file paths are opened directly. Downloads are retried up to
    num_attempts times and optionally cached on disk, keyed by the MD5 of
    the URL. Includes a workaround for Google Drive's "virus scan" nag page.
    """
    if not is_url(url) and os.path.isfile(url):
        return open(url, 'rb')

    assert is_url(url)
    assert num_attempts >= 1

    # Lookup from cache.
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache_dir is not None:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            return open(cache_files[0], "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be an HTML interstitial rather than
                    # the real payload — inspect them for Google Drive quirks.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Follow the confirm link on the next retry iteration.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive quota exceeded")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except:
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache_dir is not None:
        # Write to a unique temp file first, then rename, so concurrent
        # processes never observe a partially written cache entry.
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic

    # Return data as file object.
    return io.BytesIO(url_data)
|
encode_images.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import pickle
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import PIL.Image
|
| 6 |
+
from PIL import ImageFilter
|
| 7 |
+
import numpy as np
|
| 8 |
+
import dnnlib
|
| 9 |
+
import dnnlib.tflib as tflib
|
| 10 |
+
import config
|
| 11 |
+
from encoder.generator_model import Generator
|
| 12 |
+
from encoder.perceptual_model import PerceptualModel, load_images
|
| 13 |
+
#from tensorflow.keras.models import load_model
|
| 14 |
+
from keras.models import load_model
|
| 15 |
+
from keras.applications.resnet50 import preprocess_input
|
| 16 |
+
|
| 17 |
+
def split_to_batches(l, n):
    """Yield successive slices of l with at most n elements each."""
    start = 0
    while start < len(l):
        yield l[start:start + n]
        start += n
|
| 20 |
+
|
| 21 |
+
def str2bool(v):
    """Parse a boolean-ish command-line value for argparse.

    Accepts an actual bool (returned unchanged) or one of the usual textual
    spellings (case-insensitive). Anything else raises ArgumentTypeError so
    argparse reports a clean usage error.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in {'yes', 'true', 't', 'y', '1'}:
        return True
    if lowered in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 30 |
+
|
| 31 |
+
def main():
    """Find latent representations (dlatents) of reference images.

    For each batch of images in ``src_dir``, optimizes the dlatents of a
    StyleGAN generator under a combination of perceptual losses (optionally
    seeding the optimization with a ResNet/EfficientNet prediction), then
    saves the generated images (.png), the dlatents (.npy) and, optionally,
    a video of the optimization process.

    Fixes applied:
      * ``--average_best_loss`` was parsed but ignored (weights were
        hard-coded as 0.25/0.75); it is now honored.
      * ``vid_count`` was never incremented, so ``--video_skip`` had no
        effect; the counter now advances once per optimization step.
    """
    parser = argparse.ArgumentParser(description='Find latent representation of reference images using perceptual losses', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('src_dir', help='Directory with images for encoding')
    parser.add_argument('generated_images_dir', help='Directory for storing generated images')
    parser.add_argument('dlatent_dir', help='Directory for storing dlatent representations')
    parser.add_argument('--data_dir', default='data', help='Directory for storing optional models')
    parser.add_argument('--mask_dir', default='masks', help='Directory for storing optional masks')
    parser.add_argument('--load_last', default='', help='Start with embeddings from directory')
    parser.add_argument('--dlatent_avg', default='', help='Use dlatent from file specified here for truncation instead of dlatent_avg from Gs')
    parser.add_argument('--model_url', default='./data/karras2019stylegan-ffhq-1024x1024.pkl', help='Fetch a StyleGAN model to train on from this URL')
    parser.add_argument('--architecture', default='./data/vgg16_zhang_perceptual.pkl', help='Сonvolutional neural network model from this URL')
    parser.add_argument('--model_res', default=1024, help='The dimension of images in the StyleGAN model', type=int)
    parser.add_argument('--batch_size', default=1, help='Batch size for generator and perceptual model', type=int)
    parser.add_argument('--optimizer', default='ggt', help='Optimization algorithm used for optimizing dlatents')

    # Perceptual model params
    parser.add_argument('--image_size', default=256, help='Size of images for perceptual model', type=int)
    parser.add_argument('--resnet_image_size', default=256, help='Size of images for the Resnet model', type=int)
    parser.add_argument('--lr', default=0.25, help='Learning rate for perceptual model', type=float)
    parser.add_argument('--decay_rate', default=0.9, help='Decay rate for learning rate', type=float)
    parser.add_argument('--iterations', default=100, help='Number of optimization steps for each batch', type=int)
    parser.add_argument('--decay_steps', default=4, help='Decay steps for learning rate decay (as a percent of iterations)', type=float)
    parser.add_argument('--early_stopping', default=True, help='Stop early once training stabilizes', type=str2bool, nargs='?', const=True)
    parser.add_argument('--early_stopping_threshold', default=0.5, help='Stop after this threshold has been reached', type=float)
    parser.add_argument('--early_stopping_patience', default=10, help='Number of iterations to wait below threshold', type=int)
    parser.add_argument('--load_effnet', default='data/finetuned_effnet.h5', help='Model to load for EfficientNet approximation of dlatents')
    parser.add_argument('--load_resnet', default='data/finetuned_resnet.h5', help='Model to load for ResNet approximation of dlatents')
    parser.add_argument('--use_preprocess_input', default=True, help='Call process_input() first before using feed forward net', type=str2bool, nargs='?', const=True)
    parser.add_argument('--use_best_loss', default=True, help='Output the lowest loss value found as the solution', type=str2bool, nargs='?', const=True)
    parser.add_argument('--average_best_loss', default=0.25, help='Do a running weighted average with the previous best dlatents found', type=float)
    parser.add_argument('--sharpen_input', default=True, help='Sharpen the input images', type=str2bool, nargs='?', const=True)

    # Loss function options
    parser.add_argument('--use_vgg_loss', default=0.4, help='Use VGG perceptual loss; 0 to disable, > 0 to scale.', type=float)
    parser.add_argument('--use_vgg_layer', default=9, help='Pick which VGG layer to use.', type=int)
    parser.add_argument('--use_pixel_loss', default=1.5, help='Use logcosh image pixel loss; 0 to disable, > 0 to scale.', type=float)
    parser.add_argument('--use_mssim_loss', default=200, help='Use MS-SIM perceptual loss; 0 to disable, > 0 to scale.', type=float)
    parser.add_argument('--use_lpips_loss', default=100, help='Use LPIPS perceptual loss; 0 to disable, > 0 to scale.', type=float)
    parser.add_argument('--use_l1_penalty', default=0.5, help='Use L1 penalty on latents; 0 to disable, > 0 to scale.', type=float)
    parser.add_argument('--use_discriminator_loss', default=0.5, help='Use trained discriminator to evaluate realism.', type=float)
    parser.add_argument('--use_adaptive_loss', default=False, help='Use the adaptive robust loss function from Google Research for pixel and VGG feature loss.', type=str2bool, nargs='?', const=True)

    # Generator params
    parser.add_argument('--randomize_noise', default=False, help='Add noise to dlatents during optimization', type=str2bool, nargs='?', const=True)
    parser.add_argument('--tile_dlatents', default=False, help='Tile dlatents to use a single vector at each scale', type=str2bool, nargs='?', const=True)
    parser.add_argument('--clipping_threshold', default=2.0, help='Stochastic clipping of gradient values outside of this threshold', type=float)

    # Masking params
    parser.add_argument('--load_mask', default=False, help='Load segmentation masks', type=str2bool, nargs='?', const=True)
    parser.add_argument('--face_mask', default=True, help='Generate a mask for predicting only the face area', type=str2bool, nargs='?', const=True)
    parser.add_argument('--use_grabcut', default=True, help='Use grabcut algorithm on the face mask to better segment the foreground', type=str2bool, nargs='?', const=True)
    parser.add_argument('--scale_mask', default=1.4, help='Look over a wider section of foreground for grabcut', type=float)
    parser.add_argument('--composite_mask', default=True, help='Merge the unmasked area back into the generated image', type=str2bool, nargs='?', const=True)
    parser.add_argument('--composite_blur', default=8, help='Size of blur filter to smoothly composite the images', type=int)

    # Video params
    parser.add_argument('--video_dir', default='videos', help='Directory for storing training videos')
    parser.add_argument('--output_video', default=False, help='Generate videos of the optimization process', type=bool)
    parser.add_argument('--video_codec', default='MJPG', help='FOURCC-supported video codec name')
    parser.add_argument('--video_frame_rate', default=24, help='Video frames per second', type=int)
    parser.add_argument('--video_size', default=512, help='Video size in pixels', type=int)
    parser.add_argument('--video_skip', default=1, help='Only write every n frames (1 = write every frame)', type=int)

    args, other_args = parser.parse_known_args()

    args.decay_steps *= 0.01 * args.iterations  # Calculate steps as a percent of total iterations

    if args.output_video:
        import cv2
        synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=False), minibatch_size=args.batch_size)

    # Collect every regular file in src_dir as a reference image.
    ref_images = [os.path.join(args.src_dir, x) for x in os.listdir(args.src_dir)]
    ref_images = list(filter(os.path.isfile, ref_images))

    if len(ref_images) == 0:
        raise Exception('%s is empty' % args.src_dir)

    os.makedirs(args.data_dir, exist_ok=True)
    os.makedirs(args.mask_dir, exist_ok=True)
    os.makedirs(args.generated_images_dir, exist_ok=True)
    os.makedirs(args.dlatent_dir, exist_ok=True)
    os.makedirs(args.video_dir, exist_ok=True)

    # Initialize generator and perceptual model
    tflib.init_tf()
    with dnnlib.util.open_url(args.model_url, cache_dir=config.cache_dir) as f:
        generator_network, discriminator_network, Gs_network = pickle.load(f)

    generator = Generator(Gs_network, args.batch_size, clipping_threshold=args.clipping_threshold, tiled_dlatent=args.tile_dlatents, model_res=args.model_res, randomize_noise=args.randomize_noise)
    if (args.dlatent_avg != ''):
        generator.set_dlatent_avg(np.load(args.dlatent_avg))

    perc_model = None
    if (args.use_lpips_loss > 0.00000001):
        # LPIPS needs the pre-trained perceptual network from --architecture.
        with dnnlib.util.open_url(args.architecture, cache_dir=config.cache_dir) as f:
            perc_model = pickle.load(f)
    perceptual_model = PerceptualModel(args, perc_model=perc_model, batch_size=args.batch_size)
    perceptual_model.build_perceptual_model(generator, discriminator_network)

    ff_model = None  # lazily-loaded feed-forward model for predicting initial dlatents

    # Optimize (only) dlatents by minimizing perceptual loss between reference and generated images in feature space
    for images_batch in tqdm(split_to_batches(ref_images, args.batch_size), total=len(ref_images)//args.batch_size):
        names = [os.path.splitext(os.path.basename(x))[0] for x in images_batch]
        if args.output_video:
            video_out = {}
            for name in names:
                video_out[name] = cv2.VideoWriter(os.path.join(args.video_dir, f'{name}.avi'), cv2.VideoWriter_fourcc(*args.video_codec), args.video_frame_rate, (args.video_size, args.video_size))

        perceptual_model.set_reference_images(images_batch)
        dlatents = None
        if (args.load_last != ''):  # load previous dlatents for initialization
            for name in names:
                dl = np.expand_dims(np.load(os.path.join(args.load_last, f'{name}.npy')), axis=0)
                if (dlatents is None):
                    dlatents = dl
                else:
                    dlatents = np.vstack((dlatents, dl))
        else:
            # Prefer the fine-tuned ResNet predictor, fall back to EfficientNet.
            if (ff_model is None):
                if os.path.exists(args.load_resnet):
                    from keras.applications.resnet50 import preprocess_input
                    print("Loading ResNet Model:")
                    ff_model = load_model(args.load_resnet)
            if (ff_model is None):
                if os.path.exists(args.load_effnet):
                    import efficientnet
                    from efficientnet import preprocess_input
                    print("Loading EfficientNet Model:")
                    ff_model = load_model(args.load_effnet)
            if (ff_model is not None):  # predict initial dlatents with ResNet model
                if (args.use_preprocess_input):
                    dlatents = ff_model.predict(preprocess_input(load_images(images_batch, image_size=args.resnet_image_size)))
                else:
                    dlatents = ff_model.predict(load_images(images_batch, image_size=args.resnet_image_size))
        if dlatents is not None:
            generator.set_dlatents(dlatents)
        op = perceptual_model.optimize(generator.dlatent_variable, iterations=args.iterations, use_optimizer=args.optimizer)
        pbar = tqdm(op, leave=False, total=args.iterations)
        vid_count = 0
        best_loss = None
        best_dlatent = None
        avg_loss_count = 0
        if args.early_stopping:
            avg_loss = prev_loss = None
        for loss_dict in pbar:
            if args.early_stopping:  # early stopping feature
                if prev_loss is not None:
                    if avg_loss is not None:
                        # Exponentially-decayed running average of per-step improvement.
                        avg_loss = 0.5 * avg_loss + (prev_loss - loss_dict["loss"])
                        if avg_loss < args.early_stopping_threshold:  # count while under threshold; else reset
                            avg_loss_count += 1
                        else:
                            avg_loss_count = 0
                        if avg_loss_count > args.early_stopping_patience:  # stop once threshold is reached
                            print("")
                            break
                    else:
                        avg_loss = prev_loss - loss_dict["loss"]
            pbar.set_description(" ".join(names) + ": " + "; ".join(["{} {:.4f}".format(k, v) for k, v in loss_dict.items()]))
            if best_loss is None or loss_dict["loss"] < best_loss:
                if best_dlatent is None or args.average_best_loss <= 0.00000001:
                    best_dlatent = generator.get_dlatents()
                else:
                    # Fix: honor --average_best_loss (previously hard-coded 0.25/0.75,
                    # silently ignoring the CLI value). The default of 0.25 preserves
                    # the previous behavior.
                    best_dlatent = args.average_best_loss * best_dlatent + (1 - args.average_best_loss) * generator.get_dlatents()
                if args.use_best_loss:
                    generator.set_dlatents(best_dlatent)
                best_loss = loss_dict["loss"]
            if args.output_video and (vid_count % args.video_skip == 0):
                batch_frames = generator.generate_images()
                for i, name in enumerate(names):
                    video_frame = PIL.Image.fromarray(batch_frames[i], 'RGB').resize((args.video_size, args.video_size), PIL.Image.LANCZOS)
                    video_out[name].write(cv2.cvtColor(np.array(video_frame).astype('uint8'), cv2.COLOR_RGB2BGR))
            vid_count += 1  # Fix: advance the frame counter so --video_skip actually skips frames
            generator.stochastic_clip_dlatents()
            prev_loss = loss_dict["loss"]
        if not args.use_best_loss:
            best_loss = prev_loss
        print(" ".join(names), " Loss {:.4f}".format(best_loss))

        if args.output_video:
            for name in names:
                video_out[name].release()

        # Generate images from found dlatents and save them
        if args.use_best_loss:
            generator.set_dlatents(best_dlatent)
        generated_images = generator.generate_images()
        generated_dlatents = generator.get_dlatents()
        for img_array, dlatent, img_path, img_name in zip(generated_images, generated_dlatents, images_batch, names):
            mask_img = None
            if args.composite_mask and (args.load_mask or args.face_mask):
                _, im_name = os.path.split(img_path)
                mask_img = os.path.join(args.mask_dir, f'{im_name}')
            if args.composite_mask and mask_img is not None and os.path.isfile(mask_img):
                # Blend the generated face region into the original photo using
                # a blurred grayscale mask as per-pixel alpha.
                orig_img = PIL.Image.open(img_path).convert('RGB')
                width, height = orig_img.size
                imask = PIL.Image.open(mask_img).convert('L').resize((width, height))
                imask = imask.filter(ImageFilter.GaussianBlur(args.composite_blur))
                mask = np.array(imask)/255
                mask = np.expand_dims(mask, axis=-1)
                img_array = mask*np.array(img_array) + (1.0-mask)*np.array(orig_img)
                img_array = img_array.astype(np.uint8)
                #img_array = np.where(mask, np.array(img_array), orig_img)
            img = PIL.Image.fromarray(img_array, 'RGB')
            img.save(os.path.join(args.generated_images_dir, f'{img_name}.png'), 'PNG')
            np.save(os.path.join(args.dlatent_dir, f'{img_name}.npy'), dlatent)

        generator.reset_dlatents()


if __name__ == "__main__":
    main()
|
encoder/__init__.py
ADDED
|
File without changes
|
encoder/__pycache__/__init__.cpython-36.pyc
ADDED
|
Binary file (113 Bytes). View file
|
|
|
encoder/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (117 Bytes). View file
|
|
|
encoder/__pycache__/generator_model.cpython-36.pyc
ADDED
|
Binary file (5.09 kB). View file
|
|
|
encoder/__pycache__/generator_model.cpython-37.pyc
ADDED
|
Binary file (5.1 kB). View file
|
|
|
encoder/__pycache__/perceptual_model.cpython-36.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
encoder/__pycache__/perceptual_model.cpython-37.pyc
ADDED
|
Binary file (10 kB). View file
|
|
|
encoder/generator_model.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
import numpy as np
|
| 4 |
+
import dnnlib.tflib as tflib
|
| 5 |
+
from functools import partial
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_stub(name, batch_size):
    """Return a zero-width float32 constant used as a disabled second input.

    The `name` parameter is accepted for hook-signature compatibility but is
    not used.
    """
    stub_shape = (batch_size, 0)
    return tf.constant(0, shape=stub_shape, dtype='float32')
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_variable_for_generator(name, batch_size, tiled_dlatent, model_scale=18, tile_size=1):
    """Create the trainable 'learnable_dlatents' variable fed to synthesis.

    When `tiled_dlatent` is False, the variable directly has shape
    (batch_size, model_scale, 512). When True, a smaller
    (batch_size, tile_size, 512) variable is created and tiled up to
    model_scale layers so a single vector is shared across scales.
    """
    if not tiled_dlatent:
        return tf.get_variable('learnable_dlatents',
                               shape=(batch_size, model_scale, 512),
                               dtype='float32',
                               initializer=tf.initializers.random_normal())
    low_dim_dlatent = tf.get_variable('learnable_dlatents',
                                      shape=(batch_size, tile_size, 512),
                                      dtype='float32',
                                      initializer=tf.initializers.random_normal())
    repeats = model_scale // tile_size
    return tf.tile(low_dim_dlatent, [1, repeats, 1])
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Generator:
    """Wraps a StyleGAN model so the synthesis network's dlatent input becomes
    a trainable TF variable ('learnable_dlatents') that an external optimizer
    can update, and exposes the generated-image tensors plus helpers to read,
    write and stochastically clip the dlatents.
    """

    def __init__(self, model, batch_size, custom_input=None, clipping_threshold=2, tiled_dlatent=False, model_res=1024, randomize_noise=False):
        """Build the synthesis graph with a learnable dlatent input.

        model: loaded StyleGAN network with .components.synthesis and get_var().
        batch_size: number of images optimized simultaneously.
        custom_input: optional tensor whose evaluated value replaces the
            learnable-variable hook (non-tiled path only).
        clipping_threshold: components outside [-t, t] get re-sampled by
            stochastic_clip_dlatents().
        tiled_dlatent: if True, a single (batch, 512) vector is tiled across
            all layers instead of one vector per layer.
        model_res: model output resolution; determines model_scale.
        randomize_noise: forwarded to the synthesis network.
        """
        self.batch_size = batch_size
        self.tiled_dlatent=tiled_dlatent
        self.model_scale = int(2*(math.log(model_res,2)-1)) # For example, 1024 -> 18

        if tiled_dlatent:
            self.initial_dlatents = np.zeros((self.batch_size, 512))
            # Run synthesis once so the custom-input hooks instantiate the
            # 'learnable_dlatents' variable inside the graph.
            model.components.synthesis.run(np.zeros((self.batch_size, self.model_scale, 512)),
                                           randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                                           custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=True),
                                                          partial(create_stub, batch_size=batch_size)],
                                           structure='fixed')
        else:
            self.initial_dlatents = np.zeros((self.batch_size, self.model_scale, 512))
            if custom_input is not None:
                # NOTE(review): custom_input.eval() is passed as the partial's
                # func — presumably the evaluated value is meant to act as the
                # input hook; confirm against dnnlib's custom_inputs contract.
                model.components.synthesis.run(self.initial_dlatents,
                                               randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                                               custom_inputs=[partial(custom_input.eval(), batch_size=batch_size), partial(create_stub, batch_size=batch_size)],
                                               structure='fixed')
            else:
                model.components.synthesis.run(self.initial_dlatents,
                                               randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                                               custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=False, model_scale=self.model_scale),
                                                              partial(create_stub, batch_size=batch_size)],
                                               structure='fixed')

        # The model's average dlatent (used elsewhere for truncation).
        self.dlatent_avg_def = model.get_var('dlatent_avg')
        self.reset_dlatent_avg()
        self.sess = tf.compat.v1.get_default_session()
        self.graph = tf.compat.v1.get_default_graph()

        # Variable created by create_variable_for_generator during the run above.
        self.dlatent_variable = next(v for v in tf.compat.v1.global_variables() if 'learnable_dlatents' in v.name)
        # Placeholder + assign op pair so set_dlatents() can push numpy values
        # without creating new graph nodes on every call.
        self._assign_dlatent_ph = tf.compat.v1.placeholder(tf.float32, name="assign_dlatent_ph")
        self._assign_dlantent = tf.assign(self.dlatent_variable, self._assign_dlatent_ph)
        self.set_dlatents(self.initial_dlatents)

        def get_tensor(name):
            # Graph lookup that returns None instead of raising when absent.
            try:
                return self.graph.get_tensor_by_name(name)
            except KeyError:
                return None

        # The output tensor's name varies with how the model was exported;
        # probe the known candidates in order.
        self.generator_output = get_tensor('G_synthesis_1/_Run/concat:0')
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis_1/_Run/concat/concat:0')
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis_1/_Run/concat_1/concat:0')
        # If we loaded only Gs and didn't load G or D, then scope "G_synthesis_1" won't exist in the graph.
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis/_Run/concat:0')
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis/_Run/concat/concat:0')
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis/_Run/concat_1/concat:0')
        if self.generator_output is None:
            # Dump the graph's operations to aid debugging before failing.
            for op in self.graph.get_operations():
                print(op)
            raise Exception("Couldn't find G_synthesis_1/_Run/concat tensor output")
        self.generated_image = tflib.convert_images_to_uint8(self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
        self.generated_image_uint8 = tf.saturate_cast(self.generated_image, tf.uint8)

        # Implement stochastic clipping similar to what is described in https://arxiv.org/abs/1702.04782
        # (Slightly different in that the latent space is normal gaussian here and was uniform in [-1, 1] in that paper,
        # so we clip any vector components outside of [-2, 2]. It seems fine, but I haven't done an ablation check.)
        clipping_mask = tf.math.logical_or(self.dlatent_variable > clipping_threshold, self.dlatent_variable < -clipping_threshold)
        clipped_values = tf.where(clipping_mask, tf.random.normal(shape=self.dlatent_variable.shape), self.dlatent_variable)
        self.stochastic_clip_op = tf.assign(self.dlatent_variable, clipped_values)

    def reset_dlatents(self):
        """Reset the learnable dlatents to all zeros."""
        self.set_dlatents(self.initial_dlatents)

    def set_dlatents(self, dlatents):
        """Push a numpy dlatent array into the learnable variable.

        Inputs smaller than the batch are zero-padded to batch_size; in the
        tiled case a per-layer array is averaged down to a single vector.
        """
        if self.tiled_dlatent:
            # Collapse per-layer dlatents to one vector if needed.
            if (dlatents.shape != (self.batch_size, 512)) and (dlatents.shape[1] != 512):
                dlatents = np.mean(dlatents, axis=1)
            if (dlatents.shape != (self.batch_size, 512)):
                dlatents = np.vstack([dlatents, np.zeros((self.batch_size-dlatents.shape[0], 512))])
            assert (dlatents.shape == (self.batch_size, 512))
        else:
            # Truncate extra layers beyond the model's scale.
            if (dlatents.shape[1] > self.model_scale):
                dlatents = dlatents[:,:self.model_scale,:]
            if (isinstance(dlatents.shape[0], int)):
                if (dlatents.shape != (self.batch_size, self.model_scale, 512)):
                    dlatents = np.vstack([dlatents, np.zeros((self.batch_size-dlatents.shape[0], self.model_scale, 512))])
                assert (dlatents.shape == (self.batch_size, self.model_scale, 512))
                self.sess.run([self._assign_dlantent], {self._assign_dlatent_ph: dlatents})
                return
            else:
                # NOTE(review): symbolic (non-int) leading dimension — this
                # rebinds self._assign_dlantent to a new assign op but never
                # runs it here; confirm callers execute it later, otherwise
                # this branch silently does nothing.
                self._assign_dlantent = tf.assign(self.dlatent_variable, dlatents)
                return
        self.sess.run([self._assign_dlantent], {self._assign_dlatent_ph: dlatents})

    def stochastic_clip_dlatents(self):
        """Re-sample dlatent components lying outside [-threshold, threshold]."""
        self.sess.run(self.stochastic_clip_op)

    def get_dlatents(self):
        """Return the current dlatents as a numpy array."""
        return self.sess.run(self.dlatent_variable)

    def get_dlatent_avg(self):
        """Return the dlatent average currently in use (default or overridden)."""
        return self.dlatent_avg

    def set_dlatent_avg(self, dlatent_avg):
        """Override the dlatent average (e.g. loaded from a file)."""
        self.dlatent_avg = dlatent_avg

    def reset_dlatent_avg(self):
        """Restore the dlatent average to the model's own 'dlatent_avg' variable."""
        self.dlatent_avg = self.dlatent_avg_def

    def generate_images(self, dlatents=None):
        """Run synthesis and return uint8 images; optionally set dlatents first."""
        if dlatents is not None:
            self.set_dlatents(dlatents)
        return self.sess.run(self.generated_image_uint8)
|
encoder/perceptual_model.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
#import tensorflow_probability as tfp
|
| 4 |
+
#tf.enable_eager_execution()
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import bz2
|
| 8 |
+
import PIL.Image
|
| 9 |
+
from PIL import ImageFilter
|
| 10 |
+
import numpy as np
|
| 11 |
+
from keras.models import Model
|
| 12 |
+
from keras.utils import get_file
|
| 13 |
+
from keras.applications.vgg16 import VGG16, preprocess_input
|
| 14 |
+
import keras.backend as K
|
| 15 |
+
import traceback
|
| 16 |
+
import dnnlib.tflib as tflib
|
| 17 |
+
|
| 18 |
+
def load_images(images_list, image_size=256, sharpen=False):
    """Load image files into a stacked numpy array of shape (N, H, W, 3).

    Each image is converted to RGB; when image_size is given the image is
    resized to (image_size, image_size) with LANCZOS resampling, and may be
    sharpened with a DETAIL filter after resizing.
    """
    batches = []
    for img_path in images_list:
        image = PIL.Image.open(img_path).convert('RGB')
        if image_size is not None:
            image = image.resize((image_size, image_size), PIL.Image.LANCZOS)
            if sharpen:
                image = image.filter(ImageFilter.DETAIL)
        batches.append(np.expand_dims(np.array(image), 0))
    return np.vstack(batches)
|
| 31 |
+
|
| 32 |
+
def tf_custom_adaptive_loss(a,b):
    """Adaptive robust loss (Barron) on the flattened difference b - a."""
    from adaptive import lossfun
    flat_dim = np.prod(a.get_shape().as_list()[1:])
    a_flat = tf.reshape(a, [-1, flat_dim])
    b_flat = tf.reshape(b, [-1, flat_dim])
    loss, _, _ = lossfun(b_flat - a_flat, var_suffix='1')
    return tf.math.reduce_mean(loss)
|
| 40 |
+
|
| 41 |
+
def tf_custom_adaptive_rgb_loss(a,b):
    """Adaptive robust image loss (RGB pixel representation) on b - a."""
    from adaptive import image_lossfun
    per_pixel, _, _ = image_lossfun(b-a, color_space='RGB', representation='PIXEL')
    return tf.math.reduce_mean(per_pixel)
|
| 45 |
+
|
| 46 |
+
def tf_custom_l1_loss(img1,img2):
    """Mean absolute difference between two tensors (scalar)."""
    abs_diff = tf.math.abs(img2 - img1)
    return tf.math.reduce_mean(abs_diff, axis=None)
|
| 48 |
+
|
| 49 |
+
def tf_custom_logcosh_loss(img1,img2):
    """Mean log-cosh error between two tensors (scalar)."""
    per_element = tf.keras.losses.logcosh(img1, img2)
    return tf.math.reduce_mean(per_element)
|
| 51 |
+
|
| 52 |
+
def create_stub(batch_size):
    """Zero-width float32 constant used as a stand-in (disabled) input."""
    stub_shape = (batch_size, 0)
    return tf.constant(0, shape=stub_shape, dtype='float32')
|
| 54 |
+
|
| 55 |
+
def unpack_bz2(src_path):
    """Decompress a .bz2 file next to itself and return the decompressed path.

    `src_path` must end with '.bz2'; the output path is `src_path` with that
    suffix stripped. Fix: the BZ2File is now closed via a context manager
    (the original leaked the open file handle).
    """
    dst_path = src_path[:-4]  # strip the trailing '.bz2'
    with bz2.BZ2File(src_path) as src:
        data = src.read()
    with open(dst_path, 'wb') as fp:
        fp.write(data)
    return dst_path
|
| 61 |
+
|
| 62 |
+
class PerceptualModel:
|
| 63 |
+
def __init__(self, args, batch_size=1, perc_model=None, sess=None):
|
| 64 |
+
self.sess = tf.compat.v1.get_default_session() if sess is None else sess
|
| 65 |
+
K.set_session(self.sess)
|
| 66 |
+
self.epsilon = 0.00000001
|
| 67 |
+
self.lr = args.lr
|
| 68 |
+
self.decay_rate = args.decay_rate
|
| 69 |
+
self.decay_steps = args.decay_steps
|
| 70 |
+
self.img_size = args.image_size
|
| 71 |
+
self.layer = args.use_vgg_layer
|
| 72 |
+
self.vgg_loss = args.use_vgg_loss
|
| 73 |
+
self.face_mask = args.face_mask
|
| 74 |
+
self.use_grabcut = args.use_grabcut
|
| 75 |
+
self.scale_mask = args.scale_mask
|
| 76 |
+
self.mask_dir = args.mask_dir
|
| 77 |
+
if (self.layer <= 0 or self.vgg_loss <= self.epsilon):
|
| 78 |
+
self.vgg_loss = None
|
| 79 |
+
self.pixel_loss = args.use_pixel_loss
|
| 80 |
+
if (self.pixel_loss <= self.epsilon):
|
| 81 |
+
self.pixel_loss = None
|
| 82 |
+
self.mssim_loss = args.use_mssim_loss
|
| 83 |
+
if (self.mssim_loss <= self.epsilon):
|
| 84 |
+
self.mssim_loss = None
|
| 85 |
+
self.lpips_loss = args.use_lpips_loss
|
| 86 |
+
if (self.lpips_loss <= self.epsilon):
|
| 87 |
+
self.lpips_loss = None
|
| 88 |
+
self.l1_penalty = args.use_l1_penalty
|
| 89 |
+
if (self.l1_penalty <= self.epsilon):
|
| 90 |
+
self.l1_penalty = None
|
| 91 |
+
self.adaptive_loss = args.use_adaptive_loss
|
| 92 |
+
self.sharpen_input = args.sharpen_input
|
| 93 |
+
self.batch_size = batch_size
|
| 94 |
+
if perc_model is not None and self.lpips_loss is not None:
|
| 95 |
+
self.perc_model = perc_model
|
| 96 |
+
else:
|
| 97 |
+
self.perc_model = None
|
| 98 |
+
self.ref_img = None
|
| 99 |
+
self.ref_weight = None
|
| 100 |
+
self.perceptual_model = None
|
| 101 |
+
self.ref_img_features = None
|
| 102 |
+
self.features_weight = None
|
| 103 |
+
self.loss = None
|
| 104 |
+
self.discriminator_loss = args.use_discriminator_loss
|
| 105 |
+
if (self.discriminator_loss <= self.epsilon):
|
| 106 |
+
self.discriminator_loss = None
|
| 107 |
+
if self.discriminator_loss is not None:
|
| 108 |
+
self.discriminator = None
|
| 109 |
+
self.stub = create_stub(batch_size)
|
| 110 |
+
|
| 111 |
+
if self.face_mask:
|
| 112 |
+
import dlib
|
| 113 |
+
self.detector = dlib.get_frontal_face_detector()
|
| 114 |
+
landmarks_model_path = unpack_bz2('shape_predictor_68_face_landmarks.dat.bz2')
|
| 115 |
+
self.predictor = dlib.shape_predictor(landmarks_model_path)
|
| 116 |
+
|
| 117 |
+
def add_placeholder(self, var_name):
|
| 118 |
+
var_val = getattr(self, var_name)
|
| 119 |
+
setattr(self, var_name + "_placeholder", tf.compat.v1.placeholder(var_val.dtype, shape=var_val.get_shape()))
|
| 120 |
+
setattr(self, var_name + "_op", var_val.assign(getattr(self, var_name + "_placeholder")))
|
| 121 |
+
|
| 122 |
+
def assign_placeholder(self, var_name, var_val):
|
| 123 |
+
self.sess.run(getattr(self, var_name + "_op"), {getattr(self, var_name + "_placeholder"): var_val})
|
| 124 |
+
|
| 125 |
+
def build_perceptual_model(self, generator, discriminator=None):
    """Build the TF graph for the combined perceptual loss.

    Creates the decayed learning-rate schedule, reference-image variables
    (with feed placeholders), the optional VGG16 feature extractor, and sums
    the enabled loss terms (VGG feature loss, pixel loss, MS-SSIM, LPIPS,
    L1 dlatent penalty, discriminator realism) into `self.loss`.

    Args:
        generator: object exposing `generated_image`, `dlatent_variable`,
            and `get_dlatent_avg()` (project generator wrapper).
        discriminator: optional network with `get_output_for`; used only
            when `self.discriminator_loss` is enabled.
    """
    # Learning rate: each evaluation of `incremented_global_step` bumps the
    # step counter, so the exponential decay advances once per training fetch.
    global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name="global_step")
    incremented_global_step = tf.compat.v1.assign_add(global_step, 1)
    self._reset_global_step = tf.assign(global_step, 0)
    self.learning_rate = tf.compat.v1.train.exponential_decay(self.lr, incremented_global_step,
                self.decay_steps, self.decay_rate, staircase=True)
    self.sess.run([self._reset_global_step])

    if self.discriminator_loss is not None:
        self.discriminator = discriminator

    # Resize generator output to the working resolution used by all losses.
    generated_image_tensor = generator.generated_image
    generated_image = tf.compat.v1.image.resize_nearest_neighbor(generated_image_tensor,
                                                      (self.img_size, self.img_size), align_corners=True)

    # Reference image and per-pixel weight mask; values are fed later via
    # the placeholder/assign pairs created by add_placeholder.
    self.ref_img = tf.get_variable('ref_img', shape=generated_image.shape,
                                            dtype='float32', initializer=tf.initializers.zeros())
    self.ref_weight = tf.get_variable('ref_weight', shape=generated_image.shape,
                                               dtype='float32', initializer=tf.initializers.zeros())
    self.add_placeholder("ref_img")
    self.add_placeholder("ref_weight")

    if (self.vgg_loss is not None):
        # Weights file expected locally; see
        # https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
        vgg16 = VGG16(include_top=False, weights='vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5', input_shape=(self.img_size, self.img_size, 3))
        self.perceptual_model = Model(vgg16.input, vgg16.layers[self.layer].output)
        generated_img_features = self.perceptual_model(preprocess_input(self.ref_weight * generated_image))
        self.ref_img_features = tf.get_variable('ref_img_features', shape=generated_img_features.shape,
                                            dtype='float32', initializer=tf.initializers.zeros())
        self.features_weight = tf.get_variable('features_weight', shape=generated_img_features.shape,
                                           dtype='float32', initializer=tf.initializers.zeros())
        # BUG FIX: previously ran features_weight.initializer twice and never
        # initialized ref_img_features; initialize both variables here.
        self.sess.run([self.ref_img_features.initializer, self.features_weight.initializer])
        self.add_placeholder("ref_img_features")
        self.add_placeholder("features_weight")

    if self.perc_model is not None and self.lpips_loss is not None:
        # LPIPS model expects NCHW uint8-range inputs; convert both images.
        img1 = tflib.convert_images_from_uint8(self.ref_weight * self.ref_img, nhwc_to_nchw=True)
        img2 = tflib.convert_images_from_uint8(self.ref_weight * generated_image, nhwc_to_nchw=True)

    self.loss = 0
    # L1 loss on VGG16 features
    if (self.vgg_loss is not None):
        if self.adaptive_loss:
            self.loss += self.vgg_loss * tf_custom_adaptive_loss(self.features_weight * self.ref_img_features, self.features_weight * generated_img_features)
        else:
            self.loss += self.vgg_loss * tf_custom_logcosh_loss(self.features_weight * self.ref_img_features, self.features_weight * generated_img_features)
    # + logcosh loss on image pixels
    if (self.pixel_loss is not None):
        if self.adaptive_loss:
            self.loss += self.pixel_loss * tf_custom_adaptive_rgb_loss(self.ref_weight * self.ref_img, self.ref_weight * generated_image)
        else:
            self.loss += self.pixel_loss * tf_custom_logcosh_loss(self.ref_weight * self.ref_img, self.ref_weight * generated_image)
    # + MS-SIM loss on image pixels
    if (self.mssim_loss is not None):
        self.loss += self.mssim_loss * tf.math.reduce_mean(1-tf.image.ssim_multiscale(self.ref_weight * self.ref_img, self.ref_weight * generated_image, 1))
    # + extra perceptual loss on image pixels
    if self.perc_model is not None and self.lpips_loss is not None:
        self.loss += self.lpips_loss * tf.math.reduce_mean(self.perc_model.get_output_for(img1, img2))
    # + L1 penalty on dlatent weights
    if self.l1_penalty is not None:
        self.loss += self.l1_penalty * 512 * tf.math.reduce_mean(tf.math.abs(generator.dlatent_variable-generator.get_dlatent_avg()))
    # discriminator loss (realism)
    if self.discriminator_loss is not None:
        self.loss += self.discriminator_loss * tf.math.reduce_mean(self.discriminator.get_output_for(tflib.convert_images_from_uint8(generated_image_tensor, nhwc_to_nchw=True), self.stub))
    # - discriminator_network.get_output_for(tflib.convert_images_from_uint8(ref_img, nhwc_to_nchw=True), stub)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def generate_face_mask(self, im):
    """Compute a binary face mask for image `im` using dlib landmarks.

    Detects faces with `self.detector`, takes the convex hull of the 68
    landmarks from `self.predictor`, and optionally refines the region with
    OpenCV grabCut. Returns a 2-D 0/1 mask for the FIRST detected face
    (returns from inside the loop); if no face is detected, returns None.

    NOTE(review): `im` is assumed to be an H x W x 3 uint8 array (dlib/cv2
    convention) — confirm against callers.
    """
    from imutils import face_utils
    import cv2
    rects = self.detector(im, 1)
    # loop over the face detections
    for (j, rect) in enumerate(rects):
        """
        Determine the facial landmarks for the face region, then convert the facial landmark (x, y)-coordinates to a NumPy array
        """
        shape = self.predictor(im, rect)
        shape = face_utils.shape_to_np(shape)

        # we extract the face
        vertices = cv2.convexHull(shape)
        mask = np.zeros(im.shape[:2],np.uint8)
        cv2.fillConvexPoly(mask, vertices, 1)
        if self.use_grabcut:
            bgdModel = np.zeros((1,65),np.float64)
            fgdModel = np.zeros((1,65),np.float64)
            # BUG FIX: grabCut rect is (x, y, w, h); height is im.shape[0],
            # not im.shape[2] (which is the channel count).
            rect = (0,0,im.shape[1],im.shape[0])
            # Seed grabCut: a circle around the hull is "probably foreground",
            # the hull itself is definite foreground.
            (x,y),radius = cv2.minEnclosingCircle(vertices)
            center = (int(x),int(y))
            radius = int(radius*self.scale_mask)
            mask = cv2.circle(mask,center,radius,cv2.GC_PR_FGD,-1)
            cv2.fillConvexPoly(mask, vertices, cv2.GC_FGD)
            cv2.grabCut(im,mask,rect,bgdModel,fgdModel,5,cv2.GC_INIT_WITH_MASK)
            # Collapse grabCut's 4-way labels to binary: BGD/PR_BGD -> 0, else 1.
            mask = np.where((mask==2)|(mask==0),0,1)
        return mask
|
| 220 |
+
|
| 221 |
+
def set_reference_images(self, images_list):
    """Load the target images (and masks/features) into the graph variables.

    Loads up to `self.batch_size` images, optionally computes VGG features and
    per-pixel face masks, zero-pads a short final batch, and feeds everything
    into the `ref_img` / `ref_weight` / `ref_img_features` / `features_weight`
    variables via `assign_placeholder`.

    Side effects: may read cached mask PNGs from `self.mask_dir` and write
    newly generated ones back there.
    """
    assert(len(images_list) != 0 and len(images_list) <= self.batch_size)
    loaded_image = load_images(images_list, self.img_size, sharpen=self.sharpen_input)
    image_features = None
    # VGG features only exist when build_perceptual_model created the extractor.
    if self.perceptual_model is not None:
        image_features = self.perceptual_model.predict_on_batch(preprocess_input(np.array(loaded_image)))
        weight_mask = np.ones(self.features_weight.shape)

    if self.face_mask:
        # Per-pixel weights: 1 inside the detected face region, 0 elsewhere.
        image_mask = np.zeros(self.ref_weight.shape)
        for (i, im) in enumerate(loaded_image):
            try:
                _, img_name = os.path.split(images_list[i])
                mask_img = os.path.join(self.mask_dir, f'{img_name}')
                if (os.path.isfile(mask_img)):
                    # Reuse a previously cached mask (grayscale PNG, 0..255).
                    print("Loading mask " + mask_img)
                    imask = PIL.Image.open(mask_img).convert('L')
                    mask = np.array(imask)/255
                    mask = np.expand_dims(mask,axis=-1)
                else:
                    # No cached mask: detect the face and cache the result.
                    mask = self.generate_face_mask(im)
                    imask = (255*mask).astype('uint8')
                    imask = PIL.Image.fromarray(imask, 'L')
                    print("Saving mask " + mask_img)
                    imask.save(mask_img, 'PNG')
                    mask = np.expand_dims(mask,axis=-1)
                # Broadcast the single-channel mask across all image channels.
                mask = np.ones(im.shape,np.float32) * mask
            except Exception as e:
                # Best-effort fallback: on any mask failure, weight the whole
                # image (all-ones mask) rather than aborting the batch.
                print("Exception in mask handling for " + mask_img)
                traceback.print_exc()
                mask = np.ones(im.shape[:2],np.uint8)
                mask = np.ones(im.shape,np.float32) * np.expand_dims(mask,axis=-1)
            image_mask[i] = mask
        img = None  # NOTE(review): apparently unused leftover; kept for byte-identical behavior
    else:
        # No face masking: every pixel contributes with full weight.
        image_mask = np.ones(self.ref_weight.shape)

    if len(images_list) != self.batch_size:
        # Short batch: zero-pad features/images and zero the weights on the
        # padded slots so they don't contribute to the loss.
        if image_features is not None:
            features_space = list(self.features_weight.shape[1:])
            existing_features_shape = [len(images_list)] + features_space
            empty_features_shape = [self.batch_size - len(images_list)] + features_space
            existing_examples = np.ones(shape=existing_features_shape)
            empty_examples = np.zeros(shape=empty_features_shape)
            weight_mask = np.vstack([existing_examples, empty_examples])
            image_features = np.vstack([image_features, np.zeros(empty_features_shape)])

        images_space = list(self.ref_weight.shape[1:])
        existing_images_space = [len(images_list)] + images_space
        empty_images_space = [self.batch_size - len(images_list)] + images_space
        existing_images = np.ones(shape=existing_images_space)
        empty_images = np.zeros(shape=empty_images_space)
        image_mask = image_mask * np.vstack([existing_images, empty_images])
        loaded_image = np.vstack([loaded_image, np.zeros(empty_images_space)])

    # Feed everything into the TF variables created by build_perceptual_model.
    if image_features is not None:
        self.assign_placeholder("features_weight", weight_mask)
        self.assign_placeholder("ref_img_features", image_features)
    self.assign_placeholder("ref_weight", image_mask)
    self.assign_placeholder("ref_img", loaded_image)
|
| 281 |
+
|
| 282 |
+
def optimize(self, vars_to_optimize, iterations=200, use_optimizer='adam'):
    """Iteratively minimize `self.loss` over `vars_to_optimize`.

    Generator: yields a dict per iteration — {"loss": ..., "lr": ...} for
    adam/ggt, or {"loss": ...} for lbfgs (which has no decayed lr fetch).

    Args:
        vars_to_optimize: a TF variable or list of variables (the dlatents).
        iterations: number of yields (and, for lbfgs, also the per-call
            maxiter passed to the SciPy interface).
        use_optimizer: 'adam' (default), 'ggt', or 'lbfgs'.
    """
    vars_to_optimize = vars_to_optimize if isinstance(vars_to_optimize, list) else [vars_to_optimize]
    if use_optimizer == 'lbfgs':
        # SciPy interface drives the session itself; no minimize op / variables()
        # to initialize, hence the separate code path below.
        optimizer = tf.contrib.opt.ScipyOptimizerInterface(self.loss, var_list=vars_to_optimize, method='L-BFGS-B', options={'maxiter': iterations})
    else:
        if use_optimizer == 'ggt':
            optimizer = tf.contrib.opt.GGTOptimizer(learning_rate=self.learning_rate)
        else:
            optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        min_op = optimizer.minimize(self.loss, var_list=[vars_to_optimize])
        # Initialize the optimizer's slot variables (momenta etc.) only.
        self.sess.run(tf.variables_initializer(optimizer.variables()))
        fetch_ops = [min_op, self.loss, self.learning_rate]
    #min_op = optimizer.minimize(self.sess)
    #optim_results = tfp.optimizer.lbfgs_minimize(make_val_and_grad_fn(get_loss), initial_position=vars_to_optimize, num_correction_pairs=10, tolerance=1e-8)
    # Restart the learning-rate decay schedule for this optimization run.
    self.sess.run(self._reset_global_step)
    #self.sess.graph.finalize() # Graph is read-only after this statement.
    for _ in range(iterations):
        if use_optimizer == 'lbfgs':
            # NOTE(review): each loop call runs up to `iterations` L-BFGS steps
            # (maxiter above), so lbfgs does far more work per yield — confirm intended.
            optimizer.minimize(self.sess, fetches=[vars_to_optimize, self.loss])
            yield {"loss":self.loss.eval()}
        else:
            _, loss, lr = self.sess.run(fetch_ops)
            yield {"loss":loss,"lr":lr}
|
ffhq_dataset/__init__.py
ADDED
|
File without changes
|
ffhq_dataset/__pycache__/__init__.cpython-36.pyc
ADDED
|
Binary file (118 Bytes). View file
|
|
|
ffhq_dataset/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (122 Bytes). View file
|
|
|
ffhq_dataset/__pycache__/face_alignment.cpython-36.pyc
ADDED
|
Binary file (3.17 kB). View file
|
|
|
ffhq_dataset/__pycache__/face_alignment.cpython-37.pyc
ADDED
|
Binary file (3.17 kB). View file
|
|
|
ffhq_dataset/__pycache__/landmarks_detector.cpython-36.pyc
ADDED
|
Binary file (1.16 kB). View file
|
|
|