Upload 141 files
Browse files
- hugging/save_checkpoint.py +16 -0
- hugging/td_fuse/heal.py +10 -1
- hugging/td_fuse/transport.py +41 -1
hugging/save_checkpoint.py
CHANGED
|
@@ -75,6 +75,22 @@ def main():
|
|
| 75 |
print(f"Uploading latest: {latest}")
|
| 76 |
upload_checkpoint(api, latest)
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
print("\nAll done! Checkpoints saved to HuggingFace.")
|
| 79 |
|
| 80 |
|
|
|
|
| 75 |
print(f"Uploading latest: {latest}")
|
| 76 |
upload_checkpoint(api, latest)
|
| 77 |
|
| 78 |
+
# Also upload perm_cache if it exists (tiny files, saves 12 min per re-run)
|
| 79 |
+
perm_cache = CKPT_DIR / "perm_cache"
|
| 80 |
+
if perm_cache.exists() and any(perm_cache.glob("*.npz")):
|
| 81 |
+
try:
|
| 82 |
+
size_kb = sum(f.stat().st_size for f in perm_cache.rglob("*") if f.is_file()) / 1024
|
| 83 |
+
print(f" Uploading perm_cache ({size_kb:.0f} KB) to {REPO}/perm_cache/...")
|
| 84 |
+
api.upload_folder(
|
| 85 |
+
folder_path=str(perm_cache),
|
| 86 |
+
path_in_repo="perm_cache",
|
| 87 |
+
repo_id=REPO,
|
| 88 |
+
commit_message="Permutation cache (saves 12 min Sinkhorn)",
|
| 89 |
+
)
|
| 90 |
+
print(f" Done: perm_cache")
|
| 91 |
+
except Exception as e:
|
| 92 |
+
print(f" WARNING: perm_cache upload failed ({e})")
|
| 93 |
+
|
| 94 |
print("\nAll done! Checkpoints saved to HuggingFace.")
|
| 95 |
|
| 96 |
|
hugging/td_fuse/heal.py
CHANGED
|
@@ -347,6 +347,15 @@ def apply_qlora_standard(
|
|
| 347 |
print("\n[heal] Starting standard QLoRA healing fine-tune...")
|
| 348 |
trainer.train()
|
| 349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
# Save — merge LoRA adapters
|
| 351 |
healed_dir = Path(cfg.output_dir) / "healed"
|
| 352 |
healed_dir.mkdir(parents=True, exist_ok=True)
|
|
@@ -354,7 +363,7 @@ def apply_qlora_standard(
|
|
| 354 |
print(f"\n[heal] Merging LoRA adapters...")
|
| 355 |
merged_model = model.merge_and_unload()
|
| 356 |
|
| 357 |
-
|
| 358 |
|
| 359 |
# SAVE FIRST — never delete anything until save is confirmed
|
| 360 |
# save_pretrained can fail on 4-bit merged models (NotImplementedError)
|
|
|
|
| 347 |
print("\n[heal] Starting standard QLoRA healing fine-tune...")
|
| 348 |
trainer.train()
|
| 349 |
|
| 350 |
+
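# Free disk space: delete the training checkpoints (epoch saves) before saving the final model. NOTE(review): this deletion happens before the merged model is saved, which conflicts with the "SAVE FIRST — never delete anything until save is confirmed" rule stated further down; if the final save fails, the epoch checkpoints are already gone — confirm this trade-off is intended.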
|
| 351 |
+
# These are ~17GB and we need room for the healed model
|
| 352 |
+
import shutil, gc
|
| 353 |
+
heal_output_dir = Path(cfg.output_dir) / "heal_output"
|
| 354 |
+
if heal_output_dir.exists():
|
| 355 |
+
print(f"[heal] Cleaning up training checkpoints to free disk space...")
|
| 356 |
+
shutil.rmtree(str(heal_output_dir), ignore_errors=True)
|
| 357 |
+
print(f"[heal] Freed ~17GB from {heal_output_dir}")
|
| 358 |
+
|
| 359 |
# Save — merge LoRA adapters
|
| 360 |
healed_dir = Path(cfg.output_dir) / "healed"
|
| 361 |
healed_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 363 |
print(f"\n[heal] Merging LoRA adapters...")
|
| 364 |
merged_model = model.merge_and_unload()
|
| 365 |
|
| 366 |
+
gc.collect()
|
| 367 |
|
| 368 |
# SAVE FIRST — never delete anything until save is confirmed
|
| 369 |
# save_pretrained can fail on 4-bit merged models (NotImplementedError)
|
hugging/td_fuse/transport.py
CHANGED
|
@@ -391,7 +391,36 @@ def _compute_plans_fallback(
|
|
| 391 |
else:
|
| 392 |
corr_val = diag_corr if S0.shape[1] == T0.shape[1] else 0.0
|
| 393 |
print(f"[transport] Neurons NOT aligned (diag_corr={corr_val:.3f}) — computing permutations via Sinkhorn")
|
| 394 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
sys.stdout.flush()
|
| 396 |
|
| 397 |
# Track which block indices already have permutations (avoid computing twice)
|
|
@@ -465,6 +494,17 @@ def _compute_plans_fallback(
|
|
| 465 |
|
| 466 |
if permutations:
|
| 467 |
print(f"[transport] Computed {len(permutations)} neuron permutations")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
print(f"[transport] Direct matching complete: {n_source} layer pairs")
|
| 469 |
tracker.done()
|
| 470 |
sys.stdout.flush()
|
|
|
|
| 391 |
else:
|
| 392 |
corr_val = diag_corr if S0.shape[1] == T0.shape[1] else 0.0
|
| 393 |
print(f"[transport] Neurons NOT aligned (diag_corr={corr_val:.3f}) — computing permutations via Sinkhorn")
|
| 394 |
+
|
| 395 |
+
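# Check for cached permutations (saves ~12 min per re-run). NOTE(review): the cache filename is built from the built-in hash() of a str, which Python salts per process unless PYTHONHASHSEED is fixed, so the name computed here will usually not match the one written by a previous run — consider a stable digest (e.g. hashlib) for the key, changed in both the load and the save path together.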
|
| 396 |
+
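# Look in both the local checkpoint dir AND the HuggingFace download location. NOTE(review): the file found here is loaded with allow_pickle=True, which can execute arbitrary code from an untrusted .npz; if the stored permutations are plain numeric arrays, allow_pickle=False should suffice — TODO confirm.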
|
| 397 |
+
perm_cache_dir = Path("td_fuse_checkpoints") / "perm_cache"
|
| 398 |
+
src_name = "_".join(sorted(source_act.keys())[:3]) # first 3 layer names as key
|
| 399 |
+
cache_file = perm_cache_dir / f"perms_{n_source}_{hash(src_name) % 10**8}.npz"
|
| 400 |
+
hf_cache_file = Path("perm_cache") / f"perms_{n_source}_{hash(src_name) % 10**8}.npz"
|
| 401 |
+
if not cache_file.exists() and hf_cache_file.exists():
|
| 402 |
+
cache_file = hf_cache_file # Use HuggingFace-downloaded cache
|
| 403 |
+
if cache_file.exists():
|
| 404 |
+
print(f"[transport] LOADING CACHED permutations from {cache_file}")
|
| 405 |
+
cached = np.load(str(cache_file), allow_pickle=True)
|
| 406 |
+
for i, (sl, tl) in enumerate(zip(source_layers, target_layers)):
|
| 407 |
+
key = f"{sl}__{tl}"
|
| 408 |
+
if key in cached:
|
| 409 |
+
permutations[(sl, tl)] = cached[key]
|
| 410 |
+
Q_matrices[(sl, tl)] = np.eye(S0.shape[1]) / S0.shape[1]
|
| 411 |
+
tracker.tick(f"{sl} -> {tl}")
|
| 412 |
+
print(f"[transport] Loaded {len(permutations)} cached permutations (skipped Sinkhorn!)")
|
| 413 |
+
tracker.done()
|
| 414 |
+
sys.stdout.flush()
|
| 415 |
+
return {
|
| 416 |
+
"P": P,
|
| 417 |
+
"Q": Q_matrices,
|
| 418 |
+
"permutations": permutations,
|
| 419 |
+
"source_layers": source_layers,
|
| 420 |
+
"target_layers": target_layers,
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
print("[transport] No cache found — computing fresh (will cache for next time)...")
|
| 424 |
sys.stdout.flush()
|
| 425 |
|
| 426 |
# Track which block indices already have permutations (avoid computing twice)
|
|
|
|
| 494 |
|
| 495 |
if permutations:
|
| 496 |
print(f"[transport] Computed {len(permutations)} neuron permutations")
|
| 497 |
+
# Cache permutations so we don't recompute on re-runs (~12 min saved)
|
| 498 |
+
try:
|
| 499 |
+
perm_cache_dir = Path("td_fuse_checkpoints") / "perm_cache"
|
| 500 |
+
perm_cache_dir.mkdir(parents=True, exist_ok=True)
|
| 501 |
+
src_name = "_".join(sorted(source_act.keys())[:3])
|
| 502 |
+
cache_file = perm_cache_dir / f"perms_{n_source}_{hash(src_name) % 10**8}.npz"
|
| 503 |
+
save_dict = {f"{sl}__{tl}": perm for (sl, tl), perm in permutations.items()}
|
| 504 |
+
np.savez_compressed(str(cache_file), **save_dict)
|
| 505 |
+
print(f"[transport] Cached permutations to {cache_file} ({cache_file.stat().st_size // 1024} KB)")
|
| 506 |
+
except Exception as e:
|
| 507 |
+
print(f"[transport] WARNING: Could not cache permutations ({e})")
|
| 508 |
print(f"[transport] Direct matching complete: {n_source} layer pairs")
|
| 509 |
tracker.done()
|
| 510 |
sys.stdout.flush()
|