# self_forcing_ComfyUI_repackaged / convert_original_to_comfy.py
# Uploaded by TalmajM via huggingface_hub (commit 9db0d7b, verified).
#!/usr/bin/env python3
"""
Standalone converter for Causal Forcing training checkpoints (.pt) to
ComfyUI-compatible .safetensors.
No dependency on the ComfyUI codebase — all conversion logic is self-contained.
Requirements:
pip install torch safetensors
Usage:
# Framewise model (uses EMA weights, num_frame_per_block=1 by default):
python convert_causal_forcing_standalone.py \
--input checkpoints/framewise/causal_forcing.pt \
--output models/causal_forcing_framewise.safetensors
# Chunkwise model (uses non-EMA weights, num_frame_per_block=3):
python convert_causal_forcing_standalone.py \
--input checkpoints/chunkwise/causal_forcing.pt \
--output models/causal_forcing_chunkwise.safetensors \
--no-ema --num-frame-per-block 3
"""
import argparse
import json
import logging
import torch
from safetensors.torch import save_file
log = logging.getLogger(__name__)

# Wrapper prefixes removed from checkpoint keys. Order matters: only the first
# matching prefix is stripped, so the longer FSDP prefix must be tried before
# the bare "model." prefix that it begins with.
PREFIXES_TO_STRIP = ["model._fsdp_wrapped_module.", "model."]

# Top-level namespaces of the bare model. A key that carries none of the
# wrapper prefixes above is kept only if it starts with one of these.
_MODEL_KEY_PREFIXES = (
    "blocks.", "head.", "patch_embedding.", "text_embedding.",
    "time_embedding.", "time_projection.", "img_emb.", "rope_embedder.",
)


def extract_state_dict(state_dict: dict, use_ema: bool = True) -> dict:
    """
    Extract and clean a Causal Forcing state dict from a training checkpoint.

    Handles three checkpoint layouts:
      1. Training checkpoint with top-level ``generator_ema`` / ``generator`` keys
      2. Already-flattened state dict with ``model.*`` prefixes
      3. Already-converted ComfyUI state dict (bare model keys)

    Args:
        state_dict: The object loaded from the checkpoint file.
        use_ema: Prefer ``generator_ema`` over ``generator`` (layout 1). If only
            the non-preferred entry exists it is used instead and a warning is
            logged, so the substitution is visible to the user.

    Returns:
        A state dict with keys matching the CausalWanModel / WanModel layout.
        For layout 3 the input mapping itself is returned unchanged.

    Raises:
        KeyError: If none of the supported layouts is recognized.
        ValueError: If the cleaned result has no ``head.modulation`` key.
    """
    # Layout 3: keys are already bare model keys; nothing to convert.
    if "head.modulation" in state_dict and "blocks.0.self_attn.q.weight" in state_dict:
        return state_dict

    if use_ema:
        preferred, fallback = "generator_ema", "generator"
    else:
        preferred, fallback = "generator", "generator_ema"

    # Layout 1: pick the preferred wrapper entry, falling back to the other one.
    raw_sd = None
    for wrapper_key in (preferred, fallback):
        if wrapper_key in state_dict:
            raw_sd = state_dict[wrapper_key]
            log.info("Extracted '%s' with %d keys", wrapper_key, len(raw_sd))
            if wrapper_key != preferred:
                log.warning(
                    "Requested '%s' weights but the checkpoint has none; using '%s' instead",
                    preferred, wrapper_key,
                )
            break

    # Layout 2: the checkpoint itself is the (prefixed) state dict.
    if raw_sd is None:
        if any(k.startswith("model.") for k in state_dict):
            raw_sd = state_dict
        else:
            raise KeyError(
                "Cannot detect Causal Forcing checkpoint layout. "
                f"Top-level keys: {list(state_dict.keys())[:20]}"
            )

    out_sd = {}
    for key, value in raw_sd.items():
        prefix = next((p for p in PREFIXES_TO_STRIP if key.startswith(p)), None)
        if prefix is not None:
            new_key = key[len(prefix):]
        elif key.startswith(_MODEL_KEY_PREFIXES):
            new_key = key
        else:
            # Neither wrapped nor a recognized model key (e.g. optimizer/bookkeeping).
            log.debug("Skipping non-model key: %s", key)
            continue
        if new_key in out_sd:
            # Two source keys map to the same model key; the later one wins.
            log.warning(
                "Duplicate key '%s' after prefix stripping (from '%s'); overwriting earlier value",
                new_key, key,
            )
        out_sd[new_key] = value

    if "head.modulation" not in out_sd:
        raise ValueError("Conversion failed: 'head.modulation' not found in output keys")
    return out_sd
