Danrisi committed
Commit 97b30ef · verified · 1 Parent(s): c5ac7ed

Upload lora.py

Files changed (1): misc/comfy/lora.py (+437, -0)

misc/comfy/lora.py ADDED:
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

from __future__ import annotations
import comfy.utils
import comfy.model_management
import comfy.model_base
import comfy.weight_adapter as weight_adapter
import logging
import torch

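# Maps CLIP text-encoder submodule names to the flattened (underscore) form
# used in the text-encoder LoRA key names built below.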
LORA_CLIP_MAP = {
    "mlp.fc1": "mlp_fc1",
    "mlp.fc2": "mlp_fc2",
    "self_attn.k_proj": "self_attn_k_proj",
    "self_attn.q_proj": "self_attn_q_proj",
    "self_attn.v_proj": "self_attn_v_proj",
    "self_attn.out_proj": "self_attn_out_proj",
}


def load_lora(lora, to_load, log_missing=True):
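    """
    Build a patch dict from a LoRA state dict.

    ``to_load`` maps LoRA key prefixes to model weight names. For each prefix
    the registered classes in ``weight_adapter.adapters`` are tried first;
    remaining keys fall back to plain "diff"/"set" patches (``.w_norm`` /
    ``.b_norm``, ``.diff`` / ``.diff_b``, ``.set_weight``). Keys in ``lora``
    that match nothing are logged as warnings when ``log_missing`` is True.
    """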
    patch_dict = {}
    loaded_keys = set()
    for x in to_load:
        alpha_name = "{}.alpha".format(x)
        alpha = None
        if alpha_name in lora.keys():
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

        dora_scale_name = "{}.dora_scale".format(x)
        dora_scale = None
        if dora_scale_name in lora.keys():
            dora_scale = lora[dora_scale_name]
            loaded_keys.add(dora_scale_name)

        for adapter_cls in weight_adapter.adapters:
            adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys)
            if adapter is not None:
                patch_dict[to_load[x]] = adapter
                loaded_keys.update(adapter.loaded_keys)
                continue

        w_norm_name = "{}.w_norm".format(x)
        b_norm_name = "{}.b_norm".format(x)
        w_norm = lora.get(w_norm_name, None)
        b_norm = lora.get(b_norm_name, None)

        if w_norm is not None:
            loaded_keys.add(w_norm_name)
            patch_dict[to_load[x]] = ("diff", (w_norm,))
            if b_norm is not None:
                loaded_keys.add(b_norm_name)
                patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))

        diff_name = "{}.diff".format(x)
        diff_weight = lora.get(diff_name, None)
        if diff_weight is not None:
            patch_dict[to_load[x]] = ("diff", (diff_weight,))
            loaded_keys.add(diff_name)

        diff_bias_name = "{}.diff_b".format(x)
        diff_bias = lora.get(diff_bias_name, None)
        if diff_bias is not None:
            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
            loaded_keys.add(diff_bias_name)

        set_weight_name = "{}.set_weight".format(x)
        set_weight = lora.get(set_weight_name, None)
        if set_weight is not None:
            patch_dict[to_load[x]] = ("set", (set_weight,))
            loaded_keys.add(set_weight_name)

    if log_missing:
        for x in lora.keys():
            if x not in loaded_keys:
                logging.warning("lora key not loaded: {}".format(x))

    return patch_dict

def model_lora_keys_clip(model, key_map={}):
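    """
    Fill ``key_map`` with mappings from known text-encoder LoRA key formats
    (kohya ``lora_te*``, diffusers ``text_encoder*``, OneTrainer, Stable
    Cascade ``lora_prior_te*``) to the keys of this model's text-encoder
    state dict.
    """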
    sdk = model.state_dict().keys()
    for k in sdk:
        if k.endswith(".weight"):
            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    clip_l_present = False
    clip_g_present = False
    for b in range(32): #TODO: clean up
        for c in LORA_CLIP_MAP:
            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
                key_map[lora_key] = k

            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
                key_map[lora_key] = k
                clip_l_present = True
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
                key_map[lora_key] = k

            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                clip_g_present = True
                if clip_l_present:
                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
                    key_map[lora_key] = k
                    lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
                    key_map[lora_key] = k
                else:
                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
                    key_map[lora_key] = k
                    lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
                    key_map[lora_key] = k
                    lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
                    key_map[lora_key] = k

    for k in sdk:
        if k.endswith(".weight"):
            if k.startswith("t5xxl.transformer."): #OneTrainer SD3 and Flux lora
                l_key = k[len("t5xxl.transformer."):-len(".weight")]
                t5_index = 1
                if clip_g_present:
                    t5_index += 1
                if clip_l_present:
                    t5_index += 1
                    if t5_index == 2:
                        key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
                        t5_index += 1

                key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
            elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
                l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
                lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
                key_map[lora_key] = k


    k = "clip_g.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_prior_te_text_projection"] = k #cascade lora?
        # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
        key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora

    k = "clip_l.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning

    return key_map

def model_lora_keys_unet(model, key_map={}):
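    """
    Fill ``key_map`` with mappings from known diffusion-model LoRA key formats
    to this model's state dict keys: generic ``lora_unet_*`` / ``lycoris_*`` /
    diffusers names first, then per-architecture formats (Stable Cascade, SD3,
    AuraFlow, PixArt, HunyuanDiT, Flux, Mochi, HunyuanVideo, HiDream, ACE Step,
    Omnigen2, QwenImage, Lumina2, Kandinsky5).
    """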
    sd = model.state_dict()
    sdk = sd.keys()

    for k in sdk:
        if k.startswith("diffusion_model."):
            if k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
                key_map["lora_unet_{}".format(key_lora)] = k
                key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
            else:
                key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names

    diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            unet_key = "diffusion_model.{}".format(diffusers_keys[k])
            key_lora = k[:-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = unet_key
            key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format

            diffusers_lora_prefix = ["", "unet."]
            for p in diffusers_lora_prefix:
                diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
                if diffusers_lora_key.endswith(".to_out.0"):
                    diffusers_lora_key = diffusers_lora_key[:-2]
                key_map[diffusers_lora_key] = unet_key

    if isinstance(model, comfy.model_base.StableCascade_C):
        for k in sdk:
            if k.startswith("diffusion_model."):
                if k.endswith(".weight"):
                    key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
                    key_map["lora_prior_unet_{}".format(key_lora)] = k

    if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
        diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
                key_map[key_lora] = to

                key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
                key_map[key_lora] = to

                key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
                key_map[key_lora] = to

                key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
                key_map[key_lora] = to

    if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
        diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
                key_map[key_lora] = to

    if isinstance(model, comfy.model_base.PixArt):
        diffusers_keys = comfy.utils.pixart_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")]) #default format
                key_map[key_lora] = to

                key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #diffusers training script
                key_map[key_lora] = to

                key_lora = "unet.base_model.model.{}".format(k[:-len(".weight")]) #old reference peft script
                key_map[key_lora] = to

    if isinstance(model, comfy.model_base.HunyuanDiT):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format

    if isinstance(model, comfy.model_base.Flux): #Diffusers lora Flux
        diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
                key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
                key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
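        # Fused-qkv LoRA keys are mapped to a (weight key, narrow spec) tuple;
        # (0, 0, hidden_size * 3) is (dim, start, length), so only the qkv
        # slice of linear1 is patched (see the offset handling in
        # calculate_weight below).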
        for k in sdk:
            hidden_size = model.model_config.unet_config.get("hidden_size", 0)
            if k.endswith(".weight") and ".linear1." in k:
                key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3))
        # Direct mapping without diffusion_model prefix for Chroma/ChromaRadiance and similar Flux-based LoRA formats
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["{}".format(key_lora)] = k

    if isinstance(model, comfy.model_base.GenmoMochi):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official Mochi lora format
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["{}".format(key_lora)] = k

    if isinstance(model, comfy.model_base.HunyuanVideo):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):
                # diffusion-pipe lora format
                key_lora = k
                key_lora = key_lora.replace("_mod.lin.", "_mod.linear.").replace("_attn.qkv.", "_attn_qkv.").replace("_attn.proj.", "_attn_proj.")
                key_lora = key_lora.replace("mlp.0.", "mlp.fc1.").replace("mlp.2.", "mlp.fc2.")
                key_lora = key_lora.replace(".modulation.lin.", ".modulation.linear.")
                key_lora = key_lora[len("diffusion_model."):-len(".weight")]
                key_map["transformer.{}".format(key_lora)] = k
                key_map["diffusion_model.{}".format(key_lora)] = k # Old loras

    if isinstance(model, comfy.model_base.HiDream):
        for k in sdk:
            if k.startswith("diffusion_model."):
                if k.endswith(".weight"):
                    key_lora = k[len("diffusion_model."):-len(".weight")]
                    key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
                    key_map["transformer.{}".format(key_lora)] = k #SimpleTuner regular format

    if isinstance(model, comfy.model_base.ACEStep):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["{}".format(key_lora)] = k

    if isinstance(model, comfy.model_base.Omnigen2):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["{}".format(key_lora)] = k

    if isinstance(model, comfy.model_base.QwenImage):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format
                key_lora = k[len("diffusion_model."):-len(".weight")]
                # Direct mapping for transformer_blocks format (QwenImage LoRA format)
                key_map["{}".format(key_lora)] = k
                # Support transformer prefix format
                key_map["transformer.{}".format(key_lora)] = k
                key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format

    if isinstance(model, comfy.model_base.Lumina2):
        diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = k[:-len(".weight")]
                key_map["diffusion_model.{}".format(key_lora)] = to
                key_map["transformer.{}".format(key_lora)] = to
                key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to

    if isinstance(model, comfy.model_base.Kandinsky5):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["{}".format(key_lora)] = k
                key_map["transformer.{}".format(key_lora)] = k

    return key_map


def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
    """
    Pad a tensor to a new shape with zeros.

    Args:
        tensor (torch.Tensor): The original tensor to be padded.
        new_shape (List[int]): The desired shape of the padded tensor.

    Returns:
        torch.Tensor: A new tensor padded with zeros to the specified shape.

    Note:
        If the new shape is smaller than the original tensor in any dimension,
        a ValueError is raised.
    """
    if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
        raise ValueError("The new shape must be larger than the original tensor in all dimensions")

    if len(new_shape) != len(tensor.shape):
        raise ValueError("The new shape must have the same number of dimensions as the original tensor")

    # Create a new tensor filled with zeros
    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)

    # Create slicing tuples for both tensors
    orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
    new_slices = tuple(slice(0, dim) for dim in tensor.shape)

    # Copy the original tensor into the new tensor
    padded_tensor[new_slices] = tensor[orig_slices]

    return padded_tensor

def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
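    """
    Apply ``patches`` to ``weight`` and return the patched tensor.

    Each patch is ``(strength, patch, strength_model, offset, function)``.
    ``patch`` may be a ``WeightAdapterBase`` instance, a nested patch list, or
    a ``(patch_type, tensors)`` tuple with ``patch_type`` in "diff", "set" or
    "model_as_lora". ``offset`` narrows the weight before patching and
    ``function`` post-processes the scaled delta before it is added.
    """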
    for p in patches:
        strength = p[0]
        v = p[1]
        strength_model = p[2]
        offset = p[3]
        function = p[4]
        if function is None:
            function = lambda a: a

        old_weight = None
        if offset is not None:
            old_weight = weight
            weight = weight.narrow(offset[0], offset[1], offset[2])

        if strength_model != 1.0:
            weight *= strength_model

        if isinstance(v, list):
            v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )

        if isinstance(v, weight_adapter.WeightAdapterBase):
            output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
            if output is None:
                logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
            else:
                weight = output
                if old_weight is not None:
                    weight = old_weight
            continue

        if len(v) == 1:
            patch_type = "diff"
        elif len(v) == 2:
            patch_type = v[0]
            v = v[1]

        if patch_type == "diff":
            diff: torch.Tensor = v[0]
            # An extra flag to pad the weight if the diff's shape is larger than the weight
            do_pad_weight = len(v) > 1 and v[1]['pad_weight']
            if do_pad_weight and diff.shape != weight.shape:
                logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
                weight = pad_tensor_to_shape(weight, diff.shape)

            if strength != 0.0:
                if diff.shape != weight.shape:
                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
                else:
                    weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
        elif patch_type == "set":
            weight.copy_(v[0])
        elif patch_type == "model_as_lora":
            target_weight: torch.Tensor = v[0]
            diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
                          comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
            weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
        else:
            logging.warning("patch type not recognized {} {}".format(patch_type, key))

        if old_weight is not None:
            weight = old_weight

    return weight
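
A minimal, hypothetical usage sketch (not part of the uploaded file): it shows how calculate_weight() consumes the (strength, patch, strength_model, offset, function) tuples whose "diff" payloads load_lora() builds; it assumes a working comfy install, since calculate_weight() calls into comfy.model_management.

import torch

weight = torch.zeros(4, 4)
delta = 0.5 * torch.ones(4, 4)

# Same ("diff", (tensor,)) payload that load_lora() stores in patch_dict.
patches = [(1.0, ("diff", (delta,)), 1.0, None, None)]
patched = calculate_weight(patches, weight.clone(), "example.weight")
assert torch.allclose(patched, delta)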