File size: 2,286 Bytes
1aa7fae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
"""
Utility functions for interacting with the Hugging Face Hub for MODELS.
Used to:
- Upload the best trained model to a model repo.
- Download the registered model for inference or deployment.
"""
from pathlib import Path
from typing import Optional
import joblib
from huggingface_hub import HfApi, hf_hub_download
import config
def _get_token(explicit_token: Optional[str] = None) -> str:
token = explicit_token or config.HF_TOKEN
if not token:
raise ValueError(
"Hugging Face token is not set. "
"Set HF_TOKEN in the environment or pass token explicitly."
)
return token
def create_or_get_model_repo(
    repo_id: str, token: Optional[str] = None, private: bool = False
) -> None:
    """
    Ensure a model repo named ``repo_id`` is present on the Hugging Face Hub.

    Args:
        repo_id: Identifier of the model repo to create.
        token: Hub token; resolved through ``_get_token`` when omitted.
        private: Forwarded to ``create_repo`` as the repo's ``private`` flag.

    Raises:
        ValueError: Propagated from ``_get_token`` when no token is available.
    """
    hub_client = HfApi(token=_get_token(token))
    # exist_ok=True is passed so that calling this for an already-existing
    # repo is not treated as an error.
    hub_client.create_repo(
        repo_id=repo_id, repo_type="model", private=private, exist_ok=True
    )
def upload_model(
    local_model_path: Path,
    repo_id: Optional[str] = None,
    repo_path: str = "model.joblib",
    token: Optional[str] = None,
) -> None:
    """
    Upload the trained model artifact to the Hugging Face model hub.

    The target repo is created first (via ``create_or_get_model_repo``) and
    the file is then uploaded to ``repo_path`` inside it.

    Args:
        local_model_path: Location of the model file on disk.
        repo_id: Target model repo; falls back to ``config.HF_MODEL_REPO``.
        repo_path: Destination path of the file inside the repo.
        token: Hub token; resolved through ``_get_token`` when omitted.

    Raises:
        ValueError: If the local file does not exist, no token is available,
            or no repo id is available.
    """
    # Check the local artifact before any Hub call so a bad path cannot
    # leave a freshly created, empty repo behind. ValueError matches the
    # module's other configuration errors.
    model_file = Path(local_model_path)
    if not model_file.is_file():
        raise ValueError(f"Model file not found: {model_file}")

    token = _get_token(token)
    repo_id = repo_id or config.HF_MODEL_REPO
    if not repo_id:
        raise ValueError(
            "Hugging Face model repo is not set. "
            "Set HF_MODEL_REPO in the config or pass repo_id explicitly."
        )

    create_or_get_model_repo(repo_id=repo_id, token=token)
    HfApi(token=token).upload_file(
        path_or_fileobj=str(model_file),
        path_in_repo=repo_path,
        repo_id=repo_id,
        repo_type="model",
    )
def download_model(
    repo_id: Optional[str] = None,
    filename: str = "model.joblib",
    token: Optional[str] = None,
    local_dir: Optional[Path] = None,
):
    """
    Download a model artifact from the Hugging Face model hub and load it.

    Args:
        repo_id: Source model repo; falls back to ``config.HF_MODEL_REPO``.
        filename: Name of the artifact inside the repo.
        token: Hub token; resolved through ``_get_token`` when omitted.
        local_dir: Directory to download into; falls back to
            ``config.MODELS_DIR``. Created if it does not exist.

    Returns:
        Whatever object ``joblib.load`` reconstructs from the downloaded file.

    Raises:
        ValueError: If no token or no repo id is available.
    """
    token = _get_token(token)
    repo_id = repo_id or config.HF_MODEL_REPO
    if not repo_id:
        raise ValueError(
            "Hugging Face model repo is not set. "
            "Set HF_MODEL_REPO in the config or pass repo_id explicitly."
        )

    # Coerce so that a plain string directory works as well as a Path.
    target_dir = Path(local_dir or config.MODELS_DIR)
    target_dir.mkdir(parents=True, exist_ok=True)

    downloaded_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model",
        token=token,
        local_dir=str(target_dir),
        # NOTE(review): this flag appears to be deprecated in recent
        # huggingface_hub releases - confirm against the pinned version
        # before removing it.
        local_dir_use_symlinks=False,
    )
    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only point this at repos whose contents are trusted.
    return joblib.load(downloaded_path)
|