td-builder commited on
Commit
dd4db03
·
verified ·
1 Parent(s): 9295327

Upload 141 files

Browse files
Files changed (1) hide show
  1. hugging/td_fuse/transport.py +4 -3
hugging/td_fuse/transport.py CHANGED
@@ -23,6 +23,7 @@ Findings: #01, #07, #24
23
 
24
  import sys
25
  import time
 
26
  import torch
27
  import numpy as np
28
  from pathlib import Path
@@ -396,8 +397,8 @@ def _compute_plans_fallback(
396
  # Look in both local checkpoint dir AND HuggingFace download location
397
  perm_cache_dir = Path("td_fuse_checkpoints") / "perm_cache"
398
  src_name = "_".join(sorted(source_act.keys())[:3]) # first 3 layer names as key
399
- cache_file = perm_cache_dir / f"perms_{n_source}_{hash(src_name) % 10**8}.npz"
400
- hf_cache_file = Path("perm_cache") / f"perms_{n_source}_{hash(src_name) % 10**8}.npz"
401
  if not cache_file.exists() and hf_cache_file.exists():
402
  cache_file = hf_cache_file # Use HuggingFace-downloaded cache
403
  if cache_file.exists():
@@ -499,7 +500,7 @@ def _compute_plans_fallback(
499
  perm_cache_dir = Path("td_fuse_checkpoints") / "perm_cache"
500
  perm_cache_dir.mkdir(parents=True, exist_ok=True)
501
  src_name = "_".join(sorted(source_act.keys())[:3])
502
- cache_file = perm_cache_dir / f"perms_{n_source}_{hash(src_name) % 10**8}.npz"
503
  save_dict = {f"{sl}__{tl}": perm for (sl, tl), perm in permutations.items()}
504
  np.savez_compressed(str(cache_file), **save_dict)
505
  print(f"[transport] Cached permutations to {cache_file} ({cache_file.stat().st_size // 1024} KB)")
 
23
 
24
  import sys
25
  import time
26
+ import hashlib
27
  import torch
28
  import numpy as np
29
  from pathlib import Path
 
397
  # Look in both local checkpoint dir AND HuggingFace download location
398
  perm_cache_dir = Path("td_fuse_checkpoints") / "perm_cache"
399
  src_name = "_".join(sorted(source_act.keys())[:3]) # first 3 layer names as key
400
+ cache_file = perm_cache_dir / f"perms_{n_source}_{int(hashlib.md5(src_name.encode()).hexdigest()[:8], 16)}.npz"
401
+ hf_cache_file = Path("perm_cache") / f"perms_{n_source}_{int(hashlib.md5(src_name.encode()).hexdigest()[:8], 16)}.npz"
402
  if not cache_file.exists() and hf_cache_file.exists():
403
  cache_file = hf_cache_file # Use HuggingFace-downloaded cache
404
  if cache_file.exists():
 
500
  perm_cache_dir = Path("td_fuse_checkpoints") / "perm_cache"
501
  perm_cache_dir.mkdir(parents=True, exist_ok=True)
502
  src_name = "_".join(sorted(source_act.keys())[:3])
503
+ cache_file = perm_cache_dir / f"perms_{n_source}_{int(hashlib.md5(src_name.encode()).hexdigest()[:8], 16)}.npz"
504
  save_dict = {f"{sl}__{tl}": perm for (sl, tl), perm in permutations.items()}
505
  np.savez_compressed(str(cache_file), **save_dict)
506
  print(f"[transport] Cached permutations to {cache_file} ({cache_file.stat().st_size // 1024} KB)")