# Source: ColabWan/models/qwen/convert_diffusers_qwen_vae.py
# Uploaded by 1ripon1 via huggingface_hub (commit 7344bef, 15.4 kB).
from typing import Mapping, Dict
import torch
# ComfyUI ``residual`` Sequential sub-keys -> Diffusers resnet sub-module names.
# Indices missing from this table are not mapped.
_RESNET_SUFFIXES: Dict[str, str] = {
    "residual.0.gamma": "norm1.gamma",
    "residual.2.bias": "conv1.bias",
    "residual.2.weight": "conv1.weight",
    "residual.3.gamma": "norm2.gamma",
    "residual.6.bias": "conv2.bias",
    "residual.6.weight": "conv2.weight",
}

# Substring renames applied to encoder down-path keys, tried in order; the first match wins.
_ENCODER_RENAMES = tuple((f".{src}", f".{dst}") for src, dst in _RESNET_SUFFIXES.items()) + (
    (".shortcut.bias", ".conv_shortcut.bias"),
    (".shortcut.weight", ".conv_shortcut.weight"),
)

# The flat ComfyUI ``decoder.upsamples`` list is grouped in stages of four slots:
# slots 0-2 hold resnets and slot 3 holds the upsampler (flat indices 3, 7, 11, ...).
# This is the same layout ``convert_diffusers_state_dict`` inverts with ``b * 4 + r``.
_DECODER_SLOTS_PER_STAGE = 4
_DECODER_UPSAMPLER_SLOT = 3


def _build_exact_key_map() -> Dict[str, str]:
    """Return the one-to-one ComfyUI -> Diffusers key renames that need no parsing."""
    mapping: Dict[str, str] = {
        # Latent quantization bridges.
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }
    for side in ("encoder", "decoder"):
        # Mid block: ComfyUI middle.0 / middle.2 are resnets, middle.1 is the attention block.
        for comfy_idx, resnet_idx in ((0, 0), (2, 1)):
            for src, dst in _RESNET_SUFFIXES.items():
                mapping[f"{side}.middle.{comfy_idx}.{src}"] = f"{side}.mid_block.resnets.{resnet_idx}.{dst}"
        for param in ("norm.gamma", "to_qkv.weight", "to_qkv.bias", "proj.weight", "proj.bias"):
            mapping[f"{side}.middle.1.{param}"] = f"{side}.mid_block.attentions.0.{param}"
        # Input convolution.
        mapping[f"{side}.conv1.weight"] = f"{side}.conv_in.weight"
        mapping[f"{side}.conv1.bias"] = f"{side}.conv_in.bias"
        # Output head (norm_out / conv_out).
        mapping[f"{side}.head.0.gamma"] = f"{side}.norm_out.gamma"
        mapping[f"{side}.head.2.weight"] = f"{side}.conv_out.weight"
        mapping[f"{side}.head.2.bias"] = f"{side}.conv_out.bias"
    return mapping


_EXACT_KEY_MAP = _build_exact_key_map()


def _convert_encoder_downsample_key(key: str) -> str:
    """Rename one ``encoder.downsamples.*`` key to its ``encoder.down_blocks.*`` form."""
    new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
    for old, new in _ENCODER_RENAMES:
        if old in new_key:
            return new_key.replace(old, new)
    # Non-resnet entries keep their sub-path; only the container is renamed.
    return new_key


def _convert_decoder_upsample_key(key: str) -> str:
    """Rename one ``decoder.upsamples.*`` key to its ``decoder.up_blocks.*`` form.

    The flat index ``i`` is split into ``(stage, slot) = divmod(i, 4)``:
    resnet parameters and their shortcut convs go to ``up_blocks.{stage}.resnets.{slot}``,
    and resample/time_conv parameters in slot 3 go to ``up_blocks.{stage}.upsamplers.0``.
    Anything else only has its container renamed.
    """
    container_renamed = key.replace("decoder.upsamples.", "decoder.up_blocks.")
    parts = key.split(".")
    if len(parts) < 4 or not parts[2].isdigit():
        return container_renamed

    stage, slot = divmod(int(parts[2]), _DECODER_SLOTS_PER_STAGE)
    module, rest = parts[3], parts[3:]
    is_upsampler_slot = slot == _DECODER_UPSAMPLER_SLOT

    if module == "residual":
        renamed = _RESNET_SUFFIXES.get(".".join(rest))
        if is_upsampler_slot or renamed is None:
            # No resnet belongs in this slot, or unknown sub-layer: keep the key as-is.
            return key
        return f"decoder.up_blocks.{stage}.resnets.{slot}.{renamed}"

    if module == "shortcut":
        if is_upsampler_slot:
            return container_renamed.replace(".shortcut.", ".conv_shortcut.")
        # The shortcut belongs to the resnet occupying the same slot
        # (e.g. flat index 4 -> up_blocks.1.resnets.0.conv_shortcut).
        return ".".join(["decoder", "up_blocks", str(stage), "resnets", str(slot), "conv_shortcut", *parts[4:]])

    if module in ("resample", "time_conv") and is_upsampler_slot:
        return ".".join(["decoder", "up_blocks", str(stage), "upsamplers", "0", *rest])

    return container_renamed


