File size: 2,286 Bytes
1aa7fae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
Utility functions for interacting with the Hugging Face Hub for MODELS.

Used to:
- Upload the best trained model to a model repo.
- Download the registered model for inference or deployment.
"""

from pathlib import Path
from typing import Optional

import joblib
from huggingface_hub import HfApi, hf_hub_download

import config


def _get_token(explicit_token: Optional[str] = None) -> str:
    """
    Resolve the Hugging Face access token to use for Hub calls.

    A truthy ``explicit_token`` wins; otherwise ``config.HF_TOKEN`` is used.

    Raises:
        ValueError: if neither source provides a non-empty token.
    """
    resolved = explicit_token if explicit_token else config.HF_TOKEN
    if resolved:
        return resolved
    raise ValueError(
        "Hugging Face token is not set. "
        "Set HF_TOKEN in the environment or pass token explicitly."
    )


def create_or_get_model_repo(
    repo_id: str, token: Optional[str] = None, private: bool = False
) -> None:
    """
    Ensure a model repository named ``repo_id`` exists on the Hugging Face Hub.

    ``create_repo`` is called with ``exist_ok=True``, so an already existing
    repo does not cause an error.

    Args:
        repo_id: Id of the model repo to create or reuse.
        token: Hugging Face token; resolved through ``_get_token`` (falls
            back to ``config.HF_TOKEN``).
        private: Forwarded to ``create_repo`` as the repo visibility flag.

    Raises:
        ValueError: if no token can be resolved.
    """
    client = HfApi(token=_get_token(token))
    client.create_repo(
        repo_id=repo_id,
        repo_type="model",
        private=private,
        exist_ok=True,
    )


def upload_model(
    local_model_path: Path,
    repo_id: Optional[str] = None,
    repo_path: str = "model.joblib",
    token: Optional[str] = None,
) -> None:
    """
    Upload the trained model artifact to the Hugging Face model hub.

    The target repo is created first if it does not already exist (via
    ``create_or_get_model_repo``, which creates it as public by default).

    Args:
        local_model_path: Local file to upload.
        repo_id: Target model repo id; defaults to ``config.HF_MODEL_REPO``.
        repo_path: Destination path of the file inside the repo.
        token: Hugging Face token; defaults to ``config.HF_TOKEN``.

    Raises:
        ValueError: if no token or no repo id can be resolved.
    """
    token = _get_token(token)
    repo_id = repo_id or config.HF_MODEL_REPO
    # Fail early with an actionable message instead of passing None into the
    # Hub client, mirroring how _get_token reports a missing token.
    if not repo_id:
        raise ValueError(
            "Hugging Face model repo is not set. "
            "Set HF_MODEL_REPO in config or pass repo_id explicitly."
        )

    create_or_get_model_repo(repo_id=repo_id, token=token)

    api = HfApi(token=token)
    api.upload_file(
        path_or_fileobj=str(local_model_path),
        path_in_repo=repo_path,
        repo_id=repo_id,
        repo_type="model",
    )


def download_model(
    repo_id: Optional[str] = None,
    filename: str = "model.joblib",
    token: Optional[str] = None,
    local_dir: Optional[Path] = None,
):
    """
    Download a model artifact from the Hugging Face model hub and load it.

    Args:
        repo_id: Source model repo id; defaults to ``config.HF_MODEL_REPO``.
        filename: Path of the artifact inside the repo.
        token: Hugging Face token; defaults to ``config.HF_TOKEN``.
        local_dir: Directory to download into; defaults to
            ``config.MODELS_DIR``. Created if missing. A ``str`` is accepted
            and converted to ``Path``.

    Returns:
        The object deserialized from the downloaded file by ``joblib.load``.

    Raises:
        ValueError: if no token or no repo id can be resolved.
    """
    token = _get_token(token)
    repo_id = repo_id or config.HF_MODEL_REPO
    if not repo_id:
        raise ValueError(
            "Hugging Face model repo is not set. "
            "Set HF_MODEL_REPO in config or pass repo_id explicitly."
        )
    # Coerce so callers passing a plain string path don't hit AttributeError
    # on .mkdir().
    local_dir = Path(local_dir or config.MODELS_DIR)
    local_dir.mkdir(parents=True, exist_ok=True)

    downloaded_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model",
        token=token,
        local_dir=str(local_dir),
        # NOTE(review): this argument is deprecated in newer huggingface_hub
        # releases — verify against the pinned version before removing.
        local_dir_use_symlinks=False,
    )

    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only load artifacts from repos you control/trust.
    return joblib.load(downloaded_path)