# LongCat-Image-Edit_ComfyUI_repackaged / convert_original_to_comfy.py
# Uploaded via huggingface_hub by TalmajM (commit 312bb02, verified)
#!/usr/bin/env python3
"""
Convert LongCat-Image transformer weights from HuggingFace Diffusers format
to ComfyUI format.
Usage:
python convert_original_to_comfy.py input.safetensors output.safetensors
The input file is the Diffusers-format transformer, typically:
meituan-longcat/LongCat-Image/transformer/diffusion_pytorch_model.safetensors
The output file will contain ComfyUI-format keys with fused QKV tensors,
ready for zero-copy loading via UNETLoader.
"""
import argparse
import torch
import logging
from safetensors.torch import load_file, save_file
# Console logging at INFO so conversion progress is visible when run as a script.
logging.basicConfig(level=logging.INFO)
# Module-level logger, per stdlib logging convention.
logger = logging.getLogger(__name__)
def convert_longcat_image(state_dict):
    """Convert a LongCat-Image Diffusers transformer state dict to ComfyUI format.

    Renames Diffusers keys ("transformer_blocks.*", "single_transformer_blocks.*",
    embedders, final norm/proj) to their ComfyUI equivalents ("double_blocks.*",
    "single_blocks.*", "img_in"/"txt_in"/"time_in", "final_layer.*") and fuses the
    separate Q/K/V projections into single qkv tensors (for single blocks, the
    MLP input projection is fused in as well, matching ComfyUI's "linear1").

    Args:
        state_dict: mapping of Diffusers-format key -> torch.Tensor.

    Returns:
        dict of ComfyUI-format key -> torch.Tensor. Tensors are passed through
        unchanged except for the QKV concatenations and the "norm_out.linear"
        scale/shift half-swap.
    """
    out_sd = {}
    # Q/K/V (and single-block MLP) projections are buffered during the main pass
    # and concatenated afterwards; buffer keys are "<block_idx>.<weight|bias>".
    double_q, double_k, double_v = {}, {}, {}        # double-block image stream
    double_tq, double_tk, double_tv = {}, {}, {}     # double-block text stream
    single_q, single_k, single_v, single_mlp = {}, {}, {}, {}

    for k, v in state_dict.items():
        parts = k.split(".")
        leaf = parts[-1]  # "weight" or "bias"
        if k.startswith("transformer_blocks."):
            idx = parts[1]
            rest = ".".join(parts[2:])
            prefix = "double_blocks.{}.".format(idx)
            group_key = idx + "." + leaf
            if rest.startswith("norm1.linear."):
                out_sd[prefix + "img_mod.lin." + leaf] = v
            elif rest.startswith("norm1_context.linear."):
                out_sd[prefix + "txt_mod.lin." + leaf] = v
            elif rest.startswith("attn.to_q."):
                double_q[group_key] = v
            elif rest.startswith("attn.to_k."):
                double_k[group_key] = v
            elif rest.startswith("attn.to_v."):
                double_v[group_key] = v
            elif rest == "attn.norm_q.weight":
                out_sd[prefix + "img_attn.norm.query_norm.weight"] = v
            elif rest == "attn.norm_k.weight":
                out_sd[prefix + "img_attn.norm.key_norm.weight"] = v
            elif rest.startswith("attn.to_out.0."):
                out_sd[prefix + "img_attn.proj." + leaf] = v
            elif rest.startswith("attn.add_q_proj."):
                double_tq[group_key] = v
            elif rest.startswith("attn.add_k_proj."):
                double_tk[group_key] = v
            elif rest.startswith("attn.add_v_proj."):
                double_tv[group_key] = v
            elif rest == "attn.norm_added_q.weight":
                out_sd[prefix + "txt_attn.norm.query_norm.weight"] = v
            elif rest == "attn.norm_added_k.weight":
                out_sd[prefix + "txt_attn.norm.key_norm.weight"] = v
            elif rest.startswith("attn.to_add_out."):
                out_sd[prefix + "txt_attn.proj." + leaf] = v
            elif rest.startswith("ff.net.0.proj."):
                out_sd[prefix + "img_mlp.0." + leaf] = v
            elif rest.startswith("ff.net.2."):
                out_sd[prefix + "img_mlp.2." + leaf] = v
            elif rest.startswith("ff_context.net.0.proj."):
                out_sd[prefix + "txt_mlp.0." + leaf] = v
            elif rest.startswith("ff_context.net.2."):
                out_sd[prefix + "txt_mlp.2." + leaf] = v
            else:
                # Unrecognized sub-key: pass through with only the block prefix renamed.
                out_sd[prefix + rest] = v
        elif k.startswith("single_transformer_blocks."):
            idx = parts[1]
            rest = ".".join(parts[2:])
            prefix = "single_blocks.{}.".format(idx)
            group_key = idx + "." + leaf
            if rest.startswith("norm.linear."):
                out_sd[prefix + "modulation.lin." + leaf] = v
            elif rest.startswith("attn.to_q."):
                single_q[group_key] = v
            elif rest.startswith("attn.to_k."):
                single_k[group_key] = v
            elif rest.startswith("attn.to_v."):
                single_v[group_key] = v
            elif rest == "attn.norm_q.weight":
                out_sd[prefix + "norm.query_norm.weight"] = v
            elif rest == "attn.norm_k.weight":
                out_sd[prefix + "norm.key_norm.weight"] = v
            elif rest.startswith("proj_mlp."):
                single_mlp[group_key] = v
            elif rest.startswith("proj_out."):
                out_sd[prefix + "linear2." + leaf] = v
            else:
                out_sd[prefix + rest] = v
        elif k in ("x_embedder.weight", "x_embedder.bias"):
            out_sd["img_in." + leaf] = v
        elif k in ("context_embedder.weight", "context_embedder.bias"):
            out_sd["txt_in." + leaf] = v
        elif k.startswith("time_embed.timestep_embedder.linear_1."):
            out_sd["time_in.in_layer." + leaf] = v
        elif k.startswith("time_embed.timestep_embedder.linear_2."):
            out_sd["time_in.out_layer." + leaf] = v
        elif k.startswith("norm_out.linear."):
            # HF AdaLayerNormContinuous stores [scale | shift] but ComfyUI
            # LastLayer expects [shift | scale], so swap the two halves.
            half = v.shape[0] // 2
            out_sd["final_layer.adaLN_modulation.1." + leaf] = torch.cat(
                [v[half:], v[:half]], dim=0
            )
        elif k in ("proj_out.weight", "proj_out.bias"):
            out_sd["final_layer.linear." + leaf] = v
        else:
            out_sd[k] = v

    def _fuse(target, members):
        """Concatenate a complete projection group into out_sd[target].

        A group with no members present is silently skipped (e.g. a bias-free
        projection); a *partially* present group is dropped with a warning,
        since silently losing weights would yield a corrupt checkpoint.
        """
        present = [t for t in members if t is not None]
        if len(present) == len(members):
            out_sd[target] = torch.cat(present, dim=0)
        elif present:
            logging.getLogger(__name__).warning(
                "Incomplete projection group for %s (%d/%d tensors); skipping",
                target, len(present), len(members),
            )

    for suffix in ("weight", "bias"):
        # Iterate the union of image/text indices so a block that only has
        # text-stream projections is still examined rather than dropped.
        for idx in sorted({x.split(".")[0] for x in double_q}
                          | {x.split(".")[0] for x in double_tq}):
            gk = idx + "." + suffix
            _fuse("double_blocks.{}.img_attn.qkv.{}".format(idx, suffix),
                  [double_q.get(gk), double_k.get(gk), double_v.get(gk)])
            _fuse("double_blocks.{}.txt_attn.qkv.{}".format(idx, suffix),
                  [double_tq.get(gk), double_tk.get(gk), double_tv.get(gk)])
        for idx in sorted({x.split(".")[0] for x in single_q}):
            gk = idx + "." + suffix
            # Single blocks fuse Q/K/V *and* the MLP input projection into "linear1".
            _fuse("single_blocks.{}.linear1.{}".format(idx, suffix),
                  [single_q.get(gk), single_k.get(gk),
                   single_v.get(gk), single_mlp.get(gk)])
    return out_sd
def main():
    """Command-line entry point: load a Diffusers safetensors file, convert, save."""
    cli = argparse.ArgumentParser(
        description="Convert LongCat-Image weights from Diffusers to ComfyUI format"
    )
    cli.add_argument("input", help="Path to Diffusers-format safetensors file")
    cli.add_argument("output", help="Path to write ComfyUI-format safetensors file")
    opts = cli.parse_args()

    src_path, dst_path = opts.input, opts.output
    logger.info(f"Loading {src_path}...")
    source_sd = load_file(src_path)

    logger.info(f"Converting {len(source_sd)} keys...")
    result_sd = convert_longcat_image(source_sd)

    logger.info(f"Saving {len(result_sd)} keys to {dst_path}...")
    save_file(result_sd, dst_path)
    logger.info("Done.")


if __name__ == "__main__":
    main()