Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -587,6 +587,7 @@ def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, o
|
|
| 587 |
if "down" in g and "up" in g:
|
| 588 |
down, up = g["down"].float(), g["up"].float()
|
| 589 |
|
|
|
|
| 590 |
if len(down.shape) == 4:
|
| 591 |
merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
|
| 592 |
flat = merged.flatten(1)
|
|
@@ -594,37 +595,45 @@ def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, o
|
|
| 594 |
merged = up @ down
|
| 595 |
flat = merged
|
| 596 |
|
| 597 |
-
# FAST SVD (svd_lowrank)
|
| 598 |
target_rank = int(new_rank)
|
| 599 |
-
# Add buffer to q to ensure convergence
|
| 600 |
q = min(target_rank + 10, min(flat.shape))
|
| 601 |
|
| 602 |
U, S, V = torch.svd_lowrank(flat, q=q)
|
| 603 |
-
|
| 604 |
-
Vh = V.t()
|
| 605 |
|
| 606 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
U = U[:, :target_rank]
|
| 608 |
S = S[:target_rank]
|
| 609 |
Vh = Vh[:target_rank, :]
|
| 610 |
|
| 611 |
-
# Reconstruct
|
| 612 |
U = U @ torch.diag(S)
|
| 613 |
|
| 614 |
if len(down.shape) == 4:
|
| 615 |
U = U.reshape(up.shape[0], target_rank, 1, 1)
|
| 616 |
Vh = Vh.reshape(target_rank, down.shape[1], down.shape[2], down.shape[3])
|
| 617 |
|
| 618 |
-
|
| 619 |
-
new_state[f"{stem}.
|
|
|
|
| 620 |
new_state[f"{stem}.alpha"] = torch.tensor(target_rank).float()
|
| 621 |
|
| 622 |
out = TempDir / "resized.safetensors"
|
|
|
|
| 623 |
save_file(new_state, out)
|
|
|
|
| 624 |
api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
|
| 625 |
api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=out_repo, token=hf_token)
|
| 626 |
return "Done"
|
| 627 |
-
|
| 628 |
# =================================================================================
|
| 629 |
# UI
|
| 630 |
# =================================================================================
|
|
|
|
| 587 |
if "down" in g and "up" in g:
|
| 588 |
down, up = g["down"].float(), g["up"].float()
|
| 589 |
|
| 590 |
+
# 1. Merge Up/Down
|
| 591 |
if len(down.shape) == 4:
|
| 592 |
merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
|
| 593 |
flat = merged.flatten(1)
|
|
|
|
| 595 |
merged = up @ down
|
| 596 |
flat = merged
|
| 597 |
|
| 598 |
+
# 2. FAST SVD (svd_lowrank)
|
| 599 |
target_rank = int(new_rank)
|
| 600 |
+
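# Oversample: request q slightly above target_rank (capped at the matrix's smaller dimension) so the randomized low-rank SVD approximates the leading singular values more accurately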
|
| 601 |
q = min(target_rank + 10, min(flat.shape))
|
| 602 |
|
| 603 |
U, S, V = torch.svd_lowrank(flat, q=q)
|
| 604 |
+
Vh = V.t()
|
|
|
|
| 605 |
|
| 606 |
+
# 3. Dynamic Rank Selection
|
| 607 |
+
if dynamic_method == "sv_ratio":
|
| 608 |
+
target_rank = index_sv_ratio(S, dynamic_param)
|
| 609 |
+
|
| 610 |
+
# Hard limit: never exceed the user's requested rank or the number of singular values actually computed
|
| 611 |
+
target_rank = min(target_rank, int(new_rank), S.shape[0])
|
| 612 |
+
|
| 613 |
+
# 4. Truncate
|
| 614 |
U = U[:, :target_rank]
|
| 615 |
S = S[:target_rank]
|
| 616 |
Vh = Vh[:target_rank, :]
|
| 617 |
|
| 618 |
+
# 5. Reconstruct Up Matrix
|
| 619 |
U = U @ torch.diag(S)
|
| 620 |
|
| 621 |
if len(down.shape) == 4:
|
| 622 |
U = U.reshape(up.shape[0], target_rank, 1, 1)
|
| 623 |
Vh = Vh.reshape(target_rank, down.shape[1], down.shape[2], down.shape[3])
|
| 624 |
|
| 625 |
+
# 6. Save (FIX: Enforce contiguous memory layout)
|
| 626 |
+
new_state[f"{stem}.lora_down.weight"] = Vh.contiguous()
|
| 627 |
+
new_state[f"{stem}.lora_up.weight"] = U.contiguous()
|
| 628 |
new_state[f"{stem}.alpha"] = torch.tensor(target_rank).float()
|
| 629 |
|
| 630 |
out = TempDir / "resized.safetensors"
|
| 631 |
+
# safetensors requires contiguous tensors
|
| 632 |
save_file(new_state, out)
|
| 633 |
+
|
| 634 |
api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
|
| 635 |
api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=out_repo, token=hf_token)
|
| 636 |
return "Done"
|
|
|
|
| 637 |
# =================================================================================
|
| 638 |
# UI
|
| 639 |
# =================================================================================
|