File size: 5,053 Bytes
0ae355f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import comfy.utils
import folder_paths
import torch
import logging
from comfy_api.latest import IO, ComfyExtension
from typing_extensions import override
def load_hypernetwork_patch(path, strength):
    """Load a hypernetwork checkpoint and wrap it as an attention patch.

    The checkpoint is read with ``comfy.utils.load_torch_file``. Every entry
    whose key converts to an integer is treated as the weights for one
    feature dimension: it holds two state dicts (index 0 and index 1), each
    rebuilt into a ``torch.nn.Sequential``. The returned patch applies the
    first one to ``k`` and the second one to ``v``.

    Args:
        path: Location of the hypernetwork file.
        strength: Multiplier applied to the hypernetwork output before it is
            added back onto ``k`` and ``v``.

    Returns:
        A callable patch object (which also exposes ``to(device)``), or
        ``None`` when the checkpoint names an unsupported activation function.
    """
    sd = comfy.utils.load_torch_file(path, safe_load=True)
    activation_func = sd.get('activation_func', 'linear')
    is_layer_norm = sd.get('is_layer_norm', False)
    use_dropout = sd.get('use_dropout', False)
    activate_output = sd.get('activate_output', False)
    last_layer_dropout = sd.get('last_layer_dropout', False)

    valid_activation = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
        "softsign": torch.nn.Softsign,
        "mish": torch.nn.Mish,
    }

    if activation_func not in valid_activation:
        logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout))
        return None

    weight_suffix = ".weight"
    out = {}

    for d in sd:
        # Only integer-like keys carry network weights; everything else
        # (activation_func, is_layer_norm, ...) is metadata and is skipped.
        # Catch just what int() can raise so that unrelated errors and
        # KeyboardInterrupt are not silently swallowed.
        try:
            dim = int(d)
        except (TypeError, ValueError, OverflowError):
            continue

        output = []
        for index in (0, 1):
            # Look the entry up by its original key: a string key such as
            # "320" passes the int() check but is not reachable via sd[320].
            attn_weights = sd[d][index]
            # Every "<name>.weight" key names one layer (linear or layer norm),
            # kept in the order the state dict provides them.
            linears = [
                key[:-len(weight_suffix)]
                for key in attn_weights.keys()
                if key.endswith(weight_suffix)
            ]
            layers = []

            i = 0
            while i < len(linears):
                lin_name = linears[i]
                # Position flags are taken from the linear entry's index,
                # before any layer-norm entry below is consumed.
                last_layer = (i == (len(linears) - 1))
                penultimate_layer = (i == (len(linears) - 2))

                lin_weight = attn_weights['{}.weight'.format(lin_name)]
                lin_bias = attn_weights['{}.bias'.format(lin_name)]
                # Linear weights are (out_features, in_features).
                layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
                layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
                layers.append(layer)

                if activation_func != "linear":
                    if (not last_layer) or (activate_output):
                        layers.append(valid_activation[activation_func]())

                if is_layer_norm:
                    # The entry right after a linear layer is its layer norm.
                    i += 1
                    ln_name = linears[i]
                    ln_weight = attn_weights['{}.weight'.format(ln_name)]
                    ln_bias = attn_weights['{}.bias'.format(ln_name)]
                    ln = torch.nn.LayerNorm(ln_weight.shape[0])
                    ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
                    layers.append(ln)

                if use_dropout:
                    if (not last_layer) and (not penultimate_layer or last_layer_dropout):
                        layers.append(torch.nn.Dropout(p=0.3))
                i += 1

            output.append(torch.nn.Sequential(*layers))
        # Keyed by the integer dimension so the patch can match k.shape[-1].
        out[dim] = torch.nn.ModuleList(output)

    class hypernetwork_patch:
        """Attention patch that adds the scaled hypernetwork output to k and v."""

        def __init__(self, hypernet, strength):
            self.hypernet = hypernet
            self.strength = strength

        def __call__(self, q, k, v, extra_options):
            # Pick the networks matching the last dimension of k; inputs
            # with no matching entry pass through unchanged.
            dim = k.shape[-1]
            if dim in self.hypernet:
                hn = self.hypernet[dim]
                k = k + hn[0](k) * self.strength
                v = v + hn[1](v) * self.strength

            return q, k, v

        def to(self, device):
            # Move every per-dimension module list to the given device.
            for d in self.hypernet.keys():
                self.hypernet[d] = self.hypernet[d].to(device)
            return self

    return hypernetwork_patch(out, strength)
class HypernetworkLoader(IO.ComfyNode):
    """Node that attaches a hypernetwork patch to a clone of a model."""

    @classmethod
    def define_schema(cls):
        """Describe the node: a model, a hypernetwork file name and a strength."""
        model_input = IO.Model.Input("model")
        name_input = IO.Combo.Input(
            "hypernetwork_name",
            options=folder_paths.get_filename_list("hypernetworks"),
        )
        strength_input = IO.Float.Input(
            "strength", default=1.0, min=-10.0, max=10.0, step=0.01
        )
        return IO.Schema(
            node_id="HypernetworkLoader",
            category="loaders",
            inputs=[model_input, name_input, strength_input],
            outputs=[IO.Model.Output()],
        )

    @classmethod
    def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
        """Clone ``model`` and register the hypernetwork on both attention patches.

        The clone is returned without any patch when
        ``load_hypernetwork_patch`` returns ``None``.
        """
        full_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
        patched_model = model.clone()
        hypernet_patch = load_hypernetwork_patch(full_path, strength)

        if hypernet_patch is None:
            return IO.NodeOutput(patched_model)

        # The same patch object is registered for attn1 and attn2.
        patched_model.set_model_attn1_patch(hypernet_patch)
        patched_model.set_model_attn2_patch(hypernet_patch)
        return IO.NodeOutput(patched_model)

    load_hypernetwork = execute  # TODO: remove
class HyperNetworkExtension(ComfyExtension):
    """Extension exposing the hypernetwork loader node."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes = [HypernetworkLoader]
        return node_classes
async def comfy_entrypoint() -> HyperNetworkExtension:
    """Create and return a new ``HyperNetworkExtension`` instance."""
    extension = HyperNetworkExtension()
    return extension
|