File size: 3,412 Bytes
db6c149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c06ed8a
db6c149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
HuggingFace Hub integration for TFT-ASRO model persistence.

Solves the ephemeral storage problem on HF Spaces: after training,
checkpoints are uploaded to a dedicated HF model repo; before inference,
they are downloaded if not present locally.
"""

from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)

# Name of the environment variable that holds the HuggingFace access token.
_HF_TOKEN_ENV = "HF_TOKEN"

# Artifact file names exchanged with the Hub. Each name is used both as the
# file name inside the local artifacts directory and as the path in the repo.
_ARTIFACTS = [
    "best_tft_asro.ckpt",   # TFT checkpoint; the download helper treats it as required
    "pca_finbert.joblib",   # presumably a fitted PCA over FinBERT features — confirm
    "optuna_results.json",  # presumably Optuna hyper-parameter search results — confirm
]


def _get_token() -> Optional[str]:
    """
    Return the HuggingFace access token from the env var named by ``_HF_TOKEN_ENV``.

    Surrounding whitespace is stripped, since pasted secrets often carry a
    trailing newline. An unset or blank variable yields None, so callers pass
    None (anonymous / cached credentials) to huggingface_hub rather than an
    empty string.
    """
    token = os.environ.get(_HF_TOKEN_ENV, "").strip()
    return token or None


def upload_tft_artifacts(
    local_dir: str | Path,
    repo_id: str,
    commit_message: str = "Update TFT-ASRO checkpoint",
) -> bool:
    """
    Upload all TFT artifacts from *local_dir* to a HuggingFace model repo.

    Every name in ``_ARTIFACTS`` that exists as a regular file in *local_dir*
    is pushed in a single commit. The checkpoint and its companion files on
    the Hub therefore always come from the same run. The repo is created as
    private if it does not exist yet.

    Args:
        local_dir: Directory containing the trained artifacts.
        repo_id: Target HuggingFace model repo id.
        commit_message: Message used for the Hub commit.

    Returns:
        True on success. False if the token is missing, no artifacts were
        found, ``huggingface_hub`` is not installed, or the Hub call fails.
        Hub failures are logged rather than raised (best-effort upload).
    """
    token = _get_token()
    if not token:
        logger.warning("HF_TOKEN not set – skipping model upload to Hub")
        return False

    local_dir = Path(local_dir)
    files_to_upload = [
        local_dir / name
        for name in _ARTIFACTS
        if (local_dir / name).is_file()
    ]

    if not files_to_upload:
        logger.warning("No TFT artifacts found in %s", local_dir)
        return False

    try:
        from huggingface_hub import CommitOperationAdd, HfApi
    except ImportError:
        logger.warning("huggingface_hub not installed – cannot upload model")
        return False

    try:
        api = HfApi(token=token)
        api.create_repo(repo_id, repo_type="model", exist_ok=True, private=True)

        # One atomic commit for all files: uploading them one by one could leave
        # a new checkpoint next to stale companion files if a later upload failed.
        operations = [
            CommitOperationAdd(path_in_repo=fpath.name, path_or_fileobj=str(fpath))
            for fpath in files_to_upload
        ]
        api.create_commit(
            repo_id=repo_id,
            repo_type="model",
            operations=operations,
            commit_message=commit_message,
        )
    except Exception as exc:
        # Best-effort boundary: keep the traceback for diagnosis, but don't crash
        # the caller (e.g. a training job that already saved artifacts locally).
        logger.error("HF Hub upload failed: %s", exc, exc_info=True)
        return False

    for fpath in files_to_upload:
        logger.info("Uploaded %s → %s/%s", fpath.name, repo_id, fpath.name)
    return True


def download_tft_artifacts(
    local_dir: str | Path,
    repo_id: str,
) -> bool:
    """
    Download TFT artifacts from HuggingFace Hub to *local_dir*.

    If the checkpoint already exists locally, this returns True immediately
    without contacting the Hub; the other artifacts are not re-checked in that
    case. Otherwise, each name in ``_ARTIFACTS`` that is missing locally is
    fetched, and files that already exist are skipped.

    Args:
        local_dir: Destination directory; created (with parents) if missing.
        repo_id: HuggingFace model repo to download from.

    Returns:
        True if the checkpoint is present locally afterwards, otherwise False.
        False is also returned when ``huggingface_hub`` is not installed.
        Per-file Hub errors are logged rather than raised.
    """
    token = _get_token()
    local_dir = Path(local_dir)
    local_dir.mkdir(parents=True, exist_ok=True)

    ckpt_path = local_dir / "best_tft_asro.ckpt"
    if ckpt_path.exists():
        logger.debug("TFT checkpoint already present locally: %s", ckpt_path)
        return True

    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        logger.warning("huggingface_hub not installed – cannot download model")
        return False

    # Fetch the checkpoint last. Its local presence short-circuits future calls,
    # so it should only appear after the companion files have been attempted.
    # An interruption midway then leaves no checkpoint, and the next call retries.
    # (Stable sort: False < True moves only the checkpoint name to the end.)
    for name in sorted(_ARTIFACTS, key=lambda n: n == ckpt_path.name):
        dest = local_dir / name
        if dest.exists():
            continue
        try:
            hf_hub_download(
                repo_id=repo_id,
                filename=name,
                local_dir=str(local_dir),
                token=token,
            )
        except Exception as exc:
            # Only the checkpoint decides the return value, so its failure is
            # worth a warning. Other artifacts may not exist in the repo yet.
            if name == ckpt_path.name:
                logger.warning(
                    "Failed to download TFT checkpoint %s from %s: %s",
                    name, repo_id, exc,
                )
            else:
                logger.debug(
                    "Artifact %s not downloaded from %s (may not exist yet): %s",
                    name, repo_id, exc,
                )
            continue
        logger.info("Downloaded %s/%s → %s", repo_id, name, dest)

    return ckpt_path.exists()