"""
wan_convert_universal — convert Diffusers-style Wan (T2V/I2V) weights to the
canonical Wan key naming, producing a single output format.

- No model detection and no flags: every key goes through the same rename rules.
- Keys shared by T2V and I2V get the canonical names used by both loaders.
- I2V-only keys are renamed only when they are present in the source.

API:
    new_sd = convert_state_dict_universal(state_dict, cast_dtype=None)
    convert_files_universal([...safetensors...], "out/model.safetensors", cast_dtype=None)
"""
|
|
import os
import re
from typing import Dict, List, Optional, Union

import torch
from safetensors.torch import safe_open, save_file
|
|
# Matches the "blocks.<N>." prefix at the very start of a per-block key and
# captures the block index <N> as group 1, which the substitutions below
# re-emit as "\1" when building the renamed key.
_RE_BLOCK = r"^blocks\.(\d+)\."
|
|
def _common_unet(src: str) -> str:
    """
    Apply the per-block renames shared by every Wan variant.

    The rules run in order, each one on the output of the previous one:
      * ``blocks.N.attn1.`` / ``blocks.N.attn2.`` -> ``self_attn.`` / ``cross_attn.``
      * ``to_q`` / ``to_k`` / ``to_v`` / ``to_out.0`` -> ``q`` / ``k`` / ``v`` / ``o``
        (only when they follow a key segment whose name contains "attn")
      * ``blocks.N.ffn.net.0.proj`` / ``blocks.N.ffn.net.2`` -> ``ffn.0`` / ``ffn.2``

    Keys matching none of the rules come back unchanged.
    """
    substitutions = (
        # Attention module names.
        (rf"{_RE_BLOCK}attn1\.", r"blocks.\1.self_attn."),
        (rf"{_RE_BLOCK}attn2\.", r"blocks.\1.cross_attn."),
        # Attention projections; group 1 keeps everything up to the projection name.
        (r"(\b[^.\s]*attn[^.\s]*\b.*?\.)to_q\.", r"\1q."),
        (r"(\b[^.\s]*attn[^.\s]*\b.*?\.)to_k\.", r"\1k."),
        (r"(\b[^.\s]*attn[^.\s]*\b.*?\.)to_v\.", r"\1v."),
        (r"(\b[^.\s]*attn[^.\s]*\b.*?\.)to_out\.0\.", r"\1o."),
        # Feed-forward layers.
        (rf"{_RE_BLOCK}ffn\.net\.0\.proj\.", r"blocks.\1.ffn.0."),
        (rf"{_RE_BLOCK}ffn\.net\.2\.", r"blocks.\1.ffn.2."),
    )
    key = src
    for pattern, replacement in substitutions:
        key = re.sub(pattern, replacement, key)
    return key
