| | import argparse |
| | import os |
| | import tempfile |
| |
|
| | import torch |
| | from accelerate import load_checkpoint_and_dispatch |
| |
|
| | from diffusers import UNet2DConditionModel |
| | from diffusers.models.transformers.prior_transformer import PriorTransformer |
| | from diffusers.models.vq_model import VQModel |
| |
|
| |
|
| | """ |
| | Example - From the diffusers root directory: |
| | |
| | Download weights: |
| | ```sh |
| | $ wget https://huggingface.co/ai-forever/Kandinsky_2.1/resolve/main/prior_fp16.ckpt |
| | ``` |
| | |
| | Convert the model: |
| | ```sh |
| | python scripts/convert_kandinsky_to_diffusers.py \ |
| | --prior_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/prior_fp16.ckpt \ |
| | --clip_stat_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/ViT-L-14_stats.th \ |
| | --text2img_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/decoder_fp16.ckpt \ |
| | --inpaint_text2img_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/inpainting_fp16.ckpt \ |
| | --movq_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/movq_final.ckpt \ |
| | --dump_path /home/yiyi_huggingface_co/dump \ |
| | --debug decoder |
| | ``` |
| | """ |
| |
|
| |
|
| | |
| |
|
# Prefix of every key that the prior conversion reads from the original checkpoint
# (keys are looked up as f"{PRIOR_ORIGINAL_PREFIX}.<name>").
PRIOR_ORIGINAL_PREFIX = "model"

# Keyword arguments forwarded to `PriorTransformer(**PRIOR_CONFIG)`; it is empty, so no
# constructor argument is overridden here.
PRIOR_CONFIG = {}
| |
|
| |
|
def prior_model_from_original_config():
    """Instantiate a `PriorTransformer` from the keyword arguments in `PRIOR_CONFIG`.

    Returns:
        The newly constructed `PriorTransformer`.
    """
    return PriorTransformer(**PRIOR_CONFIG)
| |
|
| |
|
def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint, clip_stats_checkpoint):
    """Rename the entries of an original prior checkpoint to diffusers keys.

    Args:
        model: Object exposing `transformer_blocks` (one set of block entries is converted per
            element) and `attention_head_dim` (forwarded to `prior_attention_to_diffusers`).
        checkpoint: Mapping holding the original tensors under `PRIOR_ORIGINAL_PREFIX`.
        clip_stats_checkpoint: Two-item iterable unpacked as `(clip_mean, clip_std)`.

    Returns:
        A new dict keyed with the diffusers names.

    Raises:
        KeyError: If an expected original key is absent from `checkpoint`.
    """

    def source(suffix):
        # Every original key lives under the shared prior prefix.
        return checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.{suffix}"]

    converted = {
        "time_embedding.linear_1.weight": source("time_embed.0.weight"),
        "time_embedding.linear_1.bias": source("time_embed.0.bias"),
        "proj_in.weight": source("clip_img_proj.weight"),
        "proj_in.bias": source("clip_img_proj.bias"),
        "embedding_proj.weight": source("text_emb_proj.weight"),
        "embedding_proj.bias": source("text_emb_proj.bias"),
        "encoder_hidden_states_proj.weight": source("text_enc_proj.weight"),
        "encoder_hidden_states_proj.bias": source("text_enc_proj.bias"),
        "positional_embedding": source("positional_embedding"),
        "prd_embedding": source("prd_emb"),
        "time_embedding.linear_2.weight": source("time_embed.2.weight"),
        "time_embedding.linear_2.bias": source("time_embed.2.bias"),
    }

    for block_idx, _ in enumerate(model.transformer_blocks):
        target_block = f"transformer_blocks.{block_idx}"
        source_block = f"{PRIOR_ORIGINAL_PREFIX}.transformer.resblocks.{block_idx}"

        # <source_block>.attn -> <target_block>.attn1
        converted.update(
            prior_attention_to_diffusers(
                checkpoint,
                diffusers_attention_prefix=f"{target_block}.attn1",
                original_attention_prefix=f"{source_block}.attn",
                attention_head_dim=model.attention_head_dim,
            )
        )

        # <source_block>.mlp -> <target_block>.ff
        converted.update(
            prior_ff_to_diffusers(
                checkpoint,
                diffusers_ff_prefix=f"{target_block}.ff",
                original_ff_prefix=f"{source_block}.mlp",
            )
        )

        # ln_1 -> norm1, ln_2 -> norm3
        converted[f"{target_block}.norm1.weight"] = checkpoint[f"{source_block}.ln_1.weight"]
        converted[f"{target_block}.norm1.bias"] = checkpoint[f"{source_block}.ln_1.bias"]
        converted[f"{target_block}.norm3.weight"] = checkpoint[f"{source_block}.ln_2.weight"]
        converted[f"{target_block}.norm3.bias"] = checkpoint[f"{source_block}.ln_2.bias"]

    converted["norm_out.weight"] = source("final_ln.weight")
    converted["norm_out.bias"] = source("final_ln.bias")
    converted["proj_to_clip_embeddings.weight"] = source("out_proj.weight")
    converted["proj_to_clip_embeddings.bias"] = source("out_proj.bias")

    # The statistics are stored with an extra leading axis of size one.
    mean, std = clip_stats_checkpoint
    converted["clip_mean"] = mean[None, :]
    converted["clip_std"] = std[None, :]

    return converted
| |
|
| |
|
def prior_attention_to_diffusers(
    checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim
):
    """Convert one prior attention layer to diffusers keys.

    The fused `c_qkv` weight and bias are split three ways with `split_attentions`
    (chunk size `attention_head_dim`) into `to_q` / `to_k` / `to_v`; `c_proj` is copied
    to `to_out.0`.

    Args:
        checkpoint: Mapping holding the original tensors.
        diffusers_attention_prefix: Key prefix used for the returned entries.
        original_attention_prefix: Key prefix of the layer inside `checkpoint`.
        attention_head_dim: Row-chunk size passed to `split_attentions`.

    Returns:
        A new dict with the eight converted entries.
    """
    split_weights, split_biases = split_attentions(
        weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"],
        bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"],
        split=3,
        chunk_size=attention_head_dim,
    )
    # Exactly three parts are expected, in q, k, v order.
    (query_w, key_w, value_w), (query_b, key_b, value_b) = split_weights, split_biases

    converted = {
        f"{diffusers_attention_prefix}.to_q.weight": query_w,
        f"{diffusers_attention_prefix}.to_q.bias": query_b,
        f"{diffusers_attention_prefix}.to_k.weight": key_w,
        f"{diffusers_attention_prefix}.to_k.bias": key_b,
        f"{diffusers_attention_prefix}.to_v.weight": value_w,
        f"{diffusers_attention_prefix}.to_v.bias": value_b,
    }

    # Output projection is copied unchanged.
    converted[f"{diffusers_attention_prefix}.to_out.0.weight"] = checkpoint[
        f"{original_attention_prefix}.c_proj.weight"
    ]
    converted[f"{diffusers_attention_prefix}.to_out.0.bias"] = checkpoint[f"{original_attention_prefix}.c_proj.bias"]

    return converted
