#!/usr/bin/env python3
"""
Standalone converter for Causal Forcing training checkpoints (.pt) to
ComfyUI-compatible .safetensors.
No dependency on the ComfyUI codebase — all conversion logic is self-contained.
Requirements:
pip install torch safetensors
Usage:
# Framewise model (uses EMA weights, num_frame_per_block=1 by default):
python convert_causal_forcing_standalone.py \
--input checkpoints/framewise/causal_forcing.pt \
--output models/causal_forcing_framewise.safetensors
# Chunkwise model (uses non-EMA weights, num_frame_per_block=3):
python convert_causal_forcing_standalone.py \
--input checkpoints/chunkwise/causal_forcing.pt \
--output models/causal_forcing_chunkwise.safetensors \
--no-ema --num-frame-per-block 3
"""
import argparse
import json
import logging
import torch
from safetensors.torch import save_file
log = logging.getLogger(__name__)

# Wrapper prefixes added by the training harness. The FSDP variant is listed
# first because it is the more specific prefix and must win over plain "model.".
PREFIXES_TO_STRIP = ["model._fsdp_wrapped_module.", "model."]

# Key families belonging to the transformer itself; bare keys outside these
# families (optimizer state, step counters, ...) are dropped.
_MODEL_KEY_PREFIXES = (
    "blocks.", "head.", "patch_embedding.", "text_embedding.",
    "time_embedding.", "time_projection.", "img_emb.", "rope_embedder.",
)


def extract_state_dict(state_dict: dict, use_ema: bool = True) -> dict:
    """
    Extract and clean a Causal Forcing state dict from a training checkpoint.

    Handles three checkpoint layouts:
    1. Training checkpoint with top-level generator_ema / generator keys
    2. Already-flattened state dict with model.* prefixes
    3. Already-converted ComfyUI state dict (bare model keys)

    Returns a state dict with keys matching the CausalWanModel / WanModel layout.
    """
    # Layout 3: already converted -- return unchanged.
    if "head.modulation" in state_dict and "blocks.0.self_attn.q.weight" in state_dict:
        return state_dict

    # Layout 1: look for a generator wrapper, preferring EMA weights on request.
    preference = ("generator_ema", "generator") if use_ema else ("generator", "generator_ema")
    raw_sd = None
    for candidate in preference:
        if candidate in state_dict:
            raw_sd = state_dict[candidate]
            log.info("Extracted '%s' with %d keys", candidate, len(raw_sd))
            break

    # Layout 2: flattened dict whose keys carry a "model." prefix.
    if raw_sd is None:
        if not any(k.startswith("model.") for k in state_dict):
            raise KeyError(
                f"Cannot detect Causal Forcing checkpoint layout. "
                f"Top-level keys: {list(state_dict.keys())[:20]}"
            )
        raw_sd = state_dict

    out_sd = {}
    for key, value in raw_sd.items():
        # Strip the first matching wrapper prefix, if any.
        stripped = next(
            (key[len(p):] for p in PREFIXES_TO_STRIP if key.startswith(p)),
            None,
        )
        if stripped is None:
            # Bare key: keep it only when it looks like a model weight.
            if not key.startswith(_MODEL_KEY_PREFIXES):
                log.debug("Skipping non-model key: %s", key)
                continue
            stripped = key
        out_sd[stripped] = value

    if "head.modulation" not in out_sd:
        raise ValueError("Conversion failed: 'head.modulation' not found in output keys")
    return out_sd
def convert_and_save(input_path: str, output_path: str, use_ema: bool = True,
                     num_frame_per_block: int = 1) -> None:
    """
    Load a Causal Forcing training checkpoint and write a ComfyUI-ready
    .safetensors file.

    Args:
        input_path: Path to the training .pt checkpoint.
        output_path: Destination .safetensors path.
        use_ema: Prefer the 'generator_ema' weights over 'generator'.
        num_frame_per_block: Frames per AR block; values > 1 are recorded
            in the embedded transformer config (chunkwise models).

    Raises:
        KeyError, ValueError: propagated from extract_state_dict() when the
            checkpoint layout cannot be recognized.
    """
    log.info("Loading %s ...", input_path)
    # NOTE(security): weights_only=False unpickles arbitrary objects, which can
    # execute code on load. Only convert checkpoints from a trusted source.
    state_dict = torch.load(input_path, map_location="cpu", weights_only=False)
    out_sd = extract_state_dict(state_dict, use_ema=use_ema)
    del state_dict  # release the wrapper dict before saving to cut peak memory

    # Infer model geometry from the cleaned keys (informational only).
    dim = out_sd["head.modulation"].shape[-1]
    num_layers = 0
    while f"blocks.{num_layers}.self_attn.q.weight" in out_sd:
        num_layers += 1
    log.info("Detected model: dim=%s, num_layers=%d, keys=%d",
             dim, num_layers, len(out_sd))

    # Embed the ComfyUI transformer config as safetensors metadata; omit
    # num_frame_per_block when it equals the framewise default of 1.
    transformer_config = {"causal_ar": True}
    if num_frame_per_block > 1:
        transformer_config["num_frame_per_block"] = num_frame_per_block
    metadata = {
        "config": json.dumps({"transformer": transformer_config}),
    }
    save_file(out_sd, output_path, metadata=metadata)
    log.info("Saved to %s", output_path)
if __name__ == "__main__":
    # CLI entry point: parse arguments, configure logging, run the conversion.
    arg_parser = argparse.ArgumentParser(
        description="Convert Causal Forcing checkpoint to ComfyUI safetensors (standalone)"
    )
    arg_parser.add_argument("--input", required=True, help="Path to the training .pt checkpoint")
    arg_parser.add_argument("--output", required=True, help="Output .safetensors path")
    arg_parser.add_argument(
        "--no-ema", action="store_true",
        help="Use 'generator' instead of 'generator_ema' (default: use EMA)",
    )
    arg_parser.add_argument(
        "--num-frame-per-block", type=int, default=1,
        help="Frames per AR block (1=framewise, 3=chunkwise)",
    )
    arg_parser.add_argument("-v", "--verbose", action="store_true", help="Enable debug logging")
    cli = arg_parser.parse_args()

    log_level = logging.DEBUG if cli.verbose else logging.INFO
    logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s")

    convert_and_save(
        cli.input,
        cli.output,
        use_ema=not cli.no_ema,
        num_frame_per_block=cli.num_frame_per_block,
    )