AnishaNaik03 commited on
Commit
c96c295
·
1 Parent(s): 1bc65a2

Upload legacy.py

Browse files
Files changed (1) hide show
  1. legacy.py +320 -0
legacy.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import click
10
+ import pickle
11
+ import re
12
+ import copy
13
+ import numpy as np
14
+ import torch
15
+ import dnnlib
16
+ from torch_utils import misc
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ def load_network_pkl(f, force_fp16=False):
21
+ data = _LegacyUnpickler(f).load()
22
+
23
+ # Legacy TensorFlow pickle => convert.
24
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
25
+ tf_G, tf_D, tf_Gs = data
26
+ G = convert_tf_generator(tf_G)
27
+ D = convert_tf_discriminator(tf_D)
28
+ G_ema = convert_tf_generator(tf_Gs)
29
+ data = dict(G=G, D=D, G_ema=G_ema)
30
+
31
+ # Add missing fields.
32
+ if 'training_set_kwargs' not in data:
33
+ data['training_set_kwargs'] = None
34
+ if 'augment_pipe' not in data:
35
+ data['augment_pipe'] = None
36
+
37
+ # Validate contents.
38
+ assert isinstance(data['G'], torch.nn.Module)
39
+ assert isinstance(data['D'], torch.nn.Module)
40
+ assert isinstance(data['G_ema'], torch.nn.Module)
41
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
42
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
43
+
44
+ # Force FP16.
45
+ if force_fp16:
46
+ for key in ['G', 'D', 'G_ema']:
47
+ old = data[key]
48
+ kwargs = copy.deepcopy(old.init_kwargs)
49
+ if key.startswith('G'):
50
+ kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
51
+ kwargs.synthesis_kwargs.num_fp16_res = 4
52
+ kwargs.synthesis_kwargs.conv_clamp = 256
53
+ if key.startswith('D'):
54
+ kwargs.num_fp16_res = 4
55
+ kwargs.conv_clamp = 256
56
+ if kwargs != old.init_kwargs:
57
+ new = type(old)(**kwargs).eval().requires_grad_(False)
58
+ misc.copy_params_and_buffers(old, new, require_all=True)
59
+ data[key] = new
60
+ return data
61
+
62
+ #----------------------------------------------------------------------------
63
+
64
+ class _TFNetworkStub(dnnlib.EasyDict):
65
+ pass
66
+
67
+ class _LegacyUnpickler(pickle.Unpickler):
68
+ def find_class(self, module, name):
69
+ if module == 'dnnlib.tflib.network' and name == 'Network':
70
+ return _TFNetworkStub
71
+ return super().find_class(module, name)
72
+
73
+ #----------------------------------------------------------------------------
74
+
75
+ def _collect_tf_params(tf_net):
76
+ # pylint: disable=protected-access
77
+ tf_params = dict()
78
+ def recurse(prefix, tf_net):
79
+ for name, value in tf_net.variables:
80
+ tf_params[prefix + name] = value
81
+ for name, comp in tf_net.components.items():
82
+ recurse(prefix + name + '/', comp)
83
+ recurse('', tf_net)
84
+ return tf_params
85
+
86
+ #----------------------------------------------------------------------------
87
+
88
+ def _populate_module_params(module, *patterns):
89
+ for name, tensor in misc.named_params_and_buffers(module):
90
+ found = False
91
+ value = None
92
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
93
+ match = re.fullmatch(pattern, name)
94
+ if match:
95
+ found = True
96
+ if value_fn is not None:
97
+ value = value_fn(*match.groups())
98
+ break
99
+ try:
100
+ assert found
101
+ if value is not None:
102
+ tensor.copy_(torch.from_numpy(np.array(value)))
103
+ except:
104
+ print(name, list(tensor.shape))
105
+ raise
106
+
107
+ #----------------------------------------------------------------------------
108
+
109
+ def convert_tf_generator(tf_G):
110
+ if tf_G.version < 4:
111
+ raise ValueError('TensorFlow pickle version too low')
112
+
113
+ # Collect kwargs.
114
+ tf_kwargs = tf_G.static_kwargs
115
+ known_kwargs = set()
116
+ def kwarg(tf_name, default=None, none=None):
117
+ known_kwargs.add(tf_name)
118
+ val = tf_kwargs.get(tf_name, default)
119
+ return val if val is not None else none
120
+
121
+ # Convert kwargs.
122
+ kwargs = dnnlib.EasyDict(
123
+ z_dim = kwarg('latent_size', 512),
124
+ c_dim = kwarg('label_size', 0),
125
+ w_dim = kwarg('dlatent_size', 512),
126
+ img_resolution = kwarg('resolution', 1024),
127
+ img_channels = kwarg('num_channels', 3),
128
+ mapping_kwargs = dnnlib.EasyDict(
129
+ num_layers = kwarg('mapping_layers', 8),
130
+ embed_features = kwarg('label_fmaps', None),
131
+ layer_features = kwarg('mapping_fmaps', None),
132
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
133
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
134
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
135
+ ),
136
+ synthesis_kwargs = dnnlib.EasyDict(
137
+ channel_base = kwarg('fmap_base', 16384) * 2,
138
+ channel_max = kwarg('fmap_max', 512),
139
+ num_fp16_res = kwarg('num_fp16_res', 0),
140
+ conv_clamp = kwarg('conv_clamp', None),
141
+ architecture = kwarg('architecture', 'skip'),
142
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
143
+ use_noise = kwarg('use_noise', True),
144
+ activation = kwarg('nonlinearity', 'lrelu'),
145
+ ),
146
+ )
147
+
148
+ # Check for unknown kwargs.
149
+ kwarg('truncation_psi')
150
+ kwarg('truncation_cutoff')
151
+ kwarg('style_mixing_prob')
152
+ kwarg('structure')
153
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
154
+ if len(unknown_kwargs) > 0:
155
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
156
+
157
+ # Collect params.
158
+ tf_params = _collect_tf_params(tf_G)
159
+ for name, value in list(tf_params.items()):
160
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
161
+ if match:
162
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
163
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
164
+ kwargs.synthesis.kwargs.architecture = 'orig'
165
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
166
+
167
+ # Convert params.
168
+ from style_gan.training import networks
169
+ G = networks.Generator(**kwargs).eval().requires_grad_(False)
170
+ # pylint: disable=unnecessary-lambda
171
+ _populate_module_params(G,
172
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
173
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
174
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
175
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
176
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
177
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
178
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
179
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
180
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
181
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
182
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
183
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
184
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
185
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
186
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
187
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
188
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
189
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
190
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
191
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
192
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
193
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
194
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
195
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
196
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
197
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
198
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
199
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
200
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
201
+ r'.*\.resample_filter', None,
202
+ )
203
+ return G
204
+
205
+ #----------------------------------------------------------------------------
206
+
207
+ def convert_tf_discriminator(tf_D):
208
+ if tf_D.version < 4:
209
+ raise ValueError('TensorFlow pickle version too low')
210
+
211
+ # Collect kwargs.
212
+ tf_kwargs = tf_D.static_kwargs
213
+ known_kwargs = set()
214
+ def kwarg(tf_name, default=None):
215
+ known_kwargs.add(tf_name)
216
+ return tf_kwargs.get(tf_name, default)
217
+
218
+ # Convert kwargs.
219
+ kwargs = dnnlib.EasyDict(
220
+ c_dim = kwarg('label_size', 0),
221
+ img_resolution = kwarg('resolution', 1024),
222
+ img_channels = kwarg('num_channels', 3),
223
+ architecture = kwarg('architecture', 'resnet'),
224
+ channel_base = kwarg('fmap_base', 16384) * 2,
225
+ channel_max = kwarg('fmap_max', 512),
226
+ num_fp16_res = kwarg('num_fp16_res', 0),
227
+ conv_clamp = kwarg('conv_clamp', None),
228
+ cmap_dim = kwarg('mapping_fmaps', None),
229
+ block_kwargs = dnnlib.EasyDict(
230
+ activation = kwarg('nonlinearity', 'lrelu'),
231
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
232
+ freeze_layers = kwarg('freeze_layers', 0),
233
+ ),
234
+ mapping_kwargs = dnnlib.EasyDict(
235
+ num_layers = kwarg('mapping_layers', 0),
236
+ embed_features = kwarg('mapping_fmaps', None),
237
+ layer_features = kwarg('mapping_fmaps', None),
238
+ activation = kwarg('nonlinearity', 'lrelu'),
239
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
240
+ ),
241
+ epilogue_kwargs = dnnlib.EasyDict(
242
+ mbstd_group_size = kwarg('mbstd_group_size', None),
243
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
244
+ activation = kwarg('nonlinearity', 'lrelu'),
245
+ ),
246
+ )
247
+
248
+ # Check for unknown kwargs.
249
+ kwarg('structure')
250
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
251
+ if len(unknown_kwargs) > 0:
252
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
253
+
254
+ # Collect params.
255
+ tf_params = _collect_tf_params(tf_D)
256
+ for name, value in list(tf_params.items()):
257
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
258
+ if match:
259
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
260
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
261
+ kwargs.architecture = 'orig'
262
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
263
+
264
+ # Convert params.
265
+ from style_gan.training import networks
266
+ D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
267
+ # pylint: disable=unnecessary-lambda
268
+ _populate_module_params(D,
269
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
270
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
271
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
272
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
273
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
274
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
275
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
276
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
277
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
278
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
279
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
280
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
281
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
282
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
283
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
284
+ r'.*\.resample_filter', None,
285
+ )
286
+ return D
287
+
288
+ #----------------------------------------------------------------------------
289
+
290
+ @click.command()
291
+ @click.option('--source', help='Input pickle', required=True, metavar='PATH')
292
+ @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
293
+ @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
294
+ def convert_network_pickle(source, dest, force_fp16):
295
+ """Convert legacy network pickle into the native PyTorch format.
296
+
297
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
298
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
299
+
300
+ Example:
301
+
302
+ \b
303
+ python legacy.py \\
304
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
305
+ --dest=stylegan2-cat-config-f.pkl
306
+ """
307
+ print(f'Loading "{source}"...')
308
+ with dnnlib.util.open_url(source) as f:
309
+ data = load_network_pkl(f, force_fp16=force_fp16)
310
+ print(f'Saving "{dest}"...')
311
+ with open(dest, 'wb') as f:
312
+ pickle.dump(data, f)
313
+ print('Done.')
314
+
315
+ #----------------------------------------------------------------------------
316
+
317
+ if __name__ == "__main__":
318
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
319
+
320
+ #----------------------------------------------------------------------------