| |
|
| |
|
def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix):
    """Rename one prior feed-forward layer: `c_fc` -> `net.0.proj`, `c_proj` -> `net.2`.

    Args:
        checkpoint: Mapping holding the original tensors.
        diffusers_ff_prefix: Key prefix used for the returned entries.
        original_ff_prefix: Key prefix of the layer inside `checkpoint`.

    Returns:
        A new dict with the four renamed entries.
    """
    renames = (
        ("net.0.proj.weight", "c_fc.weight"),
        ("net.0.proj.bias", "c_fc.bias"),
        ("net.2.weight", "c_proj.weight"),
        ("net.2.bias", "c_proj.bias"),
    )
    return {
        f"{diffusers_ff_prefix}.{target}": checkpoint[f"{original_ff_prefix}.{origin}"] for target, origin in renames
    }
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| | |
| | |
| |
|
# Keyword arguments forwarded verbatim to `UNet2DConditionModel(**UNET_CONFIG)` when the
# text2img UNet is built. `attention_head_dim` is also read by the checkpoint conversion
# as the chunk size used to split fused attention projections.
UNET_CONFIG = {
    "act_fn": "silu",
    "addition_embed_type": "text_image",
    "addition_embed_type_num_heads": 64,
    "attention_head_dim": 64,
    "block_out_channels": [384, 768, 1152, 1536],
    "center_input_sample": False,
    "class_embed_type": None,
    "class_embeddings_concat": False,
    "conv_in_kernel": 3,
    "conv_out_kernel": 3,
    "cross_attention_dim": 768,
    "cross_attention_norm": None,
    "down_block_types": [
        "ResnetDownsampleBlock2D",
        "SimpleCrossAttnDownBlock2D",
        "SimpleCrossAttnDownBlock2D",
        "SimpleCrossAttnDownBlock2D",
    ],
    "downsample_padding": 1,
    "dual_cross_attention": False,
    "encoder_hid_dim": 1024,
    "encoder_hid_dim_type": "text_image_proj",
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "in_channels": 4,
    "layers_per_block": 3,
    "mid_block_only_cross_attention": None,
    "mid_block_scale_factor": 1,
    "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "num_class_embeds": None,
    "only_cross_attention": False,
    "out_channels": 8,
    "projection_class_embeddings_input_dim": None,
    "resnet_out_scale_factor": 1.0,
    "resnet_skip_time_act": False,
    "resnet_time_scale_shift": "scale_shift",
    "sample_size": 64,
    "time_cond_proj_dim": None,
    "time_embedding_act_fn": None,
    "time_embedding_dim": None,
    "time_embedding_type": "positional",
    "timestep_post_act": None,
    "up_block_types": [
        "SimpleCrossAttnUpBlock2D",
        "SimpleCrossAttnUpBlock2D",
        "SimpleCrossAttnUpBlock2D",
        "ResnetUpsampleBlock2D",
    ],
    "upcast_attention": False,
    "use_linear_projection": False,
}
| |
|
| |
|
def unet_model_from_original_config():
    """Instantiate a `UNet2DConditionModel` from the keyword arguments in `UNET_CONFIG`.

    Returns:
        The newly constructed `UNet2DConditionModel`.
    """
    return UNet2DConditionModel(**UNET_CONFIG)
| |
|
| |
|
def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    """Rename the entries of an original text2img UNet checkpoint to diffusers keys.

    Args:
        model: Object exposing `down_blocks`, `mid_block` and `up_blocks`; its structure
            decides how many original blocks each diffusers block consumes.
        checkpoint: Mapping holding the original tensors.

    Returns:
        A new dict keyed with the diffusers names.
    """
    head_channels = UNET_CONFIG["attention_head_dim"]
    converted = {}

    # Entries that precede the down blocks.
    for convert_part in (unet_time_embeddings, unet_conv_in, unet_add_embedding, unet_encoder_hid_proj):
        converted.update(convert_part(checkpoint))

    # Down blocks. Counting starts at 1: `input_blocks.0` is the one read by `unet_conv_in`.
    next_input_block = 1
    for down_idx, _ in enumerate(model.down_blocks):
        block_entries, consumed = unet_downblock_to_diffusers_checkpoint(
            model,
            checkpoint,
            diffusers_down_block_idx=down_idx,
            original_down_block_idx=next_input_block,
            num_head_channels=head_channels,
        )
        next_input_block += consumed
        converted.update(block_entries)

    # Mid block.
    converted.update(unet_midblock_to_diffusers_checkpoint(model, checkpoint, num_head_channels=head_channels))

    # Up blocks, counted from `output_blocks.0`.
    next_output_block = 0
    for up_idx, _ in enumerate(model.up_blocks):
        block_entries, consumed = unet_upblock_to_diffusers_checkpoint(
            model,
            checkpoint,
            diffusers_up_block_idx=up_idx,
            original_up_block_idx=next_output_block,
            num_head_channels=head_channels,
        )
        next_output_block += consumed
        converted.update(block_entries)

    # Entries that follow the up blocks.
    for convert_part in (unet_conv_norm_out, unet_conv_out):
        converted.update(convert_part(checkpoint))

    return converted
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| | |
| | |
| |
|
# Keyword arguments forwarded verbatim to `UNet2DConditionModel(**INPAINT_UNET_CONFIG)` when
# the inpainting UNet is built. It matches `UNET_CONFIG` except for `in_channels` (9 here).
# `attention_head_dim` is also read by the checkpoint conversion as the chunk size used to
# split fused attention projections.
INPAINT_UNET_CONFIG = {
    "act_fn": "silu",
    "addition_embed_type": "text_image",
    "addition_embed_type_num_heads": 64,
    "attention_head_dim": 64,
    "block_out_channels": [384, 768, 1152, 1536],
    "center_input_sample": False,
    "class_embed_type": None,
    # A boolean flag: use False (as UNET_CONFIG does) rather than None.
    "class_embeddings_concat": False,
    "conv_in_kernel": 3,
    "conv_out_kernel": 3,
    "cross_attention_dim": 768,
    "cross_attention_norm": None,
    "down_block_types": [
        "ResnetDownsampleBlock2D",
        "SimpleCrossAttnDownBlock2D",
        "SimpleCrossAttnDownBlock2D",
        "SimpleCrossAttnDownBlock2D",
    ],
    "downsample_padding": 1,
    "dual_cross_attention": False,
    "encoder_hid_dim": 1024,
    "encoder_hid_dim_type": "text_image_proj",
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "in_channels": 9,
    "layers_per_block": 3,
    "mid_block_only_cross_attention": None,
    "mid_block_scale_factor": 1,
    "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "num_class_embeds": None,
    "only_cross_attention": False,
    "out_channels": 8,
    "projection_class_embeddings_input_dim": None,
    "resnet_out_scale_factor": 1.0,
    "resnet_skip_time_act": False,
    "resnet_time_scale_shift": "scale_shift",
    "sample_size": 64,
    "time_cond_proj_dim": None,
    "time_embedding_act_fn": None,
    "time_embedding_dim": None,
    "time_embedding_type": "positional",
    "timestep_post_act": None,
    "up_block_types": [
        "SimpleCrossAttnUpBlock2D",
        "SimpleCrossAttnUpBlock2D",
        "SimpleCrossAttnUpBlock2D",
        "ResnetUpsampleBlock2D",
    ],
    "upcast_attention": False,
    "use_linear_projection": False,
}
| |
|
| |
|
def inpaint_unet_model_from_original_config():
    """Instantiate a `UNet2DConditionModel` from the keyword arguments in `INPAINT_UNET_CONFIG`.

    Returns:
        The newly constructed `UNet2DConditionModel`.
    """
    return UNet2DConditionModel(**INPAINT_UNET_CONFIG)
