import os
import typing

import torch
from safetensors import safe_open

# NOTE(review): ComfyUI's own LoRA helpers live in ``comfy.lora``; a bare
# ``import lora`` only resolves if a module named ``lora`` is on sys.path —
# confirm this is intentional.
import lora as comfy_lora
import comfy.utils as comfy_utils
import comfy.model_patcher
import folder_paths


def _get_model_state_dict(model: typing.Any) -> dict:
    """Return the state dict of *model*.

    Prefers ``model.model_state_dict()`` when the object has that attribute,
    retrying it with an explicit ``None`` argument if the no-argument call
    raises ``TypeError``; otherwise falls back to ``model.state_dict()``.
    """
    if hasattr(model, "model_state_dict"):
        try:
            return model.model_state_dict()
        except TypeError:
            # Some variants take a required positional argument; pass None.
            return model.model_state_dict(None)
    return model.state_dict()


def build_newbie_lora_key_map(model) -> dict:
    """Map the key spellings a LoRA file may use to the model weight keys.

    For every ``*.weight`` entry in the model's state dict, the base name
    (key without ``.weight``) is registered under several spellings: the raw
    name, ``base_model.model.<base>``, ``transformer.<base>`` and the
    LyCORIS-style ``lycoris_<base with dots replaced by underscores>``.
    When the base name starts with ``diffusion_model.``, the same spellings
    are also registered for the name with that prefix stripped, plus a
    ``unet.base_model.model.<short>`` form.

    If two weights produce the same spelling, the weight encountered first
    in state-dict iteration order keeps it.

    Returns:
        dict mapping each spelling to the full ``...weight`` key it patches.
    """
    sd = _get_model_state_dict(model)
    key_map = {}
    for full_key in sd:
        if not full_key.endswith(".weight"):
            continue
        base = full_key[:-len(".weight")]
        variants = {
            base,
            "base_model.model." + base,
            "transformer." + base,
            "lycoris_" + base.replace(".", "_"),
        }
        if base.startswith("diffusion_model."):
            short = base[len("diffusion_model."):]
            variants.update((
                short,
                "base_model.model." + short,
                "transformer." + short,
                "unet.base_model.model." + short,
                "lycoris_" + short.replace(".", "_"),
            ))
        for variant in variants:
            # First registration wins: never overwrite an earlier weight's mapping.
            key_map.setdefault(variant, full_key)
    return key_map


def load_newbie_lora_state_dict(lora_name: str) -> tuple:
    """Load a LoRA file from the ``loras`` model folder.

    Args:
        lora_name: Name resolved via ``folder_paths.get_full_path("loras", ...)``.

    Returns:
        ``(state_dict, metadata)``. *metadata* is the safetensors header
        metadata, or ``{}`` when the file is not a safetensors file or has no
        metadata.

    Raises:
        ValueError: if *lora_name* is empty, resolves to a directory, or the
            loaded object is not a dict.
        FileNotFoundError: if *lora_name* cannot be resolved.
    """
    if not lora_name:
        raise ValueError("LoRA name is empty.")
    lora_path = folder_paths.get_full_path("loras", lora_name)
    if lora_path is None:
        raise FileNotFoundError(f"LoRA '{lora_name}' not found in models/loras folder.")
    if os.path.isdir(lora_path):
        raise ValueError(
            f"'{lora_path}' is a directory. "
            "Please select a LoRA file instead of a folder."
        )

    metadata = {}
    # Case-insensitive and also accepting ``.sft``, so that upper-case or
    # short-extension safetensors files still get their metadata read
    # (presumed to match what comfy_utils.load_torch_file accepts — confirm).
    if lora_path.lower().endswith((".safetensors", ".sft")):
        with safe_open(lora_path, framework="pt", device="cpu") as f:
            metadata = f.metadata() or {}

    sd = comfy_utils.load_torch_file(lora_path)
    if not isinstance(sd, dict):
        raise ValueError(f"Loaded LoRA '{lora_name}' does not contain a valid state dict.")
    return sd, metadata


def _lora_scale_from_metadata(metadata: dict, lora_name: str = "") -> float:
    """Return ``lora_alpha / lora_rank`` from *metadata*, or 1.0 if unavailable.

    A missing ``lora_alpha`` defaults to the rank (scale 1.0). A missing,
    zero or negative rank yields 1.0. Values that cannot be converted to
    float produce a warning and 1.0 instead of aborting the load.
    """
    if not metadata:
        return 1.0
    try:
        lora_rank = float(metadata.get("lora_rank", 0))
        lora_alpha = float(metadata.get("lora_alpha", lora_rank))
    except (TypeError, ValueError):
        # Metadata values are free-form; a bad value should not break loading.
        print(f"Warning: Could not parse lora_rank/lora_alpha metadata in LoRA '{lora_name}'; using scale 1.0.")
        return 1.0
    if lora_rank > 0:
        return lora_alpha / lora_rank
    return 1.0


def apply_newbie_lora_to_model(
    model,
    lora_name: str,
    strength: float,
) -> comfy.model_patcher.ModelPatcher:
    """Return a clone of *model* with the named LoRA registered as weight patches.

    Args:
        model: A ``ModelPatcher``, or an object that gets wrapped in one.
        lora_name: LoRA file name inside the ``loras`` folder.
        strength: Patch strength. It is multiplied by ``lora_alpha / lora_rank``
            when the file's metadata provides a positive, parsable rank.

    Returns:
        *model* unchanged when ``strength == 0.0``. The (possibly wrapped)
        *model* when no LoRA key matched a model weight. Otherwise a clone
        with the patches added.
    """
    if strength == 0.0:
        return model
    if not isinstance(model, comfy.model_patcher.ModelPatcher):
        # NOTE(review): ModelPatcher's constructor may also require
        # load/offload device arguments in current ComfyUI — confirm.
        model = comfy.model_patcher.ModelPatcher(model)

    lora_sd, metadata = load_newbie_lora_state_dict(lora_name)
    # NOTE(review): comfy_lora.load_lora may already apply per-layer ".alpha"
    # tensors; a file carrying both those and rank/alpha metadata could end
    # up scaled twice — confirm.
    final_strength = strength * _lora_scale_from_metadata(metadata, lora_name)

    to_load = build_newbie_lora_key_map(model.model)
    patches = comfy_lora.load_lora(lora_sd, to_load, log_missing=True)
    if not patches:
        print(f"Warning: No valid patches found in LoRA '{lora_name}'.")
        return model

    patched_model = model.clone()
    patched_model.add_patches(patches, strength_patch=float(final_strength), strength_model=1.0)
    return patched_model