# RealGen-V2 / scripts/convert_realgen_v2.py
# Nynxz — "Add ComfyUI-ready LoRA, examples, converter, README, LICENSE"
# commit ebf1c59 (unverified)
"""
Convert the RealGen-V2 PEFT LoRA adapter into a ComfyUI-compatible LoRA.
Input: adapter_model.safetensors (+ adapter_config.json in the same folder,
used to read `lora_alpha` so ComfyUI scales the LoRA correctly)
Output: a single safetensors file with diffusers-style keys
(`<path>.lora_down.weight`, `<path>.lora_up.weight`, `<path>.alpha`).
Drop the output file into `ComfyUI/models/loras/` and load it with the stock
`LoraLoader` against a Z-Image model.
Usage:
python convert_realgen_v2.py adapter_model.safetensors realgen_v2.safetensors
python convert_realgen_v2.py adapter_model.safetensors # auto-names
python convert_realgen_v2.py # cwd defaults
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import torch
import safetensors.torch as st
# PEFT-saved adapter keys typically carry this wrapper prefix; it is stripped
# (when present) so the remaining key names the target module directly.
PEFT_PREFIX = "base_model.model."

# (PEFT tensor tag, ComfyUI/diffusers tensor tag) pairs: PEFT's lora_A is
# renamed to lora_down and lora_B to lora_up.
SUFFIX_MAP = (
    ("lora_A", "lora_down"),
    ("lora_B", "lora_up"),
)
# Fallback alpha when neither --alpha nor a usable `lora_alpha` in
# adapter_config.json is available (the documented RealGen-V2 default).
_REALGEN_V2_DEFAULT_ALPHA = 128.0


def _resolve_alpha(src: Path, alpha_override: float | None) -> tuple[float, bool]:
    """Decide the alpha to embed and whether the adapter was trained with rsLoRA.

    Precedence: explicit *alpha_override*, then ``lora_alpha`` from an
    ``adapter_config.json`` next to *src*, then ``_REALGEN_V2_DEFAULT_ALPHA``.

    Returns:
        ``(lora_alpha, use_rslora)``. ``use_rslora`` can only be true when it
        is read from the config; an explicit override is taken literally.
    """
    if alpha_override is not None:
        lora_alpha = float(alpha_override)
        print(f"using alpha override: {lora_alpha}")
        return lora_alpha, False

    cfg_path = src.parent / "adapter_config.json"
    if not cfg_path.is_file():
        lora_alpha = _REALGEN_V2_DEFAULT_ALPHA
        print(
            f"warning: no adapter_config.json next to {src.name}; "
            f"falling back to lora_alpha={lora_alpha}"
        )
        return lora_alpha, False

    with cfg_path.open(encoding="utf-8") as f:
        cfg = json.load(f)

    raw_alpha = cfg.get("lora_alpha")
    if raw_alpha is None:
        # Use the same fallback as the no-config path (this used to be a
        # silent 1.0, which disagreed with it).
        lora_alpha = _REALGEN_V2_DEFAULT_ALPHA
        print(
            f"warning: {cfg_path.name} has no lora_alpha; "
            f"falling back to lora_alpha={lora_alpha}"
        )
    else:
        lora_alpha = float(raw_alpha)
        print(f"read lora_alpha={lora_alpha} from {cfg_path.name}")

    if cfg.get("alpha_pattern"):
        # Per-module alpha overrides are not translated: every module gets the
        # global alpha. (rank_pattern needs no handling because the rank is
        # implied by each tensor's shape.)
        print(
            f"warning: {cfg_path.name} sets alpha_pattern; per-module alpha "
            "overrides are not applied, only the global lora_alpha"
        )
    return lora_alpha, bool(cfg.get("use_rslora", False))


def _remap_keys(
    sd: dict[str, torch.Tensor],
) -> tuple[dict[str, torch.Tensor], set[str], list[str]]:
    """Rename PEFT tensor keys to ComfyUI/diffusers-style LoRA keys.

    ``<prefix><path>.lora_A.<...>weight`` becomes ``<path>.lora_down.weight``
    and ``lora_B`` becomes ``lora_up``. Existing ``.alpha`` keys pass through
    with only the prefix stripped.

    Returns:
        ``(out, modules, skipped)``:
        - ``out``: the renamed, contiguous tensors.
        - ``modules``: the module paths that received a down/up weight.
        - ``skipped``: the original keys that matched neither pattern.
    """
    out: dict[str, torch.Tensor] = {}
    modules: set[str] = set()
    skipped: list[str] = []
    for key, tensor in sd.items():
        name = key[len(PEFT_PREFIX):] if key.startswith(PEFT_PREFIX) else key
        for peft_part, comfy_part in SUFFIX_MAP:
            tag = f".{peft_part}."
            if tag in name and name.endswith(".weight"):
                head = name.split(tag, 1)[0]
                name = f"{head}.{comfy_part}.weight"
                modules.add(head)
                break
        else:
            # No LoRA weight matched: keep pre-existing alpha tensors only.
            if not name.endswith(".alpha"):
                skipped.append(key)
                continue
        out[name] = tensor.contiguous()
    return out, modules, skipped


def convert(src: Path, dst: Path, alpha_override: float | None = None) -> None:
    """Convert a PEFT LoRA adapter into a ComfyUI-compatible LoRA file.

    Args:
        src: PEFT ``adapter_model.safetensors``. An ``adapter_config.json``
            in the same folder is consulted for ``lora_alpha``,
            ``use_rslora`` and ``alpha_pattern``.
        dst: Output ``.safetensors`` path; parent directories are created.
        alpha_override: If given, used as the base alpha instead of reading
            the config (and no rsLoRA adjustment is applied).

    Exits the process via ``sys.exit`` when *src* is missing or contains no
    recognizable LoRA weights.
    """
    if not src.is_file():
        sys.exit(f"error: source not found: {src}")

    # Pull lora_alpha from adapter_config.json if present, else fall back.
    lora_alpha, use_rslora = _resolve_alpha(src, alpha_override)

    out, modules, skipped = _remap_keys(st.load_file(str(src)))
    if not modules:
        sys.exit(
            "error: no LoRA tensors detected. "
            "Is this actually a PEFT adapter_model.safetensors?"
        )

    # A half pair cannot form a usable low-rank update; surface it.
    unpaired = sorted(
        m
        for m in modules
        if f"{m}.lora_down.weight" not in out or f"{m}.lora_up.weight" not in out
    )
    if unpaired:
        print(
            f"warning: {len(unpaired)} modules lack a lora_down/lora_up pair, "
            f"e.g. {unpaired[:3]}"
        )

    # Inject per-module alpha tensors if the source didn't ship them. Sorted
    # so insertion order does not depend on set iteration order.
    for m in sorted(modules):
        alpha = lora_alpha
        down = out.get(f"{m}.lora_down.weight")
        if use_rslora and down is not None:
            # rsLoRA scales the update by lora_alpha / sqrt(r). Folding sqrt(r)
            # into the stored alpha reproduces that scale under an alpha / r
            # loader convention.
            # NOTE(review): assumes the loader divides by r = lora_down.shape[0]
            # (ComfyUI's LoRA convention) — confirm against comfy's loader.
            alpha = lora_alpha * down.shape[0] ** 0.5
        out.setdefault(f"{m}.alpha", torch.tensor(alpha, dtype=torch.float32))
    if use_rslora:
        print("adapter uses rsLoRA: stored alpha = lora_alpha * sqrt(rank)")

    dst.parent.mkdir(parents=True, exist_ok=True)
    st.save_file(out, str(dst))
    print(f"wrote {dst}")
    print(f" tensors: {len(out)} (modules: {len(modules)})")
    if skipped:
        print(f" skipped {len(skipped)} unrecognized keys, e.g. {skipped[:3]}")
def main() -> None:
    """Command-line entry point: parse arguments and run ``convert``.

    Positional ``src`` defaults to ``./adapter_model.safetensors``; ``dst``
    defaults to ``<src stem>_comfy.safetensors`` next to ``src``. Exits with
    a usage error if ``dst`` resolves to the same file as ``src``.
    """
    # The module docstring is stripped under `python -OO` (then __doc__ is
    # None); its first non-blank line is the one-line summary.
    doc_lines = (__doc__ or "").strip().splitlines()
    ap = argparse.ArgumentParser(description=doc_lines[0] if doc_lines else None)
    ap.add_argument(
        "src",
        nargs="?",
        default="adapter_model.safetensors",
        type=Path,
        help="Input PEFT adapter (default: ./adapter_model.safetensors)",
    )
    ap.add_argument(
        "dst",
        nargs="?",
        default=None,
        type=Path,
        help="Output filename (default: <src stem>_comfy.safetensors)",
    )
    ap.add_argument(
        "--alpha",
        type=float,
        default=None,
        help="Override lora_alpha instead of reading adapter_config.json.",
    )
    args = ap.parse_args()

    src: Path = args.src
    dst: Path = args.dst or src.with_name(f"{src.stem}_comfy.safetensors")
    # Writing the output over the input would destroy the original adapter.
    if dst.resolve() == src.resolve():
        ap.error(f"output {dst} would overwrite the input adapter {src}")
    convert(src, dst, alpha_override=args.alpha)
# Script entry point; main() returns None, so the exit status stays 0 unless
# convert()/argparse exit earlier with their own status.
if __name__ == "__main__":
    sys.exit(main())