def convert_and_save(input_path: str, output_path: str, use_ema: bool = True,
                     num_frame_per_block: int = 1):
    """
    Convert a Causal Forcing checkpoint into a ComfyUI ``.safetensors`` file.

    Loads the checkpoint and cleans its keys with ``extract_state_dict``. It then
    prints the detected hidden size and block count, and writes the weights
    together with a JSON ``config`` metadata entry. That entry carries
    ``causal_ar`` and, when greater than 1, ``num_frame_per_block``.

    Args:
        input_path: Path to the training ``.pt`` checkpoint.
        output_path: Destination ``.safetensors`` path. Missing parent
            directories are created.
        use_ema: Prefer EMA generator weights (forwarded to extract_state_dict).
        num_frame_per_block: Frames per autoregressive block; must be >= 1.

    Raises:
        ValueError: If ``num_frame_per_block`` < 1, or propagated from
            extract_state_dict when ``head.modulation`` is missing.
        KeyError: Propagated from extract_state_dict for unknown layouts.
    """
    from pathlib import Path

    # Validate before the (potentially very large) checkpoint load.
    if num_frame_per_block < 1:
        raise ValueError(f"num_frame_per_block must be >= 1, got {num_frame_per_block}")

    print(f"Loading {input_path} ...")
    # SECURITY: weights_only=False unpickles arbitrary objects and can execute
    # code embedded in the file. Only convert checkpoints from trusted sources.
    state_dict = torch.load(input_path, map_location="cpu", weights_only=False)
    out_sd = extract_state_dict(state_dict, use_ema=use_ema)
    # Drop this reference to the loaded checkpoint so entries not carried
    # into out_sd can be freed before the save.
    del state_dict

    dim = out_sd["head.modulation"].shape[-1]
    num_layers = 0
    while f"blocks.{num_layers}.self_attn.q.weight" in out_sd:
        num_layers += 1
    print(f"Detected model: dim={dim}, num_layers={num_layers}, keys={len(out_sd)}")

    transformer_config = {"causal_ar": True}
    if num_frame_per_block > 1:
        transformer_config["num_frame_per_block"] = num_frame_per_block
    metadata = {
        "config": json.dumps({"transformer": transformer_config}),
    }

    # safetensors refuses to serialize non-contiguous tensors; .contiguous()
    # returns the same tensor when it is already contiguous, so this is free
    # for the common case.
    out_sd = {
        k: v.contiguous() if isinstance(v, torch.Tensor) else v
        for k, v in out_sd.items()
    }

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    save_file(out_sd, output_path, metadata=metadata)
    print(f"Saved to {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert Causal Forcing checkpoint to ComfyUI safetensors (standalone)"
)
parser.add_argument("--input", required=True, help="Path to the training .pt checkpoint")
parser.add_argument("--output", required=True, help="Output .safetensors path")
parser.add_argument(
"--no-ema", action="store_true",
help="Use 'generator' instead of 'generator_ema' (default: use EMA)",
)
parser.add_argument("--num-frame-per-block", type=int, default=1,
help="Frames per AR block (1=framewise, 3=chunkwise)")
parser.add_argument("-v", "--verbose", action="store_true", help="Enable debug logging")
args = parser.parse_args()
logging.basicConfig(
level=logging.DEBUG if args.verbose else logging.INFO,
format="%(levelname)s: %(message)s",
)
convert_and_save(args.input, args.output, use_ema=not args.no_ema,
num_frame_per_block=args.num_frame_per_block)