tranmc commited on
Commit
e0fadb3
·
1 Parent(s): 0f33176

Delete convert_original_stable_diffusion_to_diffusers.py.1

Browse files
convert_original_stable_diffusion_to_diffusers.py.1 DELETED
@@ -1,752 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Conversion script for the LDM checkpoints. """
16
-
17
- import argparse
18
- import os
19
-
20
- import torch
21
-
22
-
23
- try:
24
- from omegaconf import OmegaConf
25
- except ImportError:
26
- raise ImportError(
27
- "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
28
- )
29
-
30
- from diffusers import (
31
- AutoencoderKL,
32
- DDIMScheduler,
33
- DPMSolverMultistepScheduler,
34
- EulerAncestralDiscreteScheduler,
35
- EulerDiscreteScheduler,
36
- LDMTextToImagePipeline,
37
- LMSDiscreteScheduler,
38
- PNDMScheduler,
39
- StableDiffusionPipeline,
40
- UNet2DConditionModel,
41
- )
42
- from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
43
- from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
44
- from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
45
-
46
-
47
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    segments = path.split(".")
    if n_shave_prefix_segments >= 0:
        kept = segments[n_shave_prefix_segments:]
    else:
        kept = segments[:n_shave_prefix_segments]
    return ".".join(kept)
55
-
56
-
57
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    # LDM -> diffusers substring renames, applied in order to every key.
    replacements = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )
    mapping = []
    for old_item in old_list:
        new_item = old_item
        for src, dst in replacements:
            new_item = new_item.replace(src, dst)
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})
    return mapping
77
-
78
-
79
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        # Only one rename is needed for VAE resnets: nin_shortcut -> conv_shortcut.
        renamed = shave_segments(
            old_item.replace("nin_shortcut", "conv_shortcut"),
            n_shave_prefix_segments=n_shave_prefix_segments,
        )
        mapping.append({"old": old_item, "new": renamed})
    return mapping
93
-
94
-
95
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    # UNet attention keys already match the diffusers naming scheme, so each
    # entry maps to itself unchanged. (Earlier renames — norm -> group_norm,
    # proj_out -> proj_attn — are no longer needed and were removed upstream.)
    return [{"old": item, "new": item} for item in old_list]
114
-
115
-
116
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    # Ordered LDM -> diffusers substring renames for VAE attention keys.
    replacements = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "query.weight"),
        ("q.bias", "query.bias"),
        ("k.weight", "key.weight"),
        ("k.bias", "key.bias"),
        ("v.weight", "value.weight"),
        ("v.bias", "value.bias"),
        ("proj_out.weight", "proj_attn.weight"),
        ("proj_out.bias", "proj_attn.bias"),
    )
    mapping = []
    for old_item in old_list:
        new_item = old_item
        for src, dst in replacements:
            new_item = new_item.replace(src, dst)
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})
    return mapping
144
-
145
-
146
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.

    Args:
        paths: list of dicts with "old"/"new" key names produced by the renew_* helpers.
        checkpoint: destination state dict (mutated in place).
        old_checkpoint: source state dict the tensors are read from.
        attention_paths_to_split: optional map of fused qkv keys -> {"query","key","value"} targets.
        additional_replacements: optional list of {"old","new"} substring renames applied globally.
        config: model config; only "num_head_channels" is read, and only when splitting attention.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            # The fused tensor stacks q, k, v along dim 0.
            channels = old_tensor.shape[0] // 3

            # Rank-3 tensors (conv1d weights) flatten to (-1, channels); biases to a flat vector.
            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            # Regroup per head so the q/k/v interleaving is undone correctly.
            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]
196
-
197
-
198
def conv_attn_to_linear(checkpoint):
    """Squeeze 1x1-conv attention weights (query/key/value, proj_attn) into linear weights, in place."""
    qkv_suffixes = {"query.weight", "key.weight", "value.weight"}
    for key in list(checkpoint.keys()):
        suffix = ".".join(key.split(".")[-2:])
        if suffix in qkv_suffixes:
            # 1x1 conv2d weight (out, in, 1, 1) -> linear weight (out, in)
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            # conv1d weight (out, in, 1) -> linear weight (out, in)
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]
208
-
209
-
210
def create_unet_diffusers_config(original_config):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    model_params = original_config.model.params
    unet_params = original_config.model.params.unet_config.params

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
    n_blocks = len(block_out_channels)

    # Going down: attach cross-attention wherever the current downsampling
    # factor appears in `attention_resolutions`; the factor doubles after
    # every block except the last.
    down_block_types = []
    resolution = 1
    for i in range(n_blocks):
        attn = resolution in unet_params.attention_resolutions
        down_block_types.append("CrossAttnDownBlock2D" if attn else "DownBlock2D")
        if i != n_blocks - 1:
            resolution *= 2

    # Going back up: same test, halving the factor after every block.
    up_block_types = []
    for _ in range(n_blocks):
        attn = resolution in unet_params.attention_resolutions
        up_block_types.append("CrossAttnUpBlock2D" if attn else "UpBlock2D")
        resolution //= 2

    return dict(
        sample_size=model_params.image_size,
        in_channels=unet_params.in_channels,
        out_channels=unet_params.out_channels,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=unet_params.num_res_blocks,
        cross_attention_dim=unet_params.context_dim,
        attention_head_dim=unet_params.num_heads,
    )
246
-
247
-
248
def create_vae_diffusers_config(original_config):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    vae_params = original_config.model.params.first_stage_config.params.ddconfig
    _ = original_config.model.params.first_stage_config.params.embed_dim  # present in LDM configs, unused here

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    n_blocks = len(block_out_channels)

    # The VAE uses one encoder/decoder block type per channel multiplier.
    return dict(
        sample_size=vae_params.resolution,
        in_channels=vae_params.in_channels,
        out_channels=vae_params.out_ch,
        down_block_types=("DownEncoderBlock2D",) * n_blocks,
        up_block_types=("UpDecoderBlock2D",) * n_blocks,
        block_out_channels=tuple(block_out_channels),
        latent_channels=vae_params.z_channels,
        layers_per_block=vae_params.num_res_blocks,
    )
270
-
271
-
272
def create_diffusers_schedular(original_config):
    """Build a DDIM scheduler from the LDM training-schedule parameters."""
    params = original_config.model.params
    return DDIMScheduler(
        num_train_timesteps=params.timesteps,
        beta_start=params.linear_start,
        beta_end=params.linear_end,
        beta_schedule="scaled_linear",
    )
280
-
281
-
282
def create_ldm_bert_config(original_config):
    """
    Creates an `LDMBertConfig` from the text-encoder (cond_stage) section of the LDM config.

    Bug fix: the original read `original_config.model.parms` — a typo for
    `.params` — which raised `AttributeError` whenever an LDM-BERT checkpoint
    was converted. Every other function in this file reads `.model.params`.
    """
    bert_params = original_config.model.params.cond_stage_config.params
    config = LDMBertConfig(
        d_model=bert_params.n_embed,
        encoder_layers=bert_params.n_layer,
        # Feed-forward width follows the standard 4x expansion of the embedding size.
        encoder_ffn_dim=bert_params.n_embed * 4,
    )
    return config
290
-
291
-
292
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Args:
        checkpoint: full LDM state dict; UNet keys live under "model.diffusion_model.".
        config: diffusers UNet config dict; "layers_per_block" is used to map block indices.
        path: checkpoint path, used only in log messages.
        extract_ema: when True, pull the EMA copies of the weights instead of the live ones.

    Returns:
        A new state dict keyed with diffusers UNet2DConditionModel names.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    unet_key = "model.diffusion_model."
    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        if extract_ema:
            print(
                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
            )
            for key in keys:
                if key.startswith("model.diffusion_model"):
                    # EMA keys are stored with the dots stripped from everything after
                    # the first segment, e.g. "model_ema.diffusion_modelinput_blocks00weight".
                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
        else:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    # Time embedding and the stem/head convolutions map one-to-one.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # Down path: block 0 is the stem conv (handled above), so start at 1.
    # LDM flattens (resnet [+ attention] [+ downsample]) into sequential input_blocks;
    # layers_per_block + 1 of them make up one diffusers down block.
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        # ".op" is the LDM downsampling conv at the end of a resolution stage.
        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    # Middle block is always resnet / attention / resnet.
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    # Up path: mirror of the down path; upsamplers live at a variable index
    # inside the flattened output block, so they are located by sub-layer names.
    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        # Group sub-layer names by their index within this output block.
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # A sub-layer holding exactly ["conv.weight", "conv.bias"] is the upsampler conv.
            if ["conv.weight", "conv.bias"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Resnet-only output block: rename and copy directly.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint
462
-
463
-
464
def convert_ldm_vae_checkpoint(checkpoint, config):
    """
    Convert the VAE ("first_stage_model.") portion of an LDM state dict to diffusers
    AutoencoderKL key names.

    Args:
        checkpoint: full LDM state dict.
        config: diffusers VAE config dict, forwarded to `assign_to_checkpoint`.

    Returns:
        A new state dict keyed with diffusers AutoencoderKL names.
    """
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    # Stem/head convs and norms map one-to-one for both encoder and decoder.
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Encoder mid block: two resnets (block_1/block_2 in LDM) plus one attention.
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    # Decoder blocks: LDM orders them low-to-high resolution, diffusers the
    # opposite, hence block_id runs backwards relative to i.
    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Decoder mid block, same structure as the encoder's.
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
569
-
570
-
571
def convert_ldm_bert_checkpoint(checkpoint, config):
    """
    Copy weights from an LDM-BERT text encoder (x-transformers layout) into a fresh
    `LDMBertModel` built from `config`.

    Args:
        checkpoint: LDM model object exposing `.transformer` (not a plain state dict).
        config: an LDMBertConfig.

    Returns:
        The populated LDMBertModel in eval mode.
    """
    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
        # q/k/v projections copy tensor data; out_proj rebinds the Parameter objects.
        hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight

        hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
        hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias

    def _copy_linear(hf_linear, pt_linear):
        hf_linear.weight = pt_linear.weight
        hf_linear.bias = pt_linear.bias

    def _copy_layer(hf_layer, pt_layer):
        # `pt_layer` is a pair of (norm, sublayer) tuples: [0] attention, [1] feed-forward.
        # copy layer norms
        _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
        _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])

        # copy attn
        _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])

        # copy MLP
        pt_mlp = pt_layer[1][1]
        _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
        _copy_linear(hf_layer.fc2, pt_mlp.net[2])

    def _copy_layers(hf_layers, pt_layers):
        # Each HF layer i corresponds to the pair of source layers at 2*i and 2*i + 1.
        for i, hf_layer in enumerate(hf_layers):
            if i != 0:
                i += i
            pt_layer = pt_layers[i : i + 2]
            _copy_layer(hf_layer, pt_layer)

    hf_model = LDMBertModel(config).eval()

    # copy embeds
    hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
    hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight

    # copy layer norm
    _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)

    # copy hidden layers
    _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)

    _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)

    return hf_model
619
-
620
-
621
def convert_ldm_clip_checkpoint(checkpoint):
    """Load the CLIP text-encoder weights embedded in an LDM checkpoint into a fresh CLIPTextModel."""
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Text-encoder tensors live under this prefix in the combined checkpoint.
    prefix = "cond_stage_model.transformer."
    text_model_dict = {
        key[len(prefix) :]: value
        for key, value in checkpoint.items()
        if key.startswith("cond_stage_model.transformer")
    }

    text_model.load_state_dict(text_model_dict)

    return text_model
635
-
636
-
637
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
    parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        help="The YAML config file corresponding to the original architecture.",
    )
    parser.add_argument(
        "--scheduler_type",
        default="pndm",
        type=str,
        # Fix: the help text previously advertised 'euler-ancest', but the
        # dispatcher below only accepts 'euler-ancestral'.
        help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
    )
    parser.add_argument(
        "--extract_ema",
        action="store_true",
        help=(
            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
        ),
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()

    # Fetch the reference v1 inference config if the user did not supply one.
    # NOTE(review): shells out to wget and assumes it exists on PATH; a stdlib
    # download (urllib) would be more portable.
    if args.original_config_file is None:
        os.system(
            "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
        )
        args.original_config_file = "./v1-inference.yaml"

    original_config = OmegaConf.load(args.original_config_file)

    # NOTE: torch.load unpickles arbitrary code — only convert checkpoints you trust.
    checkpoint = torch.load(args.checkpoint_path)
    checkpoint = checkpoint["state_dict"]

    # Build the sampling scheduler from the training beta schedule.
    num_train_timesteps = original_config.model.params.timesteps
    beta_start = original_config.model.params.linear_start
    beta_end = original_config.model.params.linear_end
    if args.scheduler_type == "pndm":
        scheduler = PNDMScheduler(
            beta_end=beta_end,
            beta_schedule="scaled_linear",
            beta_start=beta_start,
            num_train_timesteps=num_train_timesteps,
            skip_prk_steps=True,
        )
    elif args.scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
    elif args.scheduler_type == "euler":
        scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
    elif args.scheduler_type == "euler-ancestral":
        scheduler = EulerAncestralDiscreteScheduler(
            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
        )
    elif args.scheduler_type == "dpm":
        scheduler = DPMSolverMultistepScheduler(
            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
        )
    elif args.scheduler_type == "ddim":
        scheduler = DDIMScheduler(
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
    else:
        raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(original_config)
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
    )

    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(original_config)
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)

    # Convert the text model: CLIP for Stable Diffusion, LDM-BERT otherwise.
    text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
    if text_model_type == "FrozenCLIPEmbedder":
        text_model = convert_ldm_clip_checkpoint(checkpoint)
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        # safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
        # feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
        pipe = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_model,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            # safety_checker=safety_checker,
            # feature_extractor=feature_extractor,
        )
    else:
        text_config = create_ldm_bert_config(original_config)
        text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    pipe.save_pretrained(args.dump_path)