| | import argparse |
| | import logging |
| |
|
| | import torch |
| | from safetensors import safe_open |
| |
|
| | from diffusers import AutoencoderKLHunyuanImage, AutoencoderKLHunyuanImageRefiner, HunyuanImageTransformer2DModel |
| |
|
| |
|
logger = logging.getLogger(__name__)


# Shown at the end of `--help` (kept raw so the shell line continuations are printed verbatim).
USAGE_EXAMPLES = r"""
Usage examples
==============

python scripts/convert_hunyuan_image_to_diffusers.py \
    --model_type hunyuanimage2.1 \
    --transformer_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/dit/hunyuanimage2.1.safetensors" \
    --vae_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/vae/vae_2_1/pytorch_model.ckpt" \
    --output_path "/raid/yiyi/test-hy21-diffusers" \
    --dtype fp32

python scripts/convert_hunyuan_image_to_diffusers.py \
    --model_type hunyuanimage2.1-distilled \
    --transformer_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/dit/hunyuanimage2.1-distilled.safetensors" \
    --vae_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/vae/vae_2_1/pytorch_model.ckpt" \
    --output_path "/raid/yiyi/test-hy21-distilled-diffusers" \
    --dtype fp32


python scripts/convert_hunyuan_image_to_diffusers.py \
    --model_type hunyuanimage-refiner \
    --transformer_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/dit/hunyuanimage-refiner.safetensors" \
    --vae_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/vae/vae_refiner/pytorch_model.pt" \
    --output_path "/raid/yiyi/test-hy2-refiner-diffusers" \
    --dtype fp32
"""

# Supported values for --dtype. Restricting the CLI to these keys makes a typo such as
# `--dtype fp16` fail loudly instead of silently falling back to float32.
_DTYPES = {"fp32": torch.float32, "bf16": torch.bfloat16}

parser = argparse.ArgumentParser(
    description="Convert original HunyuanImage checkpoints to the diffusers format.",
    epilog=USAGE_EXAMPLES,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
    "--model_type",
    type=str,
    default=None,
    help="Checkpoint variant, e.g. hunyuanimage2.1, hunyuanimage2.1-distilled or hunyuanimage-refiner.",
)
parser.add_argument(
    "--transformer_checkpoint_path",
    default=None,
    type=str,
    help="Path to the original transformer .safetensors checkpoint.",
)
parser.add_argument(
    "--vae_checkpoint_path",
    default=None,
    type=str,
    help="Path to the original VAE checkpoint (read with torch.load).",
)
parser.add_argument("--output_path", type=str, help="Output location for the converted model.")
parser.add_argument(
    "--dtype",
    type=str,
    default="fp32",
    choices=list(_DTYPES),
    help="Precision of the saved weights.",
)

args = parser.parse_args()
dtype = _DTYPES[args.dtype]
| |
|
| |
|
| | |
def convert_hunyuan_dict_for_tensor_parallel(state_dict):
    """
    Split the fused projections of a Hunyuan state dict into separate entries.

    - ``double_blocks.*attn_qkv.{weight,bias}`` -> ``*attn_q``, ``*attn_k``, ``*attn_v``
    - ``single_blocks.*.linear1.{weight,bias}`` -> ``linear1_q``, ``linear1_k``, ``linear1_v``, ``linear1_mlp``
    - ``single_blocks.*.linear2.{weight,bias}`` -> ``linear2.fc.{weight,bias}``

    Every other entry is copied unchanged. Keys are matched on whole module names, not substrings.
    An already-split state dict (``linear1_q``, ``linear2.fc``, ...) therefore passes through
    unchanged, so applying the conversion twice is harmless.

    Args:
        state_dict: Original state dict mapping parameter names to tensors. It is not modified.

    Returns:
        new_dict: Converted state dict. Split entries are slices (views) of the original tensors.

    Raises:
        KeyError: If a ``single_blocks.*.linear1`` entry has no matching ``linear2`` entry.
    """
    new_dict = {}
    for k, w in state_dict.items():
        module, _, param = k.rpartition(".")

        if k.startswith("double_blocks") and module.endswith("attn_qkv") and param in ("weight", "bias"):
            stem = module[: -len("qkv")]  # e.g. "double_blocks.0.img_attn_"
            if param == "weight":
                # Per-projection size is taken from the input dim (assumes q/k/v are hidden x hidden).
                hidden_size = w.shape[1]
            else:
                hidden_size = w.shape[0] // 3
            new_dict[f"{stem}q.{param}"] = w[:hidden_size]
            new_dict[f"{stem}k.{param}"] = w[hidden_size : 2 * hidden_size]
            new_dict[f"{stem}v.{param}"] = w[-hidden_size:]

        elif k.startswith("single_blocks") and module.endswith(".linear1"):
            # linear2's output size gives the attention width; rows past 3 * hidden belong to the MLP.
            linear2_key = module[: -len("linear1")] + "linear2." + param
            hidden_size = state_dict[linear2_key].shape[0]
            new_dict[f"{module}_q.{param}"] = w[:hidden_size]
            new_dict[f"{module}_k.{param}"] = w[hidden_size : 2 * hidden_size]
            new_dict[f"{module}_v.{param}"] = w[2 * hidden_size : 3 * hidden_size]
            new_dict[f"{module}_mlp.{param}"] = w[3 * hidden_size :]

        elif k.startswith("single_blocks") and module.endswith(".linear2"):
            new_dict[f"{module}.fc.{param}"] = w

        else:
            new_dict[k] = w
    return new_dict
