TalmajM committed on
Commit
175a754
·
verified ·
1 Parent(s): 7ac4a33

Upload folder using huggingface_hub

Browse files
convert_original_to_comfy.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Standalone converter for Causal Forcing training checkpoints (.pt) to
ComfyUI-compatible .safetensors.

No dependency on the ComfyUI codebase — all conversion logic is self-contained.

Requirements:
    pip install torch safetensors

Usage:
    # Framewise model (uses EMA weights, num_frame_per_block=1 by default):
    python convert_original_to_comfy.py \
        --input checkpoints/framewise/causal_forcing.pt \
        --output models/causal_forcing_framewise.safetensors

    # Chunkwise model (uses non-EMA weights, num_frame_per_block=3):
    python convert_original_to_comfy.py \
        --input checkpoints/chunkwise/causal_forcing.pt \
        --output models/causal_forcing_chunkwise.safetensors \
        --no-ema --num-frame-per-block 3
"""

import argparse
import json
import logging

import torch
from safetensors.torch import save_file
30
+
31
log = logging.getLogger(__name__)

# Wrapper prefixes added by the training harness, longest first so the FSDP
# variant is stripped before the plain "model." fallback can match.
PREFIXES_TO_STRIP = ["model._fsdp_wrapped_module.", "model."]

# Key families that belong to the diffusion model itself; anything else
# (optimizer state, schedulers, ...) is discarded during conversion.
_MODEL_KEY_PREFIXES = (
    "blocks.", "head.", "patch_embedding.", "text_embedding.",
    "time_embedding.", "time_projection.", "img_emb.", "rope_embedder.",
)


def extract_state_dict(state_dict: dict, use_ema: bool = True) -> dict:
    """
    Extract and clean a Causal Forcing state dict from a training checkpoint.

    Handles three checkpoint layouts:
    1. Training checkpoint with top-level generator_ema / generator keys
    2. Already-flattened state dict with model.* prefixes
    3. Already-converted ComfyUI state dict (bare model keys)

    Returns a state dict with keys matching the CausalWanModel / WanModel layout.
    """
    # Layout 3: the dict already uses bare model keys — nothing to convert.
    already_converted = (
        "head.modulation" in state_dict
        and "blocks.0.self_attn.q.weight" in state_dict
    )
    if already_converted:
        return state_dict

    # Layout 1: prefer the requested variant (EMA or raw generator) but fall
    # back to the other one if only it is present in the checkpoint.
    preferred = ("generator_ema", "generator") if use_ema else ("generator", "generator_ema")
    raw_sd = None
    for candidate in preferred:
        if candidate in state_dict:
            raw_sd = state_dict[candidate]
            log.info("Extracted '%s' with %d keys", candidate, len(raw_sd))
            break

    # Layout 2: a flattened dict carrying model.* prefixes at the top level.
    if raw_sd is None:
        if not any(k.startswith("model.") for k in state_dict):
            raise KeyError(
                f"Cannot detect Causal Forcing checkpoint layout. "
                f"Top-level keys: {list(state_dict.keys())[:20]}"
            )
        raw_sd = state_dict

    out_sd = {}
    for key, value in raw_sd.items():
        # Strip the first matching wrapper prefix, if any.
        stripped = next(
            (key[len(p):] for p in PREFIXES_TO_STRIP if key.startswith(p)),
            None,
        )
        if stripped is None:
            # No wrapper prefix: keep only keys that already look like model
            # weights; drop optimizer state and other training bookkeeping.
            if not key.startswith(_MODEL_KEY_PREFIXES):
                log.debug("Skipping non-model key: %s", key)
                continue
            stripped = key
        out_sd[stripped] = value

    if "head.modulation" not in out_sd:
        raise ValueError("Conversion failed: 'head.modulation' not found in output keys")

    return out_sd
89
+
90
+
91
def convert_and_save(input_path: str, output_path: str, use_ema: bool = True,
                     num_frame_per_block: int = 1):
    """
    Load a training checkpoint, extract/clean its state dict, and write a
    ComfyUI-compatible .safetensors file with the transformer config embedded
    in the safetensors metadata.
    """
    print(f"Loading {input_path} ...")
    # NOTE(review): weights_only=False unpickles arbitrary objects — only run
    # this on checkpoints from a trusted source.
    checkpoint = torch.load(input_path, map_location="cpu", weights_only=False)
    out_sd = extract_state_dict(checkpoint, use_ema=use_ema)
    # Drop the wrapper checkpoint before saving to reduce peak memory.
    del checkpoint

    # Infer the model geometry from the cleaned keys for the status line:
    # hidden dim from head.modulation, depth by probing consecutive blocks.
    dim = out_sd["head.modulation"].shape[-1]
    num_layers = 0
    while f"blocks.{num_layers}.self_attn.q.weight" in out_sd:
        num_layers += 1
    print(f"Detected model: dim={dim}, num_layers={num_layers}, keys={len(out_sd)}")

    # Embed the transformer config so ComfyUI can reconstruct the model;
    # num_frame_per_block is only recorded when it differs from the default.
    cfg = {"causal_ar": True}
    if num_frame_per_block > 1:
        cfg["num_frame_per_block"] = num_frame_per_block
    save_file(
        out_sd,
        output_path,
        metadata={"config": json.dumps({"transformer": cfg})},
    )
    print(f"Saved to {output_path}")
112
+
113
+
114
if __name__ == "__main__":
    # CLI entry point: parse arguments, configure logging, run the conversion.
    cli = argparse.ArgumentParser(
        description="Convert Causal Forcing checkpoint to ComfyUI safetensors (standalone)"
    )
    cli.add_argument("--input", required=True, help="Path to the training .pt checkpoint")
    cli.add_argument("--output", required=True, help="Output .safetensors path")
    cli.add_argument(
        "--no-ema", action="store_true",
        help="Use 'generator' instead of 'generator_ema' (default: use EMA)",
    )
    cli.add_argument(
        "--num-frame-per-block", type=int, default=1,
        help="Frames per AR block (1=framewise, 3=chunkwise)",
    )
    cli.add_argument("-v", "--verbose", action="store_true", help="Enable debug logging")
    opts = cli.parse_args()

    log_level = logging.DEBUG if opts.verbose else logging.INFO
    logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s")

    convert_and_save(
        opts.input,
        opts.output,
        use_ema=not opts.no_ema,
        num_frame_per_block=opts.num_frame_per_block,
    )
download_original.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
#!/bin/bash
# Download the original chunkwise Causal Forcing training checkpoint from
# Hugging Face. --fail makes curl exit non-zero on HTTP errors instead of
# saving an HTML error page as the checkpoint; set -euo pipefail aborts the
# script on any failure.
set -euo pipefail
curl --fail -Lo causal_forcing-chunkwise.pt https://huggingface.co/zhuhz22/Causal-Forcing/resolve/main/chunkwise/causal_forcing.pt
split_files/diffusion_models/causal_forcing-chunkwise.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:885d2e448bf6de7daadf5d1f038da7bf353c8a4d09585862f013b95ab95d5117
3
+ size 5676070488