| |
|
| |
|
def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    """Rename the entries of an original inpainting UNet checkpoint to diffusers keys.

    Args:
        model: Object exposing `down_blocks`, `mid_block` and `up_blocks`; its structure
            decides how many original blocks each diffusers block consumes.
        checkpoint: Mapping holding the original tensors.

    Returns:
        A new dict keyed with the diffusers names.
    """
    head_channels = INPAINT_UNET_CONFIG["attention_head_dim"]
    converted = {}

    # Entries that precede the down blocks.
    for convert_part in (unet_time_embeddings, unet_conv_in, unet_add_embedding, unet_encoder_hid_proj):
        converted.update(convert_part(checkpoint))

    # Down blocks. Counting starts at 1: `input_blocks.0` is the one read by `unet_conv_in`.
    next_input_block = 1
    for down_idx, _ in enumerate(model.down_blocks):
        block_entries, consumed = unet_downblock_to_diffusers_checkpoint(
            model,
            checkpoint,
            diffusers_down_block_idx=down_idx,
            original_down_block_idx=next_input_block,
            num_head_channels=head_channels,
        )
        next_input_block += consumed
        converted.update(block_entries)

    # Mid block.
    converted.update(unet_midblock_to_diffusers_checkpoint(model, checkpoint, num_head_channels=head_channels))

    # Up blocks, counted from `output_blocks.0`.
    next_output_block = 0
    for up_idx, _ in enumerate(model.up_blocks):
        block_entries, consumed = unet_upblock_to_diffusers_checkpoint(
            model,
            checkpoint,
            diffusers_up_block_idx=up_idx,
            original_up_block_idx=next_output_block,
            num_head_channels=head_channels,
        )
        next_output_block += consumed
        converted.update(block_entries)

    # Entries that follow the up blocks.
    for convert_part in (unet_conv_norm_out, unet_conv_out):
        converted.update(convert_part(checkpoint))

    return converted
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
def unet_time_embeddings(checkpoint):
    """Rename the original `time_embed.{0,2}` entries to `time_embedding.linear_{1,2}`.

    Args:
        checkpoint: Mapping holding the original tensors.

    Returns:
        A new dict with the four renamed weight/bias entries.
    """
    renames = (
        ("time_embedding.linear_1.weight", "time_embed.0.weight"),
        ("time_embedding.linear_1.bias", "time_embed.0.bias"),
        ("time_embedding.linear_2.weight", "time_embed.2.weight"),
        ("time_embedding.linear_2.bias", "time_embed.2.bias"),
    )
    return {target: checkpoint[origin] for target, origin in renames}
| |
|
| |
|
| | |
def unet_conv_in(checkpoint):
    """Rename the original `input_blocks.0.0` entries to `conv_in`.

    Args:
        checkpoint: Mapping holding the original tensors.

    Returns:
        A new dict with the `conv_in` weight and bias.
    """
    return {
        "conv_in.weight": checkpoint["input_blocks.0.0.weight"],
        "conv_in.bias": checkpoint["input_blocks.0.0.bias"],
    }
| |
|
| |
|
def unet_add_embedding(checkpoint):
    """Rename the original `ln_model_n`, `proj_n` and `img_layer` entries to `add_embedding.*`.

    Args:
        checkpoint: Mapping holding the original tensors.

    Returns:
        A new dict with the six renamed weight/bias entries.
    """
    renames = (
        ("add_embedding.text_norm.weight", "ln_model_n.weight"),
        ("add_embedding.text_norm.bias", "ln_model_n.bias"),
        ("add_embedding.text_proj.weight", "proj_n.weight"),
        ("add_embedding.text_proj.bias", "proj_n.bias"),
        ("add_embedding.image_proj.weight", "img_layer.weight"),
        ("add_embedding.image_proj.bias", "img_layer.bias"),
    )
    return {target: checkpoint[origin] for target, origin in renames}
| |
|
| |
|
def unet_encoder_hid_proj(checkpoint):
    """Rename the original `clip_to_seq` and `to_model_dim_n` entries to `encoder_hid_proj.*`.

    Args:
        checkpoint: Mapping holding the original tensors.

    Returns:
        A new dict with the four renamed weight/bias entries.
    """
    renames = (
        ("encoder_hid_proj.image_embeds.weight", "clip_to_seq.weight"),
        ("encoder_hid_proj.image_embeds.bias", "clip_to_seq.bias"),
        ("encoder_hid_proj.text_proj.weight", "to_model_dim_n.weight"),
        ("encoder_hid_proj.text_proj.bias", "to_model_dim_n.bias"),
    )
    return {target: checkpoint[origin] for target, origin in renames}
| |
|
| |
|
| | |
def unet_conv_norm_out(checkpoint):
    """Rename the original `out.0` entries to `conv_norm_out`.

    Args:
        checkpoint: Mapping holding the original tensors.

    Returns:
        A new dict with the `conv_norm_out` weight and bias.
    """
    return {
        "conv_norm_out.weight": checkpoint["out.0.weight"],
        "conv_norm_out.bias": checkpoint["out.0.bias"],
    }
| |
|
| |
|
| | |
def unet_conv_out(checkpoint):
    """Rename the original `out.2` entries to `conv_out`.

    Args:
        checkpoint: Mapping holding the original tensors.

    Returns:
        A new dict with the `conv_out` weight and bias.
    """
    return {
        "conv_out.weight": checkpoint["out.2.weight"],
        "conv_out.bias": checkpoint["out.2.bias"],
    }
