Spaces:
Running
Running
File size: 3,412 Bytes
db6c149 c06ed8a db6c149 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 | """
HuggingFace Hub integration for TFT-ASRO model persistence.
Solves the ephemeral storage problem on HF Spaces: after training,
checkpoints are uploaded to a dedicated HF model repo; before inference,
they are downloaded if not present locally.
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Optional
# Name of the environment variable that carries the Hugging Face access token.
_HF_TOKEN_ENV = "HF_TOKEN"

# Artifact filenames mirrored between the local model directory and the Hub
# repo. The checkpoint is listed first; it is the file whose presence decides
# whether a download counts as successful.
_ARTIFACTS = [
    "best_tft_asro.ckpt",   # TFT-ASRO model checkpoint
    "pca_finbert.joblib",   # joblib-serialized object (presumably a fitted PCA) — TODO confirm
    "optuna_results.json",  # hyper-parameter search results
]

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
def _get_token() -> Optional[str]:
    """
    Return the Hugging Face access token from the environment.

    The variable named by ``_HF_TOKEN_ENV`` is read and stripped of
    surrounding whitespace (e.g. a trailing newline from a pasted secret).
    An unset, empty, or whitespace-only value yields ``None``, so callers
    never forward a blank string as a credential — huggingface_hub would
    send it as an empty ``Bearer`` header instead of falling back.
    """
    raw = os.environ.get(_HF_TOKEN_ENV)
    if raw is None:
        return None
    token = raw.strip()
    return token or None
def upload_tft_artifacts(
    local_dir: str | Path,
    repo_id: str,
    commit_message: str = "Update TFT-ASRO checkpoint",
) -> bool:
    """
    Upload all TFT artifacts from *local_dir* to a HuggingFace model repo.

    Every name in ``_ARTIFACTS`` that exists as a regular file in *local_dir*
    is pushed in ONE commit. An interrupted upload therefore never leaves
    the repo holding a checkpoint from one training run next to auxiliary
    files from another. The repo is created as private if it does not
    already exist.

    Args:
        local_dir: Directory containing the trained artifacts.
        repo_id: Target Hub model repo id (``"namespace/name"``).
        commit_message: Message for the upload commit.

    Returns:
        True on success. False if the token is missing, no artifacts are
        found, ``huggingface_hub`` is not installed, or the upload fails.
        Errors are logged, never raised.
    """
    token = _get_token()
    if not token:
        logger.warning("HF_TOKEN not set – skipping model upload to Hub")
        return False

    local_dir = Path(local_dir)
    # is_file() rather than exists(): a directory with an artifact's name must not be uploaded.
    files_to_upload = [
        local_dir / name
        for name in _ARTIFACTS
        if (local_dir / name).is_file()
    ]
    if not files_to_upload:
        logger.warning("No TFT artifacts found in %s", local_dir)
        return False

    try:
        from huggingface_hub import CommitOperationAdd, HfApi
    except ImportError:
        logger.error("huggingface_hub not installed – cannot upload model")
        return False

    try:
        api = HfApi(token=token)
        api.create_repo(repo_id, repo_type="model", exist_ok=True, private=True)
        operations = [
            CommitOperationAdd(path_in_repo=fpath.name, path_or_fileobj=str(fpath))
            for fpath in files_to_upload
        ]
        # Single atomic commit: either the whole artifact set lands, or none of it.
        api.create_commit(
            repo_id=repo_id,
            operations=operations,
            commit_message=commit_message,
            repo_type="model",
        )
    except Exception:
        # Best-effort by contract (returns False instead of raising);
        # logger.exception keeps the traceback for diagnosis.
        logger.exception("HF Hub upload failed")
        return False

    for fpath in files_to_upload:
        logger.info("Uploaded %s → %s/%s", fpath.name, repo_id, fpath.name)
    return True
def download_tft_artifacts(
    local_dir: str | Path,
    repo_id: str,
) -> bool:
    """
    Download TFT artifacts from HuggingFace Hub to *local_dir*.

    If the checkpoint is already present locally, nothing is fetched; the
    other artifacts are not checked in that case. Otherwise, each name in
    ``_ARTIFACTS`` that is missing locally is downloaded. Files that cannot
    be retrieved are skipped, with the cause logged.

    Args:
        local_dir: Directory to download into. It is created if missing.
        repo_id: Source Hub model repo id.

    Returns:
        True if the checkpoint is present locally afterwards, otherwise
        False. Download errors are logged, never raised.
    """
    token = _get_token()
    local_dir = Path(local_dir)
    local_dir.mkdir(parents=True, exist_ok=True)

    ckpt_path = local_dir / "best_tft_asro.ckpt"
    if ckpt_path.is_file():
        logger.debug("TFT checkpoint already present locally: %s", ckpt_path)
        return True

    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        logger.warning("huggingface_hub not installed – cannot download model")
        return False

    try:
        for name in _ARTIFACTS:
            dest = local_dir / name
            if dest.is_file():
                continue
            try:
                hf_hub_download(
                    repo_id=repo_id,
                    filename=name,
                    local_dir=str(local_dir),
                    token=token,
                )
            except Exception as exc:
                # A missing optional artifact and an auth/network failure both
                # land here, so always log the actual cause. Without the
                # checkpoint this function reports failure, so that case is
                # surfaced at WARNING instead of being hidden at DEBUG.
                level = logging.WARNING if dest == ckpt_path else logging.DEBUG
                logger.log(
                    level,
                    "Artifact %s not retrieved from %s (may not exist yet): %s",
                    name,
                    repo_id,
                    exc,
                )
                continue
            logger.info("Downloaded %s/%s → %s", repo_id, name, dest)
        return ckpt_path.is_file()
    except Exception as exc:
        logger.warning("HF Hub download failed: %s", exc)
        return False
|