"""Centralized defaults for model checkpoint/cache locations.

Goal: keep all auto-downloaded model artifacts inside this repo's `checkpoints/`
directory by default (instead of user-wide cache dirs or the repo root).
"""
| |
|
| | from __future__ import annotations |
| |
|
| | import os |
| | from pathlib import Path |
| |
|
| |
|
# Directory that contains this module; every default location hangs off it.
PROJECT_ROOT = Path(__file__).resolve().parent

# Single root for all auto-downloaded model artifacts.
CHECKPOINTS_DIR = PROJECT_ROOT / "checkpoints"

# Default Hugging Face home (exported as HF_HOME when the user has not set one).
HF_HOME_DIR = CHECKPOINTS_DIR / "hf"

# Default PyTorch home (exported as TORCH_HOME when the user has not set one).
TORCH_HOME_DIR = CHECKPOINTS_DIR / "torch"
| |
|
| |
|
def ensure_default_checkpoint_dirs() -> None:
    """Create the repo-local checkpoint dirs and export cache env vars.

    The directories `CHECKPOINTS_DIR`, `HF_HOME_DIR` and `TORCH_HOME_DIR` are
    created if missing (idempotent). Each cache-related environment variable
    is then set only when it is not already present, so explicit user
    configuration always wins.

    The hub cache variables are derived from the *effective* `HF_HOME`: when
    the user has pointed `HF_HOME` elsewhere, the hub cache follows it instead
    of being forced back into the repo. With no user configuration the values
    are unchanged: `<checkpoints>/hf` and `<checkpoints>/hf/hub`.

    Raises:
        OSError: if one of the directories cannot be created.
    """
    for directory in (CHECKPOINTS_DIR, HF_HOME_DIR, TORCH_HOME_DIR):
        directory.mkdir(parents=True, exist_ok=True)

    # setdefault returns whatever ends up in the environment: our default if
    # the variable was unset, otherwise the user's pre-existing value.
    hf_home = os.environ.setdefault("HF_HOME", str(HF_HOME_DIR))

    # The hub cache variables take precedence over HF_HOME, so they must be
    # built from the effective HF_HOME; hard-coding the repo-local path here
    # would silently override a user-provided HF_HOME.
    hub_cache = str(Path(hf_home).expanduser() / "hub")
    # NOTE(review): recent transformers releases reportedly deprecate
    # TRANSFORMERS_CACHE in favour of HF_HOME; kept for older versions.
    os.environ.setdefault("TRANSFORMERS_CACHE", hub_cache)
    os.environ.setdefault("HUGGINGFACE_HUB_CACHE", hub_cache)

    os.environ.setdefault("TORCH_HOME", str(TORCH_HOME_DIR))
| |
|
| |
|
def hf_cache_dir() -> Path:
    """Return the repo-local Hugging Face directory, creating it if needed.

    Calls `ensure_default_checkpoint_dirs()` first, so the directory exists
    and the cache env vars have their defaults applied. The returned value is
    always `HF_HOME_DIR`, regardless of any `HF_HOME` the user has exported.
    """
    ensure_default_checkpoint_dirs()
    return HF_HOME_DIR
| |
|
| |
|
def torch_home_dir() -> Path:
    """Return the repo-local PyTorch home directory, creating it if needed.

    Calls `ensure_default_checkpoint_dirs()` first, so the directory exists
    and the cache env vars have their defaults applied. The returned value is
    always `TORCH_HOME_DIR`, regardless of any `TORCH_HOME` the user has
    exported.
    """
    ensure_default_checkpoint_dirs()
    return TORCH_HOME_DIR
| |
|
| |
|
def checkpoints_dir() -> Path:
    """Return the repo-local `checkpoints/` directory, creating it if needed.

    Calls `ensure_default_checkpoint_dirs()` first, so the directory (and its
    `hf`/`torch` subdirectories) exist and the cache env vars have their
    defaults applied.
    """
    ensure_default_checkpoint_dirs()
    return CHECKPOINTS_DIR
| |
|
| |
|
def default_checkpoint_path(filename: str) -> str:
    """Return an absolute path under `checkpoints/` for a given filename.

    The checkpoint directories are created as a side effect (via
    `checkpoints_dir()`); the file itself is not created or checked.

    Args:
        filename: Name (or relative sub-path) of the artifact. It is joined
            onto the checkpoints directory with `Path` semantics, so an
            absolute `filename` replaces the prefix entirely and `..`
            segments are not normalized — pass plain relative names.

    Returns:
        The joined path as a string.
    """
    return str(checkpoints_dir() / filename)
| |
|