| |
|
| |
|
| | |
def unet_downblock_to_diffusers_checkpoint(
    model, checkpoint, *, diffusers_down_block_idx, original_down_block_idx, num_head_channels
):
    """Convert one diffusers down block from consecutive original `input_blocks` entries.

    Args:
        model: Object exposing `down_blocks`.
        checkpoint: Mapping holding the original tensors.
        diffusers_down_block_idx: Index into `model.down_blocks` of the block to convert.
        original_down_block_idx: Index of the first `input_blocks.<n>` entry for this block.
        num_head_channels: Forwarded to `attention_to_diffusers_checkpoint`.

    Returns:
        `(entries, consumed)`: the converted dict and the number of original input blocks
        used (resnets plus the downsampler, when the block has one).
    """
    block = model.down_blocks[diffusers_down_block_idx]
    block_prefix = f"down_blocks.{diffusers_down_block_idx}"
    source_root = "input_blocks"

    resnet_count = len(block.resnets)
    has_downsampler = block.downsamplers is not None
    if has_downsampler:
        assert len(block.downsamplers) == 1
        # The downsampler is converted with the resnet mapping, so it counts as one more.
        resnet_count += 1

    converted = {}

    for offset in range(resnet_count):
        if has_downsampler and offset == resnet_count - 1:
            # Last slot is the downsampler.
            target_prefix = f"{block_prefix}.downsamplers.0"
        else:
            target_prefix = f"{block_prefix}.resnets.{offset}"

        converted.update(
            resnet_to_diffusers_checkpoint(
                checkpoint,
                resnet_prefix=f"{source_root}.{original_down_block_idx + offset}.0",
                diffusers_resnet_prefix=target_prefix,
            )
        )

    if hasattr(block, "attentions"):
        for offset in range(len(block.attentions)):
            # Attention sits at sub-index 1 of the same original input block.
            converted.update(
                attention_to_diffusers_checkpoint(
                    checkpoint,
                    attention_prefix=f"{source_root}.{original_down_block_idx + offset}.1",
                    diffusers_attention_prefix=f"{block_prefix}.attentions.{offset}",
                    num_head_channels=num_head_channels,
                )
            )

    return converted, resnet_count
| |
|
| |
|
| | |
def unet_midblock_to_diffusers_checkpoint(model, checkpoint, *, num_head_channels):
    """Convert the mid block from consecutive original `middle_block.<n>` entries.

    The first resnet is read from `middle_block.0`; the attention (when
    `model.mid_block.attentions[0]` is present and not None) and the second resnet
    take the following indices.

    Args:
        model: Object exposing `mid_block`.
        checkpoint: Mapping holding the original tensors.
        num_head_channels: Forwarded to `attention_to_diffusers_checkpoint`.

    Returns:
        A new dict with the converted mid-block entries.
    """
    mid_block = model.mid_block
    converted = {}
    next_source_idx = 0

    # First resnet.
    converted.update(
        resnet_to_diffusers_checkpoint(
            checkpoint,
            diffusers_resnet_prefix="mid_block.resnets.0",
            resnet_prefix=f"middle_block.{next_source_idx}",
        )
    )
    next_source_idx += 1

    # Optional attention; it only consumes an original index when present.
    if hasattr(mid_block, "attentions") and mid_block.attentions[0] is not None:
        converted.update(
            attention_to_diffusers_checkpoint(
                checkpoint,
                diffusers_attention_prefix="mid_block.attentions.0",
                attention_prefix=f"middle_block.{next_source_idx}",
                num_head_channels=num_head_channels,
            )
        )
        next_source_idx += 1

    # Second resnet.
    converted.update(
        resnet_to_diffusers_checkpoint(
            checkpoint,
            diffusers_resnet_prefix="mid_block.resnets.1",
            resnet_prefix=f"middle_block.{next_source_idx}",
        )
    )

    return converted
| |
|
| |
|
| | |
def unet_upblock_to_diffusers_checkpoint(
    model, checkpoint, *, diffusers_up_block_idx, original_up_block_idx, num_head_channels
):
    """Convert one diffusers up block from consecutive original `output_blocks` entries.

    Args:
        model: Object exposing `up_blocks`.
        checkpoint: Mapping holding the original tensors.
        diffusers_up_block_idx: Index into `model.up_blocks` of the block to convert.
        original_up_block_idx: Index of the first `output_blocks.<n>` entry for this block.
        num_head_channels: Forwarded to `attention_to_diffusers_checkpoint`.

    Returns:
        `(entries, consumed)`: the converted dict and the number of original output blocks
        used. The upsampler does not add to that count because it is read from the same
        output block as the last regular resnet.
    """
    block = model.up_blocks[diffusers_up_block_idx]
    block_prefix = f"up_blocks.{diffusers_up_block_idx}"
    source_root = "output_blocks"

    resnet_count = len(block.resnets)
    has_upsampler = block.upsamplers is not None
    if has_upsampler:
        assert len(block.upsamplers) == 1
        # The upsampler is converted with the resnet mapping, so it counts as one more.
        resnet_count += 1

    has_attentions = hasattr(block, "attentions")
    converted = {}

    for offset in range(resnet_count):
        if has_upsampler and offset == resnet_count - 1:
            # The upsampler shares the previous resnet's output block (hence `- 1`) and sits
            # at sub-index 2 when an attention occupies sub-index 1, otherwise at 1.
            sub_idx = 2 if has_attentions else 1
            source_prefix = f"{source_root}.{original_up_block_idx + offset - 1}.{sub_idx}"
            target_prefix = f"{block_prefix}.upsamplers.0"
        else:
            source_prefix = f"{source_root}.{original_up_block_idx + offset}.0"
            target_prefix = f"{block_prefix}.resnets.{offset}"

        converted.update(
            resnet_to_diffusers_checkpoint(
                checkpoint, resnet_prefix=source_prefix, diffusers_resnet_prefix=target_prefix
            )
        )

    if has_attentions:
        for offset in range(len(block.attentions)):
            converted.update(
                attention_to_diffusers_checkpoint(
                    checkpoint,
                    attention_prefix=f"{source_root}.{original_up_block_idx + offset}.1",
                    diffusers_attention_prefix=f"{block_prefix}.attentions.{offset}",
                    num_head_channels=num_head_channels,
                )
            )

    consumed_output_blocks = resnet_count - 1 if has_upsampler else resnet_count

    return converted, consumed_output_blocks
