"""Triton cache persistence via HF Hub. Call setup() BEFORE importing triton/mamba_ssm to hydrate the cache. Call teardown() AFTER training to push the (possibly updated) cache. """ import os from pathlib import Path GPU_PROFILE = os.environ.get("FEATHER_GPU_PROFILE", os.environ.get("FEATHER_HF_FLAVOR", "a10g-large")) TRITON_CACHE_DIR = os.environ.get("TRITON_CACHE_DIR", f"/workspace/triton_cache/{GPU_PROFILE}") CACHE_REPO = os.environ.get("TRITON_CACHE_REPO", f"icarus112/feather-triton-cache-{GPU_PROFILE}") def setup() -> None: os.makedirs(TRITON_CACHE_DIR, exist_ok=True) os.environ["TRITON_CACHE_DIR"] = TRITON_CACHE_DIR token = os.environ.get("HF_TOKEN") if not token: print("[triton_cache] no HF_TOKEN; skipping cache hydrate", flush=True) return try: from huggingface_hub import HfApi, snapshot_download, create_repo api = HfApi(token=token) create_repo(CACHE_REPO, repo_type="dataset", private=True, exist_ok=True, token=token) snapshot_download( repo_id=CACHE_REPO, repo_type="dataset", local_dir=TRITON_CACHE_DIR, token=token, ) n = sum(1 for p in Path(TRITON_CACHE_DIR).rglob("*") if p.is_file()) print(f"[triton_cache] hydrated {n} cached artifacts from {CACHE_REPO}", flush=True) except Exception as e: print(f"[triton_cache] hydrate failed (first run?): {e}", flush=True) def teardown() -> None: token = os.environ.get("HF_TOKEN") if not token: print("[triton_cache] no HF_TOKEN; skipping cache upload", flush=True) return try: from huggingface_hub import HfApi api = HfApi(token=token) api.upload_folder( folder_path=TRITON_CACHE_DIR, repo_id=CACHE_REPO, repo_type="dataset", commit_message="triton cache update", token=token, ) print("[triton_cache] uploaded cache to HF Hub", flush=True) except Exception as e: print(f"[triton_cache] upload failed: {e}", flush=True)