Spaces:
Sleeping
Sleeping
| """hf_persistence.py — Bulletproof Hugging Face Hub persistence for DAHS_2. | |
| Why this module exists | |
| ---------------------- | |
| Two prior HF Space runs lost every artifact when the runtime terminated. The | |
| fix is a layered, redundant uploader: | |
| 1. Incremental: every pipeline step (data gen, each model, evaluation) | |
| calls ``persistor.snapshot(folder)`` immediately after writing files. | |
| 2. Periodic: a background thread re-uploads the full ``data/``, ``models/``, | |
| ``results/`` tree every N seconds so even mid-step crashes lose at most | |
| one period of work. | |
| 3. Terminal: an ``atexit`` handler and a ``SIGTERM`` handler do a final | |
| full upload before the process dies. HF Spaces send SIGTERM on pause / | |
| hardware reclaim, so this is the path that catches "runtime ended" | |
| deletions. | |
| 4. Resilient: every ``api.upload_folder`` call is retried with exponential | |
| backoff and is wrapped so a transient Hub error never stops the run. | |
| Public API | |
| ---------- | |
| HubPersistor(repo_id, token, folders=("data", "models", "results", "logs")) | |
| .snapshot(folder=None, msg=None) # one-shot upload | |
| .upload_file(local_path, path_in_repo=None) # single-file upload | |
| .start_periodic(interval_seconds=300) # background uploader thread | |
| .stop_periodic() | |
| .install_signal_handlers() # SIGTERM/SIGINT -> final upload | |
| .install_atexit() # final upload at interpreter exit | |
| from_env(require=False) # HubPersistor from HF_TOKEN/REPO_ID, else a no-op stub | |
| """ | |
| from __future__ import annotations | |
| import atexit | |
| import logging | |
| import os | |
| import signal | |
| import threading | |
| import time | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Iterable, Optional | |
logger = logging.getLogger(__name__)

# Folders uploaded by a full snapshot when the caller does not pass ``folders``.
DEFAULT_FOLDERS = ("data", "models", "results", "logs")