| |
|
| |
|
def resnet_to_diffusers_checkpoint(checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
    """Rename the entries of one original resnet block to diffusers keys.

    Args:
        checkpoint: Mapping holding the original tensors.
        diffusers_resnet_prefix: Key prefix used for the returned entries.
        resnet_prefix: Key prefix of the resnet inside `checkpoint`.

    Returns:
        A new dict with the ten norm/conv/time-projection entries, plus `conv_shortcut`
        weight and bias when `<resnet_prefix>.skip_connection.weight` is in `checkpoint`.
    """
    renames = (
        ("norm1.weight", "in_layers.0.weight"),
        ("norm1.bias", "in_layers.0.bias"),
        ("conv1.weight", "in_layers.2.weight"),
        ("conv1.bias", "in_layers.2.bias"),
        ("time_emb_proj.weight", "emb_layers.1.weight"),
        ("time_emb_proj.bias", "emb_layers.1.bias"),
        ("norm2.weight", "out_layers.0.weight"),
        ("norm2.bias", "out_layers.0.bias"),
        ("conv2.weight", "out_layers.3.weight"),
        ("conv2.bias", "out_layers.3.bias"),
    )
    converted = {
        f"{diffusers_resnet_prefix}.{target}": checkpoint[f"{resnet_prefix}.{origin}"] for target, origin in renames
    }

    # Only some resnets carry a skip connection; its weight key decides.
    skip_prefix = f"{resnet_prefix}.skip_connection"
    if f"{skip_prefix}.weight" not in checkpoint:
        return converted

    converted[f"{diffusers_resnet_prefix}.conv_shortcut.weight"] = checkpoint[f"{skip_prefix}.weight"]
    converted[f"{diffusers_resnet_prefix}.conv_shortcut.bias"] = checkpoint[f"{skip_prefix}.bias"]
    return converted
| |
|
| |
|
def attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix, num_head_channels):
    """Convert one original UNet attention block to diffusers keys.

    `norm` is copied to `group_norm`. The fused `qkv` projection is split three ways into
    `to_q` / `to_k` / `to_v`, and the fused `encoder_kv` projection two ways into
    `add_k_proj` / `add_v_proj`, both via `split_attentions`. `proj_out` is copied to
    `to_out.0`. The `qkv`, `encoder_kv` and `proj_out` weights are indexed with `[:, :, 0]`,
    which drops their third axis.

    Args:
        checkpoint: Mapping holding the original tensors.
        diffusers_attention_prefix: Key prefix used for the returned entries.
        attention_prefix: Key prefix of the attention block inside `checkpoint`.
        num_head_channels: Row-chunk size passed to `split_attentions`.

    Returns:
        A new dict with the converted attention entries.
    """
    target = diffusers_attention_prefix
    origin = attention_prefix

    converted = {
        f"{target}.group_norm.weight": checkpoint[f"{origin}.norm.weight"],
        f"{target}.group_norm.bias": checkpoint[f"{origin}.norm.bias"],
    }

    # Fused q/k/v projection.
    (query_w, key_w, value_w), (query_b, key_b, value_b) = split_attentions(
        weight=checkpoint[f"{origin}.qkv.weight"][:, :, 0],
        bias=checkpoint[f"{origin}.qkv.bias"],
        split=3,
        chunk_size=num_head_channels,
    )
    converted[f"{target}.to_q.weight"] = query_w
    converted[f"{target}.to_q.bias"] = query_b
    converted[f"{target}.to_k.weight"] = key_w
    converted[f"{target}.to_k.bias"] = key_b
    converted[f"{target}.to_v.weight"] = value_w
    converted[f"{target}.to_v.bias"] = value_b

    # Fused k/v projection read from `encoder_kv`.
    (added_key_w, added_value_w), (added_key_b, added_value_b) = split_attentions(
        weight=checkpoint[f"{origin}.encoder_kv.weight"][:, :, 0],
        bias=checkpoint[f"{origin}.encoder_kv.bias"],
        split=2,
        chunk_size=num_head_channels,
    )
    converted[f"{target}.add_k_proj.weight"] = added_key_w
    converted[f"{target}.add_k_proj.bias"] = added_key_b
    converted[f"{target}.add_v_proj.weight"] = added_value_w
    converted[f"{target}.add_v_proj.bias"] = added_value_b

    # Output projection.
    converted[f"{target}.to_out.0.weight"] = checkpoint[f"{origin}.proj_out.weight"][:, :, 0]
    converted[f"{target}.to_out.0.bias"] = checkpoint[f"{origin}.proj_out.bias"]

    return converted
| |
|
| |
|
| | |
def split_attentions(*, weight, bias, split, chunk_size):
    """De-interleave a fused projection into `split` separate weight/bias tensors.

    The rows of `weight` (and the matching entries of `bias`) are walked in consecutive
    chunks of `chunk_size` rows. Chunk 0 goes to output 0, chunk 1 to output 1, and so on,
    wrapping around after `split` chunks.

    Args:
        weight: Tensor whose first axis holds the fused rows.
        bias: Tensor whose first axis has the same length as `weight`'s.
        split: Number of outputs to produce (e.g. 3 for q/k/v).
        chunk_size: Number of consecutive rows that belong to one output at a time.

    Returns:
        `(weights, biases)`: two lists of length `split`. An entry is None only when the
        input has zero rows.

    Raises:
        ValueError: If `split` or `chunk_size` is not positive, if the row count is not a
            multiple of `split * chunk_size`, or if `bias` and `weight` disagree on length.
    """
    if split <= 0 or chunk_size <= 0:
        raise ValueError(f"split and chunk_size must be positive, got split={split}, chunk_size={chunk_size}")

    num_rows = weight.shape[0]
    if num_rows % (split * chunk_size) != 0:
        # Otherwise the outputs would end up with different sizes or a chunk would run
        # past the end of the tensor.
        raise ValueError(
            f"weight has {num_rows} rows, which is not a multiple of split * chunk_size = {split * chunk_size}"
        )
    if bias.shape[0] != num_rows:
        raise ValueError(f"bias has {bias.shape[0]} entries but weight has {num_rows} rows")

    weight_chunks = [[] for _ in range(split)]
    bias_chunks = [[] for _ in range(split)]

    for chunk_idx, start in enumerate(range(0, num_rows, chunk_size)):
        target = chunk_idx % split
        weight_chunks[target].append(weight[start : start + chunk_size])
        bias_chunks[target].append(bias[start : start + chunk_size])

    # Concatenate once per output instead of growing a tensor chunk by chunk.
    weights = [torch.concat(chunks) if chunks else None for chunks in weight_chunks]
    biases = [torch.concat(chunks) if chunks else None for chunks in bias_chunks]

    return weights, biases
| |
|
| |
|
| | |
| |
|
| |
|
def prior(*, args, checkpoint_map_location):
    """Load the original prior checkpoint and return a `PriorTransformer` holding its weights.

    Args:
        args: Object exposing `prior_checkpoint_path` and `clip_stat_path`.
        checkpoint_map_location: Passed as `map_location` to both `torch.load` calls.

    Returns:
        The model built by `prior_model_from_original_config` after the converted
        checkpoint has been passed to `load_checkpoint_to_model(..., strict=True)`.
    """
    print("loading prior")

    # NOTE(review): `torch.load` unpickles these files; only run this on checkpoints you trust.
    original_state = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location)
    clip_stats = torch.load(args.clip_stat_path, map_location=checkpoint_map_location)

    model = prior_model_from_original_config()
    converted_state = prior_original_checkpoint_to_diffusers_checkpoint(model, original_state, clip_stats)

    # The originals are no longer needed once the converted dict exists.
    del original_state, clip_stats

    load_checkpoint_to_model(converted_state, model, strict=True)

    print("done loading prior")

    return model
| |
|
| |
|
def text2img(*, args, checkpoint_map_location):
    """Load the original text2img checkpoint and return a `UNet2DConditionModel` holding its weights.

    Args:
        args: Object exposing `text2img_checkpoint_path`.
        checkpoint_map_location: Passed as `map_location` to `torch.load`.

    Returns:
        The model built by `unet_model_from_original_config` after the converted
        checkpoint has been passed to `load_checkpoint_to_model(..., strict=True)`.
    """
    print("loading text2img")

    # NOTE(review): `torch.load` unpickles this file; only run this on checkpoints you trust.
    original_state = torch.load(args.text2img_checkpoint_path, map_location=checkpoint_map_location)

    model = unet_model_from_original_config()
    converted_state = unet_original_checkpoint_to_diffusers_checkpoint(model, original_state)

    # The original is no longer needed once the converted dict exists.
    del original_state

    load_checkpoint_to_model(converted_state, model, strict=True)

    print("done loading text2img")

    return model
