ZJW666 commited on
Commit
7a59a55
·
1 Parent(s): 774a6a3

first version

Browse files
Files changed (46) hide show
  1. app.py +130 -0
  2. dnnlib/__init__.py +9 -0
  3. dnnlib/__pycache__/__init__.cpython-39.pyc +0 -0
  4. dnnlib/__pycache__/util.cpython-39.pyc +0 -0
  5. dnnlib/util.py +491 -0
  6. feature_networks/__pycache__/constants.cpython-39.pyc +0 -0
  7. feature_networks/__pycache__/pretrained_builder.cpython-39.pyc +0 -0
  8. feature_networks/__pycache__/vit.cpython-39.pyc +0 -0
  9. feature_networks/clip/__init__.py +1 -0
  10. feature_networks/clip/__pycache__/__init__.cpython-39.pyc +0 -0
  11. feature_networks/clip/__pycache__/clip.cpython-39.pyc +0 -0
  12. feature_networks/clip/__pycache__/model.cpython-39.pyc +0 -0
  13. feature_networks/clip/__pycache__/simple_tokenizer.cpython-39.pyc +0 -0
  14. feature_networks/clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  15. feature_networks/clip/clip.py +244 -0
  16. feature_networks/clip/model.py +453 -0
  17. feature_networks/clip/simple_tokenizer.py +132 -0
  18. feature_networks/constants.py +129 -0
  19. feature_networks/pretrained_builder.py +417 -0
  20. feature_networks/vit.py +436 -0
  21. legacy.py +331 -0
  22. misc.py +275 -0
  23. pg_modules/__init__.py +0 -0
  24. pg_modules/__pycache__/MViT.cpython-39.pyc +0 -0
  25. pg_modules/__pycache__/__init__.cpython-39.pyc +0 -0
  26. pg_modules/__pycache__/blocks.cpython-38.pyc +0 -0
  27. pg_modules/__pycache__/blocks.cpython-39.pyc +0 -0
  28. pg_modules/__pycache__/diffaug.cpython-38.pyc +0 -0
  29. pg_modules/__pycache__/diffaug.cpython-39.pyc +0 -0
  30. pg_modules/__pycache__/discriminator.cpython-38.pyc +0 -0
  31. pg_modules/__pycache__/discriminator.cpython-39.pyc +0 -0
  32. pg_modules/__pycache__/mae.cpython-39.pyc +0 -0
  33. pg_modules/__pycache__/models_tnt.cpython-39.pyc +0 -0
  34. pg_modules/__pycache__/networks_fastgan.cpython-38.pyc +0 -0
  35. pg_modules/__pycache__/networks_fastgan.cpython-39.pyc +0 -0
  36. pg_modules/__pycache__/networks_stylegan2.cpython-39.pyc +0 -0
  37. pg_modules/__pycache__/projector.cpython-38.pyc +0 -0
  38. pg_modules/__pycache__/projector.cpython-39.pyc +0 -0
  39. pg_modules/__pycache__/simmim.cpython-39.pyc +0 -0
  40. pg_modules/__pycache__/vision_transformer.cpython-39.pyc +0 -0
  41. pg_modules/blocks.py +370 -0
  42. pg_modules/diffaug.py +76 -0
  43. pg_modules/discriminator.py +153 -0
  44. pg_modules/networks_fastgan.py +180 -0
  45. pg_modules/networks_stylegan2.py +537 -0
  46. pg_modules/projector.py +158 -0
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import gradio as gr
4
+ from PIL import Image
5
+
6
+ """Generate images using pretrained network pickle."""
7
+
8
+ import re
9
+ from typing import List, Optional, Tuple, Union
10
+
11
+ import click
12
+ import dnnlib
13
+ import numpy as np
14
+ import PIL.Image
15
+ import torch
16
+
17
+ import legacy
18
+
19
+ from huggingface_hub import hf_hub_url
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
def parse_range(s: Union[str, List]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.

    Ranges 'a-b' are inclusive on both ends.
    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

    Raises:
        ValueError: if a piece is neither an integer nor an 'a-b' range.
    '''
    # Already a list (e.g. passed programmatically): return unchanged.
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        m = range_re.match(p)
        if m:
            # '+1' makes the upper bound inclusive.
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges
37
+
38
+ #----------------------------------------------------------------------------
39
+
40
def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
    '''Parse a floating point 2-vector of syntax 'a,b'.

    Example: '0,1' returns (0.0, 1.0). Tuples pass through unchanged.
    '''
    # Already parsed: hand back as-is.
    if isinstance(s, tuple):
        return s
    first, sep, second = s.partition(',')
    # Accept exactly one comma; more (or none) is malformed.
    if sep and ',' not in second:
        return (float(first), float(second))
    raise ValueError(f'cannot parse 2-vector {s}')
50
+
51
+ #----------------------------------------------------------------------------
52
+
53
def make_transform(translate: Tuple[float,float], angle: float):
    """Build a 3x3 homogeneous 2D transform: rotation by `angle` (degrees)
    combined with a translation by `translate` = (tx, ty)."""
    radians = angle / 360.0 * np.pi * 2
    sin_a = np.sin(radians)
    cos_a = np.cos(radians)
    matrix = np.eye(3)
    matrix[0, 0], matrix[0, 1], matrix[0, 2] = cos_a, sin_a, translate[0]
    matrix[1, 0], matrix[1, 1], matrix[1, 2] = -sin_a, cos_a, translate[1]
    return matrix
64
+
65
+ #----------------------------------------------------------------------------
66
+
67
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
68
+
69
+ # config_file_url = hf_hub_url("autonomousvision/Projected_GAN_Pokemon", filename="pokemon.pkl")
70
+ # config_file_url = r'E:\桌面\Preparation of Papers for IEEE Signal Processing Letters (5-page limit)\codes\pokemon.pkl'
71
+ # with dnnlib.util.open_url(config_file_url) as f:
72
+ # G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
73
+
74
+ # models = {
75
+ # 'pokemon':
76
+ # }
77
+ # base_path =
78
# Directory holding the pretrained generator pickles (<name>.pkl).
# The original hard-coded a machine-specific Windows path via string
# concatenation; keep it as the default but allow overriding through the
# PKL_DIR environment variable, and join paths portably.
_PKL_DIR = os.environ.get(
    'PKL_DIR',
    r'E:\桌面\Preparation of Papers for IEEE Signal Processing Letters (5-page limit)\codes\projected-gan-clc - 副本',
)

# Pre-load every generator once at startup; 'G_ema' holds the
# exponential-moving-average weights used for inference.
models = dict()
for i in ["pokemon", "art-paint", "flowers", "landscapes", "obama"]:
    with dnnlib.util.open_url(os.path.join(_PKL_DIR, i + ".pkl")) as f:
        models[i] = legacy.load_network_pkl(f)['G_ema']
82
+
83
+
84
def generate_images(seeds, name):
    """Generate one image with the pretrained generator `models[name]`.

    Args:
        seeds: list of integer RNG seeds. One latent is drawn per seed, but
            `pilimg` is overwritten each iteration, so only the image for the
            LAST seed is returned.
        name: key into the module-level `models` dict selecting the generator.

    Returns:
        PIL.Image.Image in RGB mode.
    """
    # models
    G = models[name].to(device)
    # Labels. A zero label vector: these models are used unconditionally.
    label = torch.zeros([1, G.c_dim], device=device)

    # Generate images.
    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
        # Seeded latent: the same (seed, model) pair always yields the same image.
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device).float()

        # Construct an inverse rotation/translation matrix and pass to the generator. The
        # generator expects this matrix as an inverse to avoid potentially failing numerical
        # operations in the network.
        if hasattr(G.synthesis, 'input'):
            # NOTE(review): passes the string '0,0' where make_transform expects a
            # (tx, ty) tuple; this only works if numpy coerces the '0' substrings
            # to floats on assignment — confirm, or pass (0, 0) explicitly.
            m = make_transform('0,0', 0)
            m = np.linalg.inv(m)
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        img = G(z, label, truncation_psi=1, noise_mode='const')
        # Map the generator's [-1, 1] float output to [0, 255] uint8, NCHW -> NHWC.
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        pilimg = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
    return pilimg
118
+
119
+
120
def inference(seedin, name = None):
    """Gradio callback: coerce the slider value to an int seed and render one image
    with the generator selected by `name`."""
    print(name)
    seed = int(seedin)
    return generate_images([seed], name)
125
+
126
# UI metadata shown above the demo.
title = "Projected GAN CLC"
description = "Gradio demo for Projected GANs CLC, Pokemon."

# Seed slider + dataset radio button -> one generated image.
# launch() blocks and serves the app; this runs at import time by design (HF Space entry point).
gr.Interface(fn=inference,inputs=[gr.Slider(label="Seed",minimum=0, maximum=5000, step=1, value=0), gr.Radio(["pokemon", "art-paint", "flowers", "landscapes","obama"], label='Dataset', value='art-paint')],outputs=["image"],title=title,description=description
).launch()
dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from .util import EasyDict, make_cache_dir_path
dnnlib/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (285 Bytes). View file
 
dnnlib/__pycache__/util.cpython-39.pyc ADDED
Binary file (14.1 kB). View file
 
dnnlib/util.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import numpy as np
16
+ import os
17
+ import shutil
18
+ import sys
19
+ import types
20
+ import io
21
+ import pickle
22
+ import re
23
+ import requests
24
+ import html
25
+ import hashlib
26
+ import glob
27
+ import tempfile
28
+ import urllib
29
+ import urllib.request
30
+ import uuid
31
+
32
+ from distutils.util import strtobool
33
+ from typing import Any, List, Tuple, Union
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+
40
class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        # Missing keys must surface as AttributeError (not KeyError) so that
        # hasattr()/getattr() behave normally on instances.
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        # Attribute writes go straight into the dict.
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]
54
+
55
+
56
class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.

    NOTE: constructing a Logger installs it as BOTH sys.stdout and sys.stderr
    as a side effect; close() (or exiting the context manager) restores them.
    """

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        # Optional mirror file; stays None when no file_name is given.
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        # Keep the real streams so write()/flush() can forward to them and
        # close() can restore them.
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        # Install self as the process-wide stdout/stderr.
        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if isinstance(text, bytes):
            text = text.decode()
        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order:
        # only un-install if this logger is still the active redirect.
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            self.file = None
113
+
114
+
115
+ # Cache directories
116
+ # ------------------------------------------------------------------------------------------
117
+
118
# Module-level override set via set_cache_dir(); None means "use defaults".
_dnnlib_cache_dir = None

def set_cache_dir(path: str) -> None:
    """Override the cache root used by make_cache_dir_path()."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path

def make_cache_dir_path(*paths: str) -> str:
    """Join *paths onto the first available cache root.

    Priority: explicit set_cache_dir() > $DNNLIB_CACHE_DIR >
    $HOME/.cache/dnnlib > %USERPROFILE%/.cache/dnnlib > <tempdir>/.cache/dnnlib.
    """
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)
    if 'DNNLIB_CACHE_DIR' in os.environ:
        # The env var points directly at the cache root (no '.cache/dnnlib' suffix).
        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
    for home_var in ('HOME', 'USERPROFILE'):
        if home_var in os.environ:
            return os.path.join(os.environ[home_var], '.cache', 'dnnlib', *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
+
135
+ # Small util functions
136
+ # ------------------------------------------------------------------------------------------
137
+
138
+
139
def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    total = int(np.rint(seconds))
    days, remainder = divmod(total, 24 * 60 * 60)
    hours, remainder = divmod(remainder, 60 * 60)
    minutes, secs = divmod(remainder, 60)

    # Show at most three units, dropping seconds once days appear.
    if total < 60:
        return f"{total}s"
    if total < 60 * 60:
        return f"{minutes}m {secs:02}s"
    if total < 24 * 60 * 60:
        return f"{hours}h {minutes:02}m {secs:02}s"
    return f"{days}d {hours:02}h {minutes:02}m"
151
+
152
+
153
def format_time_brief(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds.

    Brief variant: shows at most the two most significant units.
    """
    total = int(np.rint(seconds))
    if total < 60:
        return f"{total}s"
    if total < 60 * 60:
        minutes, secs = divmod(total, 60)
        return f"{minutes}m {secs:02}s"
    if total < 24 * 60 * 60:
        return f"{total // (60 * 60)}h {(total // 60) % 60:02}m"
    return f"{total // (24 * 60 * 60)}d {(total // (60 * 60)) % 24:02}h"
165
+
166
+
167
def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer.

    Accepts the same spellings as distutils.util.strtobool (which this used to
    call, but distutils is deprecated and removed in Python 3.12):
    y/yes/t/true/on/1 -> True; n/no/f/false/off/0 -> False.
    Returns a real bool (strtobool returned 1/0).
    """
    truthy = {"y", "yes", "t", "true", "on", "1"}
    falsy = {"n", "no", "f", "false", "off", "0"}
    while True:
        print("{0} [y/n]".format(question))
        answer = input().lower()
        if answer in truthy:
            return True
        if answer in falsy:
            return False
        # Invalid input: loop and re-prompt, matching the old ValueError retry.
175
+
176
+
177
def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements (1 for an empty tuple)."""
    # math.prod (3.8+) replaces the hand-rolled accumulator loop; imported
    # locally so this block is a drop-in without touching the module header.
    import math
    return math.prod(t)
185
+
186
+
187
# Mapping from numpy dtype names to the ctypes type of identical width.
_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
    # Resolve a name: plain strings pass through, classes expose __name__,
    # numpy dtypes expose .name.
    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()

    dtype = np.dtype(type_str)
    ctype = _str_to_ctype[type_str]

    # Sanity check: both representations must occupy the same number of bytes.
    assert dtype.itemsize == ctypes.sizeof(ctype)
    return dtype, ctype
222
+
223
+
224
def is_pickleable(obj: Any) -> bool:
    """Return True if `obj` can be serialized with pickle, False otherwise.

    Serializes into an in-memory buffer so nothing touches disk.
    """
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        # The original used a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit; restrict to ordinary exceptions
        # (PicklingError, TypeError, RecursionError, ...).
        return False
231
+
232
+
233
+ # Functionality to import modules/objects by name, and call functions by name
234
+ # ------------------------------------------------------------------------------------------
235
+
236
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed).

    Raises:
        ImportError: if no (module, attribute) split of `obj_name` resolves.
    """

    # allow convenience shorthands, substitute them by full names
    obj_name = re.sub("^np.", "numpy.", obj_name)
    obj_name = re.sub("^tf.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name),
    # from the longest module prefix down to a single leading component
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn; first split that imports AND resolves wins
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except:
            pass

    # maybe some of the modules themselves contain errors?
    # re-import so that a genuine error inside a module propagates instead of
    # being reported as a generic lookup failure
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError:
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    # retry the attribute lookup so an AttributeError can surface
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)
275
+
276
+
277
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object.

    An empty `obj_name` returns the module itself.
    """
    if not obj_name:
        return module
    target = module
    for attribute in obj_name.split("."):
        target = getattr(target, attribute)
    return target
285
+
286
+
287
def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name.

    Splits `name` into a module part and an attribute path, then resolves it.
    """
    resolved_module, remainder = get_module_from_obj_name(name)
    return get_obj_from_module(resolved_module, remainder)
291
+
292
+
293
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function.

    `func_name` is required (keyword-only); all other arguments are forwarded.
    """
    assert func_name is not None
    target = get_obj_by_name(func_name)
    assert callable(target)
    return target(*args, **kwargs)
299
+
300
+
301
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments."""
    # A class is just a callable, so delegate to the generic by-name call helper.
    return call_func_by_name(*args, func_name=class_name, **kwargs)
304
+
305
+
306
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name.

    Raises TypeError (via inspect.getfile) for built-in modules with no source file.
    """
    module, _ = get_module_from_obj_name(obj_name)
    return os.path.dirname(inspect.getfile(module))
310
+
311
+
312
def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
    if not callable(obj):
        return False
    # A top-level function is reachable by name in its defining module's namespace.
    return obj.__name__ in sys.modules[obj.__module__].__dict__


def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    module_name = obj.__module__
    if module_name == '__main__':
        # Scripts executed directly report '__main__'; recover the real module
        # name from the script's file name instead.
        script_file = sys.modules[module_name].__file__
        module_name = os.path.splitext(os.path.basename(script_file))[0]
    return module_name + "." + obj.__name__
324
+
325
+
326
+ # File system helpers
327
+ # ------------------------------------------------------------------------------------------
328
+
329
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.
    Returns list of tuples containing both absolute and relative paths.

    Args:
        dir_path: root directory to walk; must exist.
        ignores: fnmatch-style patterns; matching directories are pruned from
            the walk and matching files are dropped.
        add_base_to_relative: prefix relative paths with the root's base name.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))

    if ignores is None:
        ignores = []

    result = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        # Apply each ignore pattern; filters accumulate across patterns.
        for ignore_ in ignores:
            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]

            # dirs need to be edited in-place (os.walk topdown pruning contract)
            for d in dirs_to_remove:
                dirs.remove(d)

            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]

        absolute_paths = [os.path.join(root, f) for f in files]
        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]

        if add_base_to_relative:
            relative_paths = [os.path.join(base_name, p) for p in relative_paths]

        assert len(absolute_paths) == len(relative_paths)
        result += zip(absolute_paths, relative_paths)

    return result
360
+
361
+
362
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories."""
    for source_path, target_path in files:
        target_dir_name = os.path.dirname(target_path)

        # exist_ok avoids the check-then-create race of the original
        # `if not exists: makedirs`. Skip when dst has no directory part:
        # os.makedirs('') raises FileNotFoundError.
        if target_dir_name:
            os.makedirs(target_dir_name, exist_ok=True)

        shutil.copyfile(source_path, target_path)
373
+
374
+
375
+ # URL helpers
376
+ # ------------------------------------------------------------------------------------------
377
+
378
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    A valid URL has a scheme and a dotted host, both for the URL itself and
    for its site root. 'file://' URLs pass only when allow_file_urls is True.
    """
    # requests.compat.urlparse/urljoin are just aliases of urllib.parse on
    # Python 3 -- use the stdlib directly (local import keeps this drop-in).
    from urllib.parse import urljoin, urlparse

    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        # Validate the URL and its site root, so inputs whose host disappears
        # after normalization are rejected too.
        for candidate in (obj, urljoin(obj, "/")):
            parts = urlparse(candidate)
            if not parts.scheme or not parts.netloc or "." not in parts.netloc:
                return False
    except Exception:
        return False
    return True
394
+
395
+
396
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Args:
        url: HTTP(S)/file URL, or a plain local path (returned/opened as-is).
        cache_dir: where to cache downloads; defaults to make_cache_dir_path('downloads').
        num_attempts: download retries before the last error is re-raised.
        verbose: print progress dots to stdout.
        return_filename: return a path instead of a file object (requires cache=True).
        cache: reuse/store the download keyed by the URL's MD5.
    """
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname() but
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    # Cache files are named '<md5-of-url>_<sanitized-name>'.
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be Google Drive interstitial pages
                    # rather than the payload itself.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Follow the confirm link on the next retry.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                # Out of retries: report and re-raise the last failure.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache. Write to a unique temp file first, then rename, so
    # concurrent processes never see a partially-written cache entry.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
feature_networks/__pycache__/constants.cpython-39.pyc ADDED
Binary file (2.06 kB). View file
 
feature_networks/__pycache__/pretrained_builder.cpython-39.pyc ADDED
Binary file (8.5 kB). View file
 
feature_networks/__pycache__/vit.cpython-39.pyc ADDED
Binary file (8.58 kB). View file
 
feature_networks/clip/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .clip import *
feature_networks/clip/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (254 Bytes). View file
 
feature_networks/clip/__pycache__/clip.cpython-39.pyc ADDED
Binary file (9.19 kB). View file
 
feature_networks/clip/__pycache__/model.cpython-39.pyc ADDED
Binary file (15.4 kB). View file
 
feature_networks/clip/__pycache__/simple_tokenizer.cpython-39.pyc ADDED
Binary file (5.84 kB). View file
 
feature_networks/clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
feature_networks/clip/clip.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Union, List
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+
13
+ from .model import build_model
14
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
+
16
+ __all__ = ["available_models", "load", "tokenize"]
17
+ _tokenizer = _Tokenizer()
18
+
19
+ _MODELS = {
20
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
21
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
22
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
23
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
24
+ }
25
+
26
+
27
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    """Download `url` into `root`, verifying the SHA256 embedded in the URL.

    OpenAI model URLs carry the expected SHA256 as the penultimate path
    segment. An existing file with a matching checksum is reused; a mismatch
    triggers a re-download. Returns the local file path.
    """
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if _file_sha256(download_target) == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if _file_sha256(download_target) != expected_sha256:
        # (fixed typo in original message: "does not not match")
        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _file_sha256(path: str) -> str:
    """Hex SHA256 of a file; the original used bare open(...).read() and leaked the handle."""
    with open(path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()
57
+
58
+
59
def _transform(n_px):
    """Build the CLIP image preprocessing pipeline for a square `n_px` crop.

    NOTE(review): Image.BICUBIC is deprecated in newer Pillow in favor of
    Image.Resampling.BICUBIC -- confirm against the pinned Pillow version.
    """
    return Compose([
        Resize(n_px, interpolation=Image.BICUBIC),
        CenterCrop(n_px),
        # CLIP expects 3-channel input; normalize any mode to RGB.
        lambda image: image.convert("RGB"),
        ToTensor(),
        # Per-channel normalization constants published with the CLIP release.
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
67
+
68
+
69
def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    # Preserve the registry's insertion order.
    return [model_name for model_name in _MODELS]
72
+
73
+
74
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model (default) or more hackable non-JIT model.

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    # Resolve `name` to a local checkpoint path: known model names are
    # downloaded on demand; otherwise `name` must already be a file on disk.
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # Rebuild the hackable python model from whichever weights we obtained.
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            # fp16 weights are converted to fp32 for CPU execution.
            model.float()
        return model, _transform(model.visual.input_resolution)

    # --- JIT path: the traced graph has device constants baked in, so every
    # hard-coded "cuda" constant is rewritten to the requested device. ---
    # Trace a throwaway op to obtain a graph node holding the target device.
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        # NOTE(review): `forward1` appears to be a second traced overload of
        # forward — patched when present; confirm against the JIT archive.
        graphs = [module.graph] if hasattr(module, "graph") else []
        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        # Trace a cast-to-float op to obtain a float32 dtype constant node.
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            graphs = [module.graph] if hasattr(module, "graph") else []
            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        # dtype constant 5 is the fp16 ScalarType; swap it for fp32
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())
163
+
164
+
165
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
    """
    Tokenize one string or a list of strings for CLIP's text encoder.

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    Returns
    -------
    A two-dimensional tensor of tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    # Every sequence is wrapped in the start/end-of-text special tokens.
    sot = _tokenizer.encoder["<|startoftext|>"]
    eot = _tokenizer.encoder["<|endoftext|>"]
    encoded = [[sot, *_tokenizer.encode(text), eot] for text in texts]

    # Zero-padded up to context_length; overlong inputs are an error.
    result = torch.zeros(len(encoded), context_length, dtype=torch.long)
    for row, tokens in enumerate(encoded):
        if len(tokens) > context_length:
            raise RuntimeError(f"Input {texts[row]} is too long for context length {context_length}")
        result[row, :len(tokens)] = torch.tensor(tokens)

    return result
195
+
196
def pdist(sample_1, sample_2, norm=2, eps=1e-5):
    r"""Compute the matrix of all pairwise l_p distances.

    Arguments
    ---------
    sample_1 : torch.Tensor or Variable
        The first sample, shape ``(n_1, d)``.
    sample_2 : torch.Tensor or Variable
        The second sample, shape ``(n_2, d)``.
    norm : float
        The l_p norm to be used.
    eps : float
        Small constant added for numerical stability.

    Returns
    -------
    torch.Tensor or Variable
        Matrix of shape (n_1, n_2) whose [i, j] entry is
        ``|| sample_1[i, :] - sample_2[j, :] ||_p``.
    """
    n_1, n_2 = sample_1.size(0), sample_2.size(0)
    norm = float(norm)
    if norm == 2.:
        # Fast path: ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, all in batch.
        sq_1 = torch.sum(sample_1 ** 2, dim=1, keepdim=True)
        sq_2 = torch.sum(sample_2 ** 2, dim=1, keepdim=True)
        squared = (sq_1.expand(n_1, n_2)
                   + sq_2.transpose(0, 1).expand(n_1, n_2)
                   - 2 * sample_1.mm(sample_2.t()))
        # abs() guards against tiny negatives from floating-point cancellation.
        return torch.sqrt(eps + torch.abs(squared))

    # General l_p path via broadcasting to (n_1, n_2, d).
    dim = sample_1.size(1)
    lhs = sample_1.unsqueeze(1).expand(n_1, n_2, dim)
    rhs = sample_2.unsqueeze(0).expand(n_1, n_2, dim)
    powered = torch.abs(lhs - rhs) ** norm
    return (eps + powered.sum(dim=2)) ** (1. / norm)
227
+
228
+
229
class ClipHead(nn.Module):
    """CLIP-guided loss head.

    Wraps a frozen RN50 CLIP model and scores generator features against a
    fixed text prompt via cosine similarity.
    """

    def __init__(self, prompt, device='cpu'):
        super().__init__()
        # jit=False: the hackable python model is required for
        # encode_conv_features (only defined on the non-JIT CLIP class).
        self.clip_model = load("RN50", device=device, jit=False)[0].eval()
        self.prompt = prompt

    def calc_loss(self, features):
        """Return the mean negative cosine similarity between the prompt
        embedding and the image embedding of ``features['last']``.

        `features['last']` is expected to be a conv feature map compatible
        with CLIP's attention pooling (see encode_conv_features).
        """
        # BUGFIX: Tensor.get_device() returns -1 (or raises) for CPU tensors,
        # which breaks the subsequent .to(dev); Tensor.device works on every
        # backend and is equivalent on CUDA.
        dev = features['last'].device
        text_input = tokenize(self.prompt).to(dev)

        text_features = self.clip_model.encode_text(text_input)
        image_features = self.clip_model.encode_conv_features(features['last'])
        loss = - torch.cosine_similarity(text_features, image_features, dim=1)
        # loss -= (pdist(image_features, image_features)/image_features.max()).sum()

        return loss.mean()
feature_networks/clip/model.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
class Bottleneck(nn.Module):
    """CLIP-style ResNet bottleneck (1x1 -> 3x3 -> avgpool -> 1x1).

    All convolutions use stride 1; spatial downsampling for stride > 1 is
    done with an average pool (anti-aliased downsampling).
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # Main branch: all conv layers have stride 1; an avgpool is performed
        # after the second convolution when stride > 1.
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        # Shortcut branch: identity unless shape changes; then an avgpool
        # followed by a stride-1 1x1 projection.
        self.downsample = None
        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion)),
            ]))

    def forward(self, x: torch.Tensor):
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(self.avgpool(y)))

        return self.relu(y + shortcut)
54
+
55
+
56
class AttentionPool2d(nn.Module):
    """Pool a spatial feature map with one step of multi-head QKV attention.

    The query is the mean token; its attention output replaces global
    average pooling at the end of the modified ResNet.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        # NCHW -> (HW)NC token sequence
        tokens = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)
        # Prepend the mean token as the pooling query: (HW+1)NC
        tokens = torch.cat([tokens.mean(dim=0, keepdim=True), tokens], dim=0)
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)

        pooled, _ = F.multi_head_attention_forward(
            query=tokens, key=tokens, value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        # Only the mean-query token's output is the pooled representation.
        return pooled[0]
91
+
92
+
93
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # 3-layer stem: two half-width convs then a full-width conv, avgpooled.
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages; _inplanes mutates as stages are built.
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        # First block may downsample; the rest keep the expanded width.
        stage = [Bottleneck(self._inplanes, planes, stride)]
        self._inplanes = planes * Bottleneck.expansion
        stage += [Bottleneck(self._inplanes, planes) for _ in range(blocks - 1)]
        return nn.Sequential(*stage)

    def forward(self, x):
        # Match the parameter dtype (supports fp16 checkpoints).
        x = x.type(self.conv1.weight.dtype)

        # Stem: three conv/bn/relu units followed by an average pool.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            x = self.relu(bn(conv(x)))
        x = self.avgpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        return self.attnpool(x)
151
+
152
+
153
class LayerNorm(nn.LayerNorm):
    """LayerNorm that normalizes in fp32 and casts back (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        return super().forward(x.type(torch.float32)).type(x.dtype)
160
+
161
+
162
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation used by CLIP: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        return torch.sigmoid(1.702 * x) * x
165
+
166
+
167
class ResidualAttentionBlock(nn.Module):
    """Pre-LN transformer block: self-attention and an MLP, each residual."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        # 4x expansion MLP with CLIP's QuickGELU activation.
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model)),
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # Keep the (optional) additive mask on the input's device/dtype.
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
189
+
190
+
191
class Transformer(nn.Module):
    """A stack of `layers` residual attention blocks of width `width`."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)
200
+
201
+
202
class VisualTransformer(nn.Module):
    """ViT image encoder: patchify, prepend a class token, run the
    transformer, and project the class token to `output_dim`."""

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Non-overlapping patch embedding.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)                      # [*, width, grid, grid]
        x = x.flatten(2).permute(0, 2, 1)      # [*, grid ** 2, width]

        # Prepend a learned class token broadcast across the batch.
        cls = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls, x], dim=1)         # [*, grid ** 2 + 1, width]

        x = self.ln_pre(x + self.positional_embedding.to(x.dtype))

        # Transformer expects sequence-first layout.
        x = self.transformer(x.permute(1, 0, 2)).permute(1, 0, 2)

        # Only the class token is used as the image representation.
        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x
237
+
238
+
239
class CLIP(nn.Module):
    """Contrastive Language-Image Pre-training model.

    Combines an image encoder (`self.visual`: a ModifiedResNet when
    `vision_layers` is a tuple/list of stage depths, otherwise a
    VisualTransformer) with a causal text transformer.  Both encoders map
    into a shared `embed_dim`-dimensional space, compared via scaled cosine
    similarity in `forward`.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        # A tuple/list of layer counts selects the ResNet backbone; a single
        # int selects the ViT backbone with that many transformer layers.
        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64  # 64 channels per head
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisualTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        # Text side: causally-masked transformer over BPE token embeddings.
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Learnable temperature, initialized to 1/0.07 (stored as its log).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        """Initialize all learned parameters (CLIP's published init scheme)."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            # Zero-init of each block's final BN makes residual branches start
            # as identity.
            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        # Depth-scaled init for the text transformer.
        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        """Return the additive causal mask for the text transformer."""
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        # The visual stem conv's dtype is taken as the model-wide dtype.
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        """Embed a batch of images; returns [batch, embed_dim]."""
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        """Embed a batch of token sequences; returns [batch, embed_dim]."""
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def encode_conv_features(self, features):
        """Embed an external conv feature map through the visual attention
        pool.  Only valid for the ModifiedResNet backbone, which defines
        `attnpool`."""
        # pool to 7, the feature map resolution for 224x224 input
        features = nn.AdaptiveAvgPool2d(7)(features)
        return self.visual.attnpool(features)

    def forward(self, image, text):
        """Return (logits_per_image, logits_per_text) similarity matrices."""
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text

    def forward_features(self, features, text):
        """Like `forward`, but the image side starts from conv features
        (see encode_conv_features) instead of raw pixels."""
        image_features = self.encode_conv_features(features)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
390
+
391
+
392
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16, in place."""

    def _to_half(layer):
        # Conv / linear layers: cast weight and (optional) bias.
        if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            layer.weight.data = layer.weight.data.half()
            if layer.bias is not None:
                layer.bias.data = layer.bias.data.half()

        # Multi-head attention keeps its projections as raw parameters.
        if isinstance(layer, nn.MultiheadAttention):
            attrs = ["in_proj_weight", "q_proj_weight", "k_proj_weight",
                     "v_proj_weight", "in_proj_bias", "bias_k", "bias_v"]
            for attr_name in attrs:
                tensor = getattr(layer, attr_name)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        # CLIP's bare nn.Parameter projection matrices.
        for proj_name in ("text_projection", "proj"):
            proj = getattr(layer, proj_name, None)
            if proj is not None:
                proj.data = proj.data.half()

    model.apply(_to_half)
414
+
415
+
416
def build_model(state_dict: dict):
    """Instantiate a CLIP model whose architecture matches `state_dict`,
    load the weights, and return it in eval mode."""

    def _stage_depth(prefix):
        # Number of distinct block indices under `prefix.<idx>.*` keys.
        return len(set(k.split(".")[2] for k in state_dict if k.startswith(prefix)))

    # A ViT checkpoint is recognized by its class-token projection.
    if "visual.proj" in state_dict:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys()
                             if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # Modified-ResNet checkpoint: recover the four stage depths.
        vision_layers = tuple(_stage_depth(f"visual.layer{b}") for b in (1, 2, 3, 4))
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32  # total stride of the ResNet is 32

    # Text-side hyperparameters are read directly off tensor shapes.
    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = _stage_depth("transformer.resblocks")

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    # These entries are checkpoint metadata, not module parameters.
    for key in ("input_resolution", "context_length", "vocab_size"):
        if key in state_dict:
            del state_dict[key]

    # convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
feature_networks/clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
@lru_cache()
def default_bpe():
    """Path of the BPE vocabulary archive bundled next to this module."""
    here = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(here, "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
@lru_cache()
def bytes_to_unicode():
    """
    Returns a mapping from every utf-8 byte value to a printable unicode character.

    The reversible bpe codes work on unicode strings, so every byte needs a
    character representation that is not whitespace or a control character
    (the bpe code barfs on those).  Printable latin-1 bytes map to themselves;
    the remaining byte values are shifted up past 255.
    """
    printable = (list(range(ord("!"), ord("~") + 1))
                 + list(range(ord("¡"), ord("¬") + 1))
                 + list(range(ord("®"), ord("ÿ") + 1)))
    byte_values = printable[:]
    char_codes = printable[:]
    shift = 0
    for b in range(2 ** 8):
        if b not in byte_values:
            byte_values.append(b)
            char_codes.append(2 ** 8 + shift)
            shift += 1
    return dict(zip(byte_values, (chr(c) for c in char_codes)))
+ return dict(zip(bs, cs))
36
+
37
+
38
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    `word` is a tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    left = word[0]
    for right in word[1:]:
        pairs.add((left, right))
        left = right
    return pairs
48
+
49
+
50
def basic_clean(text):
    """Repair mojibake via ftfy and strip (doubly) HTML-escaped entities."""
    fixed = ftfy.fix_text(text)
    # Unescape twice: web corpora often contain double-encoded entities
    # such as "&amp;amp;".
    return html.unescape(html.unescape(fixed)).strip()
54
+
55
+
56
def whitespace_clean(text):
    """Collapse every run of whitespace to a single space and trim the ends."""
    return re.sub(r'\s+', ' ', text).strip()
60
+
61
+
62
class SimpleTokenizer(object):
    """Byte-pair-encoding tokenizer used by CLIP.

    Text is lowercased and cleaned, split with a regex into word-like units,
    each unit is mapped to printable unicode bytes (see bytes_to_unicode) and
    then greedily merged using the ranked merge table loaded from `bpe_path`.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        # Skip the header line; cap the merge list so the final vocab is
        # 49152 entries (256 raw bytes x2 variants + merges + 2 specials).
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        # Vocab: each byte symbol, each byte symbol with the end-of-word
        # marker, every merged symbol, then the two special tokens.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Lower rank = higher merge priority.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Memoized bpe() results; specials map to themselves.
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Return `token` as a space-joined sequence of BPE symbols."""
        if token in self.cache:
            return self.cache[token]
        # The last character carries the end-of-word marker.
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        # Repeatedly merge the best-ranked adjacent pair until none remains.
        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            # Rebuild the word, fusing every (first, second) occurrence.
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Encode `text` into a list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # Map raw bytes to their printable unicode stand-ins before BPE.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Decode a list of BPE token ids back into a string."""
        text = ''.join([self.decoder[token] for token in tokens])
        # Reverse the byte mapping, then turn end-of-word markers into spaces.
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
+ return text
feature_networks/constants.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Feature-network name registries, grouped by model family.  Names are timm /
# torchvision model identifiers consumed by the pretrained builder.

TORCHVISION = [
    "vgg11_bn",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19_bn",
    "densenet121",
    "densenet169",
    "densenet201",
    "inception_v3",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "shufflenet_v2_x0_5",
    "mobilenet_v2",
    "wide_resnet50_2",
    "mnasnet0_5",
    "mnasnet1_0",
    "ghostnet_100",
    "cspresnet50",
    "fbnetc_100",
    "spnasnet_100",
    "resnet50d",
    "resnet26",
    "resnet26d",
    "seresnet50",
    "resnetblur50",
    "resnetrs50",
    "tf_mixnet_s",
    "tf_mixnet_m",
    "tf_mixnet_l",
    "ese_vovnet19b_dw",
    "ese_vovnet39b",
    "res2next50",
    "gernet_s",
    "gernet_m",
    "repvgg_a2",
    "repvgg_b0",
    "repvgg_b1",
    "repvgg_b1g4",
    "revnet",
    "dm_nfnet_f1",
    "nfnet_l0",
    ]

REGNETS = [
    "regnetx_002",
    "regnetx_004",
    "regnetx_006",
    "regnetx_008",
    "regnetx_016",
    "regnetx_032",
    "regnetx_040",
    "regnetx_064",
    "regnety_002",
    "regnety_004",
    "regnety_006",
    "regnety_008",
    "regnety_016",
    "regnety_032",
    "regnety_040",
    "regnety_064",
    ]

# EfficientNets split by which input normalization they were trained with.
EFFNETS_IMAGENET = [
    'tf_efficientnet_b0',
    'tf_efficientnet_b1',
    'tf_efficientnet_b2',
    'tf_efficientnet_b3',
    'tf_efficientnet_b4',
    'tf_efficientnet_b0_ns',
    ]

EFFNETS_INCEPTION = [
    'tf_efficientnet_lite0',
    'tf_efficientnet_lite1',
    'tf_efficientnet_lite2',
    'tf_efficientnet_lite3',
    'tf_efficientnet_lite4',
    'tf_efficientnetv2_b0',
    'tf_efficientnetv2_b1',
    'tf_efficientnetv2_b2',
    'tf_efficientnetv2_b3',
    'efficientnet_b1',
    'efficientnet_b1_pruned',
    'efficientnet_b2_pruned',
    'efficientnet_b3_pruned',
    ]

EFFNETS = EFFNETS_IMAGENET + EFFNETS_INCEPTION

VITS_IMAGENET = [
    'deit_tiny_distilled_patch16_224',
    'deit_small_distilled_patch16_224',
    'deit_base_distilled_patch16_224',
    ]

VITS_INCEPTION = [
    'vit_base_patch16_224',
    'vit_large_patch16_224'
    ]

MAES = [
    'mae_vit_base_patch16',
    'mae_vit_large_patch16',
    'mae_vit_huge_patch14'
    ]

ST = ['swin_base_patch4_window7_224']

TNT = ['tnt_b_patch16_224']

VITS = VITS_IMAGENET + VITS_INCEPTION

CLIP = [
    'resnet50_clip'
    ]

# NOTE(review): ST (swin) is listed under NORMALIZED_INCEPTION below but is
# not part of ALL_MODELS — confirm whether that omission is intentional.
ALL_MODELS = TORCHVISION + REGNETS + EFFNETS + VITS + CLIP + MAES + TNT

# Group according to input normalization

NORMALIZED_IMAGENET = TORCHVISION + REGNETS + EFFNETS_IMAGENET + VITS_IMAGENET

NORMALIZED_INCEPTION = EFFNETS_INCEPTION + VITS_INCEPTION + MAES + TNT + ST

NORMALIZED_CLIP = CLIP
feature_networks/pretrained_builder.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.models as zoomodels
5
+ from torch.autograd import Function
6
+
7
+ import timm
8
+
9
+ from feature_networks import clip
10
+ from feature_networks.vit import _make_vit_b16_backbone, forward_vit
11
+ from feature_networks.constants import ALL_MODELS, VITS, EFFNETS, REGNETS
12
+ from pg_modules.blocks import Interpolate
13
+
14
+ def _feature_splitter(model, idcs):
15
+ pretrained = nn.Module()
16
+ pretrained.layer0 = nn.Sequential(model.features[:idcs[0]])
17
+ pretrained.layer1 = nn.Sequential(model.features[idcs[0]:idcs[1]])
18
+ pretrained.layer2 = nn.Sequential(model.features[idcs[1]:idcs[2]])
19
+ pretrained.layer3 = nn.Sequential(model.features[idcs[2]:idcs[3]])
20
+ return pretrained
21
+
22
+ def _make_resnet(model):
23
+ pretrained = nn.Module()
24
+ pretrained.layer0 = nn.Sequential(
25
+ model.conv1, model.bn1, model.relu, model.maxpool, model.layer1,
26
+ )
27
+ pretrained.layer1 = model.layer2
28
+ pretrained.layer2 = model.layer3
29
+ pretrained.layer3 = model.layer4
30
+ return pretrained
31
+
32
+ def _make_regnet(model):
33
+ pretrained = nn.Module()
34
+ pretrained.layer0 = nn.Sequential(
35
+ model.stem, model.s1
36
+ )
37
+ pretrained.layer1 = model.s2
38
+ pretrained.layer2 = model.s3
39
+ pretrained.layer3 = model.s4
40
+ return pretrained
41
+
42
+ def _make_nfnet(model):
43
+ pretrained = nn.Module()
44
+ pretrained.layer0 = nn.Sequential(
45
+ model.stem, model.stages[0]
46
+ )
47
+ pretrained.layer1 = model.stages[1]
48
+ pretrained.layer2 = model.stages[2]
49
+ pretrained.layer3 = model.stages[3]
50
+ return pretrained
51
+
52
+ def _make_resnet_v2(model):
53
+ pretrained = nn.Module()
54
+ pretrained.layer0 = nn.Sequential(model.stem, model.stages[0])
55
+ pretrained.layer1 = model.stages[1]
56
+ pretrained.layer2 = model.stages[2]
57
+ pretrained.layer3 = model.stages[3]
58
+ return pretrained
59
+
60
+ def _make_resnet_clip(model):
61
+ pretrained = nn.Module()
62
+
63
+ # slightly more complicated than the standard resnet
64
+ pretrained.layer0 = nn.Sequential(
65
+ model.conv1,
66
+ model.bn1,
67
+ model.relu,
68
+ model.conv2,
69
+ model.bn2,
70
+ model.relu,
71
+ model.conv3,
72
+ model.bn3,
73
+ model.relu,
74
+ model.avgpool,
75
+ model.layer1,
76
+ )
77
+
78
+ pretrained.layer1 = model.layer2
79
+ pretrained.layer2 = model.layer3
80
+ pretrained.layer3 = model.layer4
81
+
82
+ return pretrained
83
+
84
def _make_densenet(model):
    """Regroup a torchvision DenseNet into four feature stages.

    Mutates `model.features` in place: the trailing submodule of stages 1-2
    (presumably the transition's AvgPool — TODO confirm against torchvision's
    DenseNet layout) is replaced with Identity, and an explicit stride-2
    AvgPool2d is prepended to the stage instead, moving the downsampling in
    front of the block.
    """
    pretrained = nn.Module()

    # Stage 0: stem + first dense block/transition (features[:6]).
    pretrained.layer0 = model.features[:6]

    pretrained.layer1 = model.features[6:8]
    pretrained.layer1[-1][-1] = nn.Identity()  # disable internal downsampling
    pretrained.layer1 = nn.Sequential(nn.AvgPool2d(2, 2), pretrained.layer1)

    pretrained.layer2 = model.features[8:10]
    pretrained.layer2[-1][-1] = nn.Identity()  # disable internal downsampling
    pretrained.layer2 = nn.Sequential(nn.AvgPool2d(2, 2), pretrained.layer2)

    # Stage 3: final block(s); no Identity patch needed here.
    pretrained.layer3 = model.features[10:12]
    pretrained.layer3 = nn.Sequential(nn.AvgPool2d(2, 2), pretrained.layer3)

    return pretrained
101
+
102
+ def _make_shufflenet(model):
103
+ pretrained = nn.Module()
104
+ pretrained.layer0 = nn.Sequential(model.conv1, model.maxpool)
105
+ pretrained.layer1 = model.stage2
106
+ pretrained.layer2 = model.stage3
107
+ pretrained.layer3 = model.stage4
108
+ return pretrained
109
+
110
+ def _make_cspresnet(model):
111
+ pretrained = nn.Module()
112
+ pretrained.layer0 = nn.Sequential(model.stem, model.stages[0])
113
+ pretrained.layer1 = model.stages[1]
114
+ pretrained.layer2 = model.stages[2]
115
+ pretrained.layer3 = model.stages[3]
116
+ return pretrained
117
+
118
+ def _make_efficientnet(model):
119
+ pretrained = nn.Module()
120
+ pretrained.layer0 = nn.Sequential(
121
+ model.conv_stem, model.bn1, model.act1, *model.blocks[0:2]
122
+ )
123
+ pretrained.layer1 = nn.Sequential(*model.blocks[2:3])
124
+ pretrained.layer2 = nn.Sequential(*model.blocks[3:5])
125
+ pretrained.layer3 = nn.Sequential(*model.blocks[5:9])
126
+ return pretrained
127
+
128
+ def _make_ghostnet(model):
129
+ pretrained = nn.Module()
130
+ pretrained.layer0 = nn.Sequential(
131
+ model.conv_stem, model.bn1, model.act1, *model.blocks[0:3],
132
+ )
133
+ pretrained.layer1 = nn.Sequential(*model.blocks[3:5])
134
+ pretrained.layer2 = nn.Sequential(*model.blocks[5:7])
135
+ pretrained.layer3 = nn.Sequential(*model.blocks[7:-1])
136
+ return pretrained
137
+
138
+ def _make_vit(model, name):
139
+ if 'tiny' in name:
140
+ features = [24, 48, 96, 192]
141
+ hooks = [2, 5, 8, 11]
142
+ vit_features = 192
143
+
144
+ elif 'small' in name:
145
+ features = [48, 96, 192, 384]
146
+ hooks = [2, 5, 8, 11]
147
+ vit_features = 384
148
+
149
+ elif 'base' in name:
150
+ features = [96, 192, 384, 768]
151
+ hooks = [2, 5, 8, 11]
152
+ vit_features = 768
153
+
154
+ elif 'large' in name:
155
+ features = [256, 512, 1024, 1024]
156
+ hooks = [5, 11, 17, 23]
157
+ vit_features = 1024
158
+
159
+ else:
160
+ raise NotImplementedError('Invalid ViT backbone not available')
161
+
162
+ return _make_vit_b16_backbone(
163
+ model,
164
+ features=features,
165
+ size=[224, 224],
166
+ hooks=hooks,
167
+ vit_features=vit_features,
168
+ start_index=2 if 'deit' in name else 1,
169
+ )
170
+
171
def calc_dims(pretrained, is_vit=False):
    """Probe the backbone with a dummy 1x3x256x256 input.

    Returns (channels, res_mult): per-stage channel counts and the ratio of
    each stage's spatial size to the 256-pixel input.
    """
    inp_res = 256
    probe = torch.zeros(1, 3, inp_res, inp_res)

    if is_vit:
        outs = forward_vit(pretrained, probe)
        shapes = [out.shape[1:3] for out in outs]
    else:
        shapes = []
        for stage in (pretrained.layer0, pretrained.layer1,
                      pretrained.layer2, pretrained.layer3):
            probe = stage(probe)
            shapes.append(probe.shape[1:3])

    # Column 0 is channels, column 1 is spatial height -> resolution multiplier.
    shapes = np.array(shapes)
    return shapes[:, 0], shapes[:, 1] / inp_res
194
+
195
def _make_pretrained(backbone, verbose=False):
    """Instantiate a pretrained `backbone` and regroup it into four feature stages.

    The returned module exposes layer0..layer3 plus CHANNELS / RES_MULT
    (per-stage channel counts and resolution multipliers, probed at 256px).

    Changes vs. the original: the giant copy-pasted elif chain is collapsed
    into membership groups with identical dispatch, and the unreachable
    duplicate 'dm_nfnet_f1' branch (shadowed by the earlier one, which used
    the behaviorally identical `_make_cspresnet`) is removed.

    Raises:
        AssertionError: if `backbone` is not in ALL_MODELS.
        NotImplementedError: if no handler matches (should not happen for
            names in ALL_MODELS).
    """
    assert backbone in ALL_MODELS

    # torchvision nets dispatched by split indices into `model.features`.
    vgg_idcs = {
        'vgg11_bn': [7, 14, 21, 28],
        'vgg13_bn': [13, 20, 27, 34],
        'vgg16_bn': [13, 23, 33, 43],
        'vgg19_bn': [13, 26, 39, 52],
    }
    densenets = ('densenet121', 'densenet169', 'densenet201')
    tv_resnets = ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
                  'wide_resnet50_2', 'wide_resnet101_2')
    # timm resnets that need `relu` aliased to `act1` before splitting.
    timm_resnets = ('resnet50d', 'resnet26', 'resnet26d', 'seresnet50',
                    'resnetblur50', 'resnetrs50', 'res2next50')
    # timm nets sharing the stem+stages layout handled by _make_cspresnet.
    timm_csp_style = ('cspresnet50', 'dm_nfnet_f0', 'dm_nfnet_f1',
                      'ese_vovnet19b_dw', 'ese_vovnet39b', 'gernet_s',
                      'gernet_m', 'repvgg_a2', 'repvgg_b0', 'repvgg_b1',
                      'repvgg_b1g4')
    timm_effnet_style = ('fbnetc_100', 'spnasnet_100',
                         'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l')

    if backbone in vgg_idcs:
        model = zoomodels.__dict__[backbone](True)
        pretrained = _feature_splitter(model, vgg_idcs[backbone])

    elif backbone in densenets:
        model = zoomodels.__dict__[backbone](True)
        pretrained = _make_densenet(model)

    elif backbone in tv_resnets:
        model = zoomodels.__dict__[backbone](True)
        pretrained = _make_resnet(model)

    elif backbone == 'shufflenet_v2_x0_5':
        model = zoomodels.__dict__[backbone](True)
        pretrained = _make_shufflenet(model)

    elif backbone == 'mobilenet_v2':
        model = zoomodels.__dict__[backbone](True)
        pretrained = _feature_splitter(model, [4, 7, 14, 18])  # same structure as vgg

    elif backbone in ('mnasnet0_5', 'mnasnet1_0'):
        model = zoomodels.__dict__[backbone](True)
        model.features = model.layers  # alias so the splitter can slice it
        pretrained = _feature_splitter(model, [9, 10, 12, 14])

    elif backbone == 'ghostnet_100':
        model = timm.create_model(backbone, pretrained=True)
        pretrained = _make_ghostnet(model)

    elif backbone in timm_csp_style:
        model = timm.create_model(backbone, pretrained=True)
        pretrained = _make_cspresnet(model)

    elif backbone in timm_resnets:
        model = timm.create_model(backbone, pretrained=True)
        model.relu = model.act1
        pretrained = _make_resnet(model)

    elif backbone in timm_effnet_style:
        model = timm.create_model(backbone, pretrained=True)
        pretrained = _make_efficientnet(model)

    elif backbone == 'nfnet_l0':
        model = timm.create_model(backbone, pretrained=True)
        pretrained = _make_nfnet(model)

    elif backbone in REGNETS:
        model = timm.create_model(backbone, pretrained=True)
        pretrained = _make_regnet(model)

    elif backbone in EFFNETS:
        model = timm.create_model(backbone, pretrained=True)
        pretrained = _make_efficientnet(model)

    elif backbone in VITS:
        model = timm.create_model(backbone, pretrained=True)
        pretrained = _make_vit(model, backbone)

    elif backbone == 'resnet50_clip':
        model = clip.load('RN50', device='cpu', jit=False)[0].visual
        pretrained = _make_resnet_clip(model)

    else:
        raise NotImplementedError('Wrong model name?')

    pretrained.CHANNELS, pretrained.RES_MULT = calc_dims(pretrained, is_vit=backbone in VITS)

    if verbose:
        print(f"Successfully loaded: {backbone}")  # fixed "Succesfully" typo
        print(f"Channels: {pretrained.CHANNELS}")
        print(f"Resolution Multiplier: {pretrained.RES_MULT}")
        print(f"Out Res for 256 : {pretrained.RES_MULT*256}")

    return pretrained
feature_networks/vit.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
class Slice(nn.Module):
    """Drop the leading readout token(s), keeping only patch tokens."""

    def __init__(self, start_index=1):
        super().__init__()
        self.start_index = start_index

    def forward(self, x):
        return x[:, self.start_index:]
16
+
17
+
18
class AddReadout(nn.Module):
    """Fuse the readout token into each patch token by addition.

    With start_index == 2 (cls + distillation token) the two leading tokens
    are averaged; otherwise only token 0 is used.
    """

    def __init__(self, start_index=1):
        super().__init__()
        self.start_index = start_index

    def forward(self, x):
        if self.start_index == 2:
            readout = x[:, :2].mean(dim=1)
        else:
            readout = x[:, 0]
        return x[:, self.start_index:] + readout.unsqueeze(1)
29
+
30
+
31
class ProjectReadout(nn.Module):
    """Fuse the readout token into the patch tokens via a learned projection.

    The readout (token 0) is concatenated onto every patch token, then the
    doubled-width vector is projected back to `in_features` through
    Linear + GELU.
    """

    def __init__(self, in_features, start_index=1):
        super().__init__()
        self.start_index = start_index
        self.project = nn.Sequential(
            nn.Linear(2 * in_features, in_features),
            nn.GELU(),
        )

    def forward(self, x):
        tokens = x[:, self.start_index:]
        readout = x[:, 0:1].expand_as(tokens)
        return self.project(torch.cat((tokens, readout), dim=-1))
43
+
44
+
45
class Transpose(nn.Module):
    """Swap two tensor dimensions and return a contiguous result."""

    def __init__(self, dim0, dim1):
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x):
        return x.transpose(self.dim0, self.dim1).contiguous()
54
+
55
+
56
def forward_vit(pretrained, x):
    """Run the wrapped ViT on `x` and return four pyramid feature maps.

    The forward pass through `forward_flex` fills `pretrained.activations`
    via the hooks registered in `_make_vit_b16_backbone`; those per-block
    token tensors are then reshaped and projected by layer1..layer4.
    """
    b, c, h, w = x.shape

    # Side effect: populates pretrained.activations["1".."4"].
    lantent,_ = pretrained.model.forward_flex(x)


    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    # Steps 0-1 of each postprocess stage: readout op + transpose to (B, C, N).
    layer_1 = pretrained.layer1[0:2](layer_1)
    layer_2 = pretrained.layer2[0:2](layer_2)
    layer_3 = pretrained.layer3[0:2](layer_3)
    layer_4 = pretrained.layer4[0:2](layer_4)

    # Unflatten built from the *actual* input size; it stands in for the
    # static Unflatten at index 2 of each stage, which is skipped below.
    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size(
                [
                    h // pretrained.model.patch_size[1],
                    w // pretrained.model.patch_size[0],
                ]
            ),
        )
    )

    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    # Remaining steps (index 3 onward): 1x1 projection and up/downsampling.
    layer_1 = pretrained.layer1[3 : len(pretrained.layer1)](layer_1)
    layer_2 = pretrained.layer2[3 : len(pretrained.layer2)](layer_2)
    layer_3 = pretrained.layer3[3 : len(pretrained.layer3)](layer_3)
    layer_4 = pretrained.layer4[3 : len(pretrained.layer4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4
99
+
100
def forward_swin(pretrained, x):
    """Run the wrapped Swin backbone and return four pyramid feature maps.

    NOTE(review): this is a byte-for-byte duplicate of `forward_vit`. The
    tuple unpack below assumes `forward_flex` returns a 2-tuple, but
    `forward_flex_swin` (injected by `_make_swin_b16_backbone`) returns a
    single tensor — confirm which forward is bound before relying on this.
    """
    b, c, h, w = x.shape

    # Side effect: populates pretrained.activations["1".."4"].
    lantent,_ = pretrained.model.forward_flex(x)


    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    # Steps 0-1 of each postprocess stage: readout op + transpose.
    layer_1 = pretrained.layer1[0:2](layer_1)
    layer_2 = pretrained.layer2[0:2](layer_2)
    layer_3 = pretrained.layer3[0:2](layer_3)
    layer_4 = pretrained.layer4[0:2](layer_4)

    # Dynamic Unflatten replacing each stage's static one (index 2).
    # NOTE(review): a single patch-size-based grid is used for all four
    # stages, but Swin stages have different token grids — verify.
    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size(
                [
                    h // pretrained.model.patch_size[1],
                    w // pretrained.model.patch_size[0],
                ]
            ),
        )
    )

    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    # Remaining steps (index 3 onward).
    layer_1 = pretrained.layer1[3 : len(pretrained.layer1)](layer_1)
    layer_2 = pretrained.layer2[3 : len(pretrained.layer2)](layer_2)
    layer_3 = pretrained.layer3[3 : len(pretrained.layer3)](layer_3)
    layer_4 = pretrained.layer4[3 : len(pretrained.layer4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4
143
+
144
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
145
+ posemb_tok, posemb_grid = (
146
+ posemb[:, : self.start_index],
147
+ posemb[0, self.start_index :],
148
+ )
149
+
150
+ gs_old = int(math.sqrt(len(posemb_grid)))
151
+
152
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
153
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False)
154
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
155
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
156
+
157
+ return posemb
158
+
159
+
160
def forward_flex(self, x):
    """Flexible ViT forward: tolerates input sizes other than the trained one.

    Injected into the timm VisionTransformer as a bound method. Resizes the
    position embedding to the current token grid, prepends cls (and, for
    DeiT, distillation) tokens, runs the transformer blocks — which also
    fires the registered activation hooks — and returns (tokens, None).
    """
    b, c, h, w = x.shape

    # Interpolate the position embedding to match this input's token grid.
    pos_embed = self._resize_pos_embed(
        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
    )

    B = x.shape[0]

    # Hybrid ViT: run the conv backbone first, keep its last feature map.
    if hasattr(self.patch_embed, "backbone"):
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features

    # Patchify: conv projection, then flatten to (B, N, C).
    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

    if hasattr(self, "dist_token") and self.dist_token is not None:
        # DeiT: prepend both cls and distillation tokens.
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
    else:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

    x = x + pos_embed
    x = self.pos_drop(x)

    # Hooks on selected blocks record intermediates into `activations`.
    for blk in self.blocks:
        x = blk(x)

    x = self.norm(x)

    # Second element kept for interface symmetry with callers that unpack two values.
    return x, None
199
+
200
def forward_flex_swin(self, x):
    """Swin forward injected as a bound method.

    Runs patch embedding, optional absolute position embedding, the Swin
    layers (whose hooks record intermediates), and the final norm.

    NOTE(review): unlike `forward_flex`, this returns a single tensor, not a
    (tensor, None) tuple — callers that unpack two values will fail.
    """
    x = self.patch_embed(x)
    if self.absolute_pos_embed is not None:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)
    x = self.layers(x)
    x = self.norm(x)  # B L C

    return x
209
+
210
+
211
+
212
# Module-global store shared by every backbone built here: forward hooks
# write intermediate block outputs into it, keyed by stage name ("1".."4").
activations = {}


def get_activation(name):
    """Return a forward hook that stashes the module's output under `name`."""
    def _hook(_module, _inputs, output):
        activations[name] = output
    return _hook
220
+
221
+
222
def get_readout_oper(vit_features, features, use_readout, start_index=1):
    """Build one readout-handling module per output feature stage.

    Args:
        vit_features: transformer embedding width (used by 'project').
        features: per-stage output widths; only its length matters here.
        use_readout: 'ignore' (drop token), 'add' (add to patches), or
            'project' (concat + learned projection).
        start_index: number of leading readout tokens (2 for DeiT).

    Raises:
        ValueError: for an unknown `use_readout` value.
    """
    if use_readout == "ignore":
        return [Slice(start_index)] * len(features)
    if use_readout == "add":
        return [AddReadout(start_index)] * len(features)
    if use_readout == "project":
        return [ProjectReadout(vit_features, start_index) for _ in features]
    # Was `assert False, ...`: asserts are stripped under `python -O`,
    # so raise a real exception for invalid configuration instead.
    raise ValueError(
        "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
    )
237
+
238
+
239
def _make_vit_b16_backbone(
    model,
    features=[96, 192, 384, 768],
    size=[384, 384],
    hooks=[2, 5, 8, 11],
    vit_features=768,
    use_readout="ignore",
    start_index=1,
):
    """Wrap a patch-16 ViT into a 4-stage feature-pyramid backbone.

    Registers forward hooks on the transformer blocks listed in `hooks`
    (results land in the module-global `activations` dict) and builds four
    postprocess stages (layer1..layer4) that turn token sequences into
    spatial maps at 4x, 2x, 1x and 0.5x of the 1/16 token grid. The static
    Unflatten at index 2 of each stage is replaced at runtime by
    `forward_vit` with one matching the actual input size.

    Note: `features`/`size`/`hooks` default to mutable lists; callers always
    pass fresh values and the lists are never mutated here.
    """
    pretrained = nn.Module()

    pretrained.model = model
    # Record intermediates of the chosen transformer blocks.
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # 32, 48, 136, 384
    # Stage 1: tokens -> (B, C, H/16, W/16), then 4x transpose-conv upsample.
    pretrained.layer1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[0],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[0],
            out_channels=features[0],
            kernel_size=4,
            stride=4,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    # Stage 2: same, with 2x upsample.
    pretrained.layer2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[1],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[1],
            out_channels=features[1],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    # Stage 3: projection only, keeps the 1/16 resolution.
    pretrained.layer3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    # Stage 4: projection plus stride-2 conv downsample to 1/32.
    pretrained.layer4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained
351
+
352
def _make_swin_b16_backbone(
    model,
    features=[96, 192, 384, 768],
    size=[384, 384],
    hooks=[2, 5, 8, 11],
    vit_features=768,
    use_readout="ignore",
    start_index=1,
):
    """Wrap a Swin Transformer into a 4-stage feature-pyramid backbone.

    Like `_make_vit_b16_backbone`, but each Swin stage halves the spatial
    grid and doubles the channel width, so the four projections read
    vit_features * {1, 2, 4, 8} channels at 1/4, 1/8, 1/16, 1/32 resolution.

    NOTE(review): hooks index `model.blocks[...]`, while `forward_flex_swin`
    iterates `self.layers` — confirm the wrapped model exposes both names.
    patch_size is set to [16, 16] and `_resize_pos_embed` is injected even
    though Swin uses an absolute (or no) position embedding — verify these
    are actually exercised for Swin.
    """
    pretrained = nn.Module()

    pretrained.model = model
    # Hook the last block of each Swin stage.
    pretrained.model.blocks[hooks[0]].blocks[-1].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].blocks[-1].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].blocks[-1].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].blocks[-1].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # 32, 48, 136, 384
    # Stage 1: tokens -> (B, C, H/4, W/4) projection.
    pretrained.layer1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 4, size[1] // 4])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[0],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    # Stage 2: doubled channels at 1/8 resolution.
    pretrained.layer2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 8, size[1] // 8])),
        nn.Conv2d(
            in_channels=vit_features*2,
            out_channels=features[1],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    # Stage 3: 4x channels at 1/16 resolution.
    pretrained.layer3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features*4,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    # Stage 4: 8x channels at 1/32 resolution.
    pretrained.layer4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 32, size[1] // 32])),
        nn.Conv2d(
            in_channels=vit_features*8,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex_swin, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained
legacy.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Converting legacy network pickle into the new format."""
10
+
11
+ import click
12
+ import pickle
13
+ import re
14
+ import copy
15
+ import numpy as np
16
+ import torch
17
+ import io
18
+ import dnnlib
19
+ import misc
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
def load_network_pkl(f, force_fp16=False):
    """Load a StyleGAN network pickle, converting legacy TF pickles on the fly.

    Args:
        f: open binary file-like object containing the pickle.
        force_fp16: if True, rebuild G/D/G_ema with num_fp16_res=4 and
            conv_clamp=256 and copy the weights over.

    Returns:
        dict with keys G, D, G_ema, training_set_kwargs, augment_pipe.
    """
    data = _LegacyUnpickler(f).load()

    # Legacy TensorFlow pickle => convert.
    if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
        tf_G, tf_D, tf_Gs = data
        G = convert_tf_generator(tf_G)
        D = convert_tf_discriminator(tf_D)
        G_ema = convert_tf_generator(tf_Gs)
        data = dict(G=G, D=D, G_ema=G_ema)

    # Add missing fields.
    if 'training_set_kwargs' not in data:
        data['training_set_kwargs'] = None
    if 'augment_pipe' not in data:
        data['augment_pipe'] = None

    # Validate contents.
    assert isinstance(data['G'], torch.nn.Module)
    assert isinstance(data['D'], torch.nn.Module)
    assert isinstance(data['G_ema'], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    # Force FP16.
    if force_fp16:
        for key in ['G', 'D', 'G_ema']:
            old = data[key]
            kwargs = copy.deepcopy(old.init_kwargs)
            # Generators keep fp16 options under synthesis_kwargs; fall back
            # to the top-level kwargs (e.g. for D).
            fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
            fp16_kwargs.num_fp16_res = 4
            fp16_kwargs.conv_clamp = 256
            # Rebuild only if the kwargs actually changed.
            if kwargs != old.init_kwargs:
                new = type(old)(**kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old, new, require_all=True)
                data[key] = new
    return data
60
+
61
+ #----------------------------------------------------------------------------
62
+
63
class _TFNetworkStub(dnnlib.EasyDict):
    """Placeholder standing in for dnnlib.tflib.network.Network during unpickling."""
    pass

class _LegacyUnpickler(pickle.Unpickler):
    """Unpickler that maps legacy TF classes to stubs and forces CPU tensor loads."""
    def find_class(self, module, name):
        # NOTE(review): returning None for every '__builtin__' lookup makes
        # any pickle referencing Py2-style builtins fail when instantiated —
        # confirm this is intentional (it deviates from NVIDIA's original).
        if module == '__builtin__':
            return
        if module == 'dnnlib.tflib.network' and name == 'Network':
            return _TFNetworkStub
        # Route torch storage deserialization through CPU regardless of the
        # device the checkpoint was saved on.
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)
76
+
77
+ #----------------------------------------------------------------------------
78
+
79
+ def _collect_tf_params(tf_net):
80
+ # pylint: disable=protected-access
81
+ tf_params = dict()
82
+ def recurse(prefix, tf_net):
83
+ for name, value in tf_net.variables:
84
+ tf_params[prefix + name] = value
85
+ for name, comp in tf_net.components.items():
86
+ recurse(prefix + name + '/', comp)
87
+ recurse('', tf_net)
88
+ return tf_params
89
+
90
+ #----------------------------------------------------------------------------
91
+
92
def _populate_module_params(module, *patterns):
    # Copy values into a module's parameters/buffers by matching tensor names
    # against (regex, value_fn) pairs passed as a flat alternating list.
    # value_fn receives the regex capture groups and returns a numpy-compatible
    # value; value_fn=None means "recognized, but leave the tensor untouched".
    for name, tensor in misc.named_params_and_buffers(module):
        found = False
        value = None
        for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
            match = re.fullmatch(pattern, name)
            if match:
                found = True
                if value_fn is not None:
                    value = value_fn(*match.groups())
                break
        try:
            # Every tensor must be claimed by some pattern.
            assert found
            if value is not None:
                tensor.copy_(torch.from_numpy(np.array(value)))
        except:
            # Print the offending tensor name/shape before re-raising
            # so shape mismatches are easy to diagnose.
            print(name, list(tensor.shape))
            raise
110
+
111
+ #----------------------------------------------------------------------------
112
+
113
def convert_tf_generator(tf_G):
    """Convert a legacy TensorFlow StyleGAN2(-ADA) generator into a native
    PyTorch `networks_stylegan2.Generator`.

    Args:
        tf_G: network stub holding `version`, `static_kwargs` and the TF
            variable tree of the generator.

    Returns:
        torch.nn.Module in eval mode with gradients disabled and all weights
        copied over from the TF variables.

    Raises:
        ValueError: if the pickle version is too old or an unrecognized
            TF kwarg is present.
    """
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs, remembering every TF name we interpret so leftovers
    # can be reported as unknown below.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None, none=None):
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Translate TF kwarg names to the PyTorch constructor's kwargs.
    from pg_modules import networks_stylegan2
    network_class = networks_stylegan2.Generator
    kwargs = dnnlib.EasyDict(
        z_dim           = kwarg('latent_size', 512),
        c_dim           = kwarg('label_size', 0),
        w_dim           = kwarg('dlatent_size', 512),
        img_resolution  = kwarg('resolution', 1024),
        img_channels    = kwarg('num_channels', 3),
        channel_base    = kwarg('fmap_base', 16384) * 2,  # TF counts half-channels
        channel_max     = kwarg('fmap_max', 512),
        num_fp16_res    = kwarg('num_fp16_res', 0),
        conv_clamp      = kwarg('conv_clamp', None),
        architecture    = kwarg('architecture', 'skip'),
        resample_filter = kwarg('resample_kernel', [1,3,3,1]),
        use_noise       = kwarg('use_noise', True),
        activation      = kwarg('nonlinearity', 'lrelu'),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers     = kwarg('mapping_layers', 8),
            embed_features = kwarg('label_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation     = kwarg('mapping_nonlinearity', 'lrelu'),
            lr_multiplier  = kwarg('mapping_lrmul', 0.01),
            w_avg_beta     = kwarg('w_avg_beta', 0.995, none=1),
        ),
    )

    # Mark inference-only TF kwargs as known, then reject anything unexpected.
    kwarg('truncation_psi')
    kwarg('truncation_cutoff')
    kwarg('style_mixing_prob')
    kwarg('structure')
    kwarg('conditioning')
    kwarg('fused_modconv')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params. Progressive-growing pickles store per-lod ToRGB layers;
    # rename them to the per-resolution layout and fall back to 'orig'.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
            # BUGFIX: `kwargs` is flat here (there is no 'synthesis' sub-dict),
            # so the previous `kwargs.synthesis.kwargs.architecture = 'orig'`
            # raised AttributeError. Mirror convert_tf_discriminator instead.
            kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params: instantiate the PyTorch network and copy each tensor
    # from its TF counterpart (transposing conv/dense weight layouts).
    G = network_class(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    # pylint: disable=f-string-without-interpolation
    _populate_module_params(G,
        r'mapping\.w_avg',                          lambda:     tf_params[f'dlatent_avg'],
        r'mapping\.embed\.weight',                  lambda:     tf_params[f'mapping/LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias',                    lambda:     tf_params[f'mapping/LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight',                lambda i:   tf_params[f'mapping/Dense{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias',                  lambda i:   tf_params[f'mapping/Dense{i}/bias'],
        r'synthesis\.b4\.const',                    lambda:     tf_params[f'synthesis/4x4/Const/const'][0],
        r'synthesis\.b4\.conv1\.weight',            lambda:     tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b4\.conv1\.bias',              lambda:     tf_params[f'synthesis/4x4/Conv/bias'],
        r'synthesis\.b4\.conv1\.noise_const',       lambda:     tf_params[f'synthesis/noise0'][0, 0],
        r'synthesis\.b4\.conv1\.noise_strength',    lambda:     tf_params[f'synthesis/4x4/Conv/noise_strength'],
        r'synthesis\.b4\.conv1\.affine\.weight',    lambda:     tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
        r'synthesis\.b4\.conv1\.affine\.bias',      lambda:     tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv0\.weight',        lambda r:   tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv0\.bias',          lambda r:   tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
        r'synthesis\.b(\d+)\.conv0\.noise_const',   lambda r:   tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
        r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r:  tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
        r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r:  tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv0\.affine\.bias',  lambda r:   tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv1\.weight',        lambda r:   tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv1\.bias',          lambda r:   tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
        r'synthesis\.b(\d+)\.conv1\.noise_const',   lambda r:   tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
        r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r:  tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
        r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r:  tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv1\.affine\.bias',  lambda r:   tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.torgb\.weight',        lambda r:   tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.torgb\.bias',          lambda r:   tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
        r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r:  tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.torgb\.affine\.bias',  lambda r:   tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.skip\.weight',         lambda r:   tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'.*\.resample_filter',                     None,
        r'.*\.act_filter',                          None,
    )
    return G
211
+
212
+ #----------------------------------------------------------------------------
213
+
214
def convert_tf_discriminator(tf_D):
    """Convert a legacy TensorFlow discriminator pickle into a PyTorch module.

    NOTE(review): the kwargs below follow the StyleGAN2 discriminator layout,
    but the instantiated module is pg_modules.discriminator.ProjectedDiscriminator;
    whether its constructor accepts these kwargs and whether the parameter
    name patterns below match its submodules cannot be verified from this
    file — confirm against pg_modules/discriminator.py.
    """
    if tf_D.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs, remembering which TF names we have interpreted.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None):
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Translate TF kwarg names into the PyTorch constructor layout.
    kwargs = dnnlib.EasyDict(
        c_dim = kwarg('label_size', 0),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        architecture = kwarg('architecture', 'resnet'),
        channel_base = kwarg('fmap_base', 16384) * 2,  # TF counts half-channels
        channel_max = kwarg('fmap_max', 512),
        num_fp16_res = kwarg('num_fp16_res', 0),
        conv_clamp = kwarg('conv_clamp', None),
        cmap_dim = kwarg('mapping_fmaps', None),
        block_kwargs = dnnlib.EasyDict(
            activation = kwarg('nonlinearity', 'lrelu'),
            resample_filter = kwarg('resample_kernel', [1,3,3,1]),
            freeze_layers = kwarg('freeze_layers', 0),
        ),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 0),
            embed_features = kwarg('mapping_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.1),
        ),
        epilogue_kwargs = dnnlib.EasyDict(
            mbstd_group_size = kwarg('mbstd_group_size', None),
            mbstd_num_channels = kwarg('mbstd_num_features', 1),
            activation = kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Mark ignored TF kwargs as known, then reject anything unexpected.
    kwarg('structure')
    kwarg('conditioning')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params. Progressive-growing pickles store per-lod FromRGB
    # layers; rename them to the per-resolution layout and force 'orig'.
    tf_params = _collect_tf_params(tf_D)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
            kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params: instantiate the PyTorch network and copy each tensor
    # from its TF counterpart (transposing conv/dense weight layouts).
    #from pg_modules import networks_stylegan2
    from pg_modules.discriminator import ProjectedDiscriminator

    D = ProjectedDiscriminator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    # pylint: disable=f-string-without-interpolation
    _populate_module_params(D,
        r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
        r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
        r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
        r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
        r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
        r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
        r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
        r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
        r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
        r'.*\.resample_filter', None,
    )
    return D
298
+
299
+ #----------------------------------------------------------------------------
300
+
301
@click.command()
@click.option('--source', help='Input pickle', required=True, metavar='PATH')
@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
def convert_network_pickle(source, dest, force_fp16):
    """Convert legacy network pickle into the native PyTorch format.

    The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
    It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.

    Example:

    \b
    python legacy.py \\
        --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
        --dest=stylegan2-cat-config-f.pkl
    """
    print(f'Loading "{source}"...')
    # open_url handles both local paths and HTTP(S) URLs transparently.
    with dnnlib.util.open_url(source) as f:
        data = load_network_pkl(f, force_fp16=force_fp16)
    print(f'Saving "{dest}"...')
    # Re-pickle the converted modules in the native PyTorch layout.
    with open(dest, 'wb') as f:
        pickle.dump(data, f)
    print('Done.')
325
+
326
+ #----------------------------------------------------------------------------
327
+
328
+ if __name__ == "__main__":
329
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
330
+
331
+ #----------------------------------------------------------------------------
misc.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import re
10
+ import contextlib
11
+ import numpy as np
12
+ import torch
13
+ import warnings
14
+ import dnnlib
15
+
16
+ #----------------------------------------------------------------------------
17
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18
+ # same constant is used multiple times.
19
+
20
# Cache keyed on the value bytes plus all formatting options, so the same
# constant is materialized only once per configuration.
_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached constant tensor built from `value`.

    Avoids repeated CPU=>GPU copies when the same constant is used many
    times. `shape` broadcasts the value; dtype/device/memory_format default
    to the torch defaults (CPU, contiguous).
    """
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device('cpu') if device is None else device
    memory_format = torch.contiguous_format if memory_format is None else memory_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    cached = _constant_cache.get(key)
    if cached is not None:
        return cached

    tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
    if shape is not None:
        tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
    tensor = tensor.contiguous(memory_format=memory_format)
    _constant_cache[key] = tensor
    return tensor
42
+
43
+ #----------------------------------------------------------------------------
44
+ # Replace NaN/Inf with specified numerical values.
45
+
46
try:
    nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:
    # Fallback for torch < 1.8. Only supports nan == 0: unsqueeze+nansum
    # replaces NaN with 0, and clamp maps +/-inf to the requested (or
    # dtype-extreme) finite values.
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
57
+
58
+ #----------------------------------------------------------------------------
59
+ # Symbolic assert.
60
+
61
# Prefer torch._assert (added in 1.8, survives torch.jit.trace); fall back
# to the pre-1.8 name torch.Assert.
try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0
65
+
66
+ #----------------------------------------------------------------------------
67
+ # Context manager to temporarily suppress known warnings in torch.jit.trace().
68
+ # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
69
+
70
@contextlib.contextmanager
def suppress_tracer_warnings():
    """Temporarily suppress torch.jit.TracerWarning inside the `with` block.

    Mutates warnings.filters directly instead of using catch_warnings
    because of https://bugs.python.org/issue29672.

    BUGFIX: removal now happens in a `finally` clause so the filter does not
    leak if an exception propagates out of the managed block.
    """
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    try:
        yield
    finally:
        warnings.filters.remove(flt)
76
+
77
+ #----------------------------------------------------------------------------
78
+ # Assert that the shape of a tensor matches the given list of integers.
79
+ # None indicates that the size of a dimension is allowed to vary.
80
+ # Performs symbolic assertion when used in torch.jit.trace().
81
+
82
def assert_shape(tensor, ref_shape):
    """Assert that `tensor` matches `ref_shape` (a list of sizes).

    `None` entries match any size. When either side is a traced tensor,
    the check is emitted via symbolic_assert so it survives torch.jit.trace().
    """
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            continue
        if isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
96
+
97
+ #----------------------------------------------------------------------------
98
+ # Function decorator that calls torch.autograd.profiler.record_function().
99
+
100
def profiled_function(fn):
    """Decorator that runs `fn` inside torch.autograd.profiler.record_function
    so it shows up as a named range in profiler traces.

    IMPROVEMENT: uses functools.wraps instead of copying only __name__, so
    __doc__, __qualname__ and __wrapped__ are preserved as well.
    """
    import functools
    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    return decorator
106
+
107
+ #----------------------------------------------------------------------------
108
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
109
+ # indefinitely, shuffling items as it goes.
110
+
111
class InfiniteSampler(torch.utils.data.Sampler):
    """Sampler that loops over the dataset forever, sharding indices across
    replicas round-robin and (optionally) perturbing the order on the fly."""
    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        # window_size: fraction of the dataset within which a just-visited
        # index may be swapped backwards, keeping the shuffle "local".
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        # Start from a (possibly shuffled) permutation of all indices.
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            # Round-robin sharding: this replica yields every num_replicas-th step.
            if idx % self.num_replicas == self.rank:
                yield order[i]
            # Swap the visited slot with a random earlier slot inside the
            # window so subsequent passes see a different order.
            if window >= 2:
                j = (i - rnd.randint(window)) % order.size
                order[i], order[j] = order[j], order[i]
            idx += 1
144
+ #----------------------------------------------------------------------------
145
+ # Utilities for operating with torch.nn.Module parameters and buffers.
146
+
147
def params_and_buffers(module):
    """Return all parameters followed by all buffers of `module` as one list."""
    assert isinstance(module, torch.nn.Module)
    return [*module.parameters(), *module.buffers()]
150
+
151
def named_params_and_buffers(module):
    """Return (name, tensor) pairs for all parameters followed by all buffers."""
    assert isinstance(module, torch.nn.Module)
    return [*module.named_parameters(), *module.named_buffers()]
154
+
155
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
156
+ assert isinstance(src_module, torch.nn.Module)
157
+ assert isinstance(dst_module, torch.nn.Module)
158
+ src_tensors = dict(named_params_and_buffers(src_module))
159
+ for name, tensor in named_params_and_buffers(dst_module):
160
+ assert (name in src_tensors) or (not require_all)
161
+ if name in src_tensors:
162
+ try:
163
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
164
+ except:
165
+ continue
166
+
167
+ #----------------------------------------------------------------------------
168
+ # Context manager for easily enabling/disabling DistributedDataParallel
169
+ # synchronization.
170
+
171
@contextlib.contextmanager
def ddp_sync(module, sync):
    """Context manager that disables DistributedDataParallel gradient
    synchronization when `sync` is False; a no-op for plain modules."""
    assert isinstance(module, torch.nn.Module)
    is_ddp = isinstance(module, torch.nn.parallel.DistributedDataParallel)
    if is_ddp and not sync:
        with module.no_sync():
            yield
    else:
        yield
179
+
180
+ #----------------------------------------------------------------------------
181
+ # Check DistributedDataParallel consistency across processes.
182
+
183
def check_ddp_consistency(module, ignore_regex=None):
    # Assert every parameter/buffer equals the copy broadcast from rank 0.
    # Requires torch.distributed to be initialized. Names fullmatching
    # ignore_regex (as "ClassName.param_name") are skipped.
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        if tensor.is_floating_point():
            # NaN != NaN would produce false mismatches; normalize first.
            tensor = nan_to_num(tensor)
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname
195
+
196
+ #----------------------------------------------------------------------------
197
+ # Print summary table of module hierarchy.
198
+
199
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    """Run `module(*inputs)` once and print a table of submodules with their
    parameter/buffer counts and output shapes; returns the module outputs.

    max_nesting limits how deep in the hierarchy rows are recorded;
    skip_redundant drops rows that contribute no unique tensors.
    """
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks. nesting[0] tracks call depth so only modules at or
    # above max_nesting get an entry.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module once to populate `entries`, then detach the hooks.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers (shared tensors are
    # attributed to the first entry that mentions them).
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table: one row per entry, extra rows for multi-output modules.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table with columns padded to their widest cell.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs
268
+
269
+ #----------------------------------------------------------------------------
270
+
271
# Added by Katja
import os

def get_ckpt_path(run_dir):
    """Return the path of the network snapshot pickle inside `run_dir`."""
    # Plain string literal: the previous f-string had no interpolation.
    return os.path.join(run_dir, 'network-snapshot.pkl')
pg_modules/__init__.py ADDED
File without changes
pg_modules/__pycache__/MViT.cpython-39.pyc ADDED
Binary file (17.9 kB). View file
 
pg_modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (222 Bytes). View file
 
pg_modules/__pycache__/blocks.cpython-38.pyc ADDED
Binary file (10.5 kB). View file
 
pg_modules/__pycache__/blocks.cpython-39.pyc ADDED
Binary file (11.8 kB). View file
 
pg_modules/__pycache__/diffaug.cpython-38.pyc ADDED
Binary file (2.69 kB). View file
 
pg_modules/__pycache__/diffaug.cpython-39.pyc ADDED
Binary file (2.81 kB). View file
 
pg_modules/__pycache__/discriminator.cpython-38.pyc ADDED
Binary file (5.65 kB). View file
 
pg_modules/__pycache__/discriminator.cpython-39.pyc ADDED
Binary file (4.51 kB). View file
 
pg_modules/__pycache__/mae.cpython-39.pyc ADDED
Binary file (8.85 kB). View file
 
pg_modules/__pycache__/models_tnt.cpython-39.pyc ADDED
Binary file (17.5 kB). View file
 
pg_modules/__pycache__/networks_fastgan.cpython-38.pyc ADDED
Binary file (5.2 kB). View file
 
pg_modules/__pycache__/networks_fastgan.cpython-39.pyc ADDED
Binary file (5.34 kB). View file
 
pg_modules/__pycache__/networks_stylegan2.cpython-39.pyc ADDED
Binary file (15.5 kB). View file
 
pg_modules/__pycache__/projector.cpython-38.pyc ADDED
Binary file (3.85 kB). View file
 
pg_modules/__pycache__/projector.cpython-39.pyc ADDED
Binary file (4.21 kB). View file
 
pg_modules/__pycache__/simmim.cpython-39.pyc ADDED
Binary file (4.22 kB). View file
 
pg_modules/__pycache__/vision_transformer.cpython-39.pyc ADDED
Binary file (11.9 kB). View file
 
pg_modules/blocks.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn.utils import spectral_norm
6
+
7
+
8
+ ### single layers
9
+
10
+
11
def conv2d(*args, **kwargs):
    """nn.Conv2d wrapped in spectral normalization."""
    layer = nn.Conv2d(*args, **kwargs)
    return spectral_norm(layer)
13
+
14
+
15
def convTranspose2d(*args, **kwargs):
    """nn.ConvTranspose2d wrapped in spectral normalization."""
    layer = nn.ConvTranspose2d(*args, **kwargs)
    return spectral_norm(layer)
17
+
18
+
19
def embedding(*args, **kwargs):
    """nn.Embedding wrapped in spectral normalization."""
    layer = nn.Embedding(*args, **kwargs)
    return spectral_norm(layer)
21
+
22
+
23
def linear(*args, **kwargs):
    """nn.Linear wrapped in spectral normalization."""
    layer = nn.Linear(*args, **kwargs)
    return spectral_norm(layer)
25
+
26
+
27
def NormLayer(c, mode='batch'):
    """Return a normalization layer for `c` channels.

    mode='group' -> GroupNorm with c//2 groups; mode='batch' -> BatchNorm2d.

    BUGFIX: unknown modes previously fell through and returned None, which
    surfaced later as a confusing error inside nn.Sequential; raise here.
    """
    if mode == 'group':
        return nn.GroupNorm(c//2, c)
    elif mode == 'batch':
        return nn.BatchNorm2d(c)
    raise ValueError(f'Unknown normalization mode: {mode!r}')
32
+
33
+
34
+ ### Activations
35
+
36
+
37
class GLU(nn.Module):
    """Gated linear unit over the channel dim: the first half of the
    channels is gated by the sigmoid of the second half."""
    def forward(self, x):
        channels = x.size(1)
        assert channels % 2 == 0, 'channels dont divide 2!'
        half = int(channels/2)
        content, gate = x[:, :half], x[:, half:]
        return content * torch.sigmoid(gate)
43
+
44
+
45
class Swish(nn.Module):
    """x * sigmoid(x) activation (a.k.a. SiLU)."""
    def forward(self, feat):
        gate = torch.sigmoid(feat)
        return feat * gate
48
+
49
+
50
+ ### Upblocks
51
+
52
+
53
class InitLayer(nn.Module):
    """Project a latent vector to an initial sz x sz feature map via a
    spectrally-normalized transposed conv, NormLayer and GLU."""
    def __init__(self, nz, channel, sz=4):
        super().__init__()

        # channel*2 outputs because GLU halves the channel count.
        self.init = nn.Sequential(
            convTranspose2d(nz, channel*2, sz, 1, 0, bias=False),
            NormLayer(channel*2),
            GLU(),
        )

    def forward(self, noise):
        # Accepts (N, nz) (or anything flattenable to it); reshape to NCHW.
        noise = noise.view(noise.shape[0], -1, 1, 1)
        return self.init(noise)
66
+
67
+
68
def UpBlockSmall(in_planes, out_planes):
    # 2x nearest-neighbor upsample -> spectral conv -> NormLayer -> GLU.
    # The conv emits out_planes*2 channels because GLU halves them.
    block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
        NormLayer(out_planes*2), GLU())
    return block
74
+
75
+
76
class UpBlockSmallCond(nn.Module):
    """Conditional UpBlockSmall: batch norm is replaced by class-conditional
    BN (CCBN) modulated by a z_dim-sized condition vector `c`."""
    def __init__(self, in_planes, out_planes, z_dim):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False)

        # CCBN is defined elsewhere in this package; presumably it predicts
        # per-channel scale/shift from `c` via the spectral linear — TODO confirm.
        which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim)
        self.bn = which_bn(2*out_planes)
        self.act = GLU()

    def forward(self, x, c):
        # upsample -> conv -> conditional BN -> GLU (halves channels).
        x = self.up(x)
        x = self.conv(x)
        x = self.bn(x, c)
        x = self.act(x)
        return x
94
+
95
+
96
def UpBlockBig(in_planes, out_planes):
    # Two conv stages after a 2x upsample, each with noise injection,
    # NormLayer and GLU (convs emit out_planes*2 since GLU halves channels).
    # NoiseInjection is defined elsewhere in this package.
    block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
        NoiseInjection(),
        NormLayer(out_planes*2), GLU(),
        conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False),
        NoiseInjection(),
        NormLayer(out_planes*2), GLU()
        )
    return block
107
+
108
+
109
class UpBlockBigCond(nn.Module):
    """Conditional UpBlockBig: two conv stages with noise injection and
    class-conditional BN (CCBN) modulated by condition vector `c`."""
    def __init__(self, in_planes, out_planes, z_dim):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False)
        self.conv2 = conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False)

        # CCBN / NoiseInjection are defined elsewhere in this package.
        which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim)
        self.bn1 = which_bn(2*out_planes)
        self.bn2 = which_bn(2*out_planes)
        self.act = GLU()
        self.noise = NoiseInjection()

    def forward(self, x, c):
        # block 1: upsample -> conv -> noise -> cond BN -> GLU
        x = self.up(x)
        x = self.conv1(x)
        x = self.noise(x)
        x = self.bn1(x, c)
        x = self.act(x)

        # block 2: conv -> noise -> cond BN -> GLU
        x = self.conv2(x)
        x = self.noise(x)
        x = self.bn2(x, c)
        x = self.act(x)

        return x
139
+
140
+
141
class SEBlock(nn.Module):
    """Squeeze-and-excitation style gate: the small (low-res) feature map is
    pooled to 4x4 and reduced to per-channel sigmoid weights that scale the
    big feature map."""
    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.main = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),
            conv2d(ch_in, ch_out, 4, 1, 0, bias=False),  # 4x4 -> 1x1
            Swish(),
            conv2d(ch_out, ch_out, 1, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, feat_small, feat_big):
        # Gate broadcasts over spatial dims of feat_big.
        return feat_big * self.main(feat_small)
154
+
155
+
156
+ ### Downblocks
157
+
158
+
159
class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution: a per-channel spatial conv followed
    by a 1x1 pointwise conv, both spectrally normalized via conv2d().

    IMPROVEMENT: padding is now kernel_size // 2 (keeps spatial size for any
    odd kernel) instead of a hard-coded 1, which was only correct for
    kernel_size=3. All in-file callers pass 3, so behavior is unchanged
    for them.
    """
    def __init__(self, in_channels, out_channels, kernel_size, bias=False):
        super().__init__()
        self.depthwise = conv2d(in_channels, in_channels, kernel_size=kernel_size,
                                groups=in_channels, bias=bias, padding=kernel_size // 2)
        self.pointwise = conv2d(in_channels, out_channels,
                                kernel_size=1, bias=bias)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out
171
+
172
+
173
class DownBlock(nn.Module):
    """2x spatial downsampling block.

    separable=False: strided 4x4 spectral conv.
    separable=True: 3x3 depthwise-separable conv followed by 2x2 avg-pool.
    Both variants end with NormLayer + LeakyReLU(0.2).
    """
    def __init__(self, in_planes, out_planes, separable=False):
        super().__init__()
        if not separable:
            self.main = nn.Sequential(
                conv2d(in_planes, out_planes, 4, 2, 1),
                NormLayer(out_planes),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            self.main = nn.Sequential(
                SeparableConv2d(in_planes, out_planes, 3),
                NormLayer(out_planes),
                nn.LeakyReLU(0.2, inplace=True),
                nn.AvgPool2d(2, 2),
            )

    def forward(self, feat):
        return self.main(feat)
192
+
193
+
194
class DownBlockPatch(nn.Module):
    """DownBlock followed by an extra 1x1 conv + norm + LeakyReLU,
    presumably for patch-level discrimination — confirm against callers."""
    def __init__(self, in_planes, out_planes, separable=False):
        super().__init__()
        self.main = nn.Sequential(
            DownBlock(in_planes, out_planes, separable),
            conv2d(out_planes, out_planes, 1, 1, 0, bias=False),
            NormLayer(out_planes),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, feat):
        return self.main(feat)
206
+
207
+
208
+ ### CSM
209
+
210
+
211
class ResidualConvUnit(nn.Module):
    """Single 3x3 conv with an additive skip connection.

    NOTE(review): the `activation` and `bn` arguments are accepted for
    interface compatibility but are ignored by this implementation.
    """

    def __init__(self, cin, activation, bn):
        super().__init__()
        self.conv = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=True)
        # FloatFunctional keeps the add quantization-friendly
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        residual = self.conv(x)
        return self.skip_add.add(residual, x)
219
+
220
+
221
class FeatureFusionBlock(nn.Module):
    """Fuses up to two feature maps: optional element-wise add, 2x bilinear
    upsampling, then a 1x1 projection conv."""

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, lowest=False):
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.expand = expand

        # when expanding the pyramid, halve the channel count
        out_features = features
        if self.expand == True:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        fused = xs[0]

        # merge the optional lateral input
        if len(xs) == 2:
            fused = self.skip_add.add(fused, xs[1])

        fused = nn.functional.interpolate(
            fused, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )

        return self.out_conv(fused)
249
+
250
class FeatureFusionBlock_V2(nn.Module):
    """Variant of FeatureFusionBlock without the 2x upsampling step: optional
    add of a second input followed by a 1x1 projection conv.

    NOTE(review): unlike FeatureFusionBlock, the `expand` flag did not change
    the channel count here (the branch assigned the same value); that no-op
    behavior is preserved.
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, lowest=False):
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.expand = expand

        out_features = features  # expand intentionally has no effect (see note)

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        fused = xs[0]

        if len(xs) == 2:
            fused = self.skip_add.add(fused, xs[1])

        # no spatial resampling in this variant
        return self.out_conv(fused)
278
+ from timm.models.vision_transformer import PatchEmbed, Block
279
+
280
class FeatureFusionBlockTrans(nn.Module):
    """Fusion block that projects through a timm ViT transformer `Block`
    instead of a conv; presumably expects token-sequence inputs [B, N, C] —
    TODO confirm against callers."""

    def __init__(self, features):
        super().__init__()
        self.out_conv = Block(features, num_heads=12)  # timm transformer block
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        fused = xs[0]

        # merge the optional second input before projecting
        if len(xs) == 2:
            fused = self.skip_add.add(fused, xs[1])

        return self.out_conv(fused)
294
+
295
+
296
+ ### Misc
297
+
298
+
299
class NoiseInjection(nn.Module):
    """Adds learnable-scaled per-pixel noise; the scale parameter starts at
    zero, so the layer is initially an identity."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, feat, noise=None):
        # sample one noise channel per sample when not supplied
        if noise is None:
            batch, _, height, width = feat.shape
            noise = torch.randn(batch, 1, height, width).to(feat.device)

        return feat + self.weight * noise
310
+
311
+
312
class CCBN(nn.Module):
    """Class-conditional batch normalization (BigGAN style).

    The conditioning vector y is mapped through two linear layers to produce a
    per-channel gain (centered at 1) and bias that modulate the normalized x.

    Fix: forward() previously passed the hard-coded constant 0.1 as the
    batch-norm momentum, ignoring the `momentum` constructor argument; it now
    uses `self.momentum` (default 0.1, so default behavior is unchanged).
    """

    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1):
        super().__init__()
        self.output_size, self.input_size = output_size, input_size

        # Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)

        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum for the running statistics
        self.momentum = momentum

        # running stats, updated by F.batch_norm in training mode
        self.register_buffer('stored_mean', torch.zeros(output_size))
        self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                           self.training, self.momentum, self.eps)
        return out * gain + bias
337
+
338
+
339
class Interpolate(nn.Module):
    """Wraps `nn.functional.interpolate` as a module.

    Fix: the original docstring documented a `scale_factor` argument, but the
    module actually takes (and passes through) a target output `size`.
    """

    def __init__(self, size, mode='bilinear', align_corners=False):
        """Init.
        Args:
            size (int or tuple): target output spatial size
            mode (str): interpolation mode
            align_corners (bool): passed through to F.interpolate
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resize x to `self.size` using the configured interpolation mode.
        Args:
            x (tensor): input
        Returns:
            tensor: interpolated data
        """
        x = self.interp(
            x,
            size=self.size,
            mode=self.mode,
            align_corners=self.align_corners,
        )
        return x
pg_modules/diffaug.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Differentiable Augmentation for Data-Efficient GAN Training
2
+ # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
3
+ # https://arxiv.org/pdf/2006.10738
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+
9
def DiffAugment(x, policy='', channels_first=True):
    """Apply a comma-separated augmentation policy ('color', 'translation',
    'cutout') to a batch of images; an empty policy is a no-op."""
    if policy:
        if not channels_first:
            x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
        for name in policy.split(','):
            for aug in AUGMENT_FNS[name]:
                x = aug(x)
        if not channels_first:
            x = x.permute(0, 2, 3, 1)  # back to NHWC
        x = x.contiguous()
    return x
20
+
21
+
22
def rand_brightness(x):
    """Shift each image by a random brightness offset in [-0.5, 0.5)."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset
25
+
26
+
27
+ def rand_saturation(x):
28
+ x_mean = x.mean(dim=1, keepdim=True)
29
+ x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
30
+ return x
31
+
32
+
33
+ def rand_contrast(x):
34
+ x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
35
+ x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
36
+ return x
37
+
38
+
39
def rand_translation(x, ratio=0.125):
    """Randomly shift each image by up to `ratio` of its size, padding the
    exposed border with zeros."""
    max_h, max_w = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    shift_h = torch.randint(-max_h, max_h + 1, size=[x.size(0), 1, 1], device=x.device)
    shift_w = torch.randint(-max_w, max_w + 1, size=[x.size(0), 1, 1], device=x.device)
    grid_b, grid_h, grid_w = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    # the +1 accounts for the 1-pixel zero border added by F.pad below
    grid_h = torch.clamp(grid_h + shift_h + 1, 0, x.size(2) + 1)
    grid_w = torch.clamp(grid_w + shift_w + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    return x_pad.permute(0, 2, 3, 1).contiguous()[grid_b, grid_h, grid_w].permute(0, 3, 1, 2)
53
+
54
+
55
def rand_cutout(x, ratio=0.2):
    """Zero out one randomly-placed rectangle of roughly `ratio` * size per image."""
    cut_h, cut_w = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    off_h = torch.randint(0, x.size(2) + (1 - cut_h % 2), size=[x.size(0), 1, 1], device=x.device)
    off_w = torch.randint(0, x.size(3) + (1 - cut_w % 2), size=[x.size(0), 1, 1], device=x.device)
    grid_b, grid_h, grid_w = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cut_h, dtype=torch.long, device=x.device),
        torch.arange(cut_w, dtype=torch.long, device=x.device),
    )
    # clamp keeps cutouts that overhang the border inside the image
    grid_h = torch.clamp(grid_h + off_h - cut_h // 2, min=0, max=x.size(2) - 1)
    grid_w = torch.clamp(grid_w + off_w - cut_w // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_b, grid_h, grid_w] = 0
    return x * mask.unsqueeze(1)
70
+
71
+
72
# policy name -> ordered list of augmentation functions applied by DiffAugment
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}
pg_modules/discriminator.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms import Normalize
6
+ import pickle
7
+
8
+ from pg_modules.diffaug import DiffAugment
9
+ from pg_modules.blocks import conv2d, DownBlock, DownBlockPatch
10
+ from pg_modules.projector import F_RandomProj
11
+ from feature_networks.constants import VITS
12
+
13
class SingleDisc(nn.Module):
    """Discriminator over a single feature scale: stacked DownBlocks from
    `start_sz` down to `end_sz`, ending in a 4x4 conv that produces logits."""

    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, patch=False):
        super().__init__()

        # default per-resolution channel widths (midas schedule)
        nfc_midas = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                     256: 32, 512: 16, 1024: 8}

        # snap non-power-of-two start sizes to the nearest known resolution
        if start_sz not in nfc_midas.keys():
            sizes = np.array(list(nfc_midas.keys()))
            start_sz = sizes[np.argmin(abs(sizes - start_sz))]
        self.start_sz = start_sz

        # constant width if ndf is given, otherwise the midas schedule
        if ndf is None:
            nfc = nfc_midas
        else:
            nfc = {k: ndf for k in nfc_midas}

        # feature-map input (pretrained backbone): first layer takes nc channels
        if nc is not None and head is None:
            nfc[start_sz] = nc

        layers = []

        # image-space head when the input is the full modality
        if head:
            layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
                       nn.LeakyReLU(0.2, inplace=True)]

        # down blocks until end_sz is reached
        DB = DownBlockPatch if patch else DownBlock
        while start_sz > end_sz:
            layers.append(DB(nfc[start_sz], nfc[start_sz // 2]))
            start_sz = start_sz // 2

        layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        # conditioning c is accepted for API compatibility but unused
        return self.main(x)
56
+
57
class MultiScaleD(nn.Module):
    """Holds one SingleDisc per backbone scale and concatenates their
    flattened logits along dim 1."""

    def __init__(
        self,
        channels,
        resolutions,
        num_discs=4,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        cond=0,
        patch=False,
        **kwargs,
    ):
        super().__init__()

        assert num_discs in [1, 2, 3, 4, 5]

        # the first disc is on the lowest level of the backbone
        self.disc_in_channels = channels[:num_discs]
        self.disc_in_res = resolutions[:num_discs]

        mini_discs = []
        for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
            start_sz = res if not patch else 16
            mini_discs.append((str(i), SingleDisc(nc=cin, start_sz=start_sz, end_sz=8, patch=patch)))

        self.mini_discs = nn.ModuleDict(mini_discs)

    def forward(self, features, c, rec=False):
        # run each per-scale disc on its feature map and flatten the logits
        all_logits = [disc(features[k], c).view(features[k].size(0), -1)
                      for k, disc in self.mini_discs.items()]
        return torch.cat(all_logits, dim=1)
91
+
92
class ProjectedDiscriminator(torch.nn.Module):
    """Discriminator composed of frozen pretrained feature networks and
    trainable multi-scale discriminators, one pair per backbone name."""

    def __init__(
        self,
        backbones,
        diffaug=True,
        interp224=True,
        backbone_kwargs={},
        **kwargs
    ):
        super().__init__()
        self.backbones = backbones
        self.diffaug = diffaug
        self.interp224 = interp224

        # build a (feature extractor, multi-scale disc) pair per backbone
        feature_networks, discriminators = [], []
        for bb_name in backbones:
            feat = F_RandomProj(bb_name, **backbone_kwargs)
            disc = MultiScaleD(
                channels=feat.CHANNELS,
                resolutions=feat.RESOLUTIONS,
                **backbone_kwargs,
            )
            feature_networks.append((bb_name, feat))
            discriminators.append((bb_name, disc))

        self.feature_networks = nn.ModuleDict(feature_networks)
        self.discriminators = nn.ModuleDict(discriminators)

    def train(self, mode=True):
        # feature networks stay frozen in eval mode; only the discs toggle
        self.feature_networks = self.feature_networks.train(False)
        self.discriminators = self.discriminators.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def forward(self, x, c):
        logits = []
        for bb_name, feat in self.feature_networks.items():

            # apply augmentation (x in [-1, 1])
            x_aug = DiffAugment(x, policy='color,translation,cutout') if self.diffaug else x

            # transform to [0, 1]
            x_aug = x_aug.add(1).div(2)

            # apply backbone-specific normalization
            x_n = Normalize(feat.normstats['mean'], feat.normstats['std'])(x_aug)

            # upsample if smaller, downsample if larger + ViT backbones need 224
            if self.interp224 or bb_name in VITS:
                x_n = F.interpolate(x_n, 224, mode='bilinear', align_corners=False)

            # forward pass
            # NOTE(review): `+=` extends the list with the rows of the returned
            # logit tensor (one entry per sample) — presumably intentional; verify.
            features = feat(x_n)
            logits += self.discriminators[bb_name](features, c)

        return logits
+ return logits
pg_modules/networks_fastgan.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # original implementation: https://github.com/odegeasslbc/FastGAN-pytorch/blob/main/models.py
2
+ #
3
+ # modified by Axel Sauer for "Projected GANs Converge Faster"
4
+ #
5
+ import torch.nn as nn
6
+ from pg_modules.blocks import (InitLayer, UpBlockBig, UpBlockBigCond, UpBlockSmall, UpBlockSmallCond, SEBlock, conv2d)
7
+ import torch
8
+
9
def normalize_second_moment(x, dim=1, eps=1e-8):
    """Scale x so its mean square along `dim` is ~1 (PGGAN pixel norm)."""
    mean_sq = x.square().mean(dim=dim, keepdim=True)
    return x * (mean_sq + eps).rsqrt()
11
+
12
+
13
class DummyMapping(nn.Module):
    """Identity mapping network; only adds a broadcast dimension so the
    output matches the StyleGAN 'ws' layout."""

    def __init__(self):
        super().__init__()

    def forward(self, z, c, **kwargs):
        # [B, z_dim] -> [B, 1, z_dim] to fit the StyleGAN API
        return z.unsqueeze(1)
19
+
20
+
21
class FastganSynthesis(nn.Module):
    """FastGAN generator synthesis network (unconditional variant).

    Fix: removed a stray line-continuation backslash after the `feat_128`
    assignment (it spliced the statement onto the following blank line).
    """

    def __init__(self, ngf=128, z_dim=256, nc=3, img_resolution=256, lite=False):
        super().__init__()
        self.img_resolution = img_resolution
        self.z_dim = z_dim

        # channel multiplier per resolution
        nfc_multi = {2: 16, 4: 16, 8: 8, 16: 4, 32: 2, 64: 2, 128: 1, 224: 0.5,
                     256: 0.5, 512: 0.25, 1024: 0.125}
        nfc = {k: int(v * ngf) for k, v in nfc_multi.items()}

        # layers
        self.init = InitLayer(z_dim, channel=nfc[2], sz=4)

        UpBlock = UpBlockSmall if lite else UpBlockBig

        self.feat_8 = UpBlock(nfc[4], nfc[8])
        self.feat_16 = UpBlock(nfc[8], nfc[16])
        self.feat_32 = UpBlock(nfc[16], nfc[32])
        self.feat_64 = UpBlock(nfc[32], nfc[64])
        self.feat_128 = UpBlock(nfc[64], nfc[128])
        self.feat_256 = UpBlock(nfc[128], nfc[256])

        self.se_64 = SEBlock(nfc[4], nfc[64])
        self.se_128 = SEBlock(nfc[8], nfc[128])
        self.se_256 = SEBlock(nfc[16], nfc[256])

        self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True)

        # optional high-resolution stages
        if img_resolution > 256:
            self.feat_512 = UpBlock(nfc[256], nfc[512])
            self.se_512 = SEBlock(nfc[32], nfc[512])
        if img_resolution > 512:
            self.feat_1024 = UpBlock(nfc[512], nfc[1024])

    def forward(self, input, c, **kwargs):
        # map noise to hypersphere as in "Progressive Growing of GANs"
        input = normalize_second_moment(input[:, 0])

        feat_4 = self.init(input)
        feat_8 = self.feat_8(feat_4)
        feat_16 = self.feat_16(feat_8)
        feat_32 = self.feat_32(feat_16)
        feat_64 = self.se_64(feat_4, self.feat_64(feat_32))
        feat_128 = self.se_128(feat_8, self.feat_128(feat_64))

        # pick the deepest feature for the requested output resolution
        if self.img_resolution >= 64:
            feat_last = feat_64

        if self.img_resolution >= 128:
            feat_last = feat_128

        if self.img_resolution >= 224:
            feat_last = self.se_256(feat_16, self.feat_256(feat_last))

        if self.img_resolution >= 512:
            feat_last = self.se_512(feat_32, self.feat_512(feat_last))

        if self.img_resolution >= 1024:
            feat_last = self.feat_1024(feat_last)

        return torch.tanh(self.to_big(feat_last))
84
+
85
+
86
class FastganSynthesisCond(nn.Module):
    """FastGAN generator synthesis network, class-conditional variant.

    Fix: the optional high-resolution blocks (`feat_512`, `feat_1024`) were
    constructed without the `z_dim` argument even though the conditional
    UpBlock variants require it (their forward is called with (feat, c));
    instantiating this class with img_resolution > 256 would fail. They now
    receive `z_dim` like the other conditional UpBlocks.
    """

    def __init__(self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000, lite=False):
        super().__init__()

        self.z_dim = z_dim
        # channel multiplier per resolution
        nfc_multi = {2: 16, 4: 16, 8: 8, 16: 4, 32: 2, 64: 2, 128: 1, 256: 0.5,
                     512: 0.25, 1024: 0.125, 2048: 0.125}
        nfc = {k: int(v * ngf) for k, v in nfc_multi.items()}

        self.img_resolution = img_resolution

        self.init = InitLayer(z_dim, channel=nfc[2], sz=4)

        UpBlock = UpBlockSmallCond if lite else UpBlockBigCond

        self.feat_8 = UpBlock(nfc[4], nfc[8], z_dim)
        self.feat_16 = UpBlock(nfc[8], nfc[16], z_dim)
        self.feat_32 = UpBlock(nfc[16], nfc[32], z_dim)
        self.feat_64 = UpBlock(nfc[32], nfc[64], z_dim)
        self.feat_128 = UpBlock(nfc[64], nfc[128], z_dim)
        self.feat_256 = UpBlock(nfc[128], nfc[256], z_dim)

        self.se_64 = SEBlock(nfc[4], nfc[64])
        self.se_128 = SEBlock(nfc[8], nfc[128])
        self.se_256 = SEBlock(nfc[16], nfc[256])

        self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True)

        # optional high-resolution stages (conditional UpBlocks need z_dim)
        if img_resolution > 256:
            self.feat_512 = UpBlock(nfc[256], nfc[512], z_dim)
            self.se_512 = SEBlock(nfc[32], nfc[512])
        if img_resolution > 512:
            self.feat_1024 = UpBlock(nfc[512], nfc[1024], z_dim)

        self.embed = nn.Embedding(num_classes, z_dim)

    def forward(self, input, c, update_emas=False):
        # one-hot class label -> embedding used as the conditioning vector
        c = self.embed(c.argmax(1))

        # map noise to hypersphere as in "Progressive Growing of GANs"
        input = normalize_second_moment(input[:, 0])

        feat_4 = self.init(input)
        feat_8 = self.feat_8(feat_4, c)
        feat_16 = self.feat_16(feat_8, c)
        feat_32 = self.feat_32(feat_16, c)
        feat_64 = self.se_64(feat_4, self.feat_64(feat_32, c))
        feat_128 = self.se_128(feat_8, self.feat_128(feat_64, c))

        if self.img_resolution >= 128:
            feat_last = feat_128

        if self.img_resolution >= 256:
            feat_last = self.se_256(feat_16, self.feat_256(feat_last, c))

        if self.img_resolution >= 512:
            feat_last = self.se_512(feat_32, self.feat_512(feat_last, c))

        if self.img_resolution >= 1024:
            feat_last = self.feat_1024(feat_last, c)

        return self.to_big(feat_last)
150
+
151
+
152
class Generator(nn.Module):
    """StyleGAN-API wrapper around the FastGAN synthesis network."""

    def __init__(
        self,
        z_dim=256,
        c_dim=0,
        w_dim=0,
        img_resolution=256,
        img_channels=3,
        ngf=128,
        cond=0,
        mapping_kwargs={},
        synthesis_kwargs={}
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        # mapping is a no-op that reshapes z to the StyleGAN "ws" layout
        self.mapping = DummyMapping()
        Synthesis = FastganSynthesisCond if cond else FastganSynthesis
        self.synthesis = Synthesis(ngf=ngf, z_dim=z_dim, nc=img_channels,
                                   img_resolution=img_resolution, **synthesis_kwargs)

    def forward(self, z, c, **kwargs):
        ws = self.mapping(z, c)
        return self.synthesis(ws, c)
pg_modules/networks_stylegan2.py ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ #
9
+ # modified by Axel Sauer for "Projected GANs Converge Faster"
10
+ #
11
+ import numpy as np
12
+ import torch
13
+ from torch_utils import misc
14
+ from torch_utils import persistence
15
+ from torch_utils.ops import conv2d_resample
16
+ from torch_utils.ops import upfirdn2d
17
+ from torch_utils.ops import bias_act
18
+ from torch_utils.ops import fma
19
+
20
+
21
@misc.profiled_function
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    # Scale x so the mean of x^2 along `dim` is ~1 (pixel norm).
    return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
24
+
25
+
26
@misc.profiled_function
def modulated_conv2d(
    x,                          # Input tensor of shape [batch_size, in_channels, in_height, in_width].
    weight,                     # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,                     # Modulation coefficients of shape [batch_size, in_channels].
    noise           = None,     # Optional noise tensor to add to the output activations.
    up              = 1,        # Integer upsampling factor.
    down            = 1,        # Integer downsampling factor.
    padding         = 0,        # Padding with respect to the upsampled image.
    resample_filter = None,     # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
    demodulate      = True,     # Apply weight demodulation?
    flip_weight     = True,     # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
    fused_modconv   = True,     # Perform modulation, convolution, and demodulation as a single fused operation?
):
    """StyleGAN2 modulated convolution: scales the weights per sample by
    `styles`, optionally demodulates, convolves (either as one fused grouped
    conv or as scale -> conv -> scale), and adds optional noise."""
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
    misc.assert_shape(styles, [batch_size, in_channels]) # [NI]

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0) # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
        if demodulate and noise is not None:
            x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Execute as one fused op using grouped convolution.
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x
85
+
86
+
87
@persistence.persistent_class
class FullyConnectedLayer(torch.nn.Module):
    """Linear layer with equalized learning rate (weights stored at unit scale
    and rescaled at runtime by `weight_gain`), optional bias and activation."""
    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        bias            = True,     # Apply additive bias before the activation function?
        activation      = 'linear', # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 1,        # Learning rate multiplier.
        bias_init       = 0,        # Initial value for the additive bias.
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation
        # weights are divided by lr_multiplier at init and re-multiplied at runtime
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        w = self.weight.to(x.dtype) * self.weight_gain
        b = self.bias
        if b is not None:
            b = b.to(x.dtype)
            if self.bias_gain != 1:
                b = b * self.bias_gain

        # fast path: fused addmm when no activation is applied
        if self.activation == 'linear' and b is not None:
            x = torch.addmm(b.unsqueeze(0), x, w.t())
        else:
            x = x.matmul(w.t())
            x = bias_act.bias_act(x, b, act=self.activation)
        return x

    def extra_repr(self):
        return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
123
+
124
+
125
@persistence.persistent_class
class Conv2dLayer(torch.nn.Module):
    """Conv layer with equalized learning rate, optional up/downsampling via a
    low-pass resampling filter, and a fused bias + activation."""
    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        kernel_size,                    # Width and height of the convolution kernel.
        bias            = True,         # Apply additive bias before the activation function?
        activation      = 'linear',     # Activation function: 'relu', 'lrelu', etc.
        up              = 1,            # Integer upsampling factor.
        down            = 1,            # Integer downsampling factor.
        resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp      = None,         # Clamp the output to +-X, None = disable clamping.
        channels_last   = False,        # Expect the input to have memory_format=channels_last?
        trainable       = True,         # Update the weights of this layer during training?
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.activation = activation
        self.up = up
        self.down = down
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        # equalized-lr scaling applied to the unit-variance weights at runtime
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
        bias = torch.zeros([out_channels]) if bias else None
        # non-trainable layers keep weights as buffers so they are still saved
        if trainable:
            self.weight = torch.nn.Parameter(weight)
            self.bias = torch.nn.Parameter(bias) if bias is not None else None
        else:
            self.register_buffer('weight', weight)
            if bias is not None:
                self.register_buffer('bias', bias)
            else:
                self.bias = None

    def forward(self, x, gain=1):
        w = self.weight * self.weight_gain
        b = self.bias.to(x.dtype) if self.bias is not None else None
        flip_weight = (self.up == 1) # slightly faster
        x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
        return x

    def extra_repr(self):
        return ' '.join([
            f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},',
            f'up={self.up}, down={self.down}'])
180
+
181
+
182
@persistence.persistent_class
class MappingNetwork(torch.nn.Module):
    """StyleGAN mapping network: embeds the optional label, normalizes and
    concatenates it with z, runs a stack of FC layers, tracks a moving average
    of W, and supports truncation and broadcasting to `num_ws` copies."""
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.998,    # Decay for tracking the moving average of W during training, None = do not track.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                misc.assert_shape(c, [None, self.c_dim])
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W.
        if update_emas and self.w_avg_beta is not None:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation: blend toward the tracked average W.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x

    def extra_repr(self):
        return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
262
+
263
+
264
@persistence.persistent_class
class SynthesisLayer(torch.nn.Module):
    """Single StyleGAN2 synthesis layer.

    Applies a style-modulated convolution (weights modulated per-sample by an
    affine projection of w), optional per-pixel noise, then bias + activation.
    """
    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        w_dim,                          # Intermediate latent (W) dimensionality.
        resolution,                     # Resolution of this layer.
        kernel_size     = 3,            # Convolution kernel size.
        up              = 1,            # Integer upsampling factor.
        use_noise       = True,         # Enable noise input?
        activation      = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
        resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp      = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        channels_last   = False,        # Use channels_last format for the weights?
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.up = up
        self.use_noise = use_noise
        self.activation = activation
        self.conv_clamp = conv_clamp
        # Filter is a buffer (not a parameter): saved with the model, never trained.
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        # Default gain of the chosen activation (compensates variance change).
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        # Affine maps w -> per-input-channel modulation styles (bias_init=1 => identity-ish start).
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
        if use_noise:
            # Fixed noise image for noise_mode='const'; learned per-layer noise strength.
            self.register_buffer('noise_const', torch.randn([resolution, resolution]))
            self.noise_strength = torch.nn.Parameter(torch.zeros([]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))

    def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
        """Run the layer.

        Args:
            x:             [batch, in_channels, res//up, res//up] input features.
            w:             [batch, w_dim] intermediate latent for this layer.
            noise_mode:    'random' (fresh noise), 'const' (fixed buffer), or 'none'.
            fused_modconv: use the fused modulated-conv path.
            gain:          extra output gain multiplier (also scales the clamp).
        """
        assert noise_mode in ['random', 'const', 'none']
        in_resolution = self.resolution // self.up
        misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution])
        styles = self.affine(w)

        noise = None
        if self.use_noise and noise_mode == 'random':
            noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
        if self.use_noise and noise_mode == 'const':
            noise = self.noise_const * self.noise_strength

        flip_weight = (self.up == 1) # slightly faster
        x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
            padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)

        # Clamp scales with gain so the effective bound is preserved after scaling.
        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
        return x

    def extra_repr(self):
        """Summary string shown in repr()."""
        return ' '.join([
            f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},',
            f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}'])
325
+
326
+
327
@persistence.persistent_class
class ToRGBLayer(torch.nn.Module):
    """Converts a feature map to image channels via a style-modulated 1x1 conv.

    Unlike SynthesisLayer, no demodulation, noise, or nonlinearity is applied;
    output is linear (optionally clamped).
    """
    def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.w_dim = w_dim
        self.conv_clamp = conv_clamp
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
        # Equalized-learning-rate scaling: normalize by fan-in.
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))

    def forward(self, x, w, fused_modconv=True):
        """Map features x to RGB using styles derived from w."""
        # Fold the weight gain into the styles (equivalent to scaling the weights).
        styles = self.affine(w) * self.weight_gain
        x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
        x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
        return x

    def extra_repr(self):
        """Summary string shown in repr()."""
        return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}'
349
+
350
+
351
@persistence.persistent_class
class SynthesisBlock(torch.nn.Module):
    """One resolution level of the StyleGAN2 synthesis network.

    The first block (in_channels == 0) starts from a learned constant and has
    one conv; later blocks have an upsampling conv0 plus conv1. Depending on
    architecture, the block may also carry a ToRGB layer ('skip' / last block)
    and a residual shortcut ('resnet').
    """
    def __init__(self,
        in_channels,                            # Number of input channels, 0 = first block.
        out_channels,                           # Number of output channels.
        w_dim,                                  # Intermediate latent (W) dimensionality.
        resolution,                             # Resolution of this block.
        img_channels,                           # Number of output color channels.
        is_last,                                # Is this the last block?
        architecture            = 'skip',       # Architecture: 'orig', 'skip', 'resnet'.
        resample_filter         = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp              = 256,          # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16                = False,        # Use FP16 for this block?
        fp16_channels_last      = False,        # Use channels-last memory format with FP16?
        fused_modconv_default   = True,         # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
        **layer_kwargs,                         # Arguments for SynthesisLayer.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.fused_modconv_default = fused_modconv_default
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        # Counters let SynthesisNetwork know how many ws this block consumes.
        self.num_conv = 0
        self.num_torgb = 0

        if in_channels == 0:
            # Learned constant input for the first (4x4) block.
            self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))

        if in_channels != 0:
            # conv0 upsamples the incoming features x2 to this block's resolution.
            self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
                resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
            conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
                conv_clamp=conv_clamp, channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            # 1x1 unbiased shortcut, upsampled to match conv0's output.
            self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
                resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
        """Run the block.

        Args:
            x:             incoming features (ignored for the first block) at res//2.
            img:           running RGB image from previous blocks, or None.
            ws:            [batch, num_conv + num_torgb, w_dim] latents for this block.
            force_fp32:    override use_fp16 (always forced on non-CUDA devices).
            fused_modconv: None = use the configured default; 'inference_only'
                           enables fusion only when not training.
        Returns:
            (x, img): features at this resolution, and the (float32) RGB so far.
        """
        _ = update_emas # unused
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        if ws.device.type != 'cuda':
            force_fp32 = True
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            fused_modconv = self.fused_modconv_default
        if fused_modconv == 'inference_only':
            fused_modconv = (not self.training)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            # Replicate the learned constant across the batch.
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            # Residual path: both branches scaled by sqrt(0.5) to keep variance.
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB: upsample the running image and accumulate this block's RGB output.
        if img is not None:
            misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
            img = upfirdn2d.upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            # Image accumulation is always done in float32, regardless of block dtype.
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img

    def extra_repr(self):
        """Summary string shown in repr()."""
        return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
452
+
453
+
454
@persistence.persistent_class
class SynthesisNetwork(torch.nn.Module):
    """StyleGAN2 synthesis network: a stack of SynthesisBlocks from 4x4 up to
    img_resolution, driven by per-layer intermediate latents ws."""
    def __init__(self,
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output image resolution.
        img_channels,               # Number of color channels.
        channel_base    = 32768,    # Overall multiplier for the number of channels.
        channel_max     = 512,      # Maximum number of channels in any layer.
        num_fp16_res    = 4,        # Use FP16 for the N highest resolutions.
        **block_kwargs,             # Arguments for SynthesisBlock.
    ):
        # Resolution must be a power of two >= 4.
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.num_fp16_res = num_fp16_res
        # One block per resolution: 4, 8, ..., img_resolution.
        self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
        # Channel count halves as resolution doubles, capped at channel_max.
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.num_ws = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res // 2] if res > 4 else 0
            out_channels = channels_dict[res]
            use_fp16 = (res >= fp16_resolution)
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
                img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
            self.num_ws += block.num_conv
            if is_last:
                self.num_ws += block.num_torgb
            setattr(self, f'b{res}', block)

    def forward(self, ws, c=None, **block_kwargs):
        """Render an image from per-layer latents.

        Args:
            ws: [batch, num_ws, w_dim] intermediate latents.
            c:  unused here; accepted for interface compatibility with callers
                that pass labels.
        Returns:
            [batch, img_channels, img_resolution, img_resolution] image (float32).
        """
        block_ws = []
        with torch.autograd.profiler.record_function('split_ws'):
            misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
            ws = ws.to(torch.float32)
            w_idx = 0
            for res in self.block_resolutions:
                block = getattr(self, f'b{res}')
                # Each slice overlaps the next block by num_torgb ws (skip-architecture ToRGB).
                block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
                w_idx += block.num_conv

        x = img = None
        for res, cur_ws in zip(self.block_resolutions, block_ws):
            block = getattr(self, f'b{res}')
            x, img = block(x, img, cur_ws, **block_kwargs)
        return img

    def extra_repr(self):
        """Summary string shown in repr()."""
        return ' '.join([
            f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
            f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
            f'num_fp16_res={self.num_fp16_res:d}'])
511
+
512
+
513
@persistence.persistent_class
class Generator(torch.nn.Module):
    """Full StyleGAN2 generator: MappingNetwork (z, c -> ws) followed by
    SynthesisNetwork (ws -> image)."""
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output resolution.
        img_channels,               # Number of output color channels.
        mapping_kwargs  = None,     # Arguments for MappingNetwork (dict); None = no extra arguments.
        **synthesis_kwargs,         # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        # Build synthesis first: its num_ws determines how many ws the mapping broadcasts.
        self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
        self.num_ws = self.synthesis.num_ws
        # Default changed from a shared mutable dict ({}) to None to avoid the
        # mutable-default-argument pitfall; behavior is otherwise unchanged.
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **(mapping_kwargs or {}))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        """Generate images from input latents z and labels c.

        Truncation arguments are forwarded to the mapping network; all other
        keyword arguments go to the synthesis network.
        """
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
        return img
pg_modules/projector.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from feature_networks.vit import forward_vit
5
+ from feature_networks.pretrained_builder import _make_pretrained
6
+ from feature_networks.constants import NORMALIZED_INCEPTION, NORMALIZED_IMAGENET, NORMALIZED_CLIP, VITS
7
+ from pg_modules.blocks import FeatureFusionBlock
8
+
9
def get_backbone_normstats(backbone):
    """Return the per-channel input normalization stats for a pretrained backbone.

    Args:
        backbone: backbone name; must belong to one of the NORMALIZED_* name
                  sets from feature_networks.constants.
    Returns:
        dict with 'mean' and 'std' lists (3 floats each).
    Raises:
        NotImplementedError: for a backbone not covered by any set, naming it.
    """
    if backbone in NORMALIZED_INCEPTION:
        # Inception-style preprocessing: inputs scaled to [-1, 1].
        return {
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
        }

    elif backbone in NORMALIZED_IMAGENET:
        # Standard torchvision ImageNet statistics.
        return {
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
        }

    elif backbone in NORMALIZED_CLIP:
        # Statistics used by OpenAI CLIP preprocessing.
        return {
            'mean': [0.48145466, 0.4578275, 0.40821073],
            'std': [0.26862954, 0.26130258, 0.27577711],
        }

    else:
        # Name the offending backbone so the failure is actionable.
        raise NotImplementedError(f'No normalization stats defined for backbone: {backbone!r}')
30
+
31
+ def _make_scratch_ccm(scratch, in_channels, cout, expand=False):
32
+ # shapes
33
+ out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4
34
+
35
+ scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True)
36
+ scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True)
37
+ scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True)
38
+ scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True)
39
+
40
+ scratch.CHANNELS = out_channels
41
+
42
+ return scratch
43
+
44
def _make_scratch_csm(scratch, in_channels, cout, expand):
    """Attach cross-scale-mixing (CSM) FeatureFusionBlocks to `scratch`.

    Blocks are created from the coarsest level (3, which takes no skip input)
    up to the finest (0). The level-0 block never expands, keeping channel
    counts small at the highest resolution. Records the resulting channel
    list on scratch.CHANNELS.
    """
    extra_args = {
        3: dict(expand=expand, lowest=True),
        2: dict(expand=expand),
        1: dict(expand=expand),
        0: dict(),
    }
    for level in (3, 2, 1, 0):
        fusion = FeatureFusionBlock(in_channels[level], nn.ReLU(False), **extra_args[level])
        setattr(scratch, f'layer{level}_csm', fusion)

    # Last refinenet does not expand to save channels in higher dimensions.
    scratch.CHANNELS = [cout, cout, cout * 2, cout * 4] if expand else [cout] * 4
    return scratch
54
+
55
def _make_projector(im_res, backbone, cout, proj_type, expand=False):
    """Build the pretrained feature network and optional projection modules.

    Args:
        im_res:    requested image resolution; deliberately overridden to 256
                   below (see comment).
        backbone:  name of the pretrained feature network.
        cout:      base channel count for the projection layers.
        proj_type: 0 = features only, 1 = + cross-channel mixing (CCM),
                   2 = + cross-scale mixing (CSM).
        expand:    grow channel counts per scale in the projection layers.
    Returns:
        (pretrained, scratch); scratch is None when proj_type == 0.
    """
    assert proj_type in [0, 1, 2], "Invalid projection type"

    ### Build pretrained feature network
    pretrained = _make_pretrained(backbone)

    # Following Projected GAN: the feature pyramid is always laid out for a
    # 256 input, so the caller-supplied im_res is intentionally ignored here.
    im_res = 256
    pretrained.RESOLUTIONS = [im_res//4, im_res//8, im_res//16, im_res//32]

    if proj_type == 0: return pretrained, None

    ### Build CCM
    scratch = nn.Module()
    scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand)

    # Advertise the post-CCM channel counts to downstream consumers.
    pretrained.CHANNELS = scratch.CHANNELS

    if proj_type == 1: return pretrained, scratch

    ### build CSM
    scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand)

    # CSM upsamples x2 so the feature map resolution doubles
    pretrained.RESOLUTIONS = [res*2 for res in pretrained.RESOLUTIONS]
    pretrained.CHANNELS = scratch.CHANNELS

    return pretrained, scratch
83
+
84
class F_Identity(nn.Module):
    """No-op module: returns its input unchanged."""

    def forward(self, x):
        # Identity mapping, wrapped as a module for interface uniformity.
        return x
87
+
88
class F_RandomProj(nn.Module):
    """Pretrained feature extractor with a randomly-initialized projection head.

    Wraps a pretrained backbone and, depending on proj_type, applies random
    cross-channel mixing (1x1 convs) and cross-scale mixing
    (FeatureFusionBlocks) to its four feature-map scales.
    """
    def __init__(
        self,
        backbone="tf_efficientnet_lite3",   # Name of the pretrained feature network.
        im_res=256,                         # Input resolution hint (forwarded to _make_projector).
        cout=64,                            # Base channel count of the projection layers.
        expand=True,                        # Expand channel counts per scale.
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        **kwargs,                           # Extra arguments; accepted but unused here.
    ):
        super().__init__()
        self.proj_type = proj_type
        self.backbone = backbone
        self.cout = cout
        self.expand = expand
        # Normalization stats the backbone expects; exposed for callers.
        self.normstats = get_backbone_normstats(backbone)

        # build pretrained feature network and random decoder (scratch)
        self.pretrained, self.scratch = _make_projector(im_res=im_res, backbone=self.backbone, cout=self.cout,
                                                        proj_type=self.proj_type, expand=self.expand)
        # Per-scale channel counts and resolutions of the final feature maps.
        self.CHANNELS = self.pretrained.CHANNELS
        self.RESOLUTIONS = self.pretrained.RESOLUTIONS

    def forward(self, x):
        """Extract (and optionally project) a 4-scale feature pyramid from x.

        Returns a dict {'0'..'3'} of feature maps, '0' being the finest scale.
        NOTE(review): self.normstats is exposed but not applied here — the
        caller presumably normalizes the input; verify at the call site.
        """
        # predict feature maps
        if self.backbone in VITS:
            # ViT backbones need the dedicated multi-scale readout path.
            out0, out1, out2, out3 = forward_vit(self.pretrained, x)
        else:
            out0 = self.pretrained.layer0(x)
            out1 = self.pretrained.layer1(out0)
            out2 = self.pretrained.layer2(out1)
            out3 = self.pretrained.layer3(out2)

        # start enumerating at the lowest layer (this is where we put the first discriminator)
        out = {
            '0': out0,
            '1': out1,
            '2': out2,
            '3': out3,
        }

        if self.proj_type == 0: return out

        # Cross-channel mixing: random 1x1 convs per scale.
        out0_channel_mixed = self.scratch.layer0_ccm(out['0'])
        out1_channel_mixed = self.scratch.layer1_ccm(out['1'])
        out2_channel_mixed = self.scratch.layer2_ccm(out['2'])
        out3_channel_mixed = self.scratch.layer3_ccm(out['3'])

        out = {
            '0': out0_channel_mixed,
            '1': out1_channel_mixed,
            '2': out2_channel_mixed,
            '3': out3_channel_mixed,
        }

        if self.proj_type == 1: return out

        # Cross-scale mixing, from bottom (coarsest) to top: each level fuses
        # the level below with its own channel-mixed features.
        out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed)
        out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed)
        out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed)
        out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed)

        out = {
            '0': out0_scale_mixed,
            '1': out1_scale_mixed,
            '2': out2_scale_mixed,
            '3': out3_scale_mixed,
        }

        return out