TalmajM commited on
Commit
0e8dde1
·
verified ·
1 Parent(s): 59c9a06

Upload folder using huggingface_hub

Browse files
convert_original_to_comfy.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Convert LongCat-Image transformer weights from HuggingFace Diffusers format
4
+ to ComfyUI format.
5
+
6
+ Usage:
7
+ python convert_original_to_comfy.py input.safetensors output.safetensors
8
+
9
+ The input file is the Diffusers-format transformer, typically:
10
+ meituan-longcat/LongCat-Image/transformer/diffusion_pytorch_model.safetensors
11
+
12
+ The output file will contain ComfyUI-format keys with fused QKV tensors,
13
+ ready for zero-copy loading via UNETLoader.
14
+ """
15
+
16
+ import argparse
17
+ import torch
18
+ import logging
19
+ from safetensors.torch import load_file, save_file
20
+
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
def convert_longcat_image(state_dict):
    """Convert LongCat-Image transformer weights from the HuggingFace
    Diffusers key layout to the ComfyUI (Flux-style) layout.

    Args:
        state_dict: mapping of Diffusers parameter names to tensors.

    Returns:
        dict: new state dict with ComfyUI key names. The separate Q/K/V
        projections of each attention layer (plus the MLP input projection
        for single blocks) are concatenated along dim 0 into the fused
        ``qkv`` / ``linear1`` tensors ComfyUI expects. Keys that match no
        known pattern are passed through (renamed under the block prefix
        where applicable).
    """
    out_sd = {}
    # Per-projection tensors are buffered here, keyed by
    # "<block_idx>.<weight|bias>", and fused after the renaming pass.
    double_q, double_k, double_v = {}, {}, {}
    double_tq, double_tk, double_tv = {}, {}, {}
    single_q, single_k, single_v, single_mlp = {}, {}, {}, {}

    for k, v in state_dict.items():
        parts = k.split(".")
        suffix = parts[-1]  # "weight" or "bias"

        if k.startswith("transformer_blocks."):
            idx = parts[1]
            rest = ".".join(parts[2:])
            prefix = "double_blocks.{}.".format(idx)
            buf_key = idx + "." + suffix

            if rest.startswith("norm1.linear."):
                out_sd[prefix + "img_mod.lin." + suffix] = v
            elif rest.startswith("norm1_context.linear."):
                out_sd[prefix + "txt_mod.lin." + suffix] = v
            elif rest.startswith("attn.to_q."):
                double_q[buf_key] = v
            elif rest.startswith("attn.to_k."):
                double_k[buf_key] = v
            elif rest.startswith("attn.to_v."):
                double_v[buf_key] = v
            elif rest == "attn.norm_q.weight":
                out_sd[prefix + "img_attn.norm.query_norm.weight"] = v
            elif rest == "attn.norm_k.weight":
                out_sd[prefix + "img_attn.norm.key_norm.weight"] = v
            elif rest.startswith("attn.to_out.0."):
                out_sd[prefix + "img_attn.proj." + suffix] = v
            elif rest.startswith("attn.add_q_proj."):
                double_tq[buf_key] = v
            elif rest.startswith("attn.add_k_proj."):
                double_tk[buf_key] = v
            elif rest.startswith("attn.add_v_proj."):
                double_tv[buf_key] = v
            elif rest == "attn.norm_added_q.weight":
                out_sd[prefix + "txt_attn.norm.query_norm.weight"] = v
            elif rest == "attn.norm_added_k.weight":
                out_sd[prefix + "txt_attn.norm.key_norm.weight"] = v
            elif rest.startswith("attn.to_add_out."):
                out_sd[prefix + "txt_attn.proj." + suffix] = v
            elif rest.startswith("ff.net.0.proj."):
                out_sd[prefix + "img_mlp.0." + suffix] = v
            elif rest.startswith("ff.net.2."):
                out_sd[prefix + "img_mlp.2." + suffix] = v
            elif rest.startswith("ff_context.net.0.proj."):
                out_sd[prefix + "txt_mlp.0." + suffix] = v
            elif rest.startswith("ff_context.net.2."):
                out_sd[prefix + "txt_mlp.2." + suffix] = v
            else:
                # Unknown sub-key: pass through under the renamed block prefix.
                out_sd[prefix + rest] = v

        elif k.startswith("single_transformer_blocks."):
            idx = parts[1]
            rest = ".".join(parts[2:])
            prefix = "single_blocks.{}.".format(idx)
            buf_key = idx + "." + suffix

            if rest.startswith("norm.linear."):
                out_sd[prefix + "modulation.lin." + suffix] = v
            elif rest.startswith("attn.to_q."):
                single_q[buf_key] = v
            elif rest.startswith("attn.to_k."):
                single_k[buf_key] = v
            elif rest.startswith("attn.to_v."):
                single_v[buf_key] = v
            elif rest == "attn.norm_q.weight":
                out_sd[prefix + "norm.query_norm.weight"] = v
            elif rest == "attn.norm_k.weight":
                out_sd[prefix + "norm.key_norm.weight"] = v
            elif rest.startswith("proj_mlp."):
                single_mlp[buf_key] = v
            elif rest.startswith("proj_out."):
                out_sd[prefix + "linear2." + suffix] = v
            else:
                out_sd[prefix + rest] = v

        elif k in ("x_embedder.weight", "x_embedder.bias"):
            out_sd["img_in." + suffix] = v
        elif k in ("context_embedder.weight", "context_embedder.bias"):
            out_sd["txt_in." + suffix] = v
        elif k.startswith("time_embed.timestep_embedder.linear_1."):
            out_sd["time_in.in_layer." + suffix] = v
        elif k.startswith("time_embed.timestep_embedder.linear_2."):
            out_sd["time_in.out_layer." + suffix] = v
        elif k.startswith("norm_out.linear."):
            # HF AdaLayerNormContinuous stores [scale | shift] but ComfyUI
            # LastLayer expects [shift | scale], so swap the two halves.
            half = v.shape[0] // 2
            out_sd["final_layer.adaLN_modulation.1." + suffix] = torch.cat(
                [v[half:], v[:half]], dim=0
            )
        elif k in ("proj_out.weight", "proj_out.bias"):
            out_sd["final_layer.linear." + suffix] = v
        else:
            out_sd[k] = v

    def _fuse(buffers, target_fmt):
        """Concatenate per-projection tensors along dim 0 into one fused tensor
        per block index, for each of "weight" and "bias". Emits a warning for
        tensors that cannot be fused (previously they were silently dropped,
        which could produce a corrupt checkpoint with no diagnostic)."""
        consumed = set()
        for suffix in ("weight", "bias"):
            for idx in sorted({key.split(".")[0] for buf in buffers for key in buf}):
                buf_key = idx + "." + suffix
                if all(buf_key in buf for buf in buffers):
                    out_sd[target_fmt.format(idx, suffix)] = torch.cat(
                        [buf[buf_key] for buf in buffers], dim=0
                    )
                    consumed.add(buf_key)
        leftovers = sorted({key for buf in buffers for key in buf} - consumed)
        if leftovers:
            logging.getLogger(__name__).warning(
                "Dropping unfused attention tensors for block keys: %s", leftovers
            )

    _fuse((double_q, double_k, double_v), "double_blocks.{}.img_attn.qkv.{}")
    _fuse((double_tq, double_tk, double_tv), "double_blocks.{}.txt_attn.qkv.{}")
    _fuse((single_q, single_k, single_v, single_mlp), "single_blocks.{}.linear1.{}")

    return out_sd
133
+
134
+
135
def main():
    """CLI entry point: parse arguments, load the Diffusers-format
    safetensors file, convert it, and save the ComfyUI-format result.

    Raises:
        SystemExit: on bad command-line arguments (via argparse).
    """
    parser = argparse.ArgumentParser(
        description="Convert LongCat-Image weights from Diffusers to ComfyUI format"
    )
    parser.add_argument("input", help="Path to Diffusers-format safetensors file")
    parser.add_argument("output", help="Path to write ComfyUI-format safetensors file")
    args = parser.parse_args()

    # Use lazy %-style logging args: f-strings are formatted eagerly even
    # when the log level would suppress the message.
    logger.info("Loading %s...", args.input)
    sd = load_file(args.input)

    logger.info("Converting %d keys...", len(sd))
    converted = convert_longcat_image(sd)

    logger.info("Saving %d keys to %s...", len(converted), args.output)
    save_file(converted, args.output)

    logger.info("Done.")


if __name__ == "__main__":
    main()
download_original.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
#!/bin/bash
# Download the original Diffusers-format LongCat-Image transformer weights.
# -e/-u/pipefail: abort on any command failure or unset variable instead of
# continuing silently.
set -euo pipefail
# -f: fail on HTTP errors so an error page is not saved as the model file;
# -L: follow redirects (HF resolve URLs redirect to the CDN);
# -O: keep the remote filename.
curl -fOL https://huggingface.co/meituan-longcat/LongCat-Image/resolve/main/transformer/diffusion_pytorch_model.safetensors
split_files/diffusion_models/longcat_image_bf16.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c83c314a3d879d43e5700072033256000f46a56900ae48b209a77ac1921488b
3
+ size 12541383144