def convert_state_dict(state_dict: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a ComfyUI-formatted Wan/Qwen VAE state_dict to Diffusers format.

    Input: dict-like mapping from str -> torch.Tensor (e.g. loaded from safetensors)
    Output: new dict with Diffusers key names (no mutation of the input)

    Values are passed through unchanged; only keys are renamed. Keys that match
    no known pattern are copied through under their original name.
    """
    out: Dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        if key in _EXACT_KEY_MAP:
            new_key = _EXACT_KEY_MAP[key]
        elif key.startswith("encoder.downsamples."):
            new_key = _convert_encoder_downsample_key(key)
        elif key.startswith("decoder.upsamples."):
            new_key = _convert_decoder_upsample_key(key)
        else:
            new_key = key
        out[new_key] = value
    return out
def convert_diffusers_state_dict(state_dict: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Convert a diffusers-formatted Qwen/Wan VAE checkpoint to WanGP's native WanVAE key names.

    Values are passed through unchanged and the input mapping is not mutated.
    Decoder ``up_blocks`` are flattened into ``upsamples`` using four slots per stage:
    ``up_blocks.{b}.resnets.{r}`` -> ``upsamples.{4*b + r}`` and
    ``up_blocks.{b}.upsamplers.0`` -> ``upsamples.{4*b + 3}``.
    Keys that match no known pattern are copied through unchanged.
    """
    slots_per_stage = 4
    upsampler_slot = 3

    # One-to-one renames (quant bridges, conv_in, norm_out / conv_out heads).
    exact_renames: Dict[str, str] = {
        "quant_conv.weight": "conv1.weight",
        "quant_conv.bias": "conv1.bias",
        "post_quant_conv.weight": "conv2.weight",
        "post_quant_conv.bias": "conv2.bias",
    }
    # Mid-block prefix renames; applied in this order to every key.
    mid_renames = []
    for side in ("encoder", "decoder"):
        exact_renames[f"{side}.conv_in.weight"] = f"{side}.conv1.weight"
        exact_renames[f"{side}.conv_in.bias"] = f"{side}.conv1.bias"
        exact_renames[f"{side}.norm_out.gamma"] = f"{side}.head.0.gamma"
        exact_renames[f"{side}.conv_out.weight"] = f"{side}.head.2.weight"
        exact_renames[f"{side}.conv_out.bias"] = f"{side}.head.2.bias"
        mid_renames += [
            (f"{side}.mid_block.resnets.0.norm1.gamma", f"{side}.middle.0.residual.0.gamma"),
            (f"{side}.mid_block.resnets.0.conv1.", f"{side}.middle.0.residual.2."),
            (f"{side}.mid_block.resnets.0.norm2.gamma", f"{side}.middle.0.residual.3.gamma"),
            (f"{side}.mid_block.resnets.0.conv2.", f"{side}.middle.0.residual.6."),
            (f"{side}.mid_block.attentions.0.norm.gamma", f"{side}.middle.1.norm.gamma"),
            (f"{side}.mid_block.attentions.0.to_qkv.", f"{side}.middle.1.to_qkv."),
            (f"{side}.mid_block.attentions.0.proj.", f"{side}.middle.1.proj."),
            (f"{side}.mid_block.resnets.1.norm1.gamma", f"{side}.middle.2.residual.0.gamma"),
            (f"{side}.mid_block.resnets.1.conv1.", f"{side}.middle.2.residual.2."),
            (f"{side}.mid_block.resnets.1.norm2.gamma", f"{side}.middle.2.residual.3.gamma"),
            (f"{side}.mid_block.resnets.1.conv2.", f"{side}.middle.2.residual.6."),
        ]
    # Resnet sub-module renames shared by the encoder down path and decoder up path.
    resnet_renames = (
        (".norm1.gamma", ".residual.0.gamma"),
        (".conv1.", ".residual.2."),
        (".norm2.gamma", ".residual.3.gamma"),
        (".conv2.", ".residual.6."),
        (".conv_shortcut.", ".shortcut."),
    )

    def rename_resnet_parts(name: str) -> str:
        # Apply every resnet rename in order (not first-match) to the whole key.
        for old, new in resnet_renames:
            name = name.replace(old, new)
        return name

    out: Dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        if key in exact_renames:
            out[exact_renames[key]] = value
            continue

        new_key = key
        for old, new in mid_renames:
            new_key = new_key.replace(old, new)
        if new_key != key:
            out[new_key] = value
            continue

        if key.startswith("encoder.down_blocks."):
            out[rename_resnet_parts(key.replace("encoder.down_blocks.", "encoder.downsamples."))] = value
            continue

        if key.startswith("decoder.up_blocks."):
            parts = key.split(".")
            if len(parts) >= 6 and parts[2].isdigit() and parts[3] == "resnets" and parts[4].isdigit():
                flat_idx = int(parts[2]) * slots_per_stage + int(parts[4])
                out[rename_resnet_parts(".".join(["decoder", "upsamples", str(flat_idx), *parts[5:]]))] = value
                continue
            if ".upsamplers.0." in key and parts[2].isdigit():
                # Same stage layout as the resnets: the upsampler occupies the last slot.
                flat_idx = int(parts[2]) * slots_per_stage + upsampler_slot
                out[key.replace(f"decoder.up_blocks.{parts[2]}.upsamplers.0", f"decoder.upsamples.{flat_idx}")] = value
                continue

        out[key] = value
    return out