Spaces: Running on Zero
Commit e6a0a9e · Parent: afcc9e2
update
videox_fun/utils/lora_utils.py CHANGED
@@ -389,11 +389,28 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32):
         key = key.replace(".self_attn.", "_self_attn_")
         key = key.replace(".cross_attn.", "_cross_attn_")
         key = key.replace(".ffn.", "_ffn_")
+        key = key.replace("text_embedding.", "text_embedding_")
+        key = key.replace("time_embedding.", "time_embedding_")
         key = key.replace(".lora_A.default.", ".lora_down.")
         key = key.replace(".lora_B.default.", ".lora_up.")
         key = key.replace(".lora_A.weight", ".lora_down.weight")
         key = key.replace(".lora_B.weight", ".lora_up.weight")
-        layer, elem = key.split('.', 1)
+
+        if key.endswith(".lora_down.weight"):
+            layer = key[:-len(".lora_down.weight")]
+            elem = "lora_down.weight"
+        elif key.endswith(".lora_up.weight"):
+            layer = key[:-len(".lora_up.weight")]
+            elem = "lora_up.weight"
+        elif key.endswith(".alpha"):
+            layer = key[:-len(".alpha")]
+            elem = "alpha"
+        else:
+            continue
+
+        if layer.endswith("."):
+            layer = layer[:-1]
+
         updates[layer][elem] = value
 
     sequential_cpu_offload_flag = False
@@ -512,11 +529,28 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
         key = key.replace(".self_attn.", "_self_attn_")
         key = key.replace(".cross_attn.", "_cross_attn_")
         key = key.replace(".ffn.", "_ffn_")
+        key = key.replace("text_embedding.", "text_embedding_")
+        key = key.replace("time_embedding.", "time_embedding_")
         key = key.replace(".lora_A.default.", ".lora_down.")
         key = key.replace(".lora_B.default.", ".lora_up.")
         key = key.replace(".lora_A.weight", ".lora_down.weight")
         key = key.replace(".lora_B.weight", ".lora_up.weight")
-        layer, elem = key.split('.', 1)
+
+        if key.endswith(".lora_down.weight"):
+            layer = key[:-len(".lora_down.weight")]
+            elem = "lora_down.weight"
+        elif key.endswith(".lora_up.weight"):
+            layer = key[:-len(".lora_up.weight")]
+            elem = "lora_up.weight"
+        elif key.endswith(".alpha"):
+            layer = key[:-len(".alpha")]
+            elem = "alpha"
+        else:
+            continue
+
+        if layer.endswith("."):
+            layer = layer[:-1]
+
         updates[layer][elem] = value
 
     sequential_cpu_offload_flag = False
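Both hunks make the same change in merge_lora and unmerge_lora: two new replace rules map the "text_embedding." / "time_embedding." prefixes to underscore form, and the layer/element split is now driven by known suffixes instead of the first dot. Below is a minimal sketch of why the old first-dot split breaks on the newly handled keys; the example key is hypothetical.

# Hypothetical PEFT-style key, before any normalization.
key = "text_embedding.0.lora_A.weight"

# Old parsing: split at the first dot -- the layer name gets truncated.
layer, elem = key.split(".", 1)
print((layer, elem))  # ('text_embedding', '0.lora_A.weight')

# New parsing: rewrite the prefix, normalize lora_A -> lora_down,
# then peel a known suffix off the end, so dots that remain inside
# the layer name are harmless.
key = key.replace("text_embedding.", "text_embedding_")
key = key.replace(".lora_A.weight", ".lora_down.weight")
suffix = ".lora_down.weight"
if key.endswith(suffix):
    layer, elem = key[:-len(suffix)], "lora_down.weight"
print((layer, elem))  # ('text_embedding_0', 'lora_down.weight')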
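End to end, the loop body now normalizes every state-dict key and groups values by layer before they are merged into (or subtracted from) the model weights. The following is a self-contained sketch of the new parsing, run over a few hypothetical keys with string stubs standing in for real tensors.

from collections import defaultdict

# Mirrors the parsing added in this commit; the sample keys and values
# are hypothetical stand-ins for a loaded LoRA state dict.
state_dict = {
    "blocks.0.self_attn.q.lora_A.weight": "tensor_down",
    "blocks.0.self_attn.q.lora_B.weight": "tensor_up",
    "blocks.0.self_attn.q.alpha": "alpha_scalar",
    "time_embedding.0.lora_A.default.weight": "tensor_down",
}

updates = defaultdict(dict)
for key, value in state_dict.items():
    key = key.replace(".self_attn.", "_self_attn_")
    key = key.replace(".cross_attn.", "_cross_attn_")
    key = key.replace(".ffn.", "_ffn_")
    key = key.replace("text_embedding.", "text_embedding_")
    key = key.replace("time_embedding.", "time_embedding_")
    key = key.replace(".lora_A.default.", ".lora_down.")
    key = key.replace(".lora_B.default.", ".lora_up.")
    key = key.replace(".lora_A.weight", ".lora_down.weight")
    key = key.replace(".lora_B.weight", ".lora_up.weight")

    # Peel a known suffix off the end of the key.
    if key.endswith(".lora_down.weight"):
        layer = key[:-len(".lora_down.weight")]
        elem = "lora_down.weight"
    elif key.endswith(".lora_up.weight"):
        layer = key[:-len(".lora_up.weight")]
        elem = "lora_up.weight"
    elif key.endswith(".alpha"):
        layer = key[:-len(".alpha")]
        elem = "alpha"
    else:
        continue  # unknown suffix: skip the entry

    if layer.endswith("."):
        layer = layer[:-1]

    updates[layer][elem] = value

for layer, elems in updates.items():
    print(layer, sorted(elems))
# blocks.0_self_attn_q ['alpha', 'lora_down.weight', 'lora_up.weight']
# time_embedding_0 ['lora_down.weight']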