| |
|
| |
|
def inpaint_text2img(*, args, checkpoint_map_location):
    """Load the original inpainting checkpoint and return a `UNet2DConditionModel` holding its weights.

    Args:
        args: Object exposing `inpaint_text2img_checkpoint_path`.
        checkpoint_map_location: Passed as `map_location` to `torch.load`.

    Returns:
        The model built by `inpaint_unet_model_from_original_config` after the converted
        checkpoint has been passed to `load_checkpoint_to_model(..., strict=True)`.
    """
    print("loading inpaint text2img")

    # NOTE(review): `torch.load` unpickles this file; only run this on checkpoints you trust.
    original_state = torch.load(args.inpaint_text2img_checkpoint_path, map_location=checkpoint_map_location)

    model = inpaint_unet_model_from_original_config()
    converted_state = inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, original_state)

    # The original is no longer needed once the converted dict exists.
    del original_state

    load_checkpoint_to_model(converted_state, model, strict=True)

    print("done loading inpaint text2img")

    return model
| |
|
| |
|
| | |
| |
|
# Keyword arguments forwarded verbatim to `VQModel(**MOVQ_CONFIG)` when the MoVQ model is built.
MOVQ_CONFIG = {
    "in_channels": 3,
    "out_channels": 3,
    "latent_channels": 4,
    "down_block_types": ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "AttnDownEncoderBlock2D"),
    "up_block_types": ("AttnUpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
    "num_vq_embeddings": 16384,
    "block_out_channels": (128, 256, 256, 512),
    "vq_embed_dim": 4,
    "layers_per_block": 2,
    "norm_type": "spatial",
}
| |
|
| |
|
def movq_model_from_original_config():
    """Instantiate a `VQModel` from the keyword arguments in `MOVQ_CONFIG`.

    Returns:
        The newly constructed `VQModel`.
    """
    return VQModel(**MOVQ_CONFIG)
| |
|
| |
|
def movq_encoder_to_diffusers_checkpoint(model, checkpoint):
    """Rename the MoVQ encoder weights from the original layout to the diffusers layout.

    Args:
        model: Target model; `model.encoder.down_blocks` and `model.encoder.mid_block.resnets`
            are walked to decide which keys to produce.
        checkpoint: Mapping from original parameter name to value.

    Returns:
        A new dict keyed by diffusers parameter names (all starting with `encoder.`).

    Raises:
        KeyError: If an expected original key is missing from `checkpoint`.
    """
    converted = {
        # conv_in
        "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"],
        "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"],
    }

    # down_blocks
    down_blocks = model.encoder.down_blocks
    final_block_idx = len(down_blocks) - 1

    for block_idx, block in enumerate(down_blocks):
        dst_block = f"encoder.down_blocks.{block_idx}"
        src_block = f"encoder.down.{block_idx}"

        # resnets
        for res_idx, res in enumerate(block.resnets):
            converted.update(
                movq_resnet_to_diffusers_checkpoint(
                    res,
                    checkpoint,
                    diffusers_resnet_prefix=f"{dst_block}.resnets.{res_idx}",
                    resnet_prefix=f"{src_block}.block.{res_idx}",
                )
            )

        # downsample: every down block except the last one contributes a downsampler conv
        is_final_block = block_idx == final_block_idx
        if not is_final_block:
            dst_conv = f"{dst_block}.downsamplers.0.conv"
            src_conv = f"{src_block}.downsample.conv"
            converted[f"{dst_conv}.weight"] = checkpoint[f"{src_conv}.weight"]
            converted[f"{dst_conv}.bias"] = checkpoint[f"{src_conv}.bias"]

        # attentions (only present on some block types)
        if hasattr(block, "attentions"):
            for attn_idx, _ in enumerate(block.attentions):
                converted.update(
                    movq_attention_to_diffusers_checkpoint(
                        checkpoint,
                        diffusers_attention_prefix=f"{dst_block}.attentions.{attn_idx}",
                        attention_prefix=f"{src_block}.attn.{attn_idx}",
                    )
                )

    # mid block attention: a single, fixed attention entry
    converted.update(
        movq_attention_to_diffusers_checkpoint(
            checkpoint,
            diffusers_attention_prefix="encoder.mid_block.attentions.0",
            attention_prefix="encoder.mid.attn_1",
        )
    )

    # mid block resnets: diffusers indices are 0-based, the original names (`block_1`, ...) are 1-based
    for res_idx, res in enumerate(model.encoder.mid_block.resnets):
        converted.update(
            movq_resnet_to_diffusers_checkpoint(
                res,
                checkpoint,
                diffusers_resnet_prefix=f"encoder.mid_block.resnets.{res_idx}",
                resnet_prefix=f"encoder.mid.block_{res_idx + 1}",
            )
        )

    # conv_norm_out
    converted["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]
    converted["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]
    # conv_out
    converted["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"]
    converted["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"]

    return converted
| |
|
| |
|
def movq_decoder_to_diffusers_checkpoint(model, checkpoint):
    """Rename the MoVQ decoder weights from the original layout to the diffusers layout.

    Args:
        model: Target model; `model.decoder.up_blocks` and `model.decoder.mid_block.resnets`
            are walked to decide which keys to produce.
        checkpoint: Mapping from original parameter name to value.

    Returns:
        A new dict keyed by diffusers parameter names (all starting with `decoder.`).

    Raises:
        KeyError: If an expected original key is missing from `checkpoint`.
    """
    diffusers_checkpoint = {}

    # conv_in
    diffusers_checkpoint.update(
        {
            "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"],
            "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"],
        }
    )

    # up_blocks
    up_blocks = model.decoder.up_blocks
    last_up_block_idx = len(up_blocks) - 1

    for diffusers_up_block_idx, up_block in enumerate(up_blocks):
        # The original checkpoint numbers its up blocks in the reverse order of diffusers.
        orig_up_block_idx = last_up_block_idx - diffusers_up_block_idx

        diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}"
        up_block_prefix = f"decoder.up.{orig_up_block_idx}"

        # resnets
        for resnet_idx, resnet in enumerate(up_block.resnets):
            diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
            resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}"

            diffusers_checkpoint.update(
                movq_resnet_to_diffusers_checkpoint_spatial_norm(
                    resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
                )
            )

        # upsample: every up block except the last (in diffusers order) contributes an upsampler conv
        if diffusers_up_block_idx != last_up_block_idx:
            diffusers_upsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv"
            upsample_prefix = f"{up_block_prefix}.upsample.conv"
            diffusers_checkpoint.update(
                {
                    f"{diffusers_upsample_prefix}.weight": checkpoint[f"{upsample_prefix}.weight"],
                    f"{diffusers_upsample_prefix}.bias": checkpoint[f"{upsample_prefix}.bias"],
                }
            )

        # attentions (only present on some block types)
        if hasattr(up_block, "attentions"):
            for attention_idx, _ in enumerate(up_block.attentions):
                diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}"
                attention_prefix = f"{up_block_prefix}.attn.{attention_idx}"
                diffusers_checkpoint.update(
                    movq_attention_to_diffusers_checkpoint_spatial_norm(
                        checkpoint,
                        diffusers_attention_prefix=diffusers_attention_prefix,
                        attention_prefix=attention_prefix,
                    )
                )

    # mid block attention: a single, fixed attention entry
    diffusers_attention_prefix = "decoder.mid_block.attentions.0"
    attention_prefix = "decoder.mid.attn_1"
    diffusers_checkpoint.update(
        movq_attention_to_diffusers_checkpoint_spatial_norm(
            checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
        )
    )

    # mid block resnets
    # Iterate the *decoder's* mid-block resnets. The keys written here are all `decoder.mid_block.*`,
    # so the resnet count and the `conv_shortcut` check must come from the decoder rather than the
    # encoder.
    for diffusers_resnet_idx, resnet in enumerate(model.decoder.mid_block.resnets):
        diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}"

        # diffusers indices are 0-based, the original names (`block_1`, ...) are 1-based
        orig_resnet_idx = diffusers_resnet_idx + 1
        resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}"

        diffusers_checkpoint.update(
            movq_resnet_to_diffusers_checkpoint_spatial_norm(
                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
            )
        )

    diffusers_checkpoint.update(
        {
            # conv_norm_out
            "decoder.conv_norm_out.norm_layer.weight": checkpoint["decoder.norm_out.norm_layer.weight"],
            "decoder.conv_norm_out.norm_layer.bias": checkpoint["decoder.norm_out.norm_layer.bias"],
            "decoder.conv_norm_out.conv_y.weight": checkpoint["decoder.norm_out.conv_y.weight"],
            "decoder.conv_norm_out.conv_y.bias": checkpoint["decoder.norm_out.conv_y.bias"],
            "decoder.conv_norm_out.conv_b.weight": checkpoint["decoder.norm_out.conv_b.weight"],
            "decoder.conv_norm_out.conv_b.bias": checkpoint["decoder.norm_out.conv_b.bias"],
            # conv_out
            "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"],
            "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"],
        }
    )

    return diffusers_checkpoint
