Spaces:
Runtime error
Runtime error
| """Triton cache persistence via HF Hub. | |
| Call setup() BEFORE importing triton/mamba_ssm to hydrate the cache. | |
| Call teardown() AFTER training to push the (possibly updated) cache. | |
| """ | |
| import os | |
| from pathlib import Path | |
# Hardware profile used to namespace the cache. FEATHER_GPU_PROFILE wins,
# then FEATHER_HF_FLAVOR, then the hard-coded "a10g-large" default.
GPU_PROFILE = os.getenv(
    "FEATHER_GPU_PROFILE",
    os.getenv("FEATHER_HF_FLAVOR", "a10g-large"),
)

# Local directory holding the Triton cache; an explicit TRITON_CACHE_DIR in the
# environment overrides the per-profile default path.
TRITON_CACHE_DIR = os.getenv(
    "TRITON_CACHE_DIR",
    f"/workspace/triton_cache/{GPU_PROFILE}",
)

# HF Hub repo id the cache is synced with; overridable via TRITON_CACHE_REPO.
CACHE_REPO = os.getenv(
    "TRITON_CACHE_REPO",
    f"icarus112/feather-triton-cache-{GPU_PROFILE}",
)
def setup() -> None:
    """Prepare the local Triton cache directory and hydrate it from the HF Hub.

    Creates ``TRITON_CACHE_DIR`` if needed and exports it through the
    ``TRITON_CACHE_DIR`` environment variable. When ``HF_TOKEN`` is set, the
    private dataset repo ``CACHE_REPO`` is created if missing and its snapshot
    is downloaded into the cache directory.

    Hydration is best-effort: a missing token or any failure while talking to
    the Hub is logged to stdout and the function returns normally.
    """
    os.makedirs(TRITON_CACHE_DIR, exist_ok=True)
    # Export the resolved path so code imported afterwards sees the same
    # directory this module hydrates.
    os.environ["TRITON_CACHE_DIR"] = TRITON_CACHE_DIR

    token = os.environ.get("HF_TOKEN")
    if not token:
        print("[triton_cache] no HF_TOKEN; skipping cache hydrate", flush=True)
        return

    try:
        # Imported lazily so the module stays importable without huggingface_hub;
        # an ImportError is reported by the handler below like any other failure.
        from huggingface_hub import create_repo, snapshot_download

        # exist_ok=True makes this a no-op when the repo already exists.
        create_repo(CACHE_REPO, repo_type="dataset", private=True, exist_ok=True, token=token)
        snapshot_download(
            repo_id=CACHE_REPO,
            repo_type="dataset",
            local_dir=TRITON_CACHE_DIR,
            token=token,
        )

        cache_root = Path(TRITON_CACHE_DIR)
        # snapshot_download may keep its own download metadata under
        # <local_dir>/.cache; those files are not cache artifacts, so leave
        # them out of the reported count.
        n = sum(
            1
            for p in cache_root.rglob("*")
            if p.is_file() and ".cache" not in p.relative_to(cache_root).parts
        )
        print(f"[triton_cache] hydrated {n} cached artifacts from {CACHE_REPO}", flush=True)
    except Exception as e:
        # Deliberately broad: a failed hydrate must never abort the run.
        print(f"[triton_cache] hydrate failed (first run?): {e}", flush=True)
def teardown() -> None:
    """Upload the local Triton cache directory to the HF Hub dataset repo.

    Without ``HF_TOKEN`` in the environment the upload is skipped and a notice
    is printed. Any exception raised while importing ``huggingface_hub`` or
    uploading is caught and printed, so this function never raises because of
    a failed upload.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        try:
            # Lazy import: a missing huggingface_hub is reported as a failed upload.
            from huggingface_hub import HfApi

            client = HfApi(token=hf_token)
            upload_args = {
                "folder_path": TRITON_CACHE_DIR,
                "repo_id": CACHE_REPO,
                "repo_type": "dataset",
                "commit_message": "triton cache update",
                "token": hf_token,
            }
            client.upload_folder(**upload_args)
            print("[triton_cache] uploaded cache to HF Hub", flush=True)
        except Exception as err:
            print(f"[triton_cache] upload failed: {err}", flush=True)
    else:
        print("[triton_cache] no HF_TOKEN; skipping cache upload", flush=True)