import argparse
import pathlib
from typing import Any, Dict, Tuple

import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPVisionModel,
    CLIPVisionModelWithProjection,
    UMT5EncoderModel,
)

from diffusers import (
    AutoencoderKLWan,
    UniPCMultistepScheduler,
    WanAnimatePipeline,
    WanAnimateTransformer3DModel,
    WanImageToVideoPipeline,
    WanPipeline,
    WanTransformer3DModel,
    WanVACEPipeline,
    WanVACETransformer3DModel,
)

TRANSFORMER_KEYS_RENAME_DICT = {
    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
    "time_projection.1": "condition_embedder.time_proj",
    "head.modulation": "scale_shift_table",
    "head.head": "proj_out",
    "modulation": "scale_shift_table",
    "ffn.0": "ffn.net.0.proj",
    "ffn.2": "ffn.net.2",
    # Swap norm2 and norm3 via a placeholder: the original model applies the
    # norms in the order norm1, norm3, norm2; diffusers uses norm1, norm2, norm3
    "norm2": "norm__placeholder",
    "norm3": "norm2",
    "norm__placeholder": "norm3",
    # For the I2V model
    "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
    "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
    "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
    "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
    # For the FLF2V model
    "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
    # Attention component mappings
    "self_attn.q": "attn1.to_q",
    "self_attn.k": "attn1.to_k",
    "self_attn.v": "attn1.to_v",
    "self_attn.o": "attn1.to_out.0",
    "self_attn.norm_q": "attn1.norm_q",
    "self_attn.norm_k": "attn1.norm_k",
    "cross_attn.q": "attn2.to_q",
    "cross_attn.k": "attn2.to_k",
    "cross_attn.v": "attn2.to_v",
    "cross_attn.o": "attn2.to_out.0",
    "cross_attn.norm_q": "attn2.norm_q",
    "cross_attn.norm_k": "attn2.norm_k",
    "attn2.to_k_img": "attn2.add_k_proj",
    "attn2.to_v_img": "attn2.add_v_proj",
    "attn2.norm_k_img": "attn2.norm_added_k",
}
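
# Illustration (hypothetical checkpoint key, not part of the mapping): the
# renames are plain substring replacements applied in insertion order, so e.g.
#   "blocks.0.self_attn.q.weight" -> "blocks.0.attn1.to_q.weight"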

VACE_TRANSFORMER_KEYS_RENAME_DICT = {
    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
    "time_projection.1": "condition_embedder.time_proj",
    "head.modulation": "scale_shift_table",
    "head.head": "proj_out",
    "modulation": "scale_shift_table",
    "ffn.0": "ffn.net.0.proj",
    "ffn.2": "ffn.net.2",
    # Swap norm2 and norm3 via a placeholder (see TRANSFORMER_KEYS_RENAME_DICT)
    "norm2": "norm__placeholder",
    "norm3": "norm2",
    "norm__placeholder": "norm3",
    "self_attn.q": "attn1.to_q",
    "self_attn.k": "attn1.to_k",
    "self_attn.v": "attn1.to_v",
    "self_attn.o": "attn1.to_out.0",
    "self_attn.norm_q": "attn1.norm_q",
    "self_attn.norm_k": "attn1.norm_k",
    "cross_attn.q": "attn2.to_q",
    "cross_attn.k": "attn2.to_k",
    "cross_attn.v": "attn2.to_v",
    "cross_attn.o": "attn2.to_out.0",
    "cross_attn.norm_q": "attn2.norm_q",
    "cross_attn.norm_k": "attn2.norm_k",
    "attn2.to_k_img": "attn2.add_k_proj",
    "attn2.to_v_img": "attn2.add_v_proj",
    "attn2.norm_k_img": "attn2.norm_added_k",
    "before_proj": "proj_in",
    "after_proj": "proj_out",
}

ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
    "time_projection.1": "condition_embedder.time_proj",
    "head.modulation": "scale_shift_table",
    "head.head": "proj_out",
    "modulation": "scale_shift_table",
    "ffn.0": "ffn.net.0.proj",
    "ffn.2": "ffn.net.2",
    # Swap norm2 and norm3 via a placeholder (see TRANSFORMER_KEYS_RENAME_DICT)
    "norm2": "norm__placeholder",
    "norm3": "norm2",
    "norm__placeholder": "norm3",
    "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
    "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
    "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
    "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
    # Attention component mappings
    "self_attn.q": "attn1.to_q",
    "self_attn.k": "attn1.to_k",
    "self_attn.v": "attn1.to_v",
    "self_attn.o": "attn1.to_out.0",
    "self_attn.norm_q": "attn1.norm_q",
    "self_attn.norm_k": "attn1.norm_k",
    "cross_attn.q": "attn2.to_q",
    "cross_attn.k": "attn2.to_k",
    "cross_attn.v": "attn2.to_v",
    "cross_attn.o": "attn2.to_out.0",
    "cross_attn.norm_q": "attn2.norm_q",
    "cross_attn.norm_k": "attn2.norm_k",
    "cross_attn.k_img": "attn2.to_k_img",
    "cross_attn.v_img": "attn2.to_v_img",
    "cross_attn.norm_k_img": "attn2.norm_k_img",
    # The entries above feed into the entries below, since the renames are
    # applied one after another over the same key
    "attn2.to_k_img": "attn2.add_k_proj",
    "attn2.to_v_img": "attn2.add_v_proj",
    "attn2.norm_k_img": "attn2.norm_added_k",
    # Motion encoder mappings; the remaining motion encoder structure is
    # handled by convert_animate_motion_encoder_weights below
    "motion_encoder.enc.fc": "motion_encoder.motion_network",
    "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
    # Face encoder mappings
    "face_encoder.conv1_local.conv": "face_encoder.conv1_local",
    "face_encoder.conv2.conv": "face_encoder.conv2",
    "face_encoder.conv3.conv": "face_encoder.conv3",
}
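
# Note (hypothetical key shown for illustration): because the renames run
# sequentially, the Animate image-conditioning keys pass through the diffusers
# intermediate names before landing on the added-KV names, e.g.
#   "blocks.0.cross_attn.k_img.weight" ultimately becomes
#   "blocks.0.attn2.add_k_proj.weight"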


def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
    """
    Convert all motion encoder weights for the Animate model.

    In the original model:
    - All Linear layers in fc use EqualLinear
    - All Conv2d layers in convs use EqualConv2d (except blur_conv, which is initialized separately)
    - Blur kernels are stored as buffers in Sequential modules
    - ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)]

    Conversion strategy:
    1. Drop .kernel buffers (blur kernels)
    2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)
    """
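    # Illustration (hypothetical keys): under the strategy above,
    #   "motion_encoder.enc.net_app.convs.0.0.weight" -> "motion_encoder.conv_in.weight"
    #   "motion_encoder.enc.net_app.convs.0.1.bias"   -> "motion_encoder.conv_in.act_fn.bias" (bias squeezed)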
    # Only parameters and blur-kernel buffers need handling
    if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
        return

    # Strategy 1: drop blur kernel buffers; they are fixed and re-created at
    # module initialization in diffusers
    if ".kernel" in key and "motion_encoder" in key:
        state_dict.pop(key, None)
        return

    # Strategy 2: rename the appearance-encoder conv stack to named components
    if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
        parts = key.split(".")

        # Layout of the original nn.Sequential (net_app.convs):
        #   index 0              -> conv_in
        #   indices 1..final-1   -> res_blocks (conv1 / conv2 / conv_skip)
        #   index final_conv_idx -> conv_out
        # Bias terms live inside FusedLeakyReLU, so they map to act_fn.bias
        convs_idx = parts.index("convs") if "convs" in parts else -1
        if convs_idx >= 0 and len(parts) - convs_idx >= 2:
            bias = False

            sequential_idx = int(parts[convs_idx + 1])
            if sequential_idx == 0:
                if key.endswith(".weight"):
                    new_key = "motion_encoder.conv_in.weight"
                elif key.endswith(".bias"):
                    new_key = "motion_encoder.conv_in.act_fn.bias"
                    bias = True
            elif sequential_idx == final_conv_idx:
                if key.endswith(".weight"):
                    new_key = "motion_encoder.conv_out.weight"
            else:
                # Intermediate convs map onto the residual blocks
                prefix = "motion_encoder.res_blocks."

                layer_name = parts[convs_idx + 2]
                if layer_name == "skip":
                    layer_name = "conv_skip"

                if key.endswith(".weight"):
                    param_name = "weight"
                elif key.endswith(".bias"):
                    param_name = "act_fn.bias"
                    bias = True

                suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
                suffix = ".".join(suffix_parts)
                new_key = prefix + suffix

            param = state_dict.pop(key)
            if bias:
                param = param.squeeze()
            state_dict[new_key] = param
            return
        return
    return


def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
    """
    Convert face adapter weights for the Animate model.

    The original model uses a fused KV projection, but the diffusers model uses separate K and V projections.
    """
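    # Illustration (hypothetical shapes): a fused KV weight of shape
    # (2 * dim, dim) is split along dim 0 into two (dim, dim) projections via
    # torch.chunk(kv_proj, 2, dim=0), mirroring the split performed below.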
    if ".weight" not in key and ".bias" not in key:
        return

    prefix = "face_adapter."
    if ".fuser_blocks." in key:
        parts = key.split(".")

        module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1
        if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3:
            block_idx = parts[module_list_idx + 1]
            layer_name = parts[module_list_idx + 2]
            param_name = parts[module_list_idx + 3]

            if layer_name == "linear1_kv":
                layer_name_k = "to_k"
                layer_name_v = "to_v"

                suffix_k = ".".join([block_idx, layer_name_k, param_name])
                suffix_v = ".".join([block_idx, layer_name_v, param_name])
                new_key_k = prefix + suffix_k
                new_key_v = prefix + suffix_v

                kv_proj = state_dict.pop(key)
                k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
                state_dict[new_key_k] = k_proj
                state_dict[new_key_v] = v_proj
                return
            else:
                if layer_name == "q_norm":
                    new_layer_name = "norm_q"
                elif layer_name == "k_norm":
                    new_layer_name = "norm_k"
                elif layer_name == "linear1_q":
                    new_layer_name = "to_q"
                elif layer_name == "linear2":
                    new_layer_name = "to_out"

                suffix_parts = [block_idx, new_layer_name, param_name]
                suffix = ".".join(suffix_parts)
                new_key = prefix + suffix
                state_dict[new_key] = state_dict.pop(key)
                return
        return


TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "motion_encoder": convert_animate_motion_encoder_weights,
    "face_adapter": convert_animate_face_adapter_weights,
}
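
# Each handler above mutates the state dict in place; convert_transformer calls
# it for every key that contains the dict's substring, e.g. (hypothetical key):
#   convert_animate_face_adapter_weights("face_adapter.fuser_blocks.0.linear1_kv.weight", state_dict)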


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def load_sharded_safetensors(dir: pathlib.Path):
    file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
    state_dict = {}
    for path in file_paths:
        state_dict.update(load_file(path))
    return state_dict


def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
    if model_type == "Wan-T2V-1.3B":
        config = {
            "model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "in_channels": 16,
                "num_attention_heads": 12,
                "num_layers": 30,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-T2V-14B":
        config = {
            "model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 16,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-I2V-14B-480p":
        config = {
            "model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
            "diffusers_config": {
                "image_dim": 1280,
                "added_kv_proj_dim": 5120,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 36,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-I2V-14B-720p":
        config = {
            "model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
            "diffusers_config": {
                "image_dim": 1280,
                "added_kv_proj_dim": 5120,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 36,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-FLF2V-14B-720P":
        config = {
            "model_id": "ypyp/Wan2.1-FLF2V-14B-720P",
            "diffusers_config": {
                "image_dim": 1280,
                "added_kv_proj_dim": 5120,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 36,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
                "rope_max_seq_len": 1024,
                "pos_embed_seq_len": 257 * 2,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-VACE-1.3B":
        config = {
            "model_id": "Wan-AI/Wan2.1-VACE-1.3B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "in_channels": 16,
                "num_attention_heads": 12,
                "num_layers": 30,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
                "vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
                "vace_in_channels": 96,
            },
        }
        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-VACE-14B":
        config = {
            "model_id": "Wan-AI/Wan2.1-VACE-14B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 16,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
                "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
                "vace_in_channels": 96,
            },
        }
        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-VACE-Fun-14B":
        config = {
            "model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 16,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
                "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
                "vace_in_channels": 96,
            },
        }
        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-I2V-14B-720p":
        config = {
            "model_id": "Wan-AI/Wan2.2-I2V-A14B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 36,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-T2V-A14B":
        config = {
            "model_id": "Wan-AI/Wan2.2-T2V-A14B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 16,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-TI2V-5B":
        config = {
            "model_id": "Wan-AI/Wan2.2-TI2V-5B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 14336,
                "freq_dim": 256,
                "in_channels": 48,
                "num_attention_heads": 24,
                "num_layers": 30,
                "out_channels": 48,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-Animate-14B":
        config = {
            "model_id": "Wan-AI/Wan2.2-Animate-14B",
            "diffusers_config": {
                "image_dim": 1280,
                "added_kv_proj_dim": 5120,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 36,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
                "rope_max_seq_len": 1024,
                "pos_embed_seq_len": None,
                "motion_encoder_size": 512,
                "motion_style_dim": 512,
                "motion_dim": 20,
                "motion_encoder_dim": 512,
                "face_encoder_hidden_dim": 1024,
                "face_encoder_num_heads": 4,
                "inject_face_latents_blocks": 5,
            },
        }
        RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
    return config, RENAME_DICT, SPECIAL_KEYS_REMAP


def convert_transformer(model_type: str, stage: str = None):
    config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)

    diffusers_config = config["diffusers_config"]
    model_id = config["model_id"]
    model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))

    if stage is not None:
        model_dir = model_dir / stage

    original_state_dict = load_sharded_safetensors(model_dir)

    with init_empty_weights():
        if "Animate" in model_type:
            transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
        elif "VACE" in model_type:
            transformer = WanVACETransformer3DModel.from_config(diffusers_config)
        else:
            transformer = WanTransformer3DModel.from_config(diffusers_config)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    # Assign the converted tensors directly into the meta-initialized model
    transformer.load_state_dict(original_state_dict, strict=True, assign=True)

    # Make sure the materialized weights live on CPU before returning
    transformer = transformer.to("cpu")

    return transformer
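
# Sketch of intended use (assumes Hub access; the model type must match one of
# the branches in get_transformer_config):
#   transformer = convert_transformer("Wan-T2V-1.3B")
#   high_noise = convert_transformer("Wan2.2-T2V-A14B", stage="high_noise_model")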


def convert_vae():
    vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth")
    old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
    new_state_dict = {}

    # Create mappings for specific components
    middle_key_mapping = {
        # Encoder middle block
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder middle block
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
    }

    # Create a mapping for the attention blocks
    attention_mapping = {
        # Encoder attention block
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder attention block
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
    }

    # Create a mapping for the head components
    head_mapping = {
        # Encoder head
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
    }

    # Create a mapping for the quant conv layers
    quant_mapping = {
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    # Process each key in the state dict
    for key, value in old_state_dict.items():
        # Handle middle block keys using the mapping
        if key in middle_key_mapping:
            new_key = middle_key_mapping[key]
            new_state_dict[new_key] = value
        # Handle attention blocks using the mapping
        elif key in attention_mapping:
            new_key = attention_mapping[key]
            new_state_dict[new_key] = value
        # Handle head keys using the mapping
        elif key in head_mapping:
            new_key = head_mapping[key]
            new_state_dict[new_key] = value
        # Handle quant keys using the mapping
        elif key in quant_mapping:
            new_key = quant_mapping[key]
            new_state_dict[new_key] = value
        # Handle the encoder conv_in
        elif key == "encoder.conv1.weight":
            new_state_dict["encoder.conv_in.weight"] = value
        elif key == "encoder.conv1.bias":
            new_state_dict["encoder.conv_in.bias"] = value
        # Handle the decoder conv_in
        elif key == "decoder.conv1.weight":
            new_state_dict["decoder.conv_in.weight"] = value
        elif key == "decoder.conv1.bias":
            new_state_dict["decoder.conv_in.bias"] = value
        # Handle encoder downsamples
        elif key.startswith("encoder.downsamples."):
            # Convert downsamples to down_blocks
            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")

            # Convert the residual and shortcut component names
            if ".residual.0.gamma" in new_key:
                new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
            elif ".residual.2.bias" in new_key:
                new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
            elif ".residual.2.weight" in new_key:
                new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
            elif ".residual.3.gamma" in new_key:
                new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
            elif ".residual.6.bias" in new_key:
                new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
            elif ".residual.6.weight" in new_key:
                new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
            elif ".shortcut.bias" in new_key:
                new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
            elif ".shortcut.weight" in new_key:
                new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")

            new_state_dict[new_key] = value

        # Handle decoder upsamples
        elif key.startswith("decoder.upsamples."):
            # Get the block index from the key
            parts = key.split(".")
            block_idx = int(parts[2])

            # Group the residual layers into up_blocks/resnets
            if "residual" in key:
                if block_idx in [0, 1, 2]:
                    new_block_idx = 0
                    resnet_idx = block_idx
                elif block_idx in [4, 5, 6]:
                    new_block_idx = 1
                    resnet_idx = block_idx - 4
                elif block_idx in [8, 9, 10]:
                    new_block_idx = 2
                    resnet_idx = block_idx - 8
                elif block_idx in [12, 13, 14]:
                    new_block_idx = 3
                    resnet_idx = block_idx - 12
                else:
                    # Keep unexpected indices unchanged
                    new_state_dict[key] = value
                    continue

                # Convert the residual component names
                if ".residual.0.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
                elif ".residual.2.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
                elif ".residual.2.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
                elif ".residual.3.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
                elif ".residual.6.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
                elif ".residual.6.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
                else:
                    new_key = key

                new_state_dict[new_key] = value

            # Handle shortcut connections
            elif ".shortcut." in key:
                if block_idx == 4:
                    new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
                    new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                    new_key = new_key.replace(".shortcut.", ".conv_shortcut.")

                new_state_dict[new_key] = value

            # Handle upsamplers (resample / time_conv)
            elif ".resample." in key or ".time_conv." in key:
                if block_idx == 3:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
                elif block_idx == 7:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
                elif block_idx == 11:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")

                new_state_dict[new_key] = value
            else:
                new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                new_state_dict[new_key] = value
        else:
            # Keep other keys unchanged
            new_state_dict[key] = value

    with init_empty_weights():
        vae = AutoencoderKLWan()
    vae.load_state_dict(new_state_dict, strict=True, assign=True)
    return vae


vae22_diffusers_config = {
    "base_dim": 160,
    "z_dim": 48,
    "is_residual": True,
    "in_channels": 12,
    "out_channels": 12,
    "decoder_base_dim": 256,
    "scale_factor_temporal": 4,
    "scale_factor_spatial": 16,
    "patch_size": 2,
    "latents_mean": [
        -0.2289,
        -0.0052,
        -0.1323,
        -0.2339,
        -0.2799,
        0.0174,
        0.1838,
        0.1557,
        -0.1382,
        0.0542,
        0.2813,
        0.0891,
        0.1570,
        -0.0098,
        0.0375,
        -0.1825,
        -0.2246,
        -0.1207,
        -0.0698,
        0.5109,
        0.2665,
        -0.2108,
        -0.2158,
        0.2502,
        -0.2055,
        -0.0322,
        0.1109,
        0.1567,
        -0.0729,
        0.0899,
        -0.2799,
        -0.1230,
        -0.0313,
        -0.1649,
        0.0117,
        0.0723,
        -0.2839,
        -0.2083,
        -0.0520,
        0.3748,
        0.0152,
        0.1957,
        0.1433,
        -0.2944,
        0.3573,
        -0.0548,
        -0.1681,
        -0.0667,
    ],
    "latents_std": [
        0.4765,
        1.0364,
        0.4514,
        1.1677,
        0.5313,
        0.4990,
        0.4818,
        0.5013,
        0.8158,
        1.0344,
        0.5894,
        1.0901,
        0.6885,
        0.6165,
        0.8454,
        0.4978,
        0.5759,
        0.3523,
        0.7135,
        0.6804,
        0.5833,
        1.4146,
        0.8986,
        0.5659,
        0.7069,
        0.5338,
        0.4889,
        0.4917,
        0.4069,
        0.4999,
        0.6866,
        0.4093,
        0.5709,
        0.6065,
        0.6415,
        0.4944,
        0.5726,
        1.2042,
        0.5458,
        1.6887,
        0.3971,
        1.0600,
        0.3943,
        0.5537,
        0.5444,
        0.4089,
        0.7468,
        0.7744,
    ],
    "clip_output": False,
}


def convert_vae_22():
    vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.2-TI2V-5B", "Wan2.2_VAE.pth")
    old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
    new_state_dict = {}

    # Create mappings for specific components
    middle_key_mapping = {
        # Encoder middle block
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder middle block
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
    }

    # Create a mapping for the attention blocks
    attention_mapping = {
        # Encoder attention block
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder attention block
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
    }

    # Create a mapping for the head components
    head_mapping = {
        # Encoder head
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
    }

    # Create a mapping for the quant conv layers
    quant_mapping = {
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    # Process each key in the state dict
    for key, value in old_state_dict.items():
        # Handle middle block keys using the mapping
        if key in middle_key_mapping:
            new_key = middle_key_mapping[key]
            new_state_dict[new_key] = value
        # Handle attention blocks using the mapping
        elif key in attention_mapping:
            new_key = attention_mapping[key]
            new_state_dict[new_key] = value
        # Handle head keys using the mapping
        elif key in head_mapping:
            new_key = head_mapping[key]
            new_state_dict[new_key] = value
        # Handle quant keys using the mapping
        elif key in quant_mapping:
            new_key = quant_mapping[key]
            new_state_dict[new_key] = value
        # Handle the encoder conv_in
        elif key == "encoder.conv1.weight":
            new_state_dict["encoder.conv_in.weight"] = value
        elif key == "encoder.conv1.bias":
            new_state_dict["encoder.conv_in.bias"] = value
        # Handle the decoder conv_in
        elif key == "decoder.conv1.weight":
            new_state_dict["decoder.conv_in.weight"] = value
        elif key == "decoder.conv1.bias":
            new_state_dict["decoder.conv_in.bias"] = value
        # Handle encoder downsamples
        elif key.startswith("encoder.downsamples."):
            # Convert downsamples to down_blocks
            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")

            # Handle residual blocks
            if "residual" in new_key or "shortcut" in new_key:
                # The nested blocks are residual blocks
                new_key = new_key.replace(".downsamples.", ".resnets.")

                # Convert the residual and shortcut component names
                if ".residual.0.gamma" in new_key:
                    new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
                elif ".residual.2.weight" in new_key:
                    new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
                elif ".residual.2.bias" in new_key:
                    new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
                elif ".residual.3.gamma" in new_key:
                    new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
                elif ".residual.6.weight" in new_key:
                    new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
                elif ".residual.6.bias" in new_key:
                    new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
                elif ".shortcut.weight" in new_key:
                    new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
                elif ".shortcut.bias" in new_key:
                    new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")

            # Handle downsamplers
            elif "resample" in new_key or "time_conv" in new_key:
                parts = new_key.split(".")
                # The nested "downsamples" level collapses into a single
                # downsampler module:
                #   encoder.down_blocks.{i}.downsamples.{j}.<rest>
                #   -> encoder.down_blocks.{i}.downsampler.<rest>
                if len(parts) >= 4 and parts[3] == "downsamples":
                    new_parts = parts[:3] + ["downsampler"] + parts[5:]
                    new_key = ".".join(new_parts)

            new_state_dict[new_key] = value

        # Handle decoder upsamples
        elif key.startswith("decoder.upsamples."):
            # Convert upsamples to up_blocks
            new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")

            # Handle residual blocks
            if "residual" in new_key or "shortcut" in new_key:
                # The nested blocks are residual blocks
                new_key = new_key.replace(".upsamples.", ".resnets.")

                # Convert the residual and shortcut component names
                if ".residual.0.gamma" in new_key:
                    new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
                elif ".residual.2.weight" in new_key:
                    new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
                elif ".residual.2.bias" in new_key:
                    new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
                elif ".residual.3.gamma" in new_key:
                    new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
                elif ".residual.6.weight" in new_key:
                    new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
                elif ".residual.6.bias" in new_key:
                    new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
                elif ".shortcut.weight" in new_key:
                    new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
                elif ".shortcut.bias" in new_key:
                    new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")

            # Handle upsamplers
            elif "resample" in new_key or "time_conv" in new_key:
                parts = new_key.split(".")
                # The nested "upsamples" level collapses into a single
                # upsampler module:
                #   decoder.up_blocks.{i}.upsamples.{j}.<rest>
                #   -> decoder.up_blocks.{i}.upsampler.<rest>
                if len(parts) >= 4 and parts[3] == "upsamples":
                    new_parts = parts[:3] + ["upsampler"] + parts[5:]
                    new_key = ".".join(new_parts)

            new_state_dict[new_key] = value
        else:
            # Keep other keys unchanged
            new_state_dict[key] = value

    with init_empty_weights():
        vae = AutoencoderKLWan(**vae22_diffusers_config)
    vae.load_state_dict(new_state_dict, strict=True, assign=True)
    return vae


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


if __name__ == "__main__":
    args = get_args()

    if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type:
        transformer = convert_transformer(args.model_type, stage="high_noise_model")
        transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
    else:
        transformer = convert_transformer(args.model_type)
        transformer_2 = None

    if "Wan2.2" in args.model_type and "TI2V" in args.model_type:
        vae = convert_vae_22()
    else:
        vae = convert_vae()

    text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
    if "FLF2V" in args.model_type:
        flow_shift = 16.0
    elif "TI2V" in args.model_type or "Animate" in args.model_type:
        flow_shift = 5.0
    else:
        flow_shift = 3.0
    scheduler = UniPCMultistepScheduler(
        prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
    )

    # Cast the transformer(s) before assembling the pipeline
    if args.dtype != "none":
        dtype = DTYPE_MAPPING[args.dtype]
        transformer.to(dtype)
        if transformer_2 is not None:
            transformer_2.to(dtype)

    if "Wan2.2" in args.model_type and "I2V" in args.model_type and "TI2V" not in args.model_type:
        pipe = WanImageToVideoPipeline(
            transformer=transformer,
            transformer_2=transformer_2,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            boundary_ratio=0.9,
        )
    elif "Wan2.2" in args.model_type and "T2V" in args.model_type:
        pipe = WanPipeline(
            transformer=transformer,
            transformer_2=transformer_2,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            boundary_ratio=0.875,
        )
    elif "Wan2.2" in args.model_type and "TI2V" in args.model_type:
        pipe = WanPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            expand_timesteps=True,
        )
    elif "I2V" in args.model_type or "FLF2V" in args.model_type:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
        )
        image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
        pipe = WanImageToVideoPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            image_encoder=image_encoder,
            image_processor=image_processor,
        )
    elif "Wan2.2-VACE" in args.model_type:
        pipe = WanVACEPipeline(
            transformer=transformer,
            transformer_2=transformer_2,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            boundary_ratio=0.875,
        )
    elif "Wan-VACE" in args.model_type:
        pipe = WanVACEPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
        )
    elif "Animate" in args.model_type:
        image_encoder = CLIPVisionModel.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
        )
        image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")

        pipe = WanAnimatePipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            image_encoder=image_encoder,
            image_processor=image_processor,
        )
    else:
        pipe = WanPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
        )

    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
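
# Example invocation (script name and output path are illustrative):
#   python convert_wan_to_diffusers.py \
#       --model_type Wan-T2V-1.3B --output_path ./Wan-T2V-1.3B-Diffusers --dtype bf16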