File size: 4,152 Bytes
ebf1c59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Convert the RealGen-V2 PEFT LoRA adapter into a ComfyUI-compatible LoRA.

Input:  adapter_model.safetensors  (+ adapter_config.json in the same folder,
        used to read `lora_alpha` so ComfyUI scales the LoRA correctly)
Output: a single safetensors file with diffusers-style keys
        (`<path>.lora_down.weight`, `<path>.lora_up.weight`, `<path>.alpha`).

Drop the output file into `ComfyUI/models/loras/` and load it with the stock
`LoraLoader` against a Z-Image model.

Usage:
  python convert_realgen_v2.py adapter_model.safetensors realgen_v2.safetensors
  python convert_realgen_v2.py adapter_model.safetensors            # auto-names
  python convert_realgen_v2.py                                      # cwd defaults
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import torch
import safetensors.torch as st


# Leading key prefix that convert() strips from every tensor name, when present.
PEFT_PREFIX = "base_model.model."
# (PEFT name segment, output name segment) pairs: a key containing
# ".<peft segment>." and ending in ".weight" is rewritten by convert() to
# "<module path>.<output segment>.weight".
SUFFIX_MAP = (("lora_A", "lora_down"), ("lora_B", "lora_up"))


_DEFAULT_LORA_ALPHA = 128.0  # documented default for RealGen-V2


def _resolve_lora_alpha(src: Path, alpha_override: float | None) -> float:
    """Choose the alpha value: explicit override, then adapter_config.json, then default."""
    if alpha_override is not None:
        lora_alpha = float(alpha_override)
        print(f"using alpha override: {lora_alpha}")
        return lora_alpha

    cfg_path = src.parent / "adapter_config.json"
    if not cfg_path.is_file():
        print(
            f"warning: no adapter_config.json next to {src.name}; "
            f"falling back to lora_alpha={_DEFAULT_LORA_ALPHA}"
        )
        return _DEFAULT_LORA_ALPHA

    try:
        with cfg_path.open(encoding="utf-8") as f:
            cfg = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers both malformed JSON and undecodable bytes.
        sys.exit(f"error: could not read {cfg_path}: {exc}")

    raw_alpha = cfg.get("lora_alpha") if isinstance(cfg, dict) else None
    if raw_alpha is None:
        # Use the same fallback as the "no config file" case so both paths agree.
        print(
            f"warning: {cfg_path.name} has no lora_alpha; "
            f"falling back to lora_alpha={_DEFAULT_LORA_ALPHA}"
        )
        return _DEFAULT_LORA_ALPHA

    try:
        lora_alpha = float(raw_alpha)
    except (TypeError, ValueError):
        sys.exit(f"error: lora_alpha in {cfg_path} is not a number: {raw_alpha!r}")
    print(f"read lora_alpha={lora_alpha} from {cfg_path.name}")
    return lora_alpha


def _to_comfy_key(key: str) -> tuple[str, str | None]:
    """Return (renamed key, module path); module path is None for non-LoRA-weight keys."""
    name = key[len(PEFT_PREFIX):] if key.startswith(PEFT_PREFIX) else key
    if name.endswith(".weight"):
        for peft_part, comfy_part in SUFFIX_MAP:
            # Everything before the first ".lora_A." / ".lora_B." is the module
            # path; whatever sits between the tag and ".weight" is dropped.
            head, tag, _ = name.partition(f".{peft_part}.")
            if tag:
                return f"{head}.{comfy_part}.weight", head
    return name, None


def convert(src: Path, dst: Path, alpha_override: float | None = None) -> None:
    """Rewrite a PEFT LoRA safetensors file into lora_down/lora_up/alpha keys.

    Every tensor key has ``PEFT_PREFIX`` stripped. Keys containing ``.lora_A.``
    or ``.lora_B.`` and ending in ``.weight`` are renamed according to
    ``SUFFIX_MAP``. Keys ending in ``.alpha`` are passed through. All other
    keys are skipped and reported. One ``<module>.alpha`` scalar is written
    per detected module.

    Args:
        src: Path to the input safetensors file. ``adapter_config.json`` is
            looked up in the same directory.
        dst: Path of the output safetensors file. Parent directories are
            created if needed.
        alpha_override: When given, this value is used for every detected
            module's ``.alpha`` tensor, replacing any alpha shipped in the
            source and bypassing ``adapter_config.json``. When ``None``, the
            value comes from the config's ``lora_alpha`` (or the default) and
            is only written for modules that lack an alpha tensor.

    Raises:
        SystemExit: If ``src`` does not exist, the config cannot be parsed, or
            no LoRA tensors are found in ``src``.
    """
    if not src.is_file():
        sys.exit(f"error: source not found: {src}")

    lora_alpha = _resolve_lora_alpha(src, alpha_override)

    sd = st.load_file(str(src))
    out: dict[str, torch.Tensor] = {}
    modules: set[str] = set()
    skipped: list[str] = []

    for k, v in sd.items():
        nk, module = _to_comfy_key(k)
        if module is not None:
            modules.add(module)
        elif not nk.endswith(".alpha"):
            skipped.append(k)
            continue
        out[nk] = v.contiguous()

    if not modules:
        sys.exit(
            "error: no LoRA tensors detected. "
            "Is this actually a PEFT adapter_model.safetensors?"
        )

    # Write per-module alpha tensors. An explicit override must win over alpha
    # tensors already present in the source; otherwise only fill in the gaps.
    # Sorted iteration keeps the output key order deterministic.
    alpha_tensor = torch.tensor(lora_alpha, dtype=torch.float32)
    force_alpha = alpha_override is not None
    for m in sorted(modules):
        alpha_key = f"{m}.alpha"
        if force_alpha or alpha_key not in out:
            out[alpha_key] = alpha_tensor.clone()

    dst.parent.mkdir(parents=True, exist_ok=True)
    st.save_file(out, str(dst))

    print(f"wrote {dst}")
    print(f"  tensors: {len(out)}  (modules: {len(modules)})")
    if skipped:
        print(f"  skipped {len(skipped)} unrecognized keys, e.g. {skipped[:3]}")


def main() -> None:
    """Parse command-line arguments and run the conversion.

    Positional ``src`` defaults to ``adapter_model.safetensors`` in the current
    directory. Positional ``dst`` defaults to ``<src stem>_comfy.safetensors``
    next to ``src``. ``--alpha`` is forwarded to ``convert`` as the alpha
    override.
    """
    # Use the first non-empty docstring line as the description. __doc__ is
    # None when docstrings are stripped (python -OO), so guard against that
    # instead of indexing into it directly.
    doc_lines = (__doc__ or "").strip().splitlines()
    ap = argparse.ArgumentParser(description=doc_lines[0] if doc_lines else None)
    ap.add_argument(
        "src",
        nargs="?",
        default="adapter_model.safetensors",
        type=Path,
        help="Input PEFT adapter (default: ./adapter_model.safetensors)",
    )
    ap.add_argument(
        "dst",
        nargs="?",
        default=None,
        type=Path,
        help="Output filename (default: <src stem>_comfy.safetensors)",
    )
    ap.add_argument(
        "--alpha",
        type=float,
        default=None,
        help="Override lora_alpha instead of reading adapter_config.json.",
    )
    args = ap.parse_args()

    src: Path = args.src
    dst: Path = args.dst
    if dst is None:
        # Auto-name the output alongside the input.
        dst = src.with_name(f"{src.stem}_comfy.safetensors")
    convert(src, dst, alpha_override=args.alpha)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()