# NOTE: removed pasted file-viewer residue (file-size line and commit/line-number dump)
# that was not valid Python.
wan_convert_universal — Diffusers → Wan (T2V/I2V) canonical naming, single format.
- No model detection or flags.
- Shared keys get the same canonical names used by both T2V and I2V loaders.
- I2V extras are mapped only if present in the source.
API:
new_sd = convert_state_dict_universal(state_dict, cast_dtype=None)
convert_files_universal([...safetensors...], "out/model.safetensors", cast_dtype=None)
"""
import os, re
from typing import Dict, List, Optional
import torch
from safetensors.torch import safe_open, save_file
_RE_BLOCK = r"^blocks\.(\d+)\."
def _common_unet(src: str) -> str:
s = src
# attn1/attn2 → self_attn/cross_attn
s = re.sub(rf"{_RE_BLOCK}attn1\.", r"blocks.\1.self_attn.", s)
s = re.sub(rf"{_RE_BLOCK}attn2\.", r"blocks.\1.cross_attn.", s)
# to_q/k/v/out.0 → q/k/v/o (any *attn* path)
s = re.sub(r"(\b[^.\s]*attn[^.\s]*\b.*?\.)to_q\.", r"\1q.", s)
s = re.sub(r"(\b[^.\s]*attn[^.\s]*\b.*?\.)to_k\.", r"\1k.", s)
s = re.sub(r"(\b[^.\s]*attn[^.\s]*\b.*?\.)to_v\.", r"\1v.", s)
s = re.sub(r"(\b[^.\s]*attn[^.\s]*\b.*?\.)to_out\.0\.", r"\1o.", s)
# ffn.net.0.proj → ffn.0 ; ffn.net.2 → ffn.2
s = re.sub(rf"{_RE_BLOCK}ffn\.net\.0\.proj\.", r"blocks.\1.ffn.0.", s)
s = re.sub(rf"{_RE_BLOCK}ffn\.net\.2\.", r"blocks.\1.ffn.2.", s)
return s
def rename_key_universal(src: str) -> str:
    """
    Map one Diffusers-style key to the canonical Wan name (T2V/I2V shared).

    First applies the per-block attention/FFN renames via _common_unet, then
    the rules below. I2V-only keys (image cross-attn projections, img_emb head)
    are rewritten only if present in the source; keys matching no rule
    (e.g. patch_embedding.*) pass through unchanged.
    """
    s = _common_unet(src)
    # Cross-attn image projections (I2V only). These anchors rely on attn2
    # having already been renamed to cross_attn by _common_unet above.
    s = re.sub(rf"{_RE_BLOCK}cross_attn\.add_k_proj\.", r"blocks.\1.cross_attn.k_img.", s)
    s = re.sub(rf"{_RE_BLOCK}cross_attn\.add_v_proj\.", r"blocks.\1.cross_attn.v_img.", s)
    # NOTE(review): there is no rule for norm_added_q — confirm the source
    # checkpoints never carry one.
    s = re.sub(rf"{_RE_BLOCK}cross_attn\.norm_added_k\.", r"blocks.\1.cross_attn.norm_k_img.", s)
    # Block-level canonical names ($-anchored so only the bare per-block table matches)
    s = re.sub(rf"{_RE_BLOCK}scale_shift_table$", r"blocks.\1.modulation", s)  # shared canonical
    s = re.sub(rf"{_RE_BLOCK}norm2\b", r"blocks.\1.norm3", s)  # shared canonical
    # Conditioning MLPs: linear_1/linear_2 -> Sequential slots 0/2
    s = re.sub(r"^condition_embedder\.text_embedder\.linear_1\.", r"text_embedding.0.", s)
    s = re.sub(r"^condition_embedder\.text_embedder\.linear_2\.", r"text_embedding.2.", s)
    s = re.sub(r"^condition_embedder\.time_embedder\.linear_1\.", r"time_embedding.0.", s)
    s = re.sub(r"^condition_embedder\.time_embedder\.linear_2\.", r"time_embedding.2.", s)
    # Time projection: single linear lands at Sequential slot 1
    s = re.sub(r"^condition_embedder\.time_proj\.", r"time_projection.1.", s)
    # Image conditioner "head" (corrected ordering!); slots map to img_emb.proj.{0,1,3,4}
    s = re.sub(r"^condition_embedder\.image_embedder\.norm1\.", r"img_emb.proj.0.", s)  # LN (vector)
    s = re.sub(r"^condition_embedder\.image_embedder\.ff\.net\.0\.proj\.", r"img_emb.proj.1.", s)  # Linear (matrix)
    s = re.sub(r"^condition_embedder\.image_embedder\.ff\.net\.2\.", r"img_emb.proj.3.", s)  # Linear (matrix)
    s = re.sub(r"^condition_embedder\.image_embedder\.norm2\.", r"img_emb.proj.4.", s)  # LN (vector)
    # Output head: canonical head.head.* ; top-level scale_shift_table → head.modulation
    s = re.sub(r"^proj_out\.", r"head.head.", s)
    if s == "scale_shift_table":
        s = "head.modulation"
    # patch_embedding.* remains unchanged by design
    return s
# ---------- Public helpers (single canonical format) ----------
def convert_state_dict_universal(
    state_dict: Dict[str, torch.Tensor],
    cast_dtype: Optional[str] = None,
) -> Dict[str, torch.Tensor]:
    """
    Map a Diffusers-style state_dict to the canonical Wan naming (T2V/I2V shared).

    Args:
        state_dict: source tensors keyed by Diffusers-style names.
        cast_dtype: optional dtype alias ("float16"/"fp16"/"half",
            "bfloat16"/"bf16", "float32"/"fp32"); None leaves tensors as-is.

    Returns:
        A new dict with canonical Wan keys; tensors are cast only when
        cast_dtype is given, otherwise passed through untouched.

    Raises:
        ValueError: if cast_dtype is not a recognized alias.
    """
    dtype = None
    if cast_dtype:
        aliases = {
            "float16": torch.float16, "fp16": torch.float16, "half": torch.float16,
            "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
            "float32": torch.float32, "fp32": torch.float32,
        }
        alias = cast_dtype.lower().strip()
        if alias not in aliases:
            raise ValueError(f"Unsupported cast_dtype: {cast_dtype}")
        dtype = aliases[alias]
    if dtype is None:
        return {rename_key_universal(key): tensor for key, tensor in state_dict.items()}
    return {rename_key_universal(key): tensor.to(dtype) for key, tensor in state_dict.items()}
def convert_files_universal(
    input_paths: List[str],
    out_safetensors: str,
    cast_dtype: Optional[str] = None,
) -> None:
    """
    Stream one/many Diffusers .safetensors into a single canonical Wan .safetensors.

    Args:
        input_paths: source shard paths, read sequentially on CPU.
        out_safetensors: destination file; parent directory created if missing.
        cast_dtype: optional dtype alias ("fp16"/"bf16"/"fp32" and synonyms);
            None keeps tensors as stored.

    Raises:
        ValueError: if cast_dtype is not a recognized alias.

    NOTE(review): all tensors are accumulated in `out` before save_file, so
    peak RAM is roughly the full model size despite the per-shard streaming.
    """
    # Translate the user-facing dtype alias into a torch dtype (None = no cast).
    dtype = None
    if cast_dtype:
        cd = cast_dtype.lower().strip()
        if cd in ("float16", "fp16", "half"): dtype = torch.float16
        elif cd in ("bfloat16", "bf16"): dtype = torch.bfloat16
        elif cd in ("float32", "fp32"): dtype = torch.float32
        else: raise ValueError(f"Unsupported cast_dtype: {cast_dtype}")
    os.makedirs(os.path.dirname(out_safetensors) or ".", exist_ok=True)
    out: Dict[str, torch.Tensor] = {}
    for p in input_paths:
        with safe_open(p, framework="pt", device="cpu") as f:
            for k in f.keys():
                t = f.get_tensor(k)
                if dtype is not None: t = t.to(dtype)
                # Later shards silently overwrite duplicate canonical keys.
                out[rename_key_universal(k)] = t
    # Provenance metadata is embedded in the safetensors header.
    save_file(out, out_safetensors, metadata={
        "format": "wan_universal",
        "converted_from": "diffusers",
        "script": "wan_convert_universal",
        "cast_dtype": str(dtype) if dtype is not None else "unchanged",
    })
def main(
    input_dir: str = "c:/temp/chrono",
    num_shards: int = 14,
    out_path: str = "chrono.safetensors",
) -> None:
    """
    Convert a sharded Diffusers checkpoint into one canonical Wan .safetensors.

    Generalized from the original hard-coded paths: defaults preserve the old
    behavior, but the directory, shard count, and output file are now parameters.

    Args:
        input_dir: directory holding the diffusion_pytorch_model-XXXXX-of-XXXXX shards.
        num_shards: total shard count (shards are 1-indexed in the filenames).
        out_path: destination .safetensors file.
    """
    from mmgp import safetensors2  # project-local loader; function-scoped as in the original
    files = [
        f"{input_dir}/diffusion_pytorch_model-{i:05d}-of-{num_shards:05d}.safetensors"
        for i in range(1, num_shards + 1)
    ]
    new_sd = {}
    for file in files:
        sd = safetensors2.torch_load_file(file)
        conv_sd = convert_state_dict_universal(sd)  # pass cast_dtype="bf16" to downcast
        del sd  # release the shard before loading the next one to cap peak memory
        new_sd.update(conv_sd)
    safetensors2.torch_write_file(new_sd, out_path)


if __name__ == "__main__":
    main()
# (removed stray trailing "|" left over from the pasted file viewer)