# feather-runtime/overlay/triton_cache_setup.py
# Last commit by Jackoatmon: "Update Feather h200 training runtime image" (e317e25, verified)
"""Triton cache persistence via HF Hub.
Call setup() BEFORE importing triton/mamba_ssm to hydrate the cache.
Call teardown() AFTER training to push the (possibly updated) cache.
"""
import os
from pathlib import Path
# GPU profile name; FEATHER_GPU_PROFILE wins, then FEATHER_HF_FLAVOR, then a
# fixed default. It namespaces both the local cache dir and the Hub repo, so
# each profile gets its own cache (presumably because compiled kernels are
# GPU-specific).
GPU_PROFILE = os.getenv(
    "FEATHER_GPU_PROFILE",
    os.getenv("FEATHER_HF_FLAVOR", "a10g-large"),
)
# Local Triton cache directory; an explicit TRITON_CACHE_DIR overrides the
# per-profile default.
TRITON_CACHE_DIR = os.getenv(
    "TRITON_CACHE_DIR",
    f"/workspace/triton_cache/{GPU_PROFILE}",
)
# HF Hub dataset repo the cache is synced with; TRITON_CACHE_REPO overrides it.
CACHE_REPO = os.getenv(
    "TRITON_CACHE_REPO",
    f"icarus112/feather-triton-cache-{GPU_PROFILE}",
)
def _count_cache_artifacts(root: Path) -> int:
    """Return the number of files under *root*, excluding HF Hub bookkeeping.

    ``snapshot_download(local_dir=...)`` keeps its own metadata under
    ``<local_dir>/.cache/huggingface/``, and Hub repos carry a root
    ``.gitattributes``. Neither is a Triton artifact, so counting them would
    overstate how much of the cache was actually hydrated.
    """
    count = 0
    for path in root.rglob("*"):
        rel_parts = path.relative_to(root).parts
        if rel_parts == (".gitattributes",) or rel_parts[:2] == (".cache", "huggingface"):
            continue
        if path.is_file():
            count += 1
    return count


def setup() -> None:
    """Export the Triton cache location and hydrate it from the HF Hub.

    Call this BEFORE importing triton / mamba_ssm so they see the exported
    ``TRITON_CACHE_DIR`` environment variable.

    Hydration is best-effort. Without ``HF_TOKEN`` it is skipped, and any Hub
    error is printed rather than raised, so training can still start with a
    cold cache.
    """
    os.makedirs(TRITON_CACHE_DIR, exist_ok=True)
    os.environ["TRITON_CACHE_DIR"] = TRITON_CACHE_DIR

    token = os.environ.get("HF_TOKEN")
    if not token:
        print("[triton_cache] no HF_TOKEN; skipping cache hydrate", flush=True)
        return

    try:
        # Imported lazily so the no-token path works without huggingface_hub.
        from huggingface_hub import create_repo, snapshot_download

        # Ensure the private dataset repo exists (a no-op once it does) so
        # teardown() has an upload target even on the very first run.
        create_repo(CACHE_REPO, repo_type="dataset", private=True, exist_ok=True, token=token)
        snapshot_download(
            repo_id=CACHE_REPO,
            repo_type="dataset",
            local_dir=TRITON_CACHE_DIR,
            token=token,
        )
        n = _count_cache_artifacts(Path(TRITON_CACHE_DIR))
        print(f"[triton_cache] hydrated {n} cached artifacts from {CACHE_REPO}", flush=True)
    except Exception as e:
        # Deliberately broad: a Hub or network failure must not abort training.
        print(f"[triton_cache] hydrate failed (first run?): {e}", flush=True)
def teardown() -> None:
    """Push the local Triton cache directory to the HF Hub dataset repo.

    Call after training so the next run's setup() can hydrate from it.
    Best-effort: skipped when ``HF_TOKEN`` is unset, and any upload error is
    printed instead of raised.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        try:
            # Lazy import: only needed when an upload is actually attempted.
            from huggingface_hub import HfApi

            HfApi(token=hf_token).upload_folder(
                folder_path=TRITON_CACHE_DIR,
                repo_id=CACHE_REPO,
                repo_type="dataset",
                commit_message="triton cache update",
                token=hf_token,
            )
        except Exception as err:
            print(f"[triton_cache] upload failed: {err}", flush=True)
        else:
            print("[triton_cache] uploaded cache to HF Hub", flush=True)
    else:
        print("[triton_cache] no HF_TOKEN; skipping cache upload", flush=True)