# comfyui_nunchaku_lora_patch / patch_comfyui_nunchaku_lora.py

import safetensors.torch
from safetensors import safe_open
import torch
import os
import tkinter as tk
from tkinter import filedialog
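
# Patches LoRA .safetensors files whose final_layer block carries only "linear"
# LoRA weights: zero-valued placeholder tensors are added for the missing
# "adaLN_modulation_1" keys, presumably so that loaders expecting both key
# pairs (e.g. the Nunchaku LoRA loader for ComfyUI, going by the repository
# name) accept the file. Run with `python patch_comfyui_nunchaku_lora.py`;
# tkinter dialogs ask for the input file and the output folder.
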
def patch_final_layer_adaLN(state_dict, prefix="lora_unet_final_layer", verbose=True):
    """Add zero-valued final_layer adaLN LoRA weights if they are missing.

    If the state dict contains final_layer linear LoRA weights under the given
    prefix but no matching adaLN_modulation_1 weights, zero tensors are added
    under the adaLN keys. The state dict is modified in place and returned.
    """
    final_layer_linear_down = None
    final_layer_linear_up = None

    adaLN_down_key = f"{prefix}_adaLN_modulation_1.lora_down.weight"
    adaLN_up_key = f"{prefix}_adaLN_modulation_1.lora_up.weight"
    linear_down_key = f"{prefix}_linear.lora_down.weight"
    linear_up_key = f"{prefix}_linear.lora_up.weight"

    if verbose:
        print(f"\n🔍 Checking for final_layer keys with prefix: '{prefix}'")
        print(f" Linear down: {linear_down_key}")
        print(f" Linear up: {linear_up_key}")

    if linear_down_key in state_dict:
        final_layer_linear_down = state_dict[linear_down_key]
    if linear_up_key in state_dict:
        final_layer_linear_up = state_dict[linear_up_key]

    has_adaLN = adaLN_down_key in state_dict and adaLN_up_key in state_dict
    has_linear = final_layer_linear_down is not None and final_layer_linear_up is not None

    if verbose:
        print(f" ✅ Has final_layer.linear: {has_linear}")
        print(f" ✅ Has final_layer.adaLN_modulation_1: {has_adaLN}")

    if has_linear and not has_adaLN:
        # Dummy all-zero LoRA weights, reusing the shapes of the linear weights.
        dummy_down = torch.zeros_like(final_layer_linear_down)
        dummy_up = torch.zeros_like(final_layer_linear_up)
        state_dict[adaLN_down_key] = dummy_down
        state_dict[adaLN_up_key] = dummy_up
        if verbose:
            print("✅ Added dummy adaLN weights:")
            print(f" {adaLN_down_key} (shape: {dummy_down.shape})")
            print(f" {adaLN_up_key} (shape: {dummy_up.shape})")
    else:
        if verbose:
            print("✅ No patch needed - adaLN weights already present or no final_layer.linear found.")

    return state_dict
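
# Minimal sketch of programmatic (non-GUI) use of patch_final_layer_adaLN;
# "input.safetensors" and "output.safetensors" are placeholder paths:
#
#   sd = {}
#   with safe_open("input.safetensors", framework="pt", device="cpu") as f:
#       for k in f.keys():
#           sd[k] = f.get_tensor(k)
#   sd = patch_final_layer_adaLN(sd, prefix="lora_unet_final_layer")
#   safetensors.torch.save_file(sd, "output.safetensors")
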
def main():
    print("🔄 Universal final_layer.adaLN LoRA patcher (.safetensors)")

    # GUI for file/folder selection
    root = tk.Tk()
    root.withdraw()
    input_path = filedialog.askopenfilename(
        title="Select LoRA .safetensors file",
        filetypes=[("Safetensors files", "*.safetensors")]
    )
    if not input_path:
        print("❌ No file selected. Exiting.")
        return

    output_dir = filedialog.askdirectory(
        title="Select folder to save patched file"
    )
    if not output_dir:
        print("❌ No folder selected. Exiting.")
        return

    # Generate output filename
    base_name = os.path.basename(input_path)
    name, ext = os.path.splitext(base_name)
    output_filename = f"{name}-Patched{ext}"
    output_path = os.path.join(output_dir, output_filename)

    # Load
    state_dict = {}
    with safe_open(input_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)
    print(f"\n✅ Loaded {len(state_dict)} tensors from: {input_path}")

    final_keys = [k for k in state_dict if "final_layer" in k]
    if final_keys:
        print("\n🔑 Found these final_layer-related keys:")
        for k in final_keys:
            print(f" {k}")
    else:
        print("\n⚠️ No keys with 'final_layer' found - will try patching anyway.")

    # Key prefixes used by common LoRA naming schemes; try each until one adds weights.
    prefixes = [
        "lora_unet_final_layer",
        "final_layer",
        "base_model.model.final_layer"
    ]
    patched = False
    for prefix in prefixes:
        before = len(state_dict)
        state_dict = patch_final_layer_adaLN(state_dict, prefix=prefix)
        after = len(state_dict)
        if after > before:
            patched = True
            break
    if not patched:
        print("\nℹ️ No patch applied - either adaLN already exists or no final_layer.linear found.")

    # Save
    safetensors.torch.save_file(state_dict, output_path)
    print(f"\n✅ Patched file saved to: {output_path}")
    print(f" Total tensors now: {len(state_dict)}")

    # Verify
    print("\n🔍 Verifying patched keys:")
    with safe_open(output_path, framework="pt", device="cpu") as f:
        keys = list(f.keys())
        for k in keys:
            if "final_layer" in k:
                print(f" {k}")
    has_adaLN_after = any("adaLN_modulation_1" in k for k in keys)
    print(f"✅ Contains adaLN after patch: {has_adaLN_after}")


if __name__ == "__main__":
    main()