| |
|
| |
|
def movq_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
    """Rename one resnet block's weights from the original layout to the diffusers layout.

    Args:
        resnet: Target resnet module; only `resnet.conv_shortcut` is inspected.
        checkpoint: Mapping from original parameter name to value.
        diffusers_resnet_prefix: Key prefix of the resnet in the diffusers state dict.
        resnet_prefix: Key prefix of the resnet in the original checkpoint.

    Returns:
        A dict with the `norm1`, `conv1`, `norm2` and `conv2` weight/bias entries, plus the
        `conv_shortcut` entries (read from `nin_shortcut`) when the target has a shortcut conv.
    """
    param_kinds = ("weight", "bias")

    # These four sub-modules keep the same name on both sides of the conversion.
    converted = {
        f"{diffusers_resnet_prefix}.{layer}.{kind}": checkpoint[f"{resnet_prefix}.{layer}.{kind}"]
        for layer in ("norm1", "conv1", "norm2", "conv2")
        for kind in param_kinds
    }

    if resnet.conv_shortcut is None:
        return converted

    # The shortcut conv is called `nin_shortcut` in the original checkpoint.
    for kind in param_kinds:
        converted[f"{diffusers_resnet_prefix}.conv_shortcut.{kind}"] = checkpoint[
            f"{resnet_prefix}.nin_shortcut.{kind}"
        ]

    return converted