| |
|
| |
|
def load_original_vae_checkpoint(args):
    """
    Load the original VAE weights from ``args.vae_checkpoint_path`` onto the CPU.

    If the loaded object has a ``"state_dict"`` entry, that entry is used. Only parameters whose
    name starts with ``"vae."`` are kept, with that leading prefix removed. Entries whose name
    contains ``"weight"`` and that are 5-D with a size-1 third dimension are squeezed to 4-D
    (presumably 3-D conv kernels with a temporal extent of 1 -- confirm against the checkpoint).

    Args:
        args: Namespace providing ``vae_checkpoint_path``.

    Returns:
        dict mapping the prefix-stripped parameter names to tensors, in checkpoint order.
    """
    # NOTE: torch.load unpickles the file; only use it on checkpoints from a trusted source.
    # map_location="cpu" lets checkpoints saved from a GPU load on machines without that device.
    state_dict = torch.load(args.vae_checkpoint_path, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    prefix = "vae."
    vae_state_dict = {}
    for k, v in state_dict.items():
        if not k.startswith(prefix):
            continue
        # Strip only the leading prefix; str.replace would also rewrite a later "vae." in the name.
        name = k[len(prefix) :]
        if "weight" in name and len(v.shape) == 5 and v.shape[2] == 1:
            v = v.squeeze(2)
        vae_state_dict[name] = v
    return vae_state_dict
| |
|
| |
|
def load_original_refiner_vae_checkpoint(args):
    """
    Load the original refiner VAE weights from ``args.vae_checkpoint_path`` onto the CPU.

    If the loaded object has a ``"state_dict"`` entry, that entry is used. Only parameters whose
    name starts with ``"vae."`` are kept, with that leading prefix removed. Tensor shapes are left
    untouched (unlike ``load_original_vae_checkpoint``, nothing is squeezed).

    Args:
        args: Namespace providing ``vae_checkpoint_path``.

    Returns:
        dict mapping the prefix-stripped parameter names to tensors, in checkpoint order.
    """
    # NOTE: torch.load unpickles the file; only use it on checkpoints from a trusted source.
    # map_location="cpu" lets checkpoints saved from a GPU load on machines without that device.
    state_dict = torch.load(args.vae_checkpoint_path, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    prefix = "vae."
    # Strip only the leading prefix; str.replace would also rewrite a later "vae." in the name.
    return {k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)}
| |
|
| |
|
def load_original_transformer_checkpoint(args):
    """
    Read every tensor from the safetensors file at ``args.transformer_checkpoint_path`` onto CPU.

    When ``args.model_type`` equals ``"hunyuanimage-2.1"``, the fused projections are additionally
    split with ``convert_hunyuan_dict_for_tensor_parallel``. Otherwise the tensors are returned
    as read.

    NOTE(review): the usage examples pass ``--model_type hunyuanimage2.1`` (no hyphen), which does
    not equal the literal compared below, so the split is skipped for them. The transformer
    converter pops already-split keys (``img_attn_q``, ``linear1_q``, ``linear2.fc``), so the
    original checkpoints may already be split -- confirm before changing this literal.

    Args:
        args: Namespace providing ``transformer_checkpoint_path`` and ``model_type``.

    Returns:
        dict mapping parameter names to tensors.
    """
    with safe_open(args.transformer_checkpoint_path, framework="pt", device="cpu") as handle:
        tensors = {name: handle.get_tensor(name) for name in handle.keys()}

    needs_split = args.model_type == "hunyuanimage-2.1"
    return convert_hunyuan_dict_for_tensor_parallel(tensors) if needs_split else tensors
| |
|
| |
|
def convert_hunyuan_image_transformer_checkpoint_to_diffusers(
    original_state_dict,
    use_byt5=True,
    guidance_distilled=False,
    use_meanflow=False,
    num_refiner_blocks=2,
    num_layers=20,
    num_single_layers=40,
):
    """
    Rename an original HunyuanImage transformer state dict to the diffusers layout.

    Entries are popped from ``original_state_dict``, which is mutated in place, and stored under
    their diffusers names. The token-refiner ``self_attn_qkv`` projection is split into
    ``to_q``/``to_k``/``to_v`` with ``torch.chunk``. The double and single blocks are expected to
    already use separate ``*_q``/``*_k``/``*_v`` entries (and ``linear1_mlp``/``linear2.fc``).

    Args:
        original_state_dict: Original parameter dict; consumed entries are removed from it.
        use_byt5: Convert the ``byt5_in`` embedder into ``context_embedder_2``.
        guidance_distilled: Convert the ``guidance_in`` embedder.
        use_meanflow: Convert the ``time_r_in`` embedder.
        num_refiner_blocks: Number of ``txt_in.individual_token_refiner`` blocks to convert.
        num_layers: Number of ``double_blocks`` to convert into ``transformer_blocks``.
        num_single_layers: Number of ``single_blocks`` to convert into ``single_transformer_blocks``.

    Returns:
        ``(converted_state_dict, original_state_dict)``. The second item holds every entry that
        was not consumed.

    Raises:
        KeyError: If an expected original key is missing.
    """
    converted_state_dict = {}

    def move(diffusers_key, original_key):
        # Pop one entry from the original dict and store it under its diffusers name.
        converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)

    def move_linear(diffusers_prefix, original_prefix):
        # Move a weight/bias pair (weight first, matching the original pop order).
        move(f"{diffusers_prefix}.weight", f"{original_prefix}.weight")
        move(f"{diffusers_prefix}.bias", f"{original_prefix}.bias")

    def move_mlp_embedder(diffusers_prefix, original_prefix):
        # Two-layer MLP: original mlp.0 / mlp.2 -> diffusers linear_1 / linear_2.
        move_linear(f"{diffusers_prefix}.linear_1", f"{original_prefix}.mlp.0")
        move_linear(f"{diffusers_prefix}.linear_2", f"{original_prefix}.mlp.2")

    # ByT5 text embedder
    if use_byt5:
        move_linear("context_embedder_2.norm", "byt5_in.layernorm")
        for j in (1, 2, 3):
            move_linear(f"context_embedder_2.linear_{j}", f"byt5_in.fc{j}")

    # Image input projection
    move_linear("x_embedder.proj", "img_in.proj")

    # Text input embedder
    move_linear("context_embedder.proj_in", "txt_in.input_embedder")
    move_mlp_embedder("context_embedder.time_text_embed.timestep_embedder", "txt_in.t_embedder")
    for j in (1, 2):
        move_linear(f"context_embedder.time_text_embed.text_embedder.linear_{j}", f"txt_in.c_embedder.linear_{j}")

    # Token refiner blocks
    for i in range(num_refiner_blocks):
        dst = f"context_embedder.token_refiner.refiner_blocks.{i}"
        src = f"txt_in.individual_token_refiner.blocks.{i}"

        move_linear(f"{dst}.norm1", f"{src}.norm1")
        move_linear(f"{dst}.norm2", f"{src}.norm2")

        # Fused QKV projection -> three equal chunks along the output dimension.
        qkv_weight = original_state_dict.pop(f"{src}.self_attn_qkv.weight")
        qkv_bias = original_state_dict.pop(f"{src}.self_attn_qkv.bias")
        q_weight, k_weight, v_weight = torch.chunk(qkv_weight, 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)
        for name, weight, bias in (
            ("to_q", q_weight, q_bias),
            ("to_k", k_weight, k_bias),
            ("to_v", v_weight, v_bias),
        ):
            converted_state_dict[f"{dst}.attn.{name}.weight"] = weight
            converted_state_dict[f"{dst}.attn.{name}.bias"] = bias

        move_linear(f"{dst}.attn.to_out.0", f"{src}.self_attn_proj")
        move_linear(f"{dst}.ff.net.0.proj", f"{src}.mlp.fc1")
        move_linear(f"{dst}.ff.net.2", f"{src}.mlp.fc2")
        move_linear(f"{dst}.norm_out.linear", f"{src}.adaLN_modulation.1")

    # Timestep embedder, plus the optional meanflow and guidance embedders
    move_mlp_embedder("time_guidance_embed.timestep_embedder", "time_in")
    if use_meanflow:
        move_mlp_embedder("time_guidance_embed.timestep_embedder_r", "time_r_in")
    if guidance_distilled:
        move_mlp_embedder("time_guidance_embed.guidance_embedder", "guidance_in")

    # Double-stream blocks (separate image and text paths)
    for i in range(num_layers):
        dst = f"transformer_blocks.{i}"
        src = f"double_blocks.{i}"

        move_linear(f"{dst}.norm1.linear", f"{src}.img_mod.linear")
        move_linear(f"{dst}.norm1_context.linear", f"{src}.txt_mod.linear")

        # Image stream
        for x in ("q", "k", "v"):
            move_linear(f"{dst}.attn.to_{x}", f"{src}.img_attn_{x}")
        move(f"{dst}.attn.norm_q.weight", f"{src}.img_attn_q_norm.weight")
        move(f"{dst}.attn.norm_k.weight", f"{src}.img_attn_k_norm.weight")
        move_linear(f"{dst}.attn.to_out.0", f"{src}.img_attn_proj")
        move_linear(f"{dst}.ff.net.0.proj", f"{src}.img_mlp.fc1")
        move_linear(f"{dst}.ff.net.2", f"{src}.img_mlp.fc2")

        # Text stream
        for x in ("q", "k", "v"):
            move_linear(f"{dst}.attn.add_{x}_proj", f"{src}.txt_attn_{x}")
        move(f"{dst}.attn.norm_added_q.weight", f"{src}.txt_attn_q_norm.weight")
        move(f"{dst}.attn.norm_added_k.weight", f"{src}.txt_attn_k_norm.weight")
        move_linear(f"{dst}.attn.to_add_out", f"{src}.txt_attn_proj")
        move_linear(f"{dst}.ff_context.net.0.proj", f"{src}.txt_mlp.fc1")
        move_linear(f"{dst}.ff_context.net.2", f"{src}.txt_mlp.fc2")

    # Single-stream blocks
    for i in range(num_single_layers):
        dst = f"single_transformer_blocks.{i}"
        src = f"single_blocks.{i}"

        move_linear(f"{dst}.norm.linear", f"{src}.modulation.linear")
        for x in ("q", "k", "v"):
            move_linear(f"{dst}.attn.to_{x}", f"{src}.linear1_{x}")
        move(f"{dst}.attn.norm_q.weight", f"{src}.q_norm.weight")
        move(f"{dst}.attn.norm_k.weight", f"{src}.k_norm.weight")
        move_linear(f"{dst}.proj_mlp", f"{src}.linear1_mlp")
        move_linear(f"{dst}.proj_out", f"{src}.linear2.fc")

    # Output layer
    move_linear("proj_out", "final_layer.linear")
    # The original modulation output is split as (shift, scale) and re-concatenated as
    # (scale, shift) for norm_out.linear.
    shift_w, scale_w = original_state_dict.pop("final_layer.adaLN_modulation.1.weight").chunk(2, dim=0)
    shift_b, scale_b = original_state_dict.pop("final_layer.adaLN_modulation.1.bias").chunk(2, dim=0)
    converted_state_dict["norm_out.linear.weight"] = torch.cat([scale_w, shift_w], dim=0)
    converted_state_dict["norm_out.linear.bias"] = torch.cat([scale_b, shift_b], dim=0)

    return converted_state_dict, original_state_dict
