Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,459 Bytes
41978ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import os
import typing
import torch
from safetensors import safe_open
import lora as comfy_lora
import comfy.utils as comfy_utils
import comfy.model_patcher
import folder_paths
def _get_model_state_dict(model: typing.Any) -> dict:
    """Return the state dict of *model*.

    When the object exposes ``model_state_dict`` that accessor is used: it is
    first called with no arguments and, if that call raises ``TypeError``,
    called again with ``None`` as its only argument. Objects without the
    accessor fall back to ``state_dict()``.
    """
    if not hasattr(model, "model_state_dict"):
        return model.state_dict()

    fetch = model.model_state_dict
    try:
        return fetch()
    except TypeError:
        # The zero-argument call was rejected; retry with an explicit None.
        return fetch(None)
def build_newbie_lora_key_map(model) -> dict:
    """Map candidate LoRA key names to the model's ``.weight`` state-dict keys.

    Every state-dict key ending in ``.weight`` has that suffix stripped and is
    then registered under several aliases: the bare name, the name prefixed
    with ``base_model.model.`` or ``transformer.``, and a ``lycoris_`` name
    with dots replaced by underscores. Names starting with
    ``diffusion_model.`` additionally get the same aliases for the name with
    that prefix removed, plus a ``unet.base_model.model.`` alias.

    Returns:
        A dict from alias to full state-dict key. If two state-dict keys
        produce the same alias, the one encountered first keeps it.
    """
    weight_suffix = ".weight"
    dm_prefix = "diffusion_model."
    alias_prefixes = ("", "base_model.model.", "transformer.")

    mapping = {}
    for state_key in _get_model_state_dict(model):
        if not state_key.endswith(weight_suffix):
            continue

        stem = state_key[:-len(weight_suffix)]
        aliases = {prefix + stem for prefix in alias_prefixes}
        aliases.add("lycoris_" + stem.replace(".", "_"))

        if stem.startswith(dm_prefix):
            trimmed = stem[len(dm_prefix):]
            aliases.update(prefix + trimmed for prefix in alias_prefixes)
            aliases.add("unet.base_model.model." + trimmed)
            aliases.add("lycoris_" + trimmed.replace(".", "_"))

        # First writer wins: never overwrite an alias claimed by an earlier key.
        for alias in aliases:
            mapping.setdefault(alias, state_key)
    return mapping
def load_newbie_lora_state_dict(lora_name: str) -> tuple:
    """Load a LoRA file from the ``loras`` model folder.

    Args:
        lora_name: Name of the LoRA file, resolved through
            ``folder_paths.get_full_path("loras", lora_name)``.

    Returns:
        A ``(state_dict, metadata)`` tuple. ``metadata`` is the safetensors
        header metadata, or an empty dict when the file is not a safetensors
        file or carries no metadata.

    Raises:
        ValueError: If ``lora_name`` is empty, the resolved path is a
            directory, or the loaded object is not a dict.
        FileNotFoundError: If the name cannot be resolved to a path.
    """
    if not lora_name:
        raise ValueError("LoRA name is empty.")

    lora_path = folder_paths.get_full_path("loras", lora_name)
    if lora_path is None:
        raise FileNotFoundError(f"LoRA '{lora_name}' not found in models/loras folder.")
    if os.path.isdir(lora_path):
        raise ValueError(f"'{lora_path}' is a directory. Please select a LoRA file instead of a folder.")

    metadata = {}
    # Match the extension case-insensitively and accept the short ".sft"
    # suffix as well; otherwise such files would still load below but their
    # header metadata (e.g. lora_rank / lora_alpha) would be silently dropped.
    # NOTE(review): this mirrors how ComfyUI's loader is understood to detect
    # safetensors files -- confirm against the comfy.utils version in use.
    if lora_path.lower().endswith((".safetensors", ".sft")):
        with safe_open(lora_path, framework="pt", device="cpu") as f:
            metadata = f.metadata() or {}

    sd = comfy_utils.load_torch_file(lora_path)
    if not isinstance(sd, dict):
        raise ValueError(f"Loaded LoRA '{lora_name}' does not contain a valid state dict.")
    return sd, metadata
def _newbie_lora_metadata_scale(metadata, lora_name: str) -> float:
    """Return ``lora_alpha / lora_rank`` from file metadata, or 1.0 if unusable."""
    if not metadata:
        return 1.0
    try:
        rank = float(metadata.get("lora_rank", 0))
        alpha = float(metadata.get("lora_alpha", rank))
    except (TypeError, ValueError):
        # Metadata values are free-form; a non-numeric entry should not abort
        # the whole load, so fall back to the neutral scale.
        print(f"Warning: Ignoring non-numeric lora_rank/lora_alpha metadata in LoRA '{lora_name}'.")
        return 1.0
    return alpha / rank if rank > 0 else 1.0


def apply_newbie_lora_to_model(
    model,
    lora_name: str,
    strength: float,
) -> comfy.model_patcher.ModelPatcher:
    """Return *model* with the named LoRA applied as weight patches.

    Args:
        model: A ``ModelPatcher`` or an object that is wrapped in one.
        lora_name: LoRA file name, loaded via ``load_newbie_lora_state_dict``.
        strength: User-facing strength. It is multiplied by
            ``lora_alpha / lora_rank`` when the file metadata provides a
            positive ``lora_rank``.

    Returns:
        A patched clone of the model. The input ``model`` is returned
        unchanged when ``strength`` is ``0.0``; the (possibly wrapped) model
        is returned without patches when no LoRA key matches.

    Raises:
        ValueError, FileNotFoundError: Propagated from
            ``load_newbie_lora_state_dict``.
    """
    if strength == 0.0:
        return model

    if not isinstance(model, comfy.model_patcher.ModelPatcher):
        # NOTE(review): confirm that ModelPatcher can be constructed from the
        # model alone in the ComfyUI version in use (the constructor may also
        # expect load/offload devices).
        model = comfy.model_patcher.ModelPatcher(model)

    lora_sd, metadata = load_newbie_lora_state_dict(lora_name)
    final_strength = strength * _newbie_lora_metadata_scale(metadata, lora_name)

    to_load = build_newbie_lora_key_map(model.model)
    patches = comfy_lora.load_lora(lora_sd, to_load, log_missing=True)
    if not patches:
        print(f"Warning: No valid patches found in LoRA '{lora_name}'.")
        return model

    # Patch a clone so the caller's model object is left untouched.
    patched_model = model.clone()
    patched_model.add_patches(patches, strength_patch=float(final_strength), strength_model=1.0)
    return patched_model