| |
|
| |
|
def movq_resnet_to_diffusers_checkpoint_spatial_norm(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
    """Rename one spatial-norm resnet block's weights from the original to the diffusers layout.

    Unlike the plain resnet conversion, each norm contributes six entries
    (`norm_layer`, `conv_y` and `conv_b`, each with weight and bias).

    Args:
        resnet: Target resnet module; only `resnet.conv_shortcut` is inspected.
        checkpoint: Mapping from original parameter name to value.
        diffusers_resnet_prefix: Key prefix of the resnet in the diffusers state dict.
        resnet_prefix: Key prefix of the resnet in the original checkpoint.

    Returns:
        A dict with the `norm1`, `conv1`, `norm2` and `conv2` entries, plus the `conv_shortcut`
        entries (read from `nin_shortcut`) when the target has a shortcut conv.
    """
    conv_params = ("weight", "bias")
    spatial_norm_params = (
        "norm_layer.weight",
        "norm_layer.bias",
        "conv_y.weight",
        "conv_y.bias",
        "conv_b.weight",
        "conv_b.bias",
    )
    # Sub-module name (identical on both sides) paired with the parameter suffixes it carries.
    layout = (
        ("norm1", spatial_norm_params),
        ("conv1", conv_params),
        ("norm2", spatial_norm_params),
        ("conv2", conv_params),
    )

    converted = {
        f"{diffusers_resnet_prefix}.{layer}.{param}": checkpoint[f"{resnet_prefix}.{layer}.{param}"]
        for layer, params in layout
        for param in params
    }

    if resnet.conv_shortcut is None:
        return converted

    # The shortcut conv is called `nin_shortcut` in the original checkpoint.
    for param in conv_params:
        converted[f"{diffusers_resnet_prefix}.conv_shortcut.{param}"] = checkpoint[
            f"{resnet_prefix}.nin_shortcut.{param}"
        ]

    return converted
| |
|
| |
|
def movq_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
    """Rename one attention block's weights from the original layout to the diffusers layout.

    Args:
        checkpoint: Mapping from original parameter name to value. Projection weights must
            support `[:, :, 0, 0]` indexing.
        diffusers_attention_prefix: Key prefix of the attention block in the diffusers state dict.
        attention_prefix: Key prefix of the attention block in the original checkpoint.

    Returns:
        A dict with the `group_norm`, `to_q`, `to_k`, `to_v` and `to_out.0` weight/bias entries.
    """
    converted = {
        # norm
        f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
        f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
    }

    # (diffusers projection name, original projection name)
    projection_names = (("to_q", "q"), ("to_k", "k"), ("to_v", "v"), ("to_out.0", "proj_out"))
    for diffusers_name, original_name in projection_names:
        # Index 0 on the last two axes drops them from the weight.
        # NOTE(review): presumably the original projections are 1x1 convs turned into linear
        # weights here - confirm against the original model definition.
        converted[f"{diffusers_attention_prefix}.{diffusers_name}.weight"] = checkpoint[
            f"{attention_prefix}.{original_name}.weight"
        ][:, :, 0, 0]
        converted[f"{diffusers_attention_prefix}.{diffusers_name}.bias"] = checkpoint[
            f"{attention_prefix}.{original_name}.bias"
        ]

    return converted
| |
|
| |
|
def movq_attention_to_diffusers_checkpoint_spatial_norm(checkpoint, *, diffusers_attention_prefix, attention_prefix):
    """Rename one spatial-norm attention block's weights from the original to the diffusers layout.

    Args:
        checkpoint: Mapping from original parameter name to value. Projection weights must
            support `[:, :, 0, 0]` indexing.
        diffusers_attention_prefix: Key prefix of the attention block in the diffusers state dict.
        attention_prefix: Key prefix of the attention block in the original checkpoint.

    Returns:
        A dict with the six `spatial_norm` entries followed by the `to_q`, `to_k`, `to_v` and
        `to_out.0` weight/bias entries.
    """
    # norm: the original `norm.*` sub-keys land under `spatial_norm.*` with unchanged suffixes
    converted = {
        f"{diffusers_attention_prefix}.spatial_norm.{param}": checkpoint[f"{attention_prefix}.norm.{param}"]
        for param in (
            "norm_layer.weight",
            "norm_layer.bias",
            "conv_y.weight",
            "conv_y.bias",
            "conv_b.weight",
            "conv_b.bias",
        )
    }

    # (diffusers projection name, original projection name)
    projection_names = (("to_q", "q"), ("to_k", "k"), ("to_v", "v"), ("to_out.0", "proj_out"))
    for diffusers_name, original_name in projection_names:
        # Index 0 on the last two axes drops them from the weight.
        # NOTE(review): presumably the original projections are 1x1 convs turned into linear
        # weights here - confirm against the original model definition.
        converted[f"{diffusers_attention_prefix}.{diffusers_name}.weight"] = checkpoint[
            f"{attention_prefix}.{original_name}.weight"
        ][:, :, 0, 0]
        converted[f"{diffusers_attention_prefix}.{diffusers_name}.bias"] = checkpoint[
            f"{attention_prefix}.{original_name}.bias"
        ]

    return converted
| |
|
| |
|
def movq_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    """Convert a full original MoVQ checkpoint into a diffusers state dict.

    Args:
        model: Target model, forwarded to the encoder and decoder converters.
        checkpoint: Mapping from original parameter name to value.

    Returns:
        A new dict containing, in order: the converted encoder entries, the `quant_conv`,
        `quantize.embedding` and `post_quant_conv` entries (copied under unchanged names),
        and the converted decoder entries.
    """
    converted = dict(movq_encoder_to_diffusers_checkpoint(model, checkpoint))

    # These parameters have the same name in both layouts, so they are copied as-is.
    unchanged_names = (
        # quant_conv
        "quant_conv.weight",
        "quant_conv.bias",
        # quantize
        "quantize.embedding.weight",
        # post_quant_conv
        "post_quant_conv.weight",
        "post_quant_conv.bias",
    )
    for name in unchanged_names:
        converted[name] = checkpoint[name]

    # decoder
    converted.update(movq_decoder_to_diffusers_checkpoint(model, checkpoint))

    return converted
| |
|
| |
|
def movq(*, args, checkpoint_map_location):
    """Load the original MoVQ checkpoint and return a model populated with its weights.

    Args:
        args: Parsed command-line arguments; only `args.movq_checkpoint_path` is read.
        checkpoint_map_location: Passed to `torch.load` as `map_location`.

    Returns:
        The model built by `movq_model_from_original_config`, with the converted weights
        loaded in strict mode.
    """
    print("loading movq")

    # NOTE(review): `torch.load` unpickles the file; only point this at checkpoints you trust.
    original_state = torch.load(args.movq_checkpoint_path, map_location=checkpoint_map_location)

    model = movq_model_from_original_config()
    converted_state = movq_original_checkpoint_to_diffusers_checkpoint(model, original_state)

    # Drop this function's reference to the original weights before the converted copy is loaded.
    del original_state

    load_checkpoint_to_model(converted_state, model, strict=True)

    print("done loading movq")

    return model
| |
|
| |
|
def load_checkpoint_to_model(checkpoint, model, strict=False):
    """Load a state dict into `model` by round-tripping it through a temporary file.

    Args:
        checkpoint: State dict to load.
        model: Model that receives the weights.
        strict: When true, load with `model.load_state_dict(..., strict=True)`; otherwise
            hand the saved file to `load_checkpoint_and_dispatch` with `device_map="auto"`.

    Raises:
        Whatever `torch.save`, `torch.load`, `load_state_dict` or `load_checkpoint_and_dispatch`
        raise. The temporary file is removed in every case.
    """
    # Only reserve a unique path here and close the handle straight away, so the file is not
    # written and re-read by name while this handle is still open.
    with tempfile.NamedTemporaryFile(delete=False) as file:
        checkpoint_path = file.name

    try:
        torch.save(checkpoint, checkpoint_path)
        # Drop this function's reference; the caller's own reference (if any) is unaffected.
        del checkpoint
        if strict:
            model.load_state_dict(torch.load(checkpoint_path), strict=True)
        else:
            load_checkpoint_and_dispatch(model, checkpoint_path, device_map="auto")
    finally:
        # Clean up even when saving or loading fails, so failed runs do not leave
        # checkpoint-sized files behind in the temp directory.
        os.remove(checkpoint_path)
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point. Only the single stage selected with `--debug` is converted and
    # saved under `--dump_path`; running without `--debug` currently just prints "to-do".
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--prior_checkpoint_path",
        default=None,
        type=str,
        required=False,
        help="Path to the prior checkpoint to convert.",
    )
    parser.add_argument(
        "--clip_stat_path",
        default=None,
        type=str,
        required=False,
        help="Path to the clip stats checkpoint to convert.",
    )
    parser.add_argument(
        "--text2img_checkpoint_path",
        default=None,
        type=str,
        required=False,
        help="Path to the text2img checkpoint to convert.",
    )
    parser.add_argument(
        "--movq_checkpoint_path",
        default=None,
        type=str,
        required=False,
        # Previously a copy-paste of the text2img help text.
        help="Path to the movq checkpoint to convert.",
    )
    parser.add_argument(
        "--inpaint_text2img_checkpoint_path",
        default=None,
        type=str,
        required=False,
        help="Path to the inpaint text2img checkpoint to convert.",
    )
    parser.add_argument(
        "--checkpoint_load_device",
        default="cpu",
        type=str,
        required=False,
        help="The device passed to `map_location` when loading checkpoints.",
    )

    parser.add_argument(
        "--debug",
        default=None,
        type=str,
        required=False,
        help="Only run a specific stage of the convert script. Used for debugging",
    )

    args = parser.parse_args()

    print(f"loading checkpoints to {args.checkpoint_load_device}")

    checkpoint_map_location = torch.device(args.checkpoint_load_device)

    if args.debug is not None:
        print(f"debug: only executing {args.debug}")

    if args.debug is None:
        # The full (all-stages) conversion is not implemented.
        print("to-do")
    elif args.debug == "prior":
        prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
        prior_model.save_pretrained(args.dump_path)
    elif args.debug == "text2img":
        unet_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
        unet_model.save_pretrained(f"{args.dump_path}/unet")
    elif args.debug == "inpaint_text2img":
        inpaint_unet_model = inpaint_text2img(args=args, checkpoint_map_location=checkpoint_map_location)
        inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet")
    elif args.debug == "decoder":
        # The "decoder" stage converts the MoVQ model.
        decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location)
        decoder.save_pretrained(f"{args.dump_path}/decoder")
    else:
        raise ValueError(f"unknown debug value : {args.debug}")
| |
|