Raheeb Hassan
Add code + LFS attributes
398659b
"""Centralized defaults for model checkpoint/cache locations.
Goal: keep all auto-downloaded model artifacts inside this repo's `checkpoints/`
directory by default (instead of user-wide cache dirs or repo root).
"""
from __future__ import annotations
import os
from pathlib import Path
# Repository root: the directory that contains this module.
PROJECT_ROOT = Path(__file__).resolve().parent
# Single root for every auto-downloaded model artifact.
CHECKPOINTS_DIR = PROJECT_ROOT / "checkpoints"
# Hugging Face will create subfolders like `hub/`, `datasets/`, etc under HF_HOME.
HF_HOME_DIR = CHECKPOINTS_DIR / "hf"
# Torchvision uses torch.hub.load_state_dict_from_url which respects TORCH_HOME.
TORCH_HOME_DIR = CHECKPOINTS_DIR / "torch"


def ensure_default_checkpoint_dirs() -> None:
    """Ensure checkpoint dirs exist and set cache-related env vars.

    This is intentionally a best-effort helper. Every env var is only filled
    in when it is not already set, so explicit user configuration wins. The
    derived hub-cache variables follow the *effective* configuration: a
    user-provided ``HF_HOME`` or hub cache is never shadowed by a repo-local
    default.

    The repo-local directories are always created, even when the caches
    themselves have been redirected elsewhere via env vars.

    NOTE(review): some libraries read these variables once at import time,
    so this presumably needs to run before they are imported.
    """
    for directory in (CHECKPOINTS_DIR, HF_HOME_DIR, TORCH_HOME_DIR):
        directory.mkdir(parents=True, exist_ok=True)

    # Hugging Face. Resolve the effective HF_HOME first, so the hub cache
    # derived below nests under the user's HF_HOME when one was provided.
    hf_home = os.environ.setdefault("HF_HOME", str(HF_HOME_DIR))
    # These hub-cache variables take precedence over HF_HOME, so never point
    # them back into this repo behind the user's back. Reuse whichever one
    # the user already pinned; otherwise nest the cache under HF_HOME.
    hub_cache = (
        os.environ.get("HF_HUB_CACHE")
        or os.environ.get("HUGGINGFACE_HUB_CACHE")
        or os.environ.get("TRANSFORMERS_CACHE")
        or str(Path(hf_home) / "hub")
    )
    # Compatibility env vars across transformers/huggingface-hub versions.
    os.environ.setdefault("TRANSFORMERS_CACHE", hub_cache)
    os.environ.setdefault("HUGGINGFACE_HUB_CACHE", hub_cache)

    # Torch / torchvision
    os.environ.setdefault("TORCH_HOME", str(TORCH_HOME_DIR))
def hf_cache_dir() -> Path:
    """Return the repo-local Hugging Face home directory.

    Runs ``ensure_default_checkpoint_dirs()`` first, so the directory exists
    and the cache env vars are populated. The returned path is always
    ``HF_HOME_DIR``, even if ``HF_HOME`` was already set to somewhere else.
    """
    ensure_default_checkpoint_dirs()
    hf_home: Path = HF_HOME_DIR
    return hf_home
def torch_home_dir() -> Path:
    """Return the repo-local torch hub home directory.

    Runs ``ensure_default_checkpoint_dirs()`` first, so the directory exists
    and ``TORCH_HOME`` defaults to it. The returned path is always
    ``TORCH_HOME_DIR``, even if ``TORCH_HOME`` was already set to somewhere else.
    """
    ensure_default_checkpoint_dirs()
    torch_home: Path = TORCH_HOME_DIR
    return torch_home
def checkpoints_dir() -> Path:
    """Return the repo-level ``checkpoints/`` directory, creating it if needed.

    Also populates the cache-related env vars via
    ``ensure_default_checkpoint_dirs()`` as a side effect.
    """
    ensure_default_checkpoint_dirs()
    root: Path = CHECKPOINTS_DIR
    return root
def default_checkpoint_path(filename: str) -> str:
    """Return ``filename`` joined onto the absolute checkpoints directory, as a str.

    The checkpoints directory is created first (via ``checkpoints_dir()``).
    Joining follows :mod:`pathlib` rules, so an absolute ``filename`` takes
    precedence and is not placed under ``checkpoints/``.
    """
    target = checkpoints_dir().joinpath(filename)
    return os.fspath(target)