|
|
def rename_key_universal(src: str) -> str:
    """
    Map one Diffusers-style Wan key to its canonical Wan name.

    The same rules serve T2V, I2V and FLF2V checkpoints: rules for modules that
    only exist in I2V/FLF2V models simply never match when those keys are
    absent. Keys that match no rule (e.g. ``patch_embedding.*``) are returned
    unchanged.
    """
    s = _common_unet(src)

    # I2V image cross-attention. _common_unet has already turned attn2 into
    # cross_attn, so match on the renamed prefix.
    s = re.sub(rf"{_RE_BLOCK}cross_attn\.add_k_proj\.", r"blocks.\1.cross_attn.k_img.", s)
    s = re.sub(rf"{_RE_BLOCK}cross_attn\.add_v_proj\.", r"blocks.\1.cross_attn.v_img.", s)
    s = re.sub(rf"{_RE_BLOCK}cross_attn\.norm_added_k\.", r"blocks.\1.cross_attn.norm_k_img.", s)

    # Per-block modulation table and norms. Diffusers numbers the block norms
    # differently from Wan: its norm2 is Wan's norm3 (see the diffusers Wan
    # conversion script).
    s = re.sub(rf"{_RE_BLOCK}scale_shift_table$", r"blocks.\1.modulation", s)
    s = re.sub(rf"{_RE_BLOCK}norm2\b", r"blocks.\1.norm3", s)

    # Text embedder MLP.
    s = re.sub(r"^condition_embedder\.text_embedder\.linear_1\.", r"text_embedding.0.", s)
    s = re.sub(r"^condition_embedder\.text_embedder\.linear_2\.", r"text_embedding.2.", s)

    # Timestep embedder MLP and the projection that follows it.
    s = re.sub(r"^condition_embedder\.time_embedder\.linear_1\.", r"time_embedding.0.", s)
    s = re.sub(r"^condition_embedder\.time_embedder\.linear_2\.", r"time_embedding.2.", s)
    s = re.sub(r"^condition_embedder\.time_proj\.", r"time_projection.1.", s)

    # I2V image embedder (norm -> MLP -> norm).
    s = re.sub(r"^condition_embedder\.image_embedder\.norm1\.", r"img_emb.proj.0.", s)
    s = re.sub(r"^condition_embedder\.image_embedder\.ff\.net\.0\.proj\.", r"img_emb.proj.1.", s)
    s = re.sub(r"^condition_embedder\.image_embedder\.ff\.net\.2\.", r"img_emb.proj.3.", s)
    s = re.sub(r"^condition_embedder\.image_embedder\.norm2\.", r"img_emb.proj.4.", s)
    # FLF2V (first/last-frame) models add a positional embedding to the image
    # embedder. Without this rule it would be kept under its Diffusers name.
    s = re.sub(r"^condition_embedder\.image_embedder\.pos_embed$", r"img_emb.emb_pos", s)

    # Output head: the final projection and the top-level (non-block) modulation table.
    s = re.sub(r"^proj_out\.", r"head.head.", s)
    if s == "scale_shift_table":
        s = "head.modulation"

    return s
|
|
| |
|
|
# Accepted ``cast_dtype`` spellings (matched after lower() + strip()).
_CAST_DTYPE_ALIASES: Dict[str, torch.dtype] = {
    "float16": torch.float16,
    "fp16": torch.float16,
    "half": torch.float16,
    "bfloat16": torch.bfloat16,
    "bf16": torch.bfloat16,
    "float32": torch.float32,
    "fp32": torch.float32,
}


def _parse_cast_dtype(cast_dtype: Optional[Union[str, torch.dtype]]) -> Optional[torch.dtype]:
    """
    Resolve the user-facing ``cast_dtype`` option to a ``torch.dtype``.

    Falsy values (``None``, ``""``) mean "do not cast" and return ``None``.
    Strings are looked up case-insensitively, after stripping whitespace, in
    ``_CAST_DTYPE_ALIASES``. A ``torch.dtype`` is accepted if it is one of the
    supported targets.

    Raises:
        ValueError: for any other value.
    """
    if not cast_dtype:
        return None
    if isinstance(cast_dtype, torch.dtype):
        if cast_dtype in _CAST_DTYPE_ALIASES.values():
            return cast_dtype
    elif isinstance(cast_dtype, str):
        resolved = _CAST_DTYPE_ALIASES.get(cast_dtype.lower().strip())
        if resolved is not None:
            return resolved
    raise ValueError(f"Unsupported cast_dtype: {cast_dtype}")


