td-builder commited on
Commit
d839b05
·
verified ·
1 Parent(s): 815dfda

Upload 141 files

Browse files
hugging/save_checkpoint.py CHANGED
@@ -75,6 +75,22 @@ def main():
75
  print(f"Uploading latest: {latest}")
76
  upload_checkpoint(api, latest)
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  print("\nAll done! Checkpoints saved to HuggingFace.")
79
 
80
 
 
75
  print(f"Uploading latest: {latest}")
76
  upload_checkpoint(api, latest)
77
 
78
+ # Also upload perm_cache if it exists (tiny files, saves 12 min per re-run)
79
+ perm_cache = CKPT_DIR / "perm_cache"
80
+ if perm_cache.exists() and any(perm_cache.glob("*.npz")):
81
+ try:
82
+ size_kb = sum(f.stat().st_size for f in perm_cache.rglob("*") if f.is_file()) / 1024
83
+ print(f" Uploading perm_cache ({size_kb:.0f} KB) to {REPO}/perm_cache/...")
84
+ api.upload_folder(
85
+ folder_path=str(perm_cache),
86
+ path_in_repo="perm_cache",
87
+ repo_id=REPO,
88
+ commit_message="Permutation cache (saves 12 min Sinkhorn)",
89
+ )
90
+ print(f" Done: perm_cache")
91
+ except Exception as e:
92
+ print(f" WARNING: perm_cache upload failed ({e})")
93
+
94
  print("\nAll done! Checkpoints saved to HuggingFace.")
95
 
96
 
hugging/td_fuse/heal.py CHANGED
@@ -347,6 +347,15 @@ def apply_qlora_standard(
347
  print("\n[heal] Starting standard QLoRA healing fine-tune...")
348
  trainer.train()
349
 
 
 
 
 
 
 
 
 
 
350
  # Save — merge LoRA adapters
351
  healed_dir = Path(cfg.output_dir) / "healed"
352
  healed_dir.mkdir(parents=True, exist_ok=True)
@@ -354,7 +363,7 @@ def apply_qlora_standard(
354
  print(f"\n[heal] Merging LoRA adapters...")
355
  merged_model = model.merge_and_unload()
356
 
357
- import shutil, gc
358
 
359
  # SAVE FIRST — never delete anything until save is confirmed
360
  # save_pretrained can fail on 4-bit merged models (NotImplementedError)
 
347
  print("\n[heal] Starting standard QLoRA healing fine-tune...")
348
  trainer.train()
349
 
350
+ # Free disk space: delete training checkpoints (epoch saves) before saving final model
351
+ # These are ~17GB and we need room for the healed model
352
+ import shutil, gc
353
+ heal_output_dir = Path(cfg.output_dir) / "heal_output"
354
+ if heal_output_dir.exists():
355
+ print(f"[heal] Cleaning up training checkpoints to free disk space...")
356
+ shutil.rmtree(str(heal_output_dir), ignore_errors=True)
357
+ print(f"[heal] Freed ~17GB from {heal_output_dir}")
358
+
359
  # Save — merge LoRA adapters
360
  healed_dir = Path(cfg.output_dir) / "healed"
361
  healed_dir.mkdir(parents=True, exist_ok=True)
 
363
  print(f"\n[heal] Merging LoRA adapters...")
364
  merged_model = model.merge_and_unload()
365
 
366
+ gc.collect()
367
 
368
  # SAVE FIRST — never delete anything until save is confirmed
369
  # save_pretrained can fail on 4-bit merged models (NotImplementedError)
hugging/td_fuse/transport.py CHANGED
@@ -391,7 +391,36 @@ def _compute_plans_fallback(
391
  else:
392
  corr_val = diag_corr if S0.shape[1] == T0.shape[1] else 0.0
393
  print(f"[transport] Neurons NOT aligned (diag_corr={corr_val:.3f}) — computing permutations via Sinkhorn")
394
- print("[transport] This may take 2-5 minutes...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  sys.stdout.flush()
396
 
397
  # Track which block indices already have permutations (avoid computing twice)
@@ -465,6 +494,17 @@ def _compute_plans_fallback(
465
 
466
  if permutations:
467
  print(f"[transport] Computed {len(permutations)} neuron permutations")
 
 
 
 
 
 
 
 
 
 
 
468
  print(f"[transport] Direct matching complete: {n_source} layer pairs")
469
  tracker.done()
470
  sys.stdout.flush()
 
391
  else:
392
  corr_val = diag_corr if S0.shape[1] == T0.shape[1] else 0.0
393
  print(f"[transport] Neurons NOT aligned (diag_corr={corr_val:.3f}) — computing permutations via Sinkhorn")
394
+
395
+ # Check for cached permutations (saves ~12 min per re-run)
396
+ # Look in both local checkpoint dir AND HuggingFace download location
397
+ perm_cache_dir = Path("td_fuse_checkpoints") / "perm_cache"
398
+ src_name = "_".join(sorted(source_act.keys())[:3]) # first 3 layer names as key
399
+ cache_file = perm_cache_dir / f"perms_{n_source}_{hash(src_name) % 10**8}.npz"
400
+ hf_cache_file = Path("perm_cache") / f"perms_{n_source}_{hash(src_name) % 10**8}.npz"
401
+ if not cache_file.exists() and hf_cache_file.exists():
402
+ cache_file = hf_cache_file # Use HuggingFace-downloaded cache
403
+ if cache_file.exists():
404
+ print(f"[transport] LOADING CACHED permutations from {cache_file}")
405
+ cached = np.load(str(cache_file), allow_pickle=True)
406
+ for i, (sl, tl) in enumerate(zip(source_layers, target_layers)):
407
+ key = f"{sl}__{tl}"
408
+ if key in cached:
409
+ permutations[(sl, tl)] = cached[key]
410
+ Q_matrices[(sl, tl)] = np.eye(S0.shape[1]) / S0.shape[1]
411
+ tracker.tick(f"{sl} -> {tl}")
412
+ print(f"[transport] Loaded {len(permutations)} cached permutations (skipped Sinkhorn!)")
413
+ tracker.done()
414
+ sys.stdout.flush()
415
+ return {
416
+ "P": P,
417
+ "Q": Q_matrices,
418
+ "permutations": permutations,
419
+ "source_layers": source_layers,
420
+ "target_layers": target_layers,
421
+ }
422
+
423
+ print("[transport] No cache found — computing fresh (will cache for next time)...")
424
  sys.stdout.flush()
425
 
426
  # Track which block indices already have permutations (avoid computing twice)
 
494
 
495
  if permutations:
496
  print(f"[transport] Computed {len(permutations)} neuron permutations")
497
+ # Cache permutations so we don't recompute on re-runs (~12 min saved)
498
+ try:
499
+ perm_cache_dir = Path("td_fuse_checkpoints") / "perm_cache"
500
+ perm_cache_dir.mkdir(parents=True, exist_ok=True)
501
+ src_name = "_".join(sorted(source_act.keys())[:3])
502
+ cache_file = perm_cache_dir / f"perms_{n_source}_{hash(src_name) % 10**8}.npz"
503
+ save_dict = {f"{sl}__{tl}": perm for (sl, tl), perm in permutations.items()}
504
+ np.savez_compressed(str(cache_file), **save_dict)
505
+ print(f"[transport] Cached permutations to {cache_file} ({cache_file.stat().st_size // 1024} KB)")
506
+ except Exception as e:
507
+ print(f"[transport] WARNING: Could not cache permutations ({e})")
508
  print(f"[transport] Direct matching complete: {n_source} layer pairs")
509
  tracker.done()
510
  sys.stdout.flush()