feather-a10-runtime/overlay/scripts/benchmark_assets.py
Jackoatmon's picture
Update Feather training runtime image
951f760 verified
#!/usr/bin/env python3
from __future__ import annotations
import os
from pathlib import Path
def _download_file(*, repo_id: str, filename: str, local_dir: str, token: str | None, subfolder: str | None = None) -> Path:
    """Fetch one file from a Hugging Face model repo and return its local path.

    Args:
        repo_id: Repository to download from (queried with ``repo_type="model"``).
        filename: Name of the file to fetch.
        local_dir: Directory handed to ``hf_hub_download`` as ``local_dir``.
        token: Access token forwarded to the hub client; may be ``None``.
        subfolder: Optional folder inside the repo that holds ``filename``.

    Returns:
        The path reported by ``hf_hub_download``, wrapped in a ``Path``.
    """
    # Imported lazily so merely importing this module does not require
    # huggingface_hub to be installed.
    from huggingface_hub import hf_hub_download

    request = {
        "repo_id": repo_id,
        "repo_type": "model",
        "filename": filename,
        "subfolder": subfolder,
        "token": token,
        "local_dir": local_dir,
        "local_dir_use_symlinks": False,
    }
    return Path(hf_hub_download(**request))
def resolve_tokenizer_cache_repo(*, output_repo: str, retina_cache_repo: str) -> str:
    """Choose the repo id to pull tokenizer artifacts from.

    The first non-empty value wins, checked in this order: the environment
    variables ``HYDRA_TOKENIZER_CACHE_REPO``, ``FEATHER_HF_OUTPUT_REPO``,
    ``HF_REPO_ID``, ``HYDRA_RETINA_CACHE_REPO`` and
    ``FEATHER_HF_RETINA_CACHE_REPO``, then ``output_repo``, and finally
    ``retina_cache_repo`` (returned as-is even when it is empty).
    """
    env_names = (
        "HYDRA_TOKENIZER_CACHE_REPO",
        "FEATHER_HF_OUTPUT_REPO",
        "HF_REPO_ID",
        "HYDRA_RETINA_CACHE_REPO",
        "FEATHER_HF_RETINA_CACHE_REPO",
    )
    for env_name in env_names:
        candidate = os.environ.get(env_name)
        # Unset and empty variables are both skipped.
        if candidate:
            return candidate
    if output_repo:
        return output_repo
    return retina_cache_repo
def tokenizer_cache_prefix() -> str:
    """Return the repo subfolder that holds tokenizer files for the active vocab size.

    The size is read from ``HYDRA_VOCAB_SIZE`` (``65536`` when unset) and
    parsed with ``int``, so a non-numeric value raises ``ValueError``.
    """
    raw_size = os.environ.get("HYDRA_VOCAB_SIZE", "65536")
    return "tokenizer/vocab" + str(int(raw_size))
def hydrate_benchmark_assets(*, cache_dir: Path, output_repo: str, tokenizer_repo: str, token: str | None) -> dict[str, str]:
    """Ensure the benchmark checkpoint and tokenizer files exist under ``cache_dir``.

    Files already present are reused; anything missing is downloaded from the
    Hugging Face Hub.

    Args:
        cache_dir: Root cache directory; created if missing. The checkpoint is
            expected at ``cache_dir/best_bpb.pt`` and the tokenizer files at
            ``cache_dir/tokenizer/{tokenizer.pkl,token_bytes.pt}``.
        output_repo: Repo id the ``best_bpb.pt`` checkpoint is downloaded from.
        tokenizer_repo: Fallback repo id for tokenizer files, used when none of
            the environment overrides read by ``resolve_tokenizer_cache_repo``
            is set.
        token: Hub access token, or ``None``.

    Returns:
        A dict with ``"checkpoint_path"`` (path to the checkpoint file) and
        ``"tokenizer_dir"`` (directory containing the tokenizer files).
    """
    import shutil

    cache_dir.mkdir(parents=True, exist_ok=True)
    tok_dir = cache_dir / "tokenizer"
    tok_dir.mkdir(parents=True, exist_ok=True)

    tokenizer_files = ("tokenizer.pkl", "token_bytes.pt")
    missing = [name for name in tokenizer_files if not (tok_dir / name).exists()]

    # Repo and prefix are only resolved when a tokenizer download is needed,
    # so a fully hydrated cache does not depend on the environment.
    tok_repo = None
    tok_prefix = None
    if missing:
        tok_repo = resolve_tokenizer_cache_repo(output_repo=tokenizer_repo, retina_cache_repo=tokenizer_repo)
        tok_prefix = tokenizer_cache_prefix()

    ckpt_path = cache_dir / "best_bpb.pt"
    if not ckpt_path.exists():
        ckpt_path = _download_file(repo_id=output_repo, filename="best_bpb.pt", local_dir=str(cache_dir), token=token)

    for name in missing:
        fetched = _download_file(repo_id=tok_repo, filename=name, local_dir=str(tok_dir), token=token, subfolder=tok_prefix)
        flat_path = tok_dir / name
        # hf_hub_download mirrors the repo-relative path (subfolder included)
        # beneath local_dir, so the file lands in tok_dir/<prefix>/<name>.
        # Copy it to the flat location so the returned tokenizer_dir really
        # contains it and the existence check above hits on the next run.
        if fetched.resolve() != flat_path.resolve():
            shutil.copy2(fetched, flat_path)

    return {
        "checkpoint_path": str(ckpt_path),
        "tokenizer_dir": str(tok_dir),
    }