|
|
import os |
|
|
import re |
|
|
from safetensors.torch import load_file, save_file |
|
|
|
|
|
def convert_key(key):
    """Translate one LoRA tensor name into the ``diffusion_model.`` dotted layout.

    Every occurrence of ``"lora_unet__"`` is removed. Then only the segment
    before the first ``.`` is rewritten; the trailing segments (for example
    ``lora_down.weight``) are kept verbatim:

    * ``layers_<n>_``, ``context_refiner_<n>_`` and ``noise_refiner_<n>_``
      become ``layers.<n>.``, ``context_refiner.<n>.`` and
      ``noise_refiner.<n>.``;
    * ``_to_k`` / ``_to_q`` / ``_to_v`` / ``_to_out`` become ``.to_k`` etc.;
    * ``.to_out_<n>`` becomes ``.to_out.<n>``.

    The result is prefixed with ``"diffusion_model."``.

    Example::

        >>> convert_key("lora_unet__layers_0_attention_to_out_0.lora_up.weight")
        'diffusion_model.layers.0.attention.to_out.0.lora_up.weight'
    """
    stripped = key.replace("lora_unet__", "")
    head, *tail = stripped.split(".")

    # Turn flattened "<container>_<index>_" runs into dotted module paths,
    # applied in the order layers -> context_refiner -> noise_refiner.
    for container in ("layers", "context_refiner", "noise_refiner"):
        head = re.sub(rf"{container}_(\d+)_", rf"{container}.\1.", head)

    # Split the projection name off its parent: "attention_to_q" -> "attention.to_q".
    for proj in ("to_k", "to_q", "to_v", "to_out"):
        head = head.replace("_" + proj, "." + proj)

    # The numeric suffix after to_out becomes its own path segment.
    head = re.sub(r"\.to_out_(\d+)", r".to_out.\1", head)

    return "diffusion_model." + ".".join([head, *tail])
|
|
|
|
|
|
|
|
# --- Script entry: locate the original LoRA file, rename its keys, save a copy ---
print("Looking for .safetensors files...")

# Skip outputs of earlier runs ("converted"/"fixed" in the name) so an
# already-converted file is never re-converted. Sort the candidates so the
# choice is deterministic: os.listdir() returns entries in arbitrary order.
files = sorted(
    f for f in os.listdir('.')
    if f.endswith('.safetensors') and 'converted' not in f and 'fixed' not in f
)

if not files:
    print("Error: No original .safetensors file found.")
    print("Make sure the original 'Z-Image-Fun-Lora-Distill-8-Steps.safetensors' is in this folder.")
    # Exit non-zero so shells/callers can detect the failure. SystemExit is a
    # builtin; the site-provided exit() helper is not guaranteed to exist.
    raise SystemExit(1)

input_file = files[0]
if len(files) > 1:
    print(f"Note: {len(files)} candidate files found; using the first in sorted order.")
print(f"Processing: {input_file}")

try:
    tensors = load_file(input_file)
    if not tensors:
        raise ValueError(f"{input_file} contains no tensors")

    new_tensors = {}
    source_of = {}  # converted key -> original key, used to report collisions

    print("Converting keys...")

    # Show one mapping up front so the user can sanity-check the renaming.
    first_key = next(iter(tensors))
    print(f"Preview: {first_key} \n -> {convert_key(first_key)}")

    for k, v in tensors.items():
        new_k = convert_key(k)
        # Two source keys mapping to the same target would silently overwrite
        # one tensor with another and corrupt the output; fail loudly instead.
        if new_k in new_tensors:
            raise ValueError(
                f"Key collision: {source_of[new_k]!r} and {k!r} both map to {new_k!r}"
            )
        new_tensors[new_k] = v
        source_of[new_k] = k

    # Replace only the extension. str.replace would also rewrite any earlier
    # ".safetensors" occurrence inside the file name.
    output_file = os.path.splitext(input_file)[0] + "_v3_fixed.safetensors"
    save_file(new_tensors, output_file)

    print(f"\nSUCCESS! Created: {output_file}")
    print("Move this file to ComfyUI/models/loras/ and DELETE the old converted ones.")

except Exception as e:
    # Top-level boundary: report the problem, then exit non-zero rather than
    # pretending the run succeeded.
    print(f"An error occurred: {e}")
    raise SystemExit(1) from e