Spaces:
Running
Running
| """ | |
| HuggingFace Hub integration for TFT-ASRO model persistence. | |
| Solves the ephemeral storage problem on HF Spaces: after training, | |
| checkpoints are uploaded to a dedicated HF model repo; before inference, | |
| they are downloaded if not present locally. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| from pathlib import Path | |
| from typing import Optional | |
| logger = logging.getLogger(__name__) | |
| _HF_TOKEN_ENV = "HF_TOKEN" | |
| _ARTIFACTS = [ | |
| "best_tft_asro.ckpt", | |
| "pca_finbert.joblib", | |
| "optuna_results.json", | |
| ] | |
| def _get_token() -> Optional[str]: | |
| return os.environ.get(_HF_TOKEN_ENV) | |
| def upload_tft_artifacts( | |
| local_dir: str | Path, | |
| repo_id: str, | |
| commit_message: str = "Update TFT-ASRO checkpoint", | |
| ) -> bool: | |
| """ | |
| Upload all TFT artifacts from *local_dir* to a HuggingFace model repo. | |
| Returns True on success, False if upload fails or token is missing. | |
| """ | |
| token = _get_token() | |
| if not token: | |
| logger.warning("HF_TOKEN not set β skipping model upload to Hub") | |
| return False | |
| local_dir = Path(local_dir) | |
| files_to_upload = [ | |
| local_dir / name | |
| for name in _ARTIFACTS | |
| if (local_dir / name).exists() | |
| ] | |
| if not files_to_upload: | |
| logger.warning("No TFT artifacts found in %s", local_dir) | |
| return False | |
| try: | |
| from huggingface_hub import HfApi | |
| api = HfApi(token=token) | |
| api.create_repo(repo_id, repo_type="model", exist_ok=True, private=True) | |
| for fpath in files_to_upload: | |
| api.upload_file( | |
| path_or_fileobj=str(fpath), | |
| path_in_repo=fpath.name, | |
| repo_id=repo_id, | |
| repo_type="model", | |
| commit_message=commit_message, | |
| ) | |
| logger.info("Uploaded %s β %s/%s", fpath.name, repo_id, fpath.name) | |
| return True | |
| except Exception as exc: | |
| logger.error("HF Hub upload failed: %s", exc) | |
| return False | |
| def download_tft_artifacts( | |
| local_dir: str | Path, | |
| repo_id: str, | |
| ) -> bool: | |
| """ | |
| Download TFT artifacts from HuggingFace Hub to *local_dir*. | |
| Skips files that already exist locally. | |
| Returns True if at least the checkpoint was retrieved. | |
| """ | |
| token = _get_token() | |
| local_dir = Path(local_dir) | |
| local_dir.mkdir(parents=True, exist_ok=True) | |
| ckpt_path = local_dir / "best_tft_asro.ckpt" | |
| if ckpt_path.exists(): | |
| logger.debug("TFT checkpoint already present locally: %s", ckpt_path) | |
| return True | |
| try: | |
| from huggingface_hub import hf_hub_download | |
| for name in _ARTIFACTS: | |
| dest = local_dir / name | |
| if dest.exists(): | |
| continue | |
| try: | |
| hf_hub_download( | |
| repo_id=repo_id, | |
| filename=name, | |
| local_dir=str(local_dir), | |
| token=token, | |
| ) | |
| logger.info("Downloaded %s/%s β %s", repo_id, name, dest) | |
| except Exception: | |
| logger.debug("Artifact %s not found in %s (may not exist yet)", name, repo_id) | |
| return ckpt_path.exists() | |
| except ImportError: | |
| logger.warning("huggingface_hub not installed β cannot download model") | |
| return False | |
| except Exception as exc: | |
| logger.warning("HF Hub download failed: %s", exc) | |
| return False | |