| |
|
| |
|
def _pop_weight_and_bias(src, dst, orig_prefix, diff_prefix):
    """Move ``{orig_prefix}.weight`` and ``.bias`` from ``src`` into ``dst`` under ``diff_prefix``."""
    for param in ("weight", "bias"):
        dst[f"{diff_prefix}.{param}"] = src.pop(f"{orig_prefix}.{param}")


def _pop_vae_resnet(src, dst, orig_prefix, diff_prefix):
    """Move one ResNet block's norm1/conv1/norm2/conv2 weights and biases."""
    for layer in ("norm1", "conv1", "norm2", "conv2"):
        _pop_weight_and_bias(src, dst, f"{orig_prefix}.{layer}", f"{diff_prefix}.{layer}")


def _pop_vae_mid_block(src, dst, side):
    """Move the two mid-block ResNets and the mid-block attention of ``side`` ("encoder"/"decoder")."""
    _pop_vae_resnet(src, dst, f"{side}.mid.block_1", f"{side}.mid_block.resnets.0")
    _pop_vae_resnet(src, dst, f"{side}.mid.block_2", f"{side}.mid_block.resnets.1")
    # original attn_1 sub-module name -> diffusers attentions.0 sub-module name
    for orig_name, diff_name in (("norm", "norm"), ("q", "to_q"), ("k", "to_k"), ("v", "to_v"), ("proj_out", "proj")):
        _pop_weight_and_bias(
            src, dst, f"{side}.mid.attn_1.{orig_name}", f"{side}.mid_block.attentions.0.{diff_name}"
        )


