import os
import typing

from safetensors import safe_open

import comfy.lora as comfy_lora
import comfy.model_management
import comfy.model_patcher
import comfy.utils as comfy_utils
import folder_paths

def _get_model_state_dict(model: typing.Any) -> dict:
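    """Return a weight state dict from a ModelPatcher or a bare torch module.

    ModelPatcher exposes model_state_dict(); a plain module only has
    state_dict(). The TypeError fallback covers ModelPatcher variants whose
    signature requires an explicit filter_prefix argument.
    """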
    if hasattr(model, "model_state_dict"):
        try:
            return model.model_state_dict()
        except TypeError:
            return model.model_state_dict(None)
    return model.state_dict()


def build_newbie_lora_key_map(model) -> dict:
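    """Build a map from external LoRA key spellings to checkpoint weight keys.

    Trainers name the same weight in several ways: PEFT prepends
    "base_model.model.", diffusers-style exports prepend "transformer.",
    LyCORIS flattens dots into underscores behind a "lycoris_" prefix, and
    most of them drop the checkpoint's own "diffusion_model." prefix. Every
    spelling is indexed to the full ".weight" key so the generic LoRA loader
    can match whichever convention the file on disk uses.
    """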
    sd = _get_model_state_dict(model)
    key_map = {}

    for full_key in sd.keys():
        if not full_key.endswith(".weight"):
            continue

        base = full_key[:-len(".weight")]
        variants = set()
        variants.add(base)
        variants.add("base_model.model." + base)
        variants.add("transformer." + base)

        short = None
        if base.startswith("diffusion_model."):
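            # Trainer exports usually omit the checkpoint's "diffusion_model."
            # prefix, so index the stripped spelling as well.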
            short = base[len("diffusion_model."):]
            variants.add(short)
            variants.add("base_model.model." + short)
            variants.add("transformer." + short)
            variants.add("unet.base_model.model." + short)
        
        lyco_names = ["lycoris_" + base.replace(".", "_")]
        if short is not None:
            lyco_names.append("lycoris_" + short.replace(".", "_"))

        for name in lyco_names:
            variants.add(name)

        for v in variants:
            if v not in key_map:
                key_map[v] = full_key

    return key_map


def load_newbie_lora_state_dict(lora_name: str) -> tuple:
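    """Resolve a LoRA file from models/loras and return (state_dict, metadata)."""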
    if not lora_name:
        raise ValueError("LoRA name is empty.")
    lora_path = folder_paths.get_full_path("loras", lora_name)
    if lora_path is None:
        raise FileNotFoundError(f"LoRA '{lora_name}' not found in models/loras folder.")
    if os.path.isdir(lora_path):
        raise ValueError(f"'{lora_path}' is a directory. Please select a LoRA file instead of a folder.")

    metadata = {}
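    # Embedded metadata (e.g. lora_rank / lora_alpha written by the trainer)
    # is only available from .safetensors files.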
    if lora_path.endswith('.safetensors'):
        with safe_open(lora_path, framework="pt", device="cpu") as f:
            metadata = f.metadata() or {}

    sd = comfy_utils.load_torch_file(lora_path)
    if not isinstance(sd, dict):
        raise ValueError(f"Loaded LoRA '{lora_name}' does not contain a valid state dict.")
    return sd, metadata


def apply_newbie_lora_to_model(
    model,
    lora_name: str,
    strength: float,
) -> comfy.model_patcher.ModelPatcher:
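    """Clone `model` and apply the named LoRA scaled by `strength`."""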
    if not isinstance(model, comfy.model_patcher.ModelPatcher):
        # Wrap a bare model so it can be cloned and patched. ModelPatcher
        # requires explicit load/offload devices.
        model = comfy.model_patcher.ModelPatcher(
            model,
            load_device=comfy.model_management.get_torch_device(),
            offload_device=comfy.model_management.unet_offload_device(),
        )

    # Zero strength is a no-op; return the model unpatched.
    if strength == 0.0:
        return model

    lora_sd, metadata = load_newbie_lora_state_dict(lora_name)

    # Some trainers record rank/alpha only in the safetensors metadata rather
    # than as ".alpha" tensors; recover the standard alpha/rank scale there.
    scale = 1.0
    if metadata:
        lora_rank = float(metadata.get("lora_rank", 0))
        lora_alpha = float(metadata.get("lora_alpha", lora_rank))
        if lora_rank > 0:
            scale = lora_alpha / lora_rank

    final_strength = strength * scale
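    # Index every external key spelling to its checkpoint key, then let
    # comfy's generic LoRA loader pair the A/B matrices into patches.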
    to_load = build_newbie_lora_key_map(model.model)
    patches = comfy_lora.load_lora(lora_sd, to_load, log_missing=True)
    
    if not patches:
        print(f"Warning: No valid patches found in LoRA '{lora_name}'.")
        return model

    patched_model = model.clone()
    patched_model.add_patches(patches, strength_patch=float(final_strength), strength_model=1.0)

    return patched_model
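

# ---------------------------------------------------------------------------
# Hypothetical usage sketch, not part of the original module: a minimal
# ComfyUI node wiring the helpers above together. The class name, category,
# and strength bounds are illustrative assumptions; the INPUT_TYPES /
# RETURN_TYPES / FUNCTION conventions follow ComfyUI's stock loader nodes.
# ---------------------------------------------------------------------------
class NewbieLoraLoaderExample:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "lora_name": (folder_paths.get_filename_list("loras"),),
                "strength": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_lora"
    CATEGORY = "loaders"

    def load_lora(self, model, lora_name, strength):
        # apply_newbie_lora_to_model returns a patched clone, leaving the
        # input ModelPatcher unmodified.
        return (apply_newbie_lora_to_model(model, lora_name, strength),)


# Registration hook picked up by ComfyUI when this file is loaded as a
# custom node (assumption: the module is installed under custom_nodes/).
NODE_CLASS_MAPPINGS = {"NewbieLoraLoaderExample": NewbieLoraLoaderExample}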