"""
Standalone converter for Causal Forcing training checkpoints (.pt) to
ComfyUI-compatible .safetensors.

No dependency on the ComfyUI codebase — all conversion logic is self-contained.

Requirements:
    pip install torch safetensors

Usage:
    # Framewise model (uses EMA weights, num_frame_per_block=1 by default):
    python convert_causal_forcing_standalone.py \
        --input checkpoints/framewise/causal_forcing.pt \
        --output models/causal_forcing_framewise.safetensors

    # Chunkwise model (uses non-EMA weights, num_frame_per_block=3):
    python convert_causal_forcing_standalone.py \
        --input checkpoints/chunkwise/causal_forcing.pt \
        --output models/causal_forcing_chunkwise.safetensors \
        --no-ema --num-frame-per-block 3
"""
|
|
| import argparse |
| import json |
| import logging |
|
|
| import torch |
| from safetensors.torch import save_file |
|
|
log = logging.getLogger(__name__)

# Wrapper prefixes added by training / FSDP; longest first so the FSDP
# variant is stripped before the plain "model." prefix can match.
PREFIXES_TO_STRIP = ["model._fsdp_wrapped_module.", "model."]

# Key prefixes that identify genuine model weights in the WanModel layout.
_MODEL_KEY_PREFIXES = (
    "blocks.", "head.", "patch_embedding.", "text_embedding.",
    "time_embedding.", "time_projection.", "img_emb.", "rope_embedder.",
)


def extract_state_dict(state_dict: dict, use_ema: bool = True) -> dict:
    """
    Extract and clean a Causal Forcing state dict from a training checkpoint.

    Handles three checkpoint layouts:
      1. Training checkpoint with top-level generator_ema / generator keys
      2. Already-flattened state dict with model.* prefixes
      3. Already-converted ComfyUI state dict (bare model keys)

    Args:
        state_dict: Raw dict loaded from the checkpoint file.
        use_ema: When True, prefer 'generator_ema' over 'generator'
            (falls back to the other if the preferred key is absent).

    Returns:
        A state dict with keys matching the CausalWanModel / WanModel layout.

    Raises:
        KeyError: if no known layout is detected.
        ValueError: if the cleaned dict is missing 'head.modulation'.
    """
    # Layout 3: already converted — return unchanged.
    if "head.modulation" in state_dict and "blocks.0.self_attn.q.weight" in state_dict:
        return state_dict

    # Layout 1: weights wrapped under a generator / generator_ema key.
    preference = ("generator_ema", "generator") if use_ema else ("generator", "generator_ema")
    raw_sd = None
    for wrapper_key in preference:
        if wrapper_key in state_dict:
            raw_sd = state_dict[wrapper_key]
            log.info("Extracted '%s' with %d keys", wrapper_key, len(raw_sd))
            break

    # Layout 2: flattened dict whose keys carry model.* prefixes.
    if raw_sd is None:
        if not any(k.startswith("model.") for k in state_dict):
            raise KeyError(
                f"Cannot detect Causal Forcing checkpoint layout. "
                f"Top-level keys: {list(state_dict.keys())[:20]}"
            )
        raw_sd = state_dict

    out_sd = {}
    for key, tensor in raw_sd.items():
        # Strip the first matching wrapper prefix, if any.
        cleaned = next(
            (key[len(p):] for p in PREFIXES_TO_STRIP if key.startswith(p)),
            None,
        )
        if cleaned is None:
            # No wrapper prefix: keep only keys that already look like
            # model weights; drop optimizer/scheduler/bookkeeping entries.
            if not key.startswith(_MODEL_KEY_PREFIXES):
                log.debug("Skipping non-model key: %s", key)
                continue
            cleaned = key
        out_sd[cleaned] = tensor

    # Sanity check: every valid conversion must contain this key.
    if "head.modulation" not in out_sd:
        raise ValueError("Conversion failed: 'head.modulation' not found in output keys")

    return out_sd
|
|
|
|
def convert_and_save(input_path: str, output_path: str, use_ema: bool = True,
                     num_frame_per_block: int = 1):
    """
    Convert a Causal Forcing training checkpoint to a ComfyUI .safetensors file.

    Args:
        input_path: Path to the training ``.pt`` checkpoint.
        output_path: Destination ``.safetensors`` path.
        use_ema: Prefer the EMA ('generator_ema') weights when available.
        num_frame_per_block: Frames per autoregressive block (1 = framewise,
            3 = chunkwise); recorded in the safetensors metadata when > 1.

    Raises:
        KeyError, ValueError: propagated from ``extract_state_dict`` when the
            checkpoint layout cannot be recognized.
    """
    print(f"Loading {input_path} ...")
    try:
        # Prefer the safe loader: refuses to unpickle arbitrary objects.
        state_dict = torch.load(input_path, map_location="cpu", weights_only=True)
    except Exception:
        # SECURITY: weights_only=False runs arbitrary pickled code from the
        # file. Needed only for training checkpoints embedding non-tensor
        # objects — never load untrusted checkpoints this way.
        log.warning("Safe (weights_only=True) load failed; retrying with "
                    "weights_only=False — only do this with trusted checkpoints")
        state_dict = torch.load(input_path, map_location="cpu", weights_only=False)
    out_sd = extract_state_dict(state_dict, use_ema=use_ema)
    del state_dict  # drop the raw checkpoint early to cap peak memory

    # Infer architecture from the cleaned keys for the summary printout.
    dim = out_sd["head.modulation"].shape[-1]
    num_layers = 0
    while f"blocks.{num_layers}.self_attn.q.weight" in out_sd:
        num_layers += 1
    print(f"Detected model: dim={dim}, num_layers={num_layers}, keys={len(out_sd)}")

    # ComfyUI reads this JSON config back out of the safetensors metadata.
    transformer_config = {"causal_ar": True}
    if num_frame_per_block > 1:
        # 1 is the loader's default, so only non-default values are recorded.
        transformer_config["num_frame_per_block"] = num_frame_per_block
    metadata = {
        "config": json.dumps({"transformer": transformer_config}),
    }
    save_file(out_sd, output_path, metadata=metadata)
    print(f"Saved to {output_path}")
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse arguments, configure logging, then convert.
    arg_parser = argparse.ArgumentParser(
        description="Convert Causal Forcing checkpoint to ComfyUI safetensors (standalone)"
    )
    arg_parser.add_argument("--input", required=True, help="Path to the training .pt checkpoint")
    arg_parser.add_argument("--output", required=True, help="Output .safetensors path")
    arg_parser.add_argument(
        "--no-ema", action="store_true",
        help="Use 'generator' instead of 'generator_ema' (default: use EMA)",
    )
    arg_parser.add_argument(
        "--num-frame-per-block", type=int, default=1,
        help="Frames per AR block (1=framewise, 3=chunkwise)",
    )
    arg_parser.add_argument("-v", "--verbose", action="store_true", help="Enable debug logging")
    cli = arg_parser.parse_args()

    # -v switches module loggers (log.debug key-skip messages) on.
    log_level = logging.DEBUG if cli.verbose else logging.INFO
    logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s")

    convert_and_save(
        cli.input,
        cli.output,
        use_ema=not cli.no_ema,
        num_frame_per_block=cli.num_frame_per_block,
    )
|
|