TalmajM committed on
Commit
60005bd
·
verified ·
1 Parent(s): e0637a7

Upload folder using huggingface_hub

Browse files
convert_original_to_comfy.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Standalone converter for Causal Forcing training checkpoints (.pt) to
4
+ ComfyUI-compatible .safetensors.
5
+
6
+ No dependency on the ComfyUI codebase — all conversion logic is self-contained.
7
+
8
+ Requirements:
9
+ pip install torch safetensors
10
+
11
+ Usage:
12
+ python convert_causal_forcing_standalone.py \
13
+ --input checkpoints/framewise/causal_forcing.pt \
14
+ --output models/causal_forcing_framewise.safetensors
15
+
16
+ python convert_causal_forcing_standalone.py \
17
+ --input checkpoints/framewise/causal_forcing.pt \
18
+ --output models/causal_forcing_framewise.safetensors \
19
+ --no-ema
20
+ """
21
+
22
+ import argparse
23
+ import json
24
+ import logging
25
+
26
+ import torch
27
+ from safetensors.torch import save_file
28
+
29
log = logging.getLogger(__name__)

# Wrapper prefixes found on training-checkpoint keys; longest first so that
# the FSDP-wrapped form is stripped before the plain "model." form.
PREFIXES_TO_STRIP = ["model._fsdp_wrapped_module.", "model."]

# Key families that belong to the model proper (as opposed to optimizer or
# bookkeeping state) in the CausalWanModel / WanModel layout.
_MODEL_KEY_PREFIXES = (
    "blocks.", "head.", "patch_embedding.", "text_embedding.",
    "time_embedding.", "time_projection.", "img_emb.", "rope_embedder.",
)


def extract_state_dict(state_dict: dict, use_ema: bool = True) -> dict:
    """
    Extract and clean a Causal Forcing state dict from a training checkpoint.

    Three checkpoint layouts are recognised:
      1. Training checkpoint with top-level generator_ema / generator keys
      2. Already-flattened state dict with model.* prefixes
      3. Already-converted ComfyUI state dict (bare model keys)

    Returns a state dict with keys matching the CausalWanModel / WanModel layout.
    Raises KeyError when no layout is detected and ValueError when the cleaned
    result is missing the expected 'head.modulation' key.
    """
    # Layout 3: already in ComfyUI form — return unchanged.
    if "head.modulation" in state_dict and "blocks.0.self_attn.q.weight" in state_dict:
        return state_dict

    # Layout 1: pick the preferred wrapper key, falling back to the other.
    order = ("generator_ema", "generator") if use_ema else ("generator", "generator_ema")
    wrapper = next((key for key in order if key in state_dict), None)

    if wrapper is not None:
        raw_sd = state_dict[wrapper]
        log.info("Extracted '%s' with %d keys", wrapper, len(raw_sd))
    elif any(key.startswith("model.") for key in state_dict):
        # Layout 2: flattened dict whose keys already carry model.* prefixes.
        raw_sd = state_dict
    else:
        raise KeyError(
            f"Cannot detect Causal Forcing checkpoint layout. "
            f"Top-level keys: {list(state_dict.keys())[:20]}"
        )

    def _strip_wrapper(name: str):
        # Return (cleaned name, True) when a known wrapper prefix was removed,
        # otherwise (name unchanged, False).
        for prefix in PREFIXES_TO_STRIP:
            if name.startswith(prefix):
                return name[len(prefix):], True
        return name, False

    out_sd = {}
    for key, value in raw_sd.items():
        new_key, had_prefix = _strip_wrapper(key)
        # Unprefixed keys that are not model weights (optimizer state,
        # counters, ...) are dropped.
        if not had_prefix and not new_key.startswith(_MODEL_KEY_PREFIXES):
            log.debug("Skipping non-model key: %s", key)
            continue
        out_sd[new_key] = value

    if "head.modulation" not in out_sd:
        raise ValueError("Conversion failed: 'head.modulation' not found in output keys")

    return out_sd
87
+
88
+
89
def convert_and_save(input_path: str, output_path: str, use_ema: bool = True):
    """
    Load a .pt training checkpoint, clean it, and write a ComfyUI .safetensors.

    Args:
        input_path: Path to the training .pt checkpoint.
        output_path: Destination .safetensors path.
        use_ema: Prefer the EMA generator weights when both variants exist.
    """
    print(f"Loading {input_path} ...")
    # NOTE(review): weights_only=False unpickles arbitrary objects — only run
    # this on checkpoints obtained from a trusted source.
    checkpoint = torch.load(input_path, map_location="cpu", weights_only=False)
    out_sd = extract_state_dict(checkpoint, use_ema=use_ema)
    del checkpoint  # drop the raw checkpoint before allocating the output file

    # Infer basic architecture facts for the progress report.
    dim = out_sd["head.modulation"].shape[-1]
    num_layers = 0
    while True:
        if f"blocks.{num_layers}.self_attn.q.weight" not in out_sd:
            break
        num_layers += 1
    print(f"Detected model: dim={dim}, num_layers={num_layers}, keys={len(out_sd)}")

    save_file(
        out_sd,
        output_path,
        metadata={"config": json.dumps({"transformer": {"causal_ar": True}})},
    )
    print(f"Saved to {output_path}")
106
+
107
+
108
if __name__ == "__main__":
    # CLI entry point: parse arguments, configure logging, run the conversion.
    cli = argparse.ArgumentParser(
        description="Convert Causal Forcing checkpoint to ComfyUI safetensors (standalone)"
    )
    cli.add_argument("--input", required=True, help="Path to the training .pt checkpoint")
    cli.add_argument("--output", required=True, help="Output .safetensors path")
    cli.add_argument(
        "--no-ema", action="store_true",
        help="Use 'generator' instead of 'generator_ema' (default: use EMA)",
    )
    cli.add_argument("-v", "--verbose", action="store_true", help="Enable debug logging")
    opts = cli.parse_args()

    log_level = logging.DEBUG if opts.verbose else logging.INFO
    logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s")

    convert_and_save(opts.input, opts.output, use_ema=not opts.no_ema)
download_original.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
#!/bin/bash
# Download the original framewise Causal Forcing checkpoint from Hugging Face.
set -euo pipefail  # abort on any error or unset variable
# -f: fail with a non-zero exit on HTTP errors instead of silently saving the
# server's error page as the .pt file.
curl -fLo causal_forcing-framewise.pt https://huggingface.co/zhuhz22/Causal-Forcing/resolve/main/framewise/causal_forcing.pt
split_files/diffusion_models/causal_forcing-framewise.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:408c67a8c6725756f5be2c5cf2d5c584c15dd147f5a0c458be62dcb3efb78477
3
+ size 5676070464