def convert_hunyuan_image_vae_checkpoint_to_diffusers(
    original_state_dict, block_out_channels=(128, 256, 512, 512, 1024, 1024), layers_per_block=2
):
    """
    Convert an original HunyuanImage VAE state dict to the diffusers layout.

    Entries are popped from ``original_state_dict``, which is mutated in place.

    - Encoder: down-block ResNets and any downsample conv are flattened into one sequential
      ``encoder.down_blocks.{idx}`` index.
    - Decoder: up-block ResNets (``layers_per_block + 1`` per level) and any upsample conv are
      flattened the same way into ``decoder.up_blocks.{idx}``.

    Args:
        original_state_dict: Original parameter dict; consumed entries are removed from it.
        block_out_channels: Per-level channel counts. Only its length (the number of levels) is used.
        layers_per_block: ResNet blocks per encoder level (the decoder has one more per level).

    Returns:
        ``(converted, original_state_dict)``. The second item holds every entry that was not
        consumed.

    Raises:
        KeyError: If an expected original key is missing.
    """
    converted = {}
    src = original_state_dict
    num_levels = len(block_out_channels)

    # ----- Encoder -----
    _pop_weight_and_bias(src, converted, "encoder.conv_in", "encoder.conv_in")

    diffusers_block_idx = 0
    for block_index in range(num_levels):
        for resnet_block_index in range(layers_per_block):
            _pop_vae_resnet(
                src,
                converted,
                f"encoder.down.{block_index}.block.{resnet_block_index}",
                f"encoder.down_blocks.{diffusers_block_idx}",
            )
            diffusers_block_idx += 1
        # Not every level has a downsampler; it takes the next flat index when present.
        if f"encoder.down.{block_index}.downsample.conv.weight" in src:
            _pop_weight_and_bias(
                src,
                converted,
                f"encoder.down.{block_index}.downsample.conv",
                f"encoder.down_blocks.{diffusers_block_idx}.conv",
            )
            diffusers_block_idx += 1

    _pop_vae_mid_block(src, converted, "encoder")
    for name in ("norm_out", "conv_out"):
        _pop_weight_and_bias(src, converted, f"encoder.{name}", f"encoder.{name}")

    # ----- Decoder -----
    _pop_weight_and_bias(src, converted, "decoder.conv_in", "decoder.conv_in")
    _pop_vae_mid_block(src, converted, "decoder")

    diffusers_block_idx = 0
    for up_block_index in range(num_levels):
        for resnet_block_index in range(layers_per_block + 1):
            _pop_vae_resnet(
                src,
                converted,
                f"decoder.up.{up_block_index}.block.{resnet_block_index}",
                f"decoder.up_blocks.{diffusers_block_idx}",
            )
            diffusers_block_idx += 1
        # Not every level has an upsampler; it takes the next flat index when present.
        if f"decoder.up.{up_block_index}.upsample.conv.weight" in src:
            _pop_weight_and_bias(
                src,
                converted,
                f"decoder.up.{up_block_index}.upsample.conv",
                f"decoder.up_blocks.{diffusers_block_idx}.conv",
            )
            diffusers_block_idx += 1

    for name in ("norm_out", "conv_out"):
        _pop_weight_and_bias(src, converted, f"decoder.{name}", f"decoder.{name}")

    return converted, original_state_dict