class HubPersistor:
    """Layered, retry-armoured uploader to a Hugging Face model repo.

    Upload methods are best-effort: exceptions raised by the Hub client are
    caught, logged, and reported through boolean return values so that a
    failing upload does not interrupt the caller.
    """

    def __init__(
        self,
        repo_id: str,
        token: Optional[str] = None,
        folders: Iterable[str] = DEFAULT_FOLDERS,
        repo_type: str = "model",
        max_retries: int = 4,
        retry_base_delay: float = 2.0,
    ) -> None:
        """Create the Hub client and make sure the target repo exists.

        Args:
            repo_id: Repository to upload into.
            token: Access token; when falsy, ``HfApi()`` is built without one.
            folders: Local folders uploaded by a full ``snapshot()``.
            repo_type: Repo type passed to every Hub call.
            max_retries: Maximum attempts per upload.
            retry_base_delay: Seconds slept after the first failed attempt.
        """
        # Imported here so that importing this module does not itself
        # require huggingface_hub.
        from huggingface_hub import HfApi, login

        self.repo_id = repo_id
        self.repo_type = repo_type
        self.folders = tuple(folders)
        self.max_retries = max_retries
        self.retry_base_delay = retry_base_delay

        if token:
            try:
                login(token=token, add_to_git_credential=False)
            except Exception as e:  # noqa: BLE001
                logger.warning("hf login() raised %s — proceeding with HfApi(token=...)", e)
        self.api = HfApi(token=token) if token else HfApi()

        try:
            self.api.create_repo(
                repo_id=repo_id, repo_type=repo_type, exist_ok=True
            )
        except Exception as e:  # noqa: BLE001
            # We don't raise here: the caller may want to keep running locally
            # even if the Hub is unreachable. Subsequent uploads will retry.
            logger.error("create_repo(%s) failed: %s", repo_id, e)

        # Re-entrant on purpose: the signal handler runs in the main thread
        # and calls snapshot(). If the signal interrupts a snapshot() that
        # already holds this lock, a plain Lock would deadlock the final
        # upload; an RLock lets the same thread acquire it again.
        self._lock = threading.RLock()
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self._signals_installed = False
        self._atexit_installed = False
        # time.time() of the most recent snapshot that uploaded something.
        self._last_upload_ts: float = 0.0

    # ------------------------------------------------------------------
    # Core upload
    # ------------------------------------------------------------------
    def snapshot(self, folder: Optional[str] = None, msg: Optional[str] = None) -> bool:
        """Upload one folder (or all configured folders).

        Folders that do not exist locally are skipped. Upload errors are
        logged by ``_upload_with_retry`` rather than raised.

        Args:
            folder: Single folder to upload; when falsy, every configured
                folder is uploaded.
            msg: Commit message prefix; defaults to a UTC-timestamped one.

        Returns:
            True if at least one folder was uploaded successfully.
        """
        targets = (folder,) if folder else self.folders
        commit_msg = msg or f"DAHS_2 snapshot {datetime.now(timezone.utc).isoformat()}"
        any_ok = False
        with self._lock:
            for f in targets:
                if not f or not Path(f).exists():
                    continue
                if self._upload_with_retry(f, commit_msg):
                    any_ok = True
            # Only a successful upload counts as "last upload".
            if any_ok:
                self._last_upload_ts = time.time()
        return any_ok

    def _upload_with_retry(self, folder: str, commit_msg: str) -> bool:
        """Upload ``folder`` with exponential backoff; return True on success."""
        delay = self.retry_base_delay
        for attempt in range(1, self.max_retries + 1):
            try:
                self.api.upload_folder(
                    folder_path=folder,
                    repo_id=self.repo_id,
                    repo_type=self.repo_type,
                    path_in_repo=folder,
                    commit_message=f"{commit_msg} [{folder}]",
                )
                logger.info("[hub] uploaded %s/ -> %s", folder, self.repo_id)
                return True
            except Exception as e:  # noqa: BLE001
                logger.warning(
                    "[hub] upload %s/ attempt %d/%d failed: %s",
                    folder, attempt, self.max_retries, e,
                )
                if attempt == self.max_retries:
                    return False
                time.sleep(delay)
                delay *= 2  # double the wait before the next attempt
        # Reached only when max_retries < 1 (the loop body never ran).
        return False

    # ------------------------------------------------------------------
    # Single-file upload (fast path for tiny artifacts)
    # ------------------------------------------------------------------
    def upload_file(self, local_path: str, path_in_repo: Optional[str] = None) -> bool:
        """Upload a single file, retrying with a linearly growing delay.

        Args:
            local_path: File to upload; a missing file returns False at once.
            path_in_repo: Destination path; defaults to ``local_path``.

        Returns:
            True if the upload succeeded within ``max_retries`` attempts.
        """
        if not Path(local_path).exists():
            return False
        target = path_in_repo or local_path
        for attempt in range(1, self.max_retries + 1):
            try:
                self.api.upload_file(
                    path_or_fileobj=local_path,
                    path_in_repo=target,
                    repo_id=self.repo_id,
                    repo_type=self.repo_type,
                    commit_message=f"upload {target}",
                )
                logger.info("[hub] uploaded file %s", target)
                return True
            except Exception as e:  # noqa: BLE001
                logger.warning("[hub] upload_file %s attempt %d failed: %s", target, attempt, e)
                if attempt == self.max_retries:
                    return False
                time.sleep(self.retry_base_delay * attempt)
        # Reached only when max_retries < 1 (the loop body never ran).
        return False

    # ------------------------------------------------------------------
    # Background periodic uploader
    # ------------------------------------------------------------------
    def start_periodic(self, interval_seconds: int = 300) -> None:
        """Start a daemon thread that snapshots every ``interval_seconds``.

        Does nothing if the uploader thread is already alive.
        """
        if self._thread is not None and self._thread.is_alive():
            return
        self._stop.clear()

        def _loop() -> None:
            logger.info("[hub] periodic uploader started (every %ds)", interval_seconds)
            # Event.wait returns True once stop_periodic() sets the event,
            # which ends the loop without waiting out the interval.
            while not self._stop.wait(interval_seconds):
                try:
                    self.snapshot(msg="periodic")
                except Exception as e:  # noqa: BLE001
                    logger.warning("[hub] periodic snapshot raised: %s", e)
            logger.info("[hub] periodic uploader stopped")

        self._thread = threading.Thread(target=_loop, name="HubPersistor", daemon=True)
        self._thread.start()

    def stop_periodic(self) -> None:
        """Signal the uploader thread to stop and wait up to 10 s for it."""
        self._stop.set()
        if self._thread is not None:
            self._thread.join(timeout=10)

    # ------------------------------------------------------------------
    # Terminal handlers
    # ------------------------------------------------------------------
    def install_atexit(self) -> None:
        """Register a final upload at interpreter exit (at most once)."""
        if self._atexit_installed:
            return
        atexit.register(self._final_upload, "atexit")
        self._atexit_installed = True

    def install_signal_handlers(self) -> None:
        """Install SIGTERM/SIGINT handlers that upload and then exit.

        ``signal.signal`` fails outside the main thread. In that case
        nothing is marked as installed, so a later call from the main
        thread can still register the handlers.
        """
        if self._signals_installed:
            return

        def _handler(signum, frame):  # noqa: ARG001
            logger.warning("[hub] signal %s received — final upload then exit", signum)
            self._final_upload(f"signal_{signum}")
            os._exit(0)  # bypass other atexit hooks; we already saved

        installed_any = False
        for sig in (signal.SIGTERM, signal.SIGINT):
            try:
                signal.signal(sig, _handler)
            except (ValueError, OSError) as e:
                # Not running in main thread (some HF runners).
                logger.warning("[hub] could not install handler for signal %s: %s", sig, e)
            else:
                installed_any = True
        self._signals_installed = installed_any

    def _final_upload(self, reason: str) -> None:
        """Stop the periodic uploader and push one last full snapshot."""
        try:
            logger.info("[hub] final upload triggered by %s", reason)
            self.stop_periodic()
            self.snapshot(msg=f"final-{reason}")
        except Exception as e:  # noqa: BLE001
            logger.error("[hub] final upload failed: %s", e)
