Upload 141 files
Browse files
hugging/td_fuse/transport.py
CHANGED
|
@@ -23,6 +23,7 @@ Findings: #01, #07, #24
|
|
| 23 |
|
| 24 |
import sys
|
| 25 |
import time
|
|
|
|
| 26 |
import torch
|
| 27 |
import numpy as np
|
| 28 |
from pathlib import Path
|
|
@@ -396,8 +397,8 @@ def _compute_plans_fallback(
|
|
| 396 |
# Look in both local checkpoint dir AND HuggingFace download location
|
| 397 |
perm_cache_dir = Path("td_fuse_checkpoints") / "perm_cache"
|
| 398 |
src_name = "_".join(sorted(source_act.keys())[:3]) # first 3 layer names as key
|
| 399 |
-
cache_file = perm_cache_dir / f"perms_{n_source}_{
|
| 400 |
-
hf_cache_file = Path("perm_cache") / f"perms_{n_source}_{
|
| 401 |
if not cache_file.exists() and hf_cache_file.exists():
|
| 402 |
cache_file = hf_cache_file # Use HuggingFace-downloaded cache
|
| 403 |
if cache_file.exists():
|
|
@@ -499,7 +500,7 @@ def _compute_plans_fallback(
|
|
| 499 |
perm_cache_dir = Path("td_fuse_checkpoints") / "perm_cache"
|
| 500 |
perm_cache_dir.mkdir(parents=True, exist_ok=True)
|
| 501 |
src_name = "_".join(sorted(source_act.keys())[:3])
|
| 502 |
-
cache_file = perm_cache_dir / f"perms_{n_source}_{
|
| 503 |
save_dict = {f"{sl}__{tl}": perm for (sl, tl), perm in permutations.items()}
|
| 504 |
np.savez_compressed(str(cache_file), **save_dict)
|
| 505 |
print(f"[transport] Cached permutations to {cache_file} ({cache_file.stat().st_size // 1024} KB)")
|
|
|
|
| 23 |
|
| 24 |
import sys
|
| 25 |
import time
|
| 26 |
+
import hashlib
|
| 27 |
import torch
|
| 28 |
import numpy as np
|
| 29 |
from pathlib import Path
|
|
|
|
| 397 |
# Look in both local checkpoint dir AND HuggingFace download location
|
| 398 |
perm_cache_dir = Path("td_fuse_checkpoints") / "perm_cache"
|
| 399 |
src_name = "_".join(sorted(source_act.keys())[:3]) # first 3 layer names as key
|
| 400 |
+
cache_file = perm_cache_dir / f"perms_{n_source}_{int(hashlib.md5(src_name.encode()).hexdigest()[:8], 16)}.npz"
|
| 401 |
+
hf_cache_file = Path("perm_cache") / f"perms_{n_source}_{int(hashlib.md5(src_name.encode()).hexdigest()[:8], 16)}.npz"
|
| 402 |
if not cache_file.exists() and hf_cache_file.exists():
|
| 403 |
cache_file = hf_cache_file # Use HuggingFace-downloaded cache
|
| 404 |
if cache_file.exists():
|
|
|
|
| 500 |
perm_cache_dir = Path("td_fuse_checkpoints") / "perm_cache"
|
| 501 |
perm_cache_dir.mkdir(parents=True, exist_ok=True)
|
| 502 |
src_name = "_".join(sorted(source_act.keys())[:3])
|
| 503 |
+
cache_file = perm_cache_dir / f"perms_{n_source}_{int(hashlib.md5(src_name.encode()).hexdigest()[:8], 16)}.npz"
|
| 504 |
save_dict = {f"{sl}__{tl}": perm for (sl, tl), perm in permutations.items()}
|
| 505 |
np.savez_compressed(str(cache_file), **save_dict)
|
| 506 |
print(f"[transport] Cached permutations to {cache_file} ({cache_file.stat().st_size // 1024} KB)")
|