| |
|
| |
|
def convert_hunyuan_image_refiner_vae_checkpoint_to_diffusers(
    original_state_dict, block_out_channels=(128, 256, 512, 1024, 1024), layers_per_block=2
):
    """Rename an original HunyuanImage refiner VAE state dict to the diffusers key layout.

    Every converted entry is *popped* from ``original_state_dict``. After the call, that dict
    only holds the original keys this function did not consume, so the caller can report them.

    Args:
        original_state_dict: Original checkpoint mapping of key -> tensor. Mutated in place.
        block_out_channels: Channel count per resolution level. Only its length (the number of
            encoder down levels and decoder up levels) is used here. The default is a tuple so
            the shared default object cannot be mutated between calls.
        layers_per_block: Number of resnets per encoder down level. Every decoder up level has
            ``layers_per_block + 1`` resnets.

    Returns:
        ``(converted, original_state_dict)``: the diffusers-keyed dict and the dict of leftover
        original entries (the same object that was passed in).

    Raises:
        KeyError: If a required original key is missing.
    """
    converted = {}

    def identity(*names):
        return tuple((name, name) for name in names)

    # (original suffix, diffusers suffix) pairs for each repeated sub-module.
    conv_keys = identity("conv.weight", "conv.bias")
    resnet_keys = identity(
        "norm1.gamma", "conv1.conv.weight", "conv1.conv.bias", "norm2.gamma", "conv2.conv.weight", "conv2.conv.bias"
    )
    attention_keys = (
        ("norm.gamma", "norm.gamma"),
        ("q.weight", "to_q.weight"),
        ("q.bias", "to_q.bias"),
        ("k.weight", "to_k.weight"),
        ("k.bias", "to_k.bias"),
        ("v.weight", "to_v.weight"),
        ("v.bias", "to_v.bias"),
        ("proj_out.weight", "proj_out.weight"),
        ("proj_out.bias", "proj_out.bias"),
    )
    output_keys = identity("norm_out.gamma", "conv_out.conv.weight", "conv_out.conv.bias")

    def move(src_prefix, dst_prefix, key_pairs):
        """Pop each ``src_prefix.<orig>`` entry and store it as ``dst_prefix.<new>``."""
        for src_suffix, dst_suffix in key_pairs:
            converted[f"{dst_prefix}.{dst_suffix}"] = original_state_dict.pop(f"{src_prefix}.{src_suffix}")

    def move_mid_block(side):
        """Convert the two mid-block resnets, then the mid-block attention."""
        move(f"{side}.mid.block_1", f"{side}.mid_block.resnets.0", resnet_keys)
        move(f"{side}.mid.block_2", f"{side}.mid_block.resnets.1", resnet_keys)
        move(f"{side}.mid.attn_1", f"{side}.mid_block.attentions.0", attention_keys)

    def move_optional_sampler(src_prefix, dst_prefix):
        """Convert a down/upsampler conv only when the checkpoint has one at this level."""
        if f"{src_prefix}.conv.weight" in original_state_dict:
            move(src_prefix, dst_prefix, conv_keys)

    # Encoder: conv_in -> down levels -> mid block -> norm_out / conv_out.
    move("encoder.conv_in", "encoder.conv_in", conv_keys)
    for level in range(len(block_out_channels)):
        for resnet_idx in range(layers_per_block):
            move(
                f"encoder.down.{level}.block.{resnet_idx}",
                f"encoder.down_blocks.{level}.resnets.{resnet_idx}",
                resnet_keys,
            )
        move_optional_sampler(
            f"encoder.down.{level}.downsample.conv", f"encoder.down_blocks.{level}.downsamplers.0.conv"
        )
    move_mid_block("encoder")
    move("encoder", "encoder", output_keys)

    # Decoder: conv_in -> mid block -> up levels -> norm_out / conv_out.
    move("decoder.conv_in", "decoder.conv_in", conv_keys)
    move_mid_block("decoder")
    for level in range(len(block_out_channels)):
        # Decoder up levels carry one more resnet than the encoder down levels.
        for resnet_idx in range(layers_per_block + 1):
            move(
                f"decoder.up.{level}.block.{resnet_idx}",
                f"decoder.up_blocks.{level}.resnets.{resnet_idx}",
                resnet_keys,
            )
        move_optional_sampler(f"decoder.up.{level}.upsample.conv", f"decoder.up_blocks.{level}.upsamplers.0.conv")
    move("decoder", "decoder", output_keys)

    return converted, original_state_dict
