File size: 2,087 Bytes
e317e25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""Triton cache persistence via HF Hub.

Call setup() BEFORE importing triton/mamba_ssm to hydrate the cache.
Call teardown() AFTER training to push the (possibly updated) cache.
"""
import os
from pathlib import Path

# Hardware profile used to namespace the cache: FEATHER_GPU_PROFILE takes
# precedence, then FEATHER_HF_FLAVOR, then the "a10g-large" default.
GPU_PROFILE = os.getenv(
    "FEATHER_GPU_PROFILE",
    os.getenv("FEATHER_HF_FLAVOR", "a10g-large"),
)

# Local cache directory; the default path embeds the profile name.
TRITON_CACHE_DIR = os.getenv(
    "TRITON_CACHE_DIR",
    f"/workspace/triton_cache/{GPU_PROFILE}",
)

# HF Hub repo id holding the persisted cache; the default embeds the profile.
CACHE_REPO = os.getenv(
    "TRITON_CACHE_REPO",
    f"icarus112/feather-triton-cache-{GPU_PROFILE}",
)


def setup() -> None:
    """Point Triton at the local cache directory and hydrate it from HF Hub.

    Always creates ``TRITON_CACHE_DIR`` and exports it through the
    ``TRITON_CACHE_DIR`` environment variable. When ``HF_TOKEN`` is set, the
    private dataset repo ``CACHE_REPO`` is created if it does not exist and
    its snapshot is downloaded into the cache directory.

    Hydration is best-effort: without a token it is skipped, and any
    exception raised while importing ``huggingface_hub`` or talking to the
    Hub is printed instead of propagated.
    """
    os.makedirs(TRITON_CACHE_DIR, exist_ok=True)
    os.environ["TRITON_CACHE_DIR"] = TRITON_CACHE_DIR
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("[triton_cache] no HF_TOKEN; skipping cache hydrate", flush=True)
        return
    try:
        # Imported inside the try so a missing huggingface_hub package is
        # reported through the same best-effort handler below.
        from huggingface_hub import create_repo, snapshot_download

        # exist_ok=True makes this a no-op when the repo already exists.
        create_repo(CACHE_REPO, repo_type="dataset", private=True, exist_ok=True, token=token)
        snapshot_download(
            repo_id=CACHE_REPO,
            repo_type="dataset",
            local_dir=TRITON_CACHE_DIR,
            token=token,
        )
        # Count every regular file now present under the cache directory.
        n = sum(1 for p in Path(TRITON_CACHE_DIR).rglob("*") if p.is_file())
        print(f"[triton_cache] hydrated {n} cached artifacts from {CACHE_REPO}", flush=True)
    except Exception as e:
        print(f"[triton_cache] hydrate failed (first run?): {e}", flush=True)


def teardown() -> None:
    """Push the local Triton cache directory to its HF Hub dataset repo.

    Best-effort: when ``HF_TOKEN`` is absent from the environment the upload
    is skipped, and any exception raised while importing ``huggingface_hub``
    or uploading is printed instead of propagated.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        try:
            from huggingface_hub import HfApi

            hub = HfApi(token=hf_token)
            upload_args = {
                "folder_path": TRITON_CACHE_DIR,
                "repo_id": CACHE_REPO,
                "repo_type": "dataset",
                "commit_message": "triton cache update",
                "token": hf_token,
            }
            hub.upload_folder(**upload_args)
            print("[triton_cache] uploaded cache to HF Hub", flush=True)
        except Exception as exc:
            print(f"[triton_cache] upload failed: {exc}", flush=True)
    else:
        print("[triton_cache] no HF_TOKEN; skipping cache upload", flush=True)