"""
Convert LongCat-Image transformer weights from HuggingFace Diffusers format
to ComfyUI format.

Usage:
    python conversion.py input.safetensors output.safetensors

The input file is the Diffusers-format transformer, typically:
    meituan-longcat/LongCat-Image/transformer/diffusion_pytorch_model.safetensors

The output file will contain ComfyUI-format keys with fused QKV tensors,
ready for zero-copy loading via UNETLoader.
"""

# Standard library.
import argparse
import logging

# Third party.
import torch
from safetensors.torch import load_file, save_file

# Module-level logging so conversion progress is visible when run as a script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def convert_longcat_image(state_dict):
    """Convert a LongCat-Image transformer state dict from Diffusers naming
    to ComfyUI naming.

    Diffusers ``transformer_blocks.*`` keys become ``double_blocks.*`` and
    ``single_transformer_blocks.*`` keys become ``single_blocks.*``.  The
    separate q/k/v projections (and, for single blocks, the mlp projection)
    are concatenated along dim 0 into one fused tensor, matching ComfyUI's
    Flux-style layout.

    Args:
        state_dict: Mapping of Diffusers parameter names to tensors.

    Returns:
        dict: New mapping with ComfyUI parameter names.  Tensors are passed
        through by reference except for the fused qkv/linear1 concatenations
        and the re-ordered final-layer modulation tensor.
    """
    out_sd = {}
    # Staging dicts for projections that are fused after the renaming pass,
    # keyed by "<block_idx>.<weight|bias>".
    double_q, double_k, double_v = {}, {}, {}      # image-stream attention
    double_tq, double_tk, double_tv = {}, {}, {}   # text-stream attention
    single_q, single_k, single_v, single_mlp = {}, {}, {}, {}

    for k, v in state_dict.items():
        parts = k.split(".")   # split once; original code re-split per branch
        suffix = parts[-1]     # "weight" or "bias"

        if k.startswith("transformer_blocks."):
            idx = parts[1]
            rest = ".".join(parts[2:])
            prefix = "double_blocks.{}.".format(idx)
            stage = idx + "." + suffix

            if rest.startswith("norm1.linear."):
                out_sd[prefix + "img_mod.lin." + suffix] = v
            elif rest.startswith("norm1_context.linear."):
                out_sd[prefix + "txt_mod.lin." + suffix] = v
            elif rest.startswith("attn.to_q."):
                double_q[stage] = v
            elif rest.startswith("attn.to_k."):
                double_k[stage] = v
            elif rest.startswith("attn.to_v."):
                double_v[stage] = v
            elif rest == "attn.norm_q.weight":
                out_sd[prefix + "img_attn.norm.query_norm.weight"] = v
            elif rest == "attn.norm_k.weight":
                out_sd[prefix + "img_attn.norm.key_norm.weight"] = v
            elif rest.startswith("attn.to_out.0."):
                out_sd[prefix + "img_attn.proj." + suffix] = v
            elif rest.startswith("attn.add_q_proj."):
                double_tq[stage] = v
            elif rest.startswith("attn.add_k_proj."):
                double_tk[stage] = v
            elif rest.startswith("attn.add_v_proj."):
                double_tv[stage] = v
            elif rest == "attn.norm_added_q.weight":
                out_sd[prefix + "txt_attn.norm.query_norm.weight"] = v
            elif rest == "attn.norm_added_k.weight":
                out_sd[prefix + "txt_attn.norm.key_norm.weight"] = v
            elif rest.startswith("attn.to_add_out."):
                out_sd[prefix + "txt_attn.proj." + suffix] = v
            elif rest.startswith("ff.net.0.proj."):
                out_sd[prefix + "img_mlp.0." + suffix] = v
            elif rest.startswith("ff.net.2."):
                out_sd[prefix + "img_mlp.2." + suffix] = v
            elif rest.startswith("ff_context.net.0.proj."):
                out_sd[prefix + "txt_mlp.0." + suffix] = v
            elif rest.startswith("ff_context.net.2."):
                out_sd[prefix + "txt_mlp.2." + suffix] = v
            else:
                # Unknown sub-key: keep it, renamed under double_blocks.
                out_sd[prefix + rest] = v

        elif k.startswith("single_transformer_blocks."):
            idx = parts[1]
            rest = ".".join(parts[2:])
            prefix = "single_blocks.{}.".format(idx)
            stage = idx + "." + suffix

            if rest.startswith("norm.linear."):
                out_sd[prefix + "modulation.lin." + suffix] = v
            elif rest.startswith("attn.to_q."):
                single_q[stage] = v
            elif rest.startswith("attn.to_k."):
                single_k[stage] = v
            elif rest.startswith("attn.to_v."):
                single_v[stage] = v
            elif rest == "attn.norm_q.weight":
                out_sd[prefix + "norm.query_norm.weight"] = v
            elif rest == "attn.norm_k.weight":
                out_sd[prefix + "norm.key_norm.weight"] = v
            elif rest.startswith("proj_mlp."):
                single_mlp[stage] = v
            elif rest.startswith("proj_out."):
                out_sd[prefix + "linear2." + suffix] = v
            else:
                # Unknown sub-key: keep it, renamed under single_blocks.
                out_sd[prefix + rest] = v

        elif k in ("x_embedder.weight", "x_embedder.bias"):
            out_sd["img_in." + suffix] = v
        elif k in ("context_embedder.weight", "context_embedder.bias"):
            out_sd["txt_in." + suffix] = v
        elif k.startswith("time_embed.timestep_embedder.linear_1."):
            out_sd["time_in.in_layer." + suffix] = v
        elif k.startswith("time_embed.timestep_embedder.linear_2."):
            out_sd["time_in.out_layer." + suffix] = v
        elif k.startswith("norm_out.linear."):
            # Diffusers orders the modulation halves the opposite way from
            # ComfyUI's adaLN, so swap the two halves along dim 0.
            half = v.shape[0] // 2
            out_sd["final_layer.adaLN_modulation.1." + suffix] = torch.cat(
                [v[half:], v[:half]], dim=0
            )
        elif k in ("proj_out.weight", "proj_out.bias"):
            out_sd["final_layer.linear." + suffix] = v
        else:
            # Anything unrecognized is passed through untouched.
            out_sd[k] = v

    def _fuse(groups, fmt):
        """Concatenate matching staged entries of *groups* (dim 0) into out_sd.

        A fused tensor is emitted only when every group has the entry for a
        given "<idx>.<suffix>".  A partially-present group is still dropped
        (preserving the original output), but now with a warning instead of
        silently losing tensors.
        """
        indices = sorted({key.split(".")[0] for g in groups for key in g})
        for suffix in ("weight", "bias"):
            for idx in indices:
                qk = idx + "." + suffix
                hits = [qk in g for g in groups]
                if all(hits):
                    out_sd[fmt.format(idx, suffix)] = torch.cat(
                        [g[qk] for g in groups], dim=0
                    )
                elif any(hits):
                    logging.getLogger(__name__).warning(
                        "Dropping incomplete projection group for %s "
                        "(present flags: %s)",
                        fmt.format(idx, suffix), hits,
                    )

    _fuse([double_q, double_k, double_v], "double_blocks.{}.img_attn.qkv.{}")
    _fuse([double_tq, double_tk, double_tv], "double_blocks.{}.txt_attn.qkv.{}")
    _fuse([single_q, single_k, single_v, single_mlp],
          "single_blocks.{}.linear1.{}")

    return out_sd


def main():
    """Command-line entry point: parse paths, convert the weights, save.

    Raises:
        SystemExit: from argparse on bad arguments.
        OSError (and safetensors errors): propagated from load/save —
        presumably the caller wants the traceback; TODO confirm.
    """
    parser = argparse.ArgumentParser(
        description="Convert LongCat-Image weights from Diffusers to ComfyUI format"
    )
    parser.add_argument("input", help="Path to Diffusers-format safetensors file")
    parser.add_argument("output", help="Path to write ComfyUI-format safetensors file")
    args = parser.parse_args()

    # Lazy %-style logging args: formatting is deferred until the record is
    # actually emitted (the original eagerly built f-strings).
    logger.info("Loading %s...", args.input)
    sd = load_file(args.input)

    logger.info("Converting %d keys...", len(sd))
    converted = convert_longcat_image(sd)

    logger.info("Saving %d keys to %s...", len(converted), args.output)
    save_file(converted, args.output)

    logger.info("Done.")


if __name__ == "__main__":
    main()