XiangpengYang commited on
Commit
e6a0a9e
·
1 Parent(s): afcc9e2
Files changed (1) hide show
  1. videox_fun/utils/lora_utils.py +36 -2
videox_fun/utils/lora_utils.py CHANGED
@@ -389,11 +389,28 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3
389
  key = key.replace(".self_attn.", "_self_attn_")
390
  key = key.replace(".cross_attn.", "_cross_attn_")
391
  key = key.replace(".ffn.", "_ffn_")
 
 
392
  key = key.replace(".lora_A.default.", ".lora_down.")
393
  key = key.replace(".lora_B.default.", ".lora_up.")
394
  key = key.replace(".lora_A.weight", ".lora_down.weight")
395
  key = key.replace(".lora_B.weight", ".lora_up.weight")
396
- layer, elem = key.split('.', 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
  updates[layer][elem] = value
398
 
399
  sequential_cpu_offload_flag = False
@@ -512,11 +529,28 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl
512
  key = key.replace(".self_attn.", "_self_attn_")
513
  key = key.replace(".cross_attn.", "_cross_attn_")
514
  key = key.replace(".ffn.", "_ffn_")
 
 
515
  key = key.replace(".lora_A.default.", ".lora_down.")
516
  key = key.replace(".lora_B.default.", ".lora_up.")
517
  key = key.replace(".lora_A.weight", ".lora_down.weight")
518
  key = key.replace(".lora_B.weight", ".lora_up.weight")
519
- layer, elem = key.split('.', 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
520
  updates[layer][elem] = value
521
 
522
  sequential_cpu_offload_flag = False
 
389
  key = key.replace(".self_attn.", "_self_attn_")
390
  key = key.replace(".cross_attn.", "_cross_attn_")
391
  key = key.replace(".ffn.", "_ffn_")
392
+ key = key.replace("text_embedding.", "text_embedding_")
393
+ key = key.replace("time_embedding.", "time_embedding_")
394
  key = key.replace(".lora_A.default.", ".lora_down.")
395
  key = key.replace(".lora_B.default.", ".lora_up.")
396
  key = key.replace(".lora_A.weight", ".lora_down.weight")
397
  key = key.replace(".lora_B.weight", ".lora_up.weight")
398
+
399
+ if key.endswith(".lora_down.weight"):
400
+ layer = key[:-len(".lora_down.weight")]
401
+ elem = "lora_down.weight"
402
+ elif key.endswith(".lora_up.weight"):
403
+ layer = key[:-len(".lora_up.weight")]
404
+ elem = "lora_up.weight"
405
+ elif key.endswith(".alpha"):
406
+ layer = key[:-len(".alpha")]
407
+ elem = "alpha"
408
+ else:
409
+ continue
410
+
411
+ if layer.endswith("."):
412
+ layer = layer[:-1]
413
+
414
  updates[layer][elem] = value
415
 
416
  sequential_cpu_offload_flag = False
 
529
  key = key.replace(".self_attn.", "_self_attn_")
530
  key = key.replace(".cross_attn.", "_cross_attn_")
531
  key = key.replace(".ffn.", "_ffn_")
532
+ key = key.replace("text_embedding.", "text_embedding_")
533
+ key = key.replace("time_embedding.", "time_embedding_")
534
  key = key.replace(".lora_A.default.", ".lora_down.")
535
  key = key.replace(".lora_B.default.", ".lora_up.")
536
  key = key.replace(".lora_A.weight", ".lora_down.weight")
537
  key = key.replace(".lora_B.weight", ".lora_up.weight")
538
+
539
+ if key.endswith(".lora_down.weight"):
540
+ layer = key[:-len(".lora_down.weight")]
541
+ elem = "lora_down.weight"
542
+ elif key.endswith(".lora_up.weight"):
543
+ layer = key[:-len(".lora_up.weight")]
544
+ elem = "lora_up.weight"
545
+ elif key.endswith(".alpha"):
546
+ layer = key[:-len(".alpha")]
547
+ elem = "alpha"
548
+ else:
549
+ continue
550
+
551
+ if layer.endswith("."):
552
+ layer = layer[:-1]
553
+
554
  updates[layer][elem] = value
555
 
556
  sequential_cpu_offload_flag = False