File size: 4,879 Bytes
9db0d7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python3
"""
Standalone converter for Causal Forcing training checkpoints (.pt) to
ComfyUI-compatible .safetensors.

No dependency on the ComfyUI codebase — all conversion logic is self-contained.

Requirements:
    pip install torch safetensors

Usage:
    # Framewise model (uses EMA weights, num_frame_per_block=1 by default):
    python convert_causal_forcing_standalone.py \
        --input checkpoints/framewise/causal_forcing.pt \
        --output models/causal_forcing_framewise.safetensors

    # Chunkwise model (uses non-EMA weights, num_frame_per_block=3):
    python convert_causal_forcing_standalone.py \
        --input checkpoints/chunkwise/causal_forcing.pt \
        --output models/causal_forcing_chunkwise.safetensors \
        --no-ema --num-frame-per-block 3
"""

import argparse
import json
import logging

import torch
from safetensors.torch import save_file

log = logging.getLogger(__name__)

# Wrapper prefixes to strip, longest first so the FSDP form wins over plain "model.".
PREFIXES_TO_STRIP = ["model._fsdp_wrapped_module.", "model."]

# Key prefixes that identify genuine model weights in an already-bare layout;
# anything else (optimizer state, step counters, ...) is dropped.
_MODEL_KEY_PREFIXES = (
    "blocks.", "head.", "patch_embedding.", "text_embedding.",
    "time_embedding.", "time_projection.", "img_emb.", "rope_embedder.",
)


def extract_state_dict(state_dict: dict, use_ema: bool = True) -> dict:
    """
    Extract and clean a Causal Forcing state dict from a training checkpoint.

    Handles three checkpoint layouts:
      1. Training checkpoint with top-level generator_ema / generator keys
      2. Already-flattened state dict with model.* prefixes
      3. Already-converted ComfyUI state dict (bare model keys)

    Args:
        state_dict: Raw object loaded from the .pt checkpoint.
        use_ema: Prefer 'generator_ema' over 'generator' when both are present
            (and non-None); otherwise the lookup order is reversed.

    Returns:
        A state dict with keys matching the CausalWanModel / WanModel layout.

    Raises:
        KeyError: if no recognizable checkpoint layout is detected.
        ValueError: if the cleaned output is missing expected model keys.
    """
    log = logging.getLogger(__name__)

    # Layout 3: already converted — pass through untouched.
    if "head.modulation" in state_dict and "blocks.0.self_attn.q.weight" in state_dict:
        return state_dict

    # Layout 1: wrapped training checkpoint. Some trainers store the EMA key
    # with a None value when EMA is disabled — check the value, not just key
    # presence, so we fall back to the other wrapper instead of crashing on
    # raw_sd.items() below.
    raw_sd = None
    order = ["generator_ema", "generator"] if use_ema else ["generator", "generator_ema"]
    for wrapper_key in order:
        if state_dict.get(wrapper_key) is not None:
            raw_sd = state_dict[wrapper_key]
            log.info("Extracted '%s' with %d keys", wrapper_key, len(raw_sd))
            break

    if raw_sd is None:
        # Layout 2: flattened dict with model.* prefixes.
        if any(k.startswith("model.") for k in state_dict):
            raw_sd = state_dict
        else:
            raise KeyError(
                f"Cannot detect Causal Forcing checkpoint layout. "
                f"Top-level keys: {list(state_dict.keys())[:20]}"
            )

    out_sd = {}
    for k, v in raw_sd.items():
        new_k = k
        for prefix in PREFIXES_TO_STRIP:
            if new_k.startswith(prefix):
                new_k = new_k[len(prefix):]
                break
        else:
            # No wrapper prefix matched: keep only bare model keys; drop
            # everything else (optimizer state, schedulers, counters, ...).
            if not new_k.startswith(_MODEL_KEY_PREFIXES):
                log.debug("Skipping non-model key: %s", k)
                continue
        out_sd[new_k] = v

    # Sanity check: a valid conversion always yields the head modulation tensor.
    if "head.modulation" not in out_sd:
        raise ValueError("Conversion failed: 'head.modulation' not found in output keys")

    return out_sd


def convert_and_save(input_path: str, output_path: str, use_ema: bool = True,
                     num_frame_per_block: int = 1):
    """
    Load a training checkpoint, clean its state dict, and write a
    ComfyUI-compatible .safetensors file with the config embedded as metadata.

    Args:
        input_path: Path to the training .pt checkpoint.
        output_path: Destination .safetensors path.
        use_ema: Prefer EMA weights when the checkpoint contains both.
        num_frame_per_block: Frames per AR block; recorded in the metadata
            only when greater than 1.
    """
    print(f"Loading {input_path} ...")
    # NOTE(review): weights_only=False unpickles arbitrary objects — only run
    # this on checkpoints from a trusted source.
    checkpoint = torch.load(input_path, map_location="cpu", weights_only=False)
    cleaned = extract_state_dict(checkpoint, use_ema=use_ema)
    del checkpoint  # release the raw checkpoint before probing/saving

    # Probe the architecture from well-known keys for the status printout.
    model_dim = cleaned["head.modulation"].shape[-1]
    depth = 0
    while f"blocks.{depth}.self_attn.q.weight" in cleaned:
        depth += 1
    print(f"Detected model: dim={model_dim}, num_layers={depth}, keys={len(cleaned)}")

    cfg = {"causal_ar": True}
    if num_frame_per_block > 1:
        cfg["num_frame_per_block"] = num_frame_per_block
    save_file(cleaned, output_path,
              metadata={"config": json.dumps({"transformer": cfg})})
    print(f"Saved to {output_path}")


if __name__ == "__main__":
    # CLI entry point: parse arguments, configure logging, run the conversion.
    cli = argparse.ArgumentParser(
        description="Convert Causal Forcing checkpoint to ComfyUI safetensors (standalone)"
    )
    cli.add_argument("--input", required=True, help="Path to the training .pt checkpoint")
    cli.add_argument("--output", required=True, help="Output .safetensors path")
    cli.add_argument(
        "--no-ema", action="store_true",
        help="Use 'generator' instead of 'generator_ema' (default: use EMA)",
    )
    cli.add_argument("--num-frame-per-block", type=int, default=1,
                     help="Frames per AR block (1=framewise, 3=chunkwise)")
    cli.add_argument("-v", "--verbose", action="store_true", help="Enable debug logging")
    opts = cli.parse_args()

    logging.basicConfig(
        format="%(levelname)s: %(message)s",
        level=logging.DEBUG if opts.verbose else logging.INFO,
    )

    convert_and_save(
        opts.input,
        opts.output,
        use_ema=not opts.no_ema,
        num_frame_per_block=opts.num_frame_per_block,
    )