def _cast_tensor(tensor: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.Tensor:
    """
    Cast a floating-point *tensor* to *dtype*.

    When *dtype* is None, or the tensor is integer/bool, it is returned
    untouched. Converting integer or bool tensors to a float dtype would
    silently change their meaning.
    """
    if dtype is None or not tensor.is_floating_point():
        return tensor
    return tensor.to(dtype)


def _claim_key(owners: Dict[str, str], new_key: str, source: str) -> None:
    """
    Record that *source* produced the canonical name *new_key*.

    Raises:
        ValueError: if a different source already mapped to *new_key*.
            Without this check the later tensor would silently overwrite the
            earlier one.
    """
    owner = owners.setdefault(new_key, source)
    if owner != source:
        raise ValueError(
            f"Key collision: {owner!r} and {source!r} both map to {new_key!r}"
        )


def convert_state_dict_universal(
    state_dict: Dict[str, torch.Tensor],
    cast_dtype: Optional[Union[str, torch.dtype]] = None,
) -> Dict[str, torch.Tensor]:
    """
    Map a Diffusers-style state_dict to the canonical Wan naming (T2V/I2V shared).

    Args:
        state_dict: source tensors keyed by Diffusers parameter names.
        cast_dtype: optional target dtype. Accepts "float16"/"fp16"/"half",
            "bfloat16"/"bf16" or "float32"/"fp32" (case-insensitive), or the
            matching ``torch.dtype``. Only floating-point tensors are cast.

    Returns:
        A new dict keyed by canonical names. Tensors that are not cast are the
        same objects as in *state_dict* (no copy).

    Raises:
        ValueError: on an unsupported ``cast_dtype``, or if two source keys map
            to the same canonical name.
    """
    dtype = _parse_cast_dtype(cast_dtype)

    out: Dict[str, torch.Tensor] = {}
    owners: Dict[str, str] = {}
    for key, tensor in state_dict.items():
        new_key = rename_key_universal(key)
        _claim_key(owners, new_key, key)
        out[new_key] = _cast_tensor(tensor, dtype)
    return out


def convert_files_universal(
    input_paths: List[str],
    out_safetensors: str,
    cast_dtype: Optional[Union[str, torch.dtype]] = None,
) -> None:
    """
    Convert one or many Diffusers .safetensors files into one canonical Wan .safetensors.

    Tensors are read one at a time, but every converted tensor is held in
    memory until the single ``save_file`` call at the end.

    Args:
        input_paths: source .safetensors files (for example, all shards of one model).
        out_safetensors: destination path. Its parent directory is created if needed.
        cast_dtype: same meaning as in ``convert_state_dict_universal``.

    Raises:
        ValueError: if ``cast_dtype`` is unsupported, *input_paths* is empty,
            or two source tensors (within or across files) map to the same
            canonical name.
    """
    dtype = _parse_cast_dtype(cast_dtype)
    if not input_paths:
        # Writing an empty checkpoint is never what the caller wanted.
        raise ValueError("convert_files_universal: no input files given")

    os.makedirs(os.path.dirname(out_safetensors) or ".", exist_ok=True)
    out: Dict[str, torch.Tensor] = {}
    owners: Dict[str, str] = {}
    for path in input_paths:
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                new_key = rename_key_universal(key)
                # Check for a collision before loading, so it fails without reading the tensor.
                _claim_key(owners, new_key, f"{path}:{key}")
                out[new_key] = _cast_tensor(f.get_tensor(key), dtype)

    save_file(out, out_safetensors, metadata={
        "format": "wan_universal",
        "converted_from": "diffusers",
        "script": "wan_convert_universal",
        "cast_dtype": str(dtype) if dtype is not None else "unchanged",
    })
|
|
|
|
|
|
def main(argv: Optional[List[str]] = None) -> None:
    """
    Command-line entry point: convert Diffusers Wan shards into one canonical file.

    Called without arguments, it reproduces the original hard-coded job. The 14
    shards ``c:/temp/chrono/diffusion_pytorch_model-000NN-of-00014.safetensors``
    are converted and written to ``chrono.safetensors``. Files are read and
    written through ``mmgp.safetensors2``.

    Args:
        argv: argument list to parse. ``None`` means ``sys.argv[1:]``.
    """
    import argparse

    default_inputs = [
        f"c:/temp/chrono/diffusion_pytorch_model-{i:05d}-of-00014.safetensors"
        for i in range(1, 15)
    ]
    parser = argparse.ArgumentParser(
        description="Convert Diffusers Wan weights to canonical Wan naming."
    )
    parser.add_argument(
        "inputs",
        nargs="*",
        default=default_inputs,
        help="input .safetensors shards (default: the 14 c:/temp/chrono shards)",
    )
    parser.add_argument(
        "-o",
        "--output",
        default="chrono.safetensors",
        help="output file (default: %(default)s)",
    )
    parser.add_argument(
        "--cast-dtype",
        default=None,
        help="optional cast: fp16/half, bf16 or fp32 (default: keep source dtypes)",
    )
    args = parser.parse_args(argv)

    # Imported after parsing so that --help works even without mmgp installed.
    from mmgp import safetensors2

    new_sd = {}
    for path in args.inputs:
        shard = safetensors2.torch_load_file(path)
        new_sd.update(convert_state_dict_universal(shard, cast_dtype=args.cast_dtype))
        # Drop our reference to the raw shard before loading the next one.
        del shard

    safetensors2.torch_write_file(new_sd, args.output)


if __name__ == "__main__":
    main()
|
|
|
|