| | """ |
| | wan_convert_universal β Diffusers β Wan (T2V/I2V) canonical naming, single format. |
| | |
| | - No model detection or flags. |
| | - Shared keys get the same canonical names used by both T2V and I2V loaders. |
| | - I2V extras are mapped only if present in the source. |
| | |
| | API: |
| | new_sd = convert_state_dict_universal(state_dict, cast_dtype=None) |
| | convert_files_universal([...safetensors...], "out/model.safetensors", cast_dtype=None) |
| | """ |
| |
|
import os
import re
from typing import Dict, List, Optional

import torch
from safetensors.torch import safe_open, save_file

# Leading "blocks.<index>." prefix shared by all transformer-block keys.
_RE_BLOCK = r"^blocks\.(\d+)\."

def _common_unet(src: str) -> str:
    s = src

    # Attention modules: Diffusers attn1/attn2 -> Wan self_attn/cross_attn.
    s = re.sub(rf"{_RE_BLOCK}attn1\.", r"blocks.\1.self_attn.", s)
    s = re.sub(rf"{_RE_BLOCK}attn2\.", r"blocks.\1.cross_attn.", s)

    # Projections inside any attention module: to_q/to_k/to_v/to_out.0 -> q/k/v/o.
    s = re.sub(r"(\b[^.\s]*attn[^.\s]*\b.*?\.)to_q\.", r"\1q.", s)
    s = re.sub(r"(\b[^.\s]*attn[^.\s]*\b.*?\.)to_k\.", r"\1k.", s)
    s = re.sub(r"(\b[^.\s]*attn[^.\s]*\b.*?\.)to_v\.", r"\1v.", s)
    s = re.sub(r"(\b[^.\s]*attn[^.\s]*\b.*?\.)to_out\.0\.", r"\1o.", s)

    # Feed-forward: Diffusers FeedForward (net.0.proj, net.2) -> Wan ffn.0/ffn.2.
    s = re.sub(rf"{_RE_BLOCK}ffn\.net\.0\.proj\.", r"blocks.\1.ffn.0.", s)
    s = re.sub(rf"{_RE_BLOCK}ffn\.net\.2\.", r"blocks.\1.ffn.2.", s)
    return s
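

# Hedged sanity check for the shared renames above: the keys below are
# illustrative block-0 names, not keys read from any real checkpoint.
def _selftest_common_unet() -> None:
    cases = {
        "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
        "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
        "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
    }
    for src, expected in cases.items():
        got = _common_unet(src)
        assert got == expected, (src, got)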


def rename_key_universal(src: str) -> str:
    """
    Canonical Wan naming (works for both T2V and I2V shared keys).
    """
    s = _common_unet(src)

    # I2V extras: image cross-attention projections (no-ops on T2V keys).
    s = re.sub(rf"{_RE_BLOCK}cross_attn\.add_k_proj\.", r"blocks.\1.cross_attn.k_img.", s)
    s = re.sub(rf"{_RE_BLOCK}cross_attn\.add_v_proj\.", r"blocks.\1.cross_attn.v_img.", s)
    s = re.sub(rf"{_RE_BLOCK}cross_attn\.norm_added_k\.", r"blocks.\1.cross_attn.norm_k_img.", s)

    # Per-block modulation table; Diffusers norm2 corresponds to Wan's norm3.
    s = re.sub(rf"{_RE_BLOCK}scale_shift_table$", r"blocks.\1.modulation", s)
    s = re.sub(rf"{_RE_BLOCK}norm2\b", r"blocks.\1.norm3", s)

    # Text embedder: linear_1/linear_2 -> indices 0 and 2 of an nn.Sequential.
    s = re.sub(r"^condition_embedder\.text_embedder\.linear_1\.", r"text_embedding.0.", s)
    s = re.sub(r"^condition_embedder\.text_embedder\.linear_2\.", r"text_embedding.2.", s)

    # Time embedder follows the same Sequential layout.
    s = re.sub(r"^condition_embedder\.time_embedder\.linear_1\.", r"time_embedding.0.", s)
    s = re.sub(r"^condition_embedder\.time_embedder\.linear_2\.", r"time_embedding.2.", s)

    # Time projection: the linear sits at index 1 (index 0 is the activation).
    s = re.sub(r"^condition_embedder\.time_proj\.", r"time_projection.1.", s)

    # I2V image embedder: norm1 / ff / norm2 -> img_emb.proj.{0,1,3,4}.
    s = re.sub(r"^condition_embedder\.image_embedder\.norm1\.", r"img_emb.proj.0.", s)
    s = re.sub(r"^condition_embedder\.image_embedder\.ff\.net\.0\.proj\.", r"img_emb.proj.1.", s)
    s = re.sub(r"^condition_embedder\.image_embedder\.ff\.net\.2\.", r"img_emb.proj.3.", s)
    s = re.sub(r"^condition_embedder\.image_embedder\.norm2\.", r"img_emb.proj.4.", s)

    # Output head and its top-level modulation table.
    s = re.sub(r"^proj_out\.", r"head.head.", s)
    if s == "scale_shift_table":
        s = "head.modulation"

    return s
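

# Spot checks for the non-block mappings; again, sample names only, chosen to
# exercise the embedder, I2V, and head branches.
def _selftest_rename_key_universal() -> None:
    cases = {
        "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
        "condition_embedder.time_proj.weight": "time_projection.1.weight",
        "blocks.3.attn2.add_k_proj.weight": "blocks.3.cross_attn.k_img.weight",
        "scale_shift_table": "head.modulation",
    }
    for src, expected in cases.items():
        got = rename_key_universal(src)
        assert got == expected, (src, got)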


def convert_state_dict_universal(
    state_dict: Dict[str, torch.Tensor],
    cast_dtype: Optional[str] = None,
) -> Dict[str, torch.Tensor]:
    """
    Map a Diffusers-style state_dict to the canonical Wan naming (T2V/I2V shared).
    """
    dtype = None
    if cast_dtype:
        cd = cast_dtype.lower().strip()
        if cd in ("float16", "fp16", "half"):
            dtype = torch.float16
        elif cd in ("bfloat16", "bf16"):
            dtype = torch.bfloat16
        elif cd in ("float32", "fp32"):
            dtype = torch.float32
        else:
            raise ValueError(f"Unsupported cast_dtype: {cast_dtype}")

    out: Dict[str, torch.Tensor] = {}
    for k, t in state_dict.items():
        out[rename_key_universal(k)] = t.to(dtype) if dtype is not None else t
    return out
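
# Minimal usage sketch, assuming a single Diffusers shard on disk; the path
# below is a placeholder, not a file shipped with this script.
#
#     from safetensors.torch import load_file
#     sd = load_file("diffusion_pytorch_model.safetensors")
#     wan_sd = convert_state_dict_universal(sd, cast_dtype="bf16")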


def convert_files_universal(
    input_paths: List[str],
    out_safetensors: str,
    cast_dtype: Optional[str] = None,
) -> None:
    """
    Convert one or many Diffusers .safetensors shards into a single canonical
    Wan .safetensors file. Tensors are read lazily, but the full converted
    dict is held in memory for the final save.
    """
    # Same dtype parsing as convert_state_dict_universal.
    dtype = None
    if cast_dtype:
        cd = cast_dtype.lower().strip()
        if cd in ("float16", "fp16", "half"):
            dtype = torch.float16
        elif cd in ("bfloat16", "bf16"):
            dtype = torch.bfloat16
        elif cd in ("float32", "fp32"):
            dtype = torch.float32
        else:
            raise ValueError(f"Unsupported cast_dtype: {cast_dtype}")

    os.makedirs(os.path.dirname(out_safetensors) or ".", exist_ok=True)
    out: Dict[str, torch.Tensor] = {}
    for p in input_paths:
        # safe_open loads one tensor at a time, keeping peak memory lower
        # than loading each whole shard up front.
        with safe_open(p, framework="pt", device="cpu") as f:
            for k in f.keys():
                t = f.get_tensor(k)
                if dtype is not None:
                    t = t.to(dtype)
                out[rename_key_universal(k)] = t

    save_file(out, out_safetensors, metadata={
        "format": "wan_universal",
        "converted_from": "diffusers",
        "script": "wan_convert_universal",
        "cast_dtype": str(dtype) if dtype is not None else "unchanged",
    })
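
# Usage sketch for sharded checkpoints; the glob pattern is a placeholder and
# assumes the standard Diffusers shard naming.
#
#     import glob
#     shards = sorted(glob.glob("ckpt/diffusion_pytorch_model-*.safetensors"))
#     convert_files_universal(shards, "out/wan_model.safetensors", cast_dtype="bf16")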


def main():
    # Example conversion of a 14-shard Diffusers checkpoint using the I/O
    # helpers from mmgp; adjust the input paths for your own checkpoint.
    from mmgp import safetensors2

    files = [
        f"c:/temp/chrono/diffusion_pytorch_model-{i:05d}-of-00014.safetensors"
        for i in range(1, 15)
    ]
    new_sd = {}
    for file in files:
        sd = safetensors2.torch_load_file(file)
        conv_sd = convert_state_dict_universal(sd)
        sd = None  # drop the unconverted shard before loading the next one
        new_sd.update(conv_sd)

    safetensors2.torch_write_file(new_sd, "chrono.safetensors")


if __name__ == "__main__":
    main()