| # --------------------------------------------------------------------------- | |
| # Helper: build a persistor from environment, or return a no-op stub. | |
| # --------------------------------------------------------------------------- | |
class _NullPersistor:
    """No-op stand-in used when no HF credentials are configured.

    Exposes the same method names as ``HubPersistor`` so callers can invoke
    them unconditionally; nothing is ever uploaded.
    """

    def snapshot(self, *args, **kwargs) -> bool:  # noqa: D401, ARG002
        """Accept any arguments and report that nothing was uploaded."""
        return False

    def upload_file(self, *args, **kwargs) -> bool:  # noqa: ARG002
        """Accept any arguments and report that nothing was uploaded."""
        return False

    def start_periodic(self, *args, **kwargs) -> None:  # noqa: ARG002
        """Accept any arguments; no background thread is started."""

    def stop_periodic(self) -> None:
        """Nothing to stop."""

    def install_atexit(self) -> None:
        """No exit hook is registered."""

    def install_signal_handlers(self) -> None:
        """No signal handlers are installed."""
def from_env(require: bool = False):
    """Create a persistor configured from the process environment.

    Reads ``HF_TOKEN`` and ``REPO_ID``. When both are set, a ``HubPersistor``
    is returned. Otherwise the result depends on ``require``: True raises
    ``RuntimeError``, False returns a ``_NullPersistor`` so callers can use
    the same API during local runs.
    """
    env = os.environ
    repo_id = env.get("REPO_ID")
    token = env.get("HF_TOKEN")

    # Fully configured: build the real uploader.
    if token and repo_id:
        return HubPersistor(repo_id=repo_id, token=token)

    if require:
        raise RuntimeError("HF_TOKEN and REPO_ID env vars are required.")

    logger.info("[hub] HF_TOKEN/REPO_ID not set — Hub persistence disabled.")
    return _NullPersistor()