| import safetensors.torch |
| from safetensors import safe_open |
| import torch |
| import os |
| import tkinter as tk |
| from tkinter import filedialog |
|
|
def patch_final_layer_adaLN(state_dict, prefix="lora_unet_final_layer", verbose=True):
    """Insert zero-filled LoRA weights for ``adaLN_modulation_1`` when they are missing.

    Keys are built by joining *prefix* and the module name with an underscore,
    e.g. ``f"{prefix}_linear.lora_down.weight"``. When both ``{prefix}_linear``
    LoRA matrices (``lora_down`` / ``lora_up``) are present but the
    ``{prefix}_adaLN_modulation_1`` pair is not complete, all-zero tensors with
    the same shape and dtype as the linear pair are written under both adaLN
    keys. If only one of the two adaLN keys existed, it is overwritten too, so
    the two adaLN tensors always come from the same source pair.

    Args:
        state_dict: Mapping of tensor name -> ``torch.Tensor``; modified in place.
        prefix: Name prefix of the final layer within the LoRA key space.
        verbose: When true, print which keys were checked and what was added.

    Returns:
        The same ``state_dict`` object that was passed in.
    """
    adaLN_down_key = f"{prefix}_adaLN_modulation_1.lora_down.weight"
    adaLN_up_key = f"{prefix}_adaLN_modulation_1.lora_up.weight"
    linear_down_key = f"{prefix}_linear.lora_down.weight"
    linear_up_key = f"{prefix}_linear.lora_up.weight"

    if verbose:
        print(f"\n🔍 Checking for final_layer keys with prefix: '{prefix}'")
        print(f" Linear down: {linear_down_key}")
        print(f" Linear up: {linear_up_key}")

    final_layer_linear_down = state_dict.get(linear_down_key)
    final_layer_linear_up = state_dict.get(linear_up_key)

    # The adaLN layer only counts as present when BOTH of its LoRA keys exist.
    has_adaLN = adaLN_down_key in state_dict and adaLN_up_key in state_dict
    has_linear = final_layer_linear_down is not None and final_layer_linear_up is not None

    if verbose:
        print(f" ✅ Has final_layer.linear: {has_linear}")
        print(f" ✅ Has final_layer.adaLN_modulation_1: {has_adaLN}")

    if has_linear and not has_adaLN:
        # All-zero matrices presumably make this LoRA contribute nothing to
        # adaLN_modulation_1 (zero update under the usual up @ down formula);
        # they exist so the key is present for whatever loader expects it.
        # NOTE(review): the dummies copy the *linear* pair's shapes. If the
        # target model's adaLN_modulation_1 has a different output size than
        # final_layer.linear, a shape-checking loader may reject the up matrix
        # — confirm against the model being targeted.
        dummy_down = torch.zeros_like(final_layer_linear_down)
        dummy_up = torch.zeros_like(final_layer_linear_up)
        state_dict[adaLN_down_key] = dummy_down
        state_dict[adaLN_up_key] = dummy_up

        if verbose:
            print("✅ Added dummy adaLN weights:")
            print(f" {adaLN_down_key} (shape: {dummy_down.shape})")
            print(f" {adaLN_up_key} (shape: {dummy_up.shape})")
    elif verbose:
        print("✅ No patch needed — adaLN weights already present or no final_layer.linear found.")

    return state_dict
|
|
def _select_input_and_output():
    """Ask for the LoRA file and then the output folder, using hidden-root Tk dialogs.

    Returns ``(input_path, output_dir)``. If the file dialog is cancelled, the
    folder dialog is never shown and ``output_dir`` is ``""``.
    """
    root = tk.Tk()
    root.withdraw()  # hide the blank main window; only the dialogs should appear
    try:
        input_path = filedialog.askopenfilename(
            title="Select LoRA .safetensors file",
            filetypes=[("Safetensors files", "*.safetensors")],
        )
        if not input_path:
            return input_path, ""
        output_dir = filedialog.askdirectory(
            title="Select folder to save patched file"
        )
        return input_path, output_dir
    finally:
        # Tear down the Tk interpreter once the dialogs are done instead of
        # leaving the hidden root window alive for the rest of the run.
        root.destroy()


def _load_safetensors(path):
    """Read every tensor in *path* onto the CPU.

    Returns ``(state_dict, metadata)``, where *metadata* is whatever
    ``safe_open(...).metadata()`` reports for the file's header (``None`` when
    the file has none).
    """
    with safe_open(path, framework="pt", device="cpu") as f:
        metadata = f.metadata()
        state_dict = {key: f.get_tensor(key) for key in f.keys()}
    return state_dict, metadata


def _print_saved_final_layer_keys(path):
    """Re-open *path*, list its final_layer keys, and report whether any adaLN key exists."""
    print("\n🔍 Verifying patched keys:")
    with safe_open(path, framework="pt", device="cpu") as f:
        keys = list(f.keys())
    for k in keys:
        if "final_layer" in k:
            print(f" {k}")
    has_adaLN_after = any("adaLN_modulation_1" in k for k in keys)
    print(f"✅ Contains adaLN after patch: {has_adaLN_after}")


def main():
    """Interactively patch a LoRA ``.safetensors`` file and save ``<name>-Patched<ext>``.

    Steps:
    1. Pick the input file and the output folder.
    2. Load all tensors, together with the header metadata.
    3. List the ``final_layer`` keys.
    4. Try ``patch_final_layer_adaLN`` with each known prefix until one adds keys.
    5. Save the result, even when nothing was patched.
    6. Re-open the saved file and report its ``final_layer`` keys.
    """
    print("🔧 Universal final_layer.adaLN LoRA patcher (.safetensors)")

    input_path, output_dir = _select_input_and_output()
    if not input_path:
        print("❌ No file selected. Exiting.")
        return
    if not output_dir:
        print("❌ No folder selected. Exiting.")
        return

    name, ext = os.path.splitext(os.path.basename(input_path))
    output_path = os.path.join(output_dir, f"{name}-Patched{ext}")

    state_dict, metadata = _load_safetensors(input_path)
    print(f"\n✅ Loaded {len(state_dict)} tensors from: {input_path}")

    final_keys = [k for k in state_dict if "final_layer" in k]
    if final_keys:
        print("\n🔍 Found these final_layer-related keys:")
        for k in final_keys:
            print(f" {k}")
    else:
        print("\n⚠️ No keys with 'final_layer' found — will try patch anyway.")

    # NOTE(review): patch_final_layer_adaLN joins prefix and module name with
    # "_", so the last two prefixes yield keys such as
    # "final_layer_linear.lora_down.weight"; confirm these match the LoRA
    # formats they are meant to cover.
    prefixes = [
        "lora_unet_final_layer",
        "final_layer",
        "base_model.model.final_layer",
    ]
    patched = False
    for prefix in prefixes:
        before = len(state_dict)
        state_dict = patch_final_layer_adaLN(state_dict, prefix=prefix)
        # A successful patch always adds at least one key, so growth marks success.
        if len(state_dict) > before:
            patched = True
            break

    if not patched:
        print("\nℹ️ No patch applied — either adaLN already exists or no final_layer.linear found.")

    # Carry the input's header metadata over; without metadata= the saved
    # file would silently lose it.
    safetensors.torch.save_file(state_dict, output_path, metadata=metadata)
    print(f"\n✅ Patched file saved to: {output_path}")
    print(f" Total tensors now: {len(state_dict)}")

    _print_saved_final_layer_keys(output_path)
|
|
# Script entry point: launch the interactive patcher only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|