"""Centralized defaults for model checkpoint/cache locations. Goal: keep all auto-downloaded model artifacts inside this repo's `checkpoints/` directory by default (instead of user-wide cache dirs or repo root). """ from __future__ import annotations import os from pathlib import Path PROJECT_ROOT = Path(__file__).resolve().parent CHECKPOINTS_DIR = PROJECT_ROOT / "checkpoints" # Hugging Face will create subfolders like `hub/`, `datasets/`, etc under HF_HOME. HF_HOME_DIR = CHECKPOINTS_DIR / "hf" # Torchvision uses torch.hub.load_state_dict_from_url which respects TORCH_HOME. TORCH_HOME_DIR = CHECKPOINTS_DIR / "torch" def ensure_default_checkpoint_dirs() -> None: """Ensure checkpoint dirs exist and set cache-related env vars. This is intentionally a best-effort helper. If the user has explicitly set env vars already, we do not override them. """ CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True) HF_HOME_DIR.mkdir(parents=True, exist_ok=True) TORCH_HOME_DIR.mkdir(parents=True, exist_ok=True) # Hugging Face os.environ.setdefault("HF_HOME", str(HF_HOME_DIR)) # Compatibility env vars across transformers/huggingface-hub versions. os.environ.setdefault("TRANSFORMERS_CACHE", str(HF_HOME_DIR / "hub")) os.environ.setdefault("HUGGINGFACE_HUB_CACHE", str(HF_HOME_DIR / "hub")) # Torch / torchvision os.environ.setdefault("TORCH_HOME", str(TORCH_HOME_DIR)) def hf_cache_dir() -> Path: ensure_default_checkpoint_dirs() return HF_HOME_DIR def torch_home_dir() -> Path: ensure_default_checkpoint_dirs() return TORCH_HOME_DIR def checkpoints_dir() -> Path: ensure_default_checkpoint_dirs() return CHECKPOINTS_DIR def default_checkpoint_path(filename: str) -> str: """Return an absolute path under `checkpoints/` for a given filename.""" return str(checkpoints_dir() / filename)