Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -454,42 +454,111 @@ def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision
|
|
| 454 |
# TAB 2: EXTRACT LORA
|
| 455 |
# =================================================================================
|
| 456 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
|
| 458 |
org = MemoryEfficientSafeOpen(model_org)
|
| 459 |
tuned = MemoryEfficientSafeOpen(model_tuned)
|
| 460 |
lora_sd = {}
|
| 461 |
-
print("Calculating diffs...")
|
| 462 |
-
|
| 463 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
mat_org = org.get_tensor(key).float()
|
| 465 |
mat_tuned = tuned.get_tensor(key).float()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
diff = mat_tuned - mat_org
|
|
|
|
|
|
|
| 467 |
if torch.max(torch.abs(diff)) < 1e-4: continue
|
| 468 |
|
| 469 |
-
out_dim
|
|
|
|
|
|
|
| 470 |
r = min(rank, in_dim, out_dim)
|
|
|
|
| 471 |
is_conv = len(diff.shape) == 4
|
| 472 |
if is_conv: diff = diff.flatten(start_dim=1)
|
|
|
|
| 473 |
|
| 474 |
try:
|
| 475 |
-
|
| 476 |
-
U, S,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
U = U @ torch.diag(S)
|
|
|
|
|
|
|
| 478 |
dist = torch.cat([U.flatten(), Vh.flatten()])
|
| 479 |
-
hi_val = torch.quantile(dist, clamp)
|
| 480 |
-
|
| 481 |
-
|
|
|
|
|
|
|
| 482 |
if is_conv:
|
| 483 |
U = U.reshape(out_dim, r, 1, 1)
|
| 484 |
Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
|
| 485 |
else:
|
| 486 |
U = U.reshape(out_dim, r)
|
| 487 |
Vh = Vh.reshape(r, in_dim)
|
|
|
|
| 488 |
stem = key.replace(".weight", "")
|
| 489 |
-
lora_sd[f"{stem}.lora_up.weight"] = U
|
| 490 |
-
lora_sd[f"{stem}.lora_down.weight"] = Vh
|
| 491 |
lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
|
| 492 |
-
except
|
|
|
|
|
|
|
|
|
|
| 493 |
out = TempDir / "extracted.safetensors"
|
| 494 |
save_file(lora_sd, out)
|
| 495 |
return str(out)
|
|
@@ -498,12 +567,16 @@ def task_extract(hf_token, org, tun, rank, out):
|
|
| 498 |
cleanup_temp()
|
| 499 |
if hf_token: login(hf_token.strip())
|
| 500 |
try:
|
| 501 |
-
|
| 502 |
-
|
|
|
|
|
|
|
|
|
|
| 503 |
f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)
|
|
|
|
| 504 |
api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
|
| 505 |
-
api.upload_file(path_or_fileobj=f, path_in_repo="
|
| 506 |
-
return "Done"
|
| 507 |
except Exception as e: return f"Error: {e}"
|
| 508 |
|
| 509 |
# =================================================================================
|
|
|
|
| 454 |
# TAB 2: EXTRACT LORA
|
| 455 |
# =================================================================================
|
| 456 |
|
| 457 |
+
def identify_and_download_model(repo_id, token):
    """Locate and download the main weight file of a Hugging Face model repo.

    Prefers known diffusers layouts (``transformer/`` or ``unet/``
    subfolders), then a top-level ``model.safetensors``, and finally falls
    back to any ``.safetensors`` file that does not look like a
    LoRA/adapter artifact.

    Args:
        repo_id: Hugging Face repository id to scan.
        token: HF access token (may be None/empty for public repos).

    Returns:
        Path to the downloaded weight file inside ``TempDir``.

    Raises:
        ValueError: if no suitable ``.safetensors`` file exists in the repo.
        RuntimeError: if the download succeeded but the file cannot be
            located on disk afterwards.
    """
    print(f"Scanning {repo_id} for model weights...")
    files = list_repo_files(repo_id=repo_id, token=token)

    # Priority order: exact known filenames first, then a generic predicate
    # that excludes LoRA/adapter/previously-extracted artifacts.
    priorities = [
        "transformer/diffusion_pytorch_model.safetensors",
        "unet/diffusion_pytorch_model.safetensors",
        "model.safetensors",
        lambda f: f.endswith(".safetensors") and "lora" not in f and "adapter" not in f and "extracted" not in f,
    ]

    target_file = None
    for p in priorities:
        if callable(p):
            candidates = [f for f in files if p(f)]
            if candidates:
                target_file = candidates[0]
                break
        elif p in files:
            target_file = p
            break

    if not target_file:
        raise ValueError(
            f"Could not find a valid model weight file in {repo_id}. "
            "Ensure it contains .safetensors weights."
        )

    print(f"Downloading main weight file: {target_file}")
    hf_hub_download(repo_id=repo_id, filename=target_file, token=token, local_dir=TempDir)

    # hf_hub_download may nest the file under its repo subfolder; search for
    # it rather than indexing blindly so a miss raises a clear error instead
    # of a bare IndexError.
    matches = list(TempDir.rglob(os.path.basename(target_file)))
    if not matches:
        raise RuntimeError(
            f"Downloaded {target_file} but could not locate it under {TempDir}."
        )
    return matches[0]
| 493 |
+
|
| 494 |
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    """Extract a LoRA state dict from the weight difference of two models.

    For every key present in both checkpoints, computes the diff tensor,
    takes a low-rank SVD approximation, and stores it in kohya-style
    ``lora_up`` / ``lora_down`` / ``alpha`` entries.

    Args:
        model_org: path to the original (base) .safetensors file.
        model_tuned: path to the fine-tuned .safetensors file.
        rank: maximum LoRA rank to extract per layer.
        clamp: quantile (0..1) used to clip outlier values in U/Vh.

    Returns:
        str path of the saved ``extracted.safetensors`` file in TempDir.
    """
    org = MemoryEfficientSafeOpen(model_org)
    tuned = MemoryEfficientSafeOpen(model_tuned)
    lora_sd = {}
    print("Calculating diffs & extracting LoRA...")

    # Only keys present in both checkpoints can be diffed.
    keys = set(org.keys()).intersection(set(tuned.keys()))

    for key in tqdm(keys, desc="Extracting"):
        # Skip norm-layer buffers/metadata — not meaningful LoRA targets.
        if "num_batches_tracked" in key or "running_mean" in key or "running_var" in key:
            continue

        mat_org = org.get_tensor(key).float()
        mat_tuned = tuned.get_tensor(key).float()

        # Skip if shapes mismatch (shouldn't happen if models match).
        if mat_org.shape != mat_tuned.shape:
            continue

        diff = mat_tuned - mat_org

        # Skip layers the fine-tune did not actually change.
        if torch.max(torch.abs(diff)) < 1e-4:
            continue

        out_dim = diff.shape[0]
        in_dim = diff.shape[1] if len(diff.shape) > 1 else 1

        r = min(rank, in_dim, out_dim)

        is_conv = len(diff.shape) == 4
        if is_conv:
            diff = diff.flatten(start_dim=1)
        elif len(diff.shape) == 1:
            # Handle 1-D tensors (biases) by treating them as (out_dim, 1).
            diff = diff.unsqueeze(1)

        try:
            # svd_lowrank is much faster than linalg.svd on CPU, but it
            # requires q <= min(matrix dims). Clamp the oversampling margin
            # so small layers don't raise and get silently skipped by the
            # except below. Since r <= min(diff.shape), q is always >= r
            # and the [:r] slices remain valid.
            q = min(r + 4, min(diff.shape))
            U, S, V = torch.svd_lowrank(diff, q=q, niter=4)
            Vh = V.t()

            U = U[:, :r]
            S = S[:r]
            Vh = Vh[:r, :]

            # Merge S into U for standard LoRA format.
            U = U @ torch.diag(S)

            # Clamp outliers at the `clamp` quantile of absolute values.
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(torch.abs(dist), clamp)
            if hi_val > 0:
                U = U.clamp(-hi_val, hi_val)
                Vh = Vh.clamp(-hi_val, hi_val)

            if is_conv:
                U = U.reshape(out_dim, r, 1, 1)
                Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
            else:
                U = U.reshape(out_dim, r)
                Vh = Vh.reshape(r, in_dim)

            stem = key.replace(".weight", "")
            lora_sd[f"{stem}.lora_up.weight"] = U.contiguous()
            lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous()
            lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
        except Exception as e:
            # Best-effort: log and move on rather than aborting the whole run.
            print(f"Skipping {key} due to error: {e}")

    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)
|
|
|
|
| 567 |
cleanup_temp()
|
| 568 |
if hf_token: login(hf_token.strip())
|
| 569 |
try:
|
| 570 |
+
print("Downloading Original Model...")
|
| 571 |
+
p1 = identify_and_download_model(org, hf_token)
|
| 572 |
+
print("Downloading Tuned Model...")
|
| 573 |
+
p2 = identify_and_download_model(tun, hf_token)
|
| 574 |
+
|
| 575 |
f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)
|
| 576 |
+
|
| 577 |
api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
|
| 578 |
+
api.upload_file(path_or_fileobj=f, path_in_repo="extracted_lora.safetensors", repo_id=out, token=hf_token)
|
| 579 |
+
return "Done! Extracted to " + out
|
| 580 |
except Exception as e: return f"Error: {e}"
|
| 581 |
|
| 582 |
# =================================================================================
|