ifieryarrows's picture
Sync from GitHub (tests passed)
c06ed8a verified
"""
HuggingFace Hub integration for TFT-ASRO model persistence.
Solves the ephemeral storage problem on HF Spaces: after training,
checkpoints are uploaded to a dedicated HF model repo; before inference,
they are downloaded if not present locally.
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Optional
# Name of the environment variable the Hub access token is read from.
_HF_TOKEN_ENV = "HF_TOKEN"

# File names handled by the upload/download helpers in this module.
# "best_tft_asro.ckpt" is the one download_tft_artifacts uses to decide
# whether a download succeeded.
_ARTIFACTS = [
    "best_tft_asro.ckpt",
    "pca_finbert.joblib",
    "optuna_results.json",
]

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _get_token() -> Optional[str]:
    """
    Return the Hub access token from the environment, or None if unusable.

    The value of the variable named by ``_HF_TOKEN_ENV`` is stripped of
    surrounding whitespace. An unset, empty, or whitespace-only variable
    yields None, so callers never forward a blank or newline-padded token
    to the Hub client.
    """
    token = os.environ.get(_HF_TOKEN_ENV, "").strip()
    return token or None
def upload_tft_artifacts(
    local_dir: str | Path,
    repo_id: str,
    commit_message: str = "Update TFT-ASRO checkpoint",
) -> bool:
    """
    Upload all TFT artifacts from *local_dir* to a HuggingFace model repo.

    Only the file names listed in ``_ARTIFACTS`` that exist in *local_dir*
    are considered. The repo is created with ``private=True`` and
    ``exist_ok=True`` before uploading, and each file is sent with its own
    ``upload_file`` call, stored under its bare file name.

    Args:
        local_dir: Directory that holds the artifacts.
        repo_id: Target Hub repo id.
        commit_message: Commit message passed to every ``upload_file`` call.

    Returns:
        True when every artifact found locally was uploaded. False when the
        token is missing, no artifact exists in *local_dir*,
        ``huggingface_hub`` is not installed, or any Hub call raises. Files
        uploaded before a failure are not rolled back.
    """
    token = _get_token()
    if not token:
        logger.warning("HF_TOKEN not set – skipping model upload to Hub")
        return False

    local_dir = Path(local_dir)
    candidates = (local_dir / name for name in _ARTIFACTS)
    files_to_upload = [path for path in candidates if path.exists()]
    if not files_to_upload:
        logger.warning("No TFT artifacts found in %s", local_dir)
        return False

    try:
        # Imported lazily so the module stays importable without the
        # optional huggingface_hub dependency.
        from huggingface_hub import HfApi

        api = HfApi(token=token)
        api.create_repo(repo_id, repo_type="model", exist_ok=True, private=True)

        for fpath in files_to_upload:
            api.upload_file(
                path_or_fileobj=str(fpath),
                path_in_repo=fpath.name,
                repo_id=repo_id,
                repo_type="model",
                commit_message=commit_message,
            )
            logger.info("Uploaded %s → %s/%s", fpath.name, repo_id, fpath.name)
        return True
    except ImportError:
        # Reported separately so a missing dependency is not mistaken for a
        # Hub-side failure (mirrors download_tft_artifacts).
        logger.warning("huggingface_hub not installed – cannot upload model")
        return False
    except Exception as exc:
        # Best-effort by design: an upload problem must not crash the caller.
        logger.error("HF Hub upload failed: %s", exc)
        return False
def download_tft_artifacts(
    local_dir: str | Path,
    repo_id: str,
) -> bool:
    """
    Download TFT artifacts from HuggingFace Hub to *local_dir*.

    *local_dir* is created if needed. If the checkpoint
    ``best_tft_asro.ckpt`` is already present, the function returns True
    immediately without contacting the Hub, so other missing artifacts are
    NOT fetched in that case. Otherwise every name in ``_ARTIFACTS`` that
    is missing locally is downloaded on a best-effort basis: a failure for
    one file is logged and the remaining files are still attempted.

    Args:
        local_dir: Destination directory for the artifacts.
        repo_id: Hub repo id to download from.

    Returns:
        True if the checkpoint file exists in *local_dir* when the function
        finishes, False otherwise (including when ``huggingface_hub`` is not
        installed).
    """
    token = _get_token()
    local_dir = Path(local_dir)
    local_dir.mkdir(parents=True, exist_ok=True)

    ckpt_path = local_dir / "best_tft_asro.ckpt"
    if ckpt_path.exists():
        logger.debug("TFT checkpoint already present locally: %s", ckpt_path)
        return True

    try:
        # Imported lazily so the module stays importable without the
        # optional huggingface_hub dependency.
        from huggingface_hub import hf_hub_download

        for name in _ARTIFACTS:
            dest = local_dir / name
            if dest.exists():
                continue
            try:
                hf_hub_download(
                    repo_id=repo_id,
                    filename=name,
                    local_dir=str(local_dir),
                    token=token,
                )
            except Exception as exc:
                # Optional artifacts may simply not be uploaded yet, so keep
                # going. The cause is logged so that auth or network
                # failures are not mistaken for a missing file, and a
                # failure of the required checkpoint is raised to WARNING.
                level = logging.WARNING if name == ckpt_path.name else logging.DEBUG
                logger.log(
                    level,
                    "Artifact %s could not be downloaded from %s: %s",
                    name,
                    repo_id,
                    exc,
                )
            else:
                logger.info("Downloaded %s/%s → %s", repo_id, name, dest)

        return ckpt_path.exists()
    except ImportError:
        logger.warning("huggingface_hub not installed – cannot download model")
        return False
    except Exception as exc:
        # Safety net for anything outside the per-file handler (e.g. a
        # filesystem error while checking for existing files).
        logger.warning("HF Hub download failed: %s", exc)
        return False