| |
|
| |
|
| | def main(args): |
| | if args.model_type == "hunyuanimage2.1": |
| | original_transformer_state_dict = load_original_transformer_checkpoint(args) |
| | original_vae_state_dict = load_original_vae_checkpoint(args) |
| |
|
| | transformer_config = { |
| | "in_channels": 64, |
| | "out_channels": 64, |
| | "num_attention_heads": 28, |
| | "attention_head_dim": 128, |
| | "num_layers": 20, |
| | "num_single_layers": 40, |
| | "num_refiner_layers": 2, |
| | "patch_size": (1, 1), |
| | "qk_norm": "rms_norm", |
| | "guidance_embeds": False, |
| | "text_embed_dim": 3584, |
| | "text_embed_2_dim": 1472, |
| | "rope_theta": 256.0, |
| | "rope_axes_dim": (64, 64), |
| | } |
| |
|
| | converted_transformer_state_dict, original_transformer_state_dict = ( |
| | convert_hunyuan_image_transformer_checkpoint_to_diffusers( |
| | original_transformer_state_dict, use_byt5=True, guidance_distilled=False |
| | ) |
| | ) |
| |
|
| | if original_transformer_state_dict: |
| | logger.warning( |
| | f"Unused {len(original_transformer_state_dict)} original keys for transformer: {list(original_transformer_state_dict.keys())}" |
| | ) |
| |
|
| | transformer = HunyuanImageTransformer2DModel(**transformer_config) |
| | missing_keys, unexpected_key = transformer.load_state_dict(converted_transformer_state_dict, strict=True) |
| |
|
| | if missing_keys: |
| | logger.warning(f"Missing keys for transformer: {missing_keys}") |
| | if unexpected_key: |
| | logger.warning(f"Unexpected keys for transformer: {unexpected_key}") |
| |
|
| | transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer") |
| |
|
| | vae_config_diffusers = { |
| | "in_channels": 3, |
| | "out_channels": 3, |
| | "latent_channels": 64, |
| | "block_out_channels": [128, 256, 512, 512, 1024, 1024], |
| | "layers_per_block": 2, |
| | "spatial_compression_ratio": 32, |
| | "sample_size": 384, |
| | "scaling_factor": 0.75289, |
| | "downsample_match_channel": True, |
| | "upsample_match_channel": True, |
| | } |
| | converted_vae_state_dict, original_vae_state_dict = convert_hunyuan_image_vae_checkpoint_to_diffusers( |
| | original_vae_state_dict, block_out_channels=[128, 256, 512, 512, 1024, 1024], layers_per_block=2 |
| | ) |
| | if original_vae_state_dict: |
| | logger.warning( |
| | f"Unused {len(original_vae_state_dict)} original keys for vae: {list(original_vae_state_dict.keys())}" |
| | ) |
| |
|
| | vae = AutoencoderKLHunyuanImage(**vae_config_diffusers) |
| | missing_keys, unexpected_key = vae.load_state_dict(converted_vae_state_dict, strict=True) |
| |
|
| | if missing_keys: |
| | logger.warning(f"Missing keys for vae: {missing_keys}") |
| | if unexpected_key: |
| | logger.warning(f"Unexpected keys for vae: {unexpected_key}") |
| |
|
| | vae.to(dtype).save_pretrained(f"{args.output_path}/vae") |
| |
|
| | elif args.model_type == "hunyuanimage2.1-distilled": |
| | original_transformer_state_dict = load_original_transformer_checkpoint(args) |
| | original_vae_state_dict = load_original_vae_checkpoint(args) |
| |
|
| | transformer_config = { |
| | "in_channels": 64, |
| | "out_channels": 64, |
| | "num_attention_heads": 28, |
| | "attention_head_dim": 128, |
| | "num_layers": 20, |
| | "num_single_layers": 40, |
| | "num_refiner_layers": 2, |
| | "patch_size": (1, 1), |
| | "qk_norm": "rms_norm", |
| | "guidance_embeds": True, |
| | "text_embed_dim": 3584, |
| | "text_embed_2_dim": 1472, |
| | "rope_theta": 256.0, |
| | "rope_axes_dim": (64, 64), |
| | "use_meanflow": True, |
| | } |
| |
|
| | converted_transformer_state_dict, original_transformer_state_dict = ( |
| | convert_hunyuan_image_transformer_checkpoint_to_diffusers( |
| | original_transformer_state_dict, use_byt5=True, guidance_distilled=True, use_meanflow=True |
| | ) |
| | ) |
| |
|
| | if original_transformer_state_dict: |
| | logger.warning( |
| | f"Unused {len(original_transformer_state_dict)} original keys for transformer: {list(original_transformer_state_dict.keys())}" |
| | ) |
| |
|
| | transformer = HunyuanImageTransformer2DModel(**transformer_config) |
| | missing_keys, unexpected_key = transformer.load_state_dict(converted_transformer_state_dict, strict=True) |
| |
|
| | if missing_keys: |
| | logger.warning(f"Missing keys for transformer: {missing_keys}") |
| | if unexpected_key: |
| | logger.warning(f"Unexpected keys for transformer: {unexpected_key}") |
| |
|
| | transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer") |
| |
|
| | vae_config_diffusers = { |
| | "in_channels": 3, |
| | "out_channels": 3, |
| | "latent_channels": 64, |
| | "block_out_channels": [128, 256, 512, 512, 1024, 1024], |
| | "layers_per_block": 2, |
| | "spatial_compression_ratio": 32, |
| | "sample_size": 384, |
| | "scaling_factor": 0.75289, |
| | "downsample_match_channel": True, |
| | "upsample_match_channel": True, |
| | } |
| | converted_vae_state_dict, original_vae_state_dict = convert_hunyuan_image_vae_checkpoint_to_diffusers( |
| | original_vae_state_dict, block_out_channels=[128, 256, 512, 512, 1024, 1024], layers_per_block=2 |
| | ) |
| | if original_vae_state_dict: |
| | logger.warning( |
| | f"Unused {len(original_vae_state_dict)} original keys for vae: {list(original_vae_state_dict.keys())}" |
| | ) |
| |
|
| | vae = AutoencoderKLHunyuanImage(**vae_config_diffusers) |
| | missing_keys, unexpected_key = vae.load_state_dict(converted_vae_state_dict, strict=True) |
| |
|
| | if missing_keys: |
| | logger.warning(f"Missing keys for vae: {missing_keys}") |
| | if unexpected_key: |
| | logger.warning(f"Unexpected keys for vae: {unexpected_key}") |
| |
|
| | vae.to(dtype).save_pretrained(f"{args.output_path}/vae") |
| |
|
| | elif args.model_type == "hunyuanimage-refiner": |
| | original_transformer_state_dict = load_original_transformer_checkpoint(args) |
| | original_vae_state_dict = load_original_refiner_vae_checkpoint(args) |
| |
|
| | transformer_config = { |
| | "in_channels": 128, |
| | "out_channels": 64, |
| | "num_layers": 20, |
| | "num_single_layers": 40, |
| | "rope_axes_dim": [16, 56, 56], |
| | "num_attention_heads": 26, |
| | "attention_head_dim": 128, |
| | "mlp_ratio": 4, |
| | "patch_size": (1, 1, 1), |
| | "text_embed_dim": 3584, |
| | "guidance_embeds": True, |
| | } |
| | converted_transformer_state_dict, original_transformer_state_dict = ( |
| | convert_hunyuan_image_transformer_checkpoint_to_diffusers( |
| | original_transformer_state_dict, use_byt5=False, guidance_distilled=True |
| | ) |
| | ) |
| | if original_transformer_state_dict: |
| | logger.warning( |
| | f"Unused {len(original_transformer_state_dict)} original keys for transformer: {list(original_transformer_state_dict.keys())}" |
| | ) |
| |
|
| | transformer = HunyuanImageTransformer2DModel(**transformer_config) |
| | missing_keys, unexpected_key = transformer.load_state_dict(converted_transformer_state_dict, strict=True) |
| | if missing_keys: |
| | logger.warning(f"Missing keys for transformer: {missing_keys}") |
| | if unexpected_key: |
| | logger.warning(f"Unexpected keys for transformer: {unexpected_key}") |
| |
|
| | transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer") |
| |
|
| | vae = AutoencoderKLHunyuanImageRefiner() |
| |
|
| | converted_vae_state_dict, original_vae_state_dict = convert_hunyuan_image_refiner_vae_checkpoint_to_diffusers( |
| | original_vae_state_dict |
| | ) |
| | if original_vae_state_dict: |
| | logger.warning( |
| | f"Unused {len(original_vae_state_dict)} original keys for vae: {list(original_vae_state_dict.keys())}" |
| | ) |
| |
|
| | missing_keys, unexpected_key = vae.load_state_dict(converted_vae_state_dict, strict=True) |
| | logger.warning(f"Missing keys for vae: {missing_keys}") |
| | logger.warning(f"Unexpected keys for vae: {unexpected_key}") |
| |
|
| | vae.to(dtype).save_pretrained(f"{args.output_path}/vae") |
| |
|
| |
|
if __name__ == "__main__":
    # `args` is parsed from the command line at module level; hand it to the converter entry point.
    main(args=args)
| |
|