# XAPI/deploy/huggingface/backup_manager.py
# Last commit: "Make HF backups atomic" by cjovs (9d3f59d, verified)
from __future__ import annotations
import argparse
import os
import sys
import tarfile
import tempfile
from datetime import datetime, timezone
from pathlib import Path, PurePosixPath
from typing import Iterable, Mapping
# huggingface_hub is optional at import time so the stdlib-only helpers in this
# module (archive creation/extraction, retention selection) remain importable in
# images that do not ship it. Hub operations are gated by _require_hub_support().
try:
    from huggingface_hub import CommitOperationAdd, CommitOperationDelete, HfApi, hf_hub_download
    from huggingface_hub.errors import EntryNotFoundError
    from huggingface_hub.utils import HfHubHTTPError, get_token
except ImportError:  # pragma: no cover - only exercised in runtime images without the dependency
    CommitOperationAdd = None  # type: ignore[assignment]
    CommitOperationDelete = None  # type: ignore[assignment]
    HfApi = None  # type: ignore[assignment]

    # Dedicated placeholder classes instead of aliases of ``Exception``: an
    # ``except EntryNotFoundError`` clause must never match unrelated errors
    # (and silently swallow them) just because the dependency is missing.
    class EntryNotFoundError(Exception):  # type: ignore[no-redef]
        """Placeholder for huggingface_hub's EntryNotFoundError; never raised here."""

    class HfHubHTTPError(Exception):  # type: ignore[no-redef]
        """Placeholder for huggingface_hub's HfHubHTTPError; never raised here."""

        # Handlers in this module read ``exc.response``; keep the attribute present.
        response = None

    def get_token() -> str | None:  # type: ignore[no-redef]
        """Fallback token lookup: read ``HF_TOKEN`` from the environment."""
        return os.getenv("HF_TOKEN")

    def hf_hub_download(*args, **kwargs):  # type: ignore[no-redef]
        """Fallback that always fails because the real downloader is unavailable."""
        raise RuntimeError("huggingface_hub is required for backup download support")
class BackupArchiveError(RuntimeError):
    """Signals that a backup archive contains an entry that must not be extracted.

    Raised during archive validation for absolute paths, ``..`` components,
    symbolic or hard links, and targets that resolve outside the destination.
    """
def build_backup_filename(timestamp: str | None = None) -> str:
    """Return the archive file name ``backup-<timestamp>.tar.gz``.

    Args:
        timestamp: Text to embed in the name. When ``None``, the current UTC
            time formatted as ``YYYYMMDDTHHMMSSZ`` is used.

    Returns:
        The file name only, without any directory component.
    """
    stamp = (
        timestamp
        if timestamp is not None
        else datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    )
    return f"backup-{stamp}.tar.gz"
def select_backups_to_delete(paths: Iterable[str], keep: int = 3) -> list[str]:
    """Return the timestamped backup archives that fall outside the retention window.

    Only direct children of ``backups/`` whose name matches ``backup-*.tar.gz``
    are considered, so ``backups/latest.tar.gz`` and nested paths are never
    selected for deletion.

    Args:
        paths: Repository-relative POSIX paths to inspect.
        keep: Number of archives to retain, counted from the top of the
            reverse-sorted list. Must be non-negative.

    Returns:
        The paths to delete, in reverse lexicographic order.

    Raises:
        ValueError: If ``keep`` is negative.
    """
    if keep < 0:
        # A negative slice start would silently select only the |keep| oldest
        # archives instead of enforcing a retention count; fail loudly instead.
        raise ValueError(f"keep must be non-negative, got {keep}")

    backups_dir = PurePosixPath("backups")
    candidates: list[str] = []
    for path in paths:
        pure = PurePosixPath(path)
        if (
            pure.parent == backups_dir
            and pure.name.startswith("backup-")
            and pure.name.endswith(".tar.gz")
        ):
            candidates.append(path)

    # Names produced by build_backup_filename embed a fixed-width UTC timestamp,
    # so reverse lexicographic order puts the newest archive first.
    candidates.sort(reverse=True)
    return candidates[keep:]
def create_backup_archive(archive_path: Path, sources: Mapping[str, Path]) -> Path:
    """Write a gzip-compressed tar archive containing the existing sources.

    Args:
        archive_path: Where to write the archive; missing parent directories
            are created.
        sources: Mapping of archive member name to the filesystem path stored
            under that name. Paths that do not exist are skipped.

    Returns:
        ``archive_path``, unchanged.
    """
    archive_path.parent.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive_path, "w:gz") as archive:
        for member_name in sources:
            location = sources[member_name]
            if not location.exists():
                continue
            archive.add(location, arcname=member_name, recursive=True)
    return archive_path
def extract_backup_archive(archive_path: Path, destination_root: Path) -> None:
destination_root.mkdir(parents=True, exist_ok=True)
destination_root = destination_root.resolve()
with tarfile.open(archive_path, "r:gz") as tar:
members = tar.getmembers()
for member in members:
pure_path = PurePosixPath(member.name)
if pure_path.is_absolute() or ".." in pure_path.parts:
raise BackupArchiveError(f"unsafe member path: {member.name}")
if member.issym() or member.islnk():
raise BackupArchiveError(f"symlinks are not allowed in backup archives: {member.name}")
target_path = (destination_root / Path(*pure_path.parts)).resolve()
try:
target_path.relative_to(destination_root)
except ValueError as exc: # pragma: no cover - defensive redundancy
raise BackupArchiveError(f"unsafe extraction target: {member.name}") from exc
tar.extractall(destination_root, filter="data")
def default_sources(data_root: Path) -> dict[str, Path]:
    """Map each archive member name to its directory under ``data_root``.

    The member names are ``data``, ``postgres`` and ``redis``; each maps to
    the child of ``data_root`` with the same name.
    """
    return {name: data_root / name for name in ("data", "postgres", "redis")}
def data_root_has_state(data_root: Path) -> bool:
    """Report whether any default source under ``data_root`` exists and is non-empty.

    Returns ``True`` as soon as one source path exists and contains at least
    one entry; ``False`` when every source is missing or empty.
    """
    return any(
        candidate.exists() and any(candidate.iterdir())
        for candidate in default_sources(data_root).values()
    )
def _require_hub_support() -> None:
    """Fail fast when the ``huggingface_hub`` symbols were not importable.

    Raises:
        RuntimeError: If ``HfApi``, ``CommitOperationAdd`` or
            ``CommitOperationDelete`` is ``None`` (the import fallback value).
    """
    hub_symbols = (HfApi, CommitOperationAdd, CommitOperationDelete)
    if any(symbol is None for symbol in hub_symbols):
        raise RuntimeError("huggingface_hub is required for Hugging Face backup operations")
def _resolve_token(explicit_token: str | None = None) -> str | None:
    """Pick the Hugging Face token to use, in priority order.

    Order: the explicit argument, then the ``HF_TOKEN`` environment variable,
    then whatever ``get_token()`` returns. Empty strings count as "not set".
    """
    if explicit_token:
        return explicit_token
    env_token = os.getenv("HF_TOKEN")
    if env_token:
        return env_token
    # Only consult get_token() when neither of the cheaper sources has a value.
    return get_token()
def _iter_repo_paths(api: HfApi, repo_id: str, token: str) -> list[str]:
    """List the paths stored under ``backups/`` in the dataset repository.

    Args:
        api: Hub client used to query the repository tree.
        repo_id: Identifier of the dataset repository.
        token: Access token forwarded to the listing call.

    Returns:
        Every non-empty ``path`` (or, failing that, ``rfilename``) reported by
        the recursive listing. Empty when the ``backups/`` folder does not
        exist yet.
    """
    repo_paths: list[str] = []
    try:
        # The listing is consumed inside the ``try`` because a lazily evaluated
        # result only surfaces a missing folder once iteration starts.
        for item in api.list_repo_tree(
            repo_id=repo_id,
            path_in_repo="backups",
            recursive=True,
            repo_type="dataset",
            token=token,
        ):
            path = getattr(item, "path", None) or getattr(item, "rfilename", None)
            if path:
                repo_paths.append(path)
    except EntryNotFoundError:
        # A fresh repository has no backups/ folder until the first backup is
        # committed; treat that as "nothing to prune" instead of failing.
        return []
    return repo_paths
def backup_to_dataset(
    repo_id: str,
    data_root: Path,
    keep: int = 3,
    token: str | None = None,
    space_name: str | None = None,
) -> str:
    """Archive the default sources and publish them to a dataset repo in one commit.

    The single commit adds the timestamped archive, refreshes
    ``backups/latest.tar.gz`` with the same file, and deletes the archives
    that ``select_backups_to_delete`` picks from the pre-commit listing.
    Because retention is computed before the commit, the new archive is
    uploaded in addition to the ``keep`` archives that are retained.

    Args:
        repo_id: Target dataset repository.
        data_root: Directory holding the default source directories.
        keep: Retention count forwarded to ``select_backups_to_delete``.
        token: Explicit access token; falls back to ``_resolve_token`` lookup.
        space_name: Optional name appended to the commit message.

    Returns:
        The file name of the uploaded timestamped archive.

    Raises:
        RuntimeError: If hub support is unavailable or no token can be resolved.
    """
    _require_hub_support()
    hub_token = _resolve_token(token)
    if not hub_token:
        raise RuntimeError("HF token is required to upload backups")

    with tempfile.TemporaryDirectory(prefix="sub2api-backup-") as workdir:
        archive = Path(workdir) / build_backup_filename()
        create_backup_archive(archive, default_sources(data_root))
        archive_name = archive.name

        client = HfApi(token=hub_token)
        message = f"Update backup {archive_name}" + (f" for {space_name}" if space_name else "")
        existing_paths = _iter_repo_paths(client, repo_id, hub_token)

        # The same local file is uploaded under both its timestamped name and
        # the stable "latest" alias.
        local_file = str(archive)
        operations = [
            CommitOperationAdd(path_in_repo=destination, path_or_fileobj=local_file)
            for destination in (f"backups/{archive_name}", "backups/latest.tar.gz")
        ]
        operations.extend(
            CommitOperationDelete(path_in_repo=stale)
            for stale in select_backups_to_delete(existing_paths, keep=keep)
        )

        client.create_commit(
            repo_id=repo_id,
            operations=operations,
            repo_type="dataset",
            token=hub_token,
            commit_message=message,
        )
    return archive_name
def restore_from_dataset(repo_id: str, data_root: Path, token: str | None = None) -> Path | None:
    """Download ``backups/latest.tar.gz`` from the dataset repo and extract it.

    Args:
        repo_id: Source dataset repository.
        data_root: Directory the archive is extracted into.
        token: Explicit access token; falls back to ``_resolve_token`` lookup.

    Returns:
        The path the archive was downloaded to, or ``None`` when the entry is
        missing or the hub answers with HTTP 404. The download goes to a
        temporary directory that is removed before this function returns, so
        the path is presumably only useful as a label (e.g. its ``.name``).

    Raises:
        RuntimeError: If hub support is unavailable or no token can be resolved.
        HfHubHTTPError: For hub HTTP failures other than 404.
    """
    _require_hub_support()
    hub_token = _resolve_token(token)
    if not hub_token:
        raise RuntimeError("HF token is required to restore backups")

    with tempfile.TemporaryDirectory(prefix="sub2api-restore-") as workdir:
        try:
            local_copy = hf_hub_download(
                repo_id=repo_id,
                filename="backups/latest.tar.gz",
                repo_type="dataset",
                token=hub_token,
                local_dir=workdir,
                force_download=True,
            )
        except EntryNotFoundError:
            return None
        except HfHubHTTPError as http_error:
            not_found = getattr(http_error.response, "status_code", None) == 404
            if not not_found:
                raise
            return None

        restored = Path(local_copy)
        extract_backup_archive(restored, data_root)
    return restored
def _non_negative_int(raw: str) -> int:
    """Convert an argparse value to an integer >= 0 or reject it with a usage error."""
    try:
        value = int(raw)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a non-negative integer, got {raw!r}") from None
    if value < 0:
        raise argparse.ArgumentTypeError(f"expected a non-negative integer, got {raw!r}")
    return value


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with the ``backup`` and ``restore`` subcommands.

    Defaults come from the environment: ``HF_BACKUP_REPO``, ``DATA_ROOT``
    (default ``/data``), ``BACKUP_KEEP`` (default ``3``) and ``SPACE_ID`` /
    ``HF_SPACE_ID``.

    Returns:
        The configured parser; a subcommand is required.
    """
    parser = argparse.ArgumentParser(description="Manage Sub2API backups stored in a Hugging Face dataset.")
    subparsers = parser.add_subparsers(dest="command", required=True)

    backup_parser = subparsers.add_parser("backup", help="Create and upload a backup archive.")
    backup_parser.add_argument("--repo-id", default=os.getenv("HF_BACKUP_REPO"))
    backup_parser.add_argument("--data-root", default=os.getenv("DATA_ROOT", "/data"))
    # Keep the default as a string: argparse runs ``type`` on string defaults only
    # when the option is actually needed, so a malformed BACKUP_KEEP yields a usage
    # error for ``backup`` instead of a traceback while building the parser (which
    # previously also broke ``restore``). An empty value falls back to "3".
    backup_parser.add_argument(
        "--keep",
        type=_non_negative_int,
        default=os.getenv("BACKUP_KEEP") or "3",
    )
    backup_parser.add_argument("--token", default=None)
    backup_parser.add_argument("--space-name", default=os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"))

    restore_parser = subparsers.add_parser("restore", help="Restore the latest backup archive.")
    restore_parser.add_argument("--repo-id", default=os.getenv("HF_BACKUP_REPO"))
    restore_parser.add_argument("--data-root", default=os.getenv("DATA_ROOT", "/data"))
    restore_parser.add_argument("--token", default=None)
    return parser
def main(argv: list[str] | None = None) -> int:
    """Run the ``backup`` or ``restore`` subcommand and return the exit status.

    Args:
        argv: Command-line arguments; ``None`` lets argparse read ``sys.argv``.

    Returns:
        ``0`` on every path that returns, including the "skipped" cases (no
        repository configured, or restore target already holds data).
    """
    args = _build_parser().parse_args(argv)
    if not args.repo_id:
        print("[backup] skipped: HF_BACKUP_REPO is not configured", file=sys.stderr)
        return 0

    root = Path(args.data_root)

    if args.command != "restore":
        uploaded = backup_to_dataset(
            repo_id=args.repo_id,
            data_root=root,
            keep=args.keep,
            token=args.token,
            space_name=args.space_name,
        )
        print(f"[backup] uploaded {uploaded}")
        return 0

    # Never restore over existing state.
    if data_root_has_state(root):
        print(f"[restore] skipped: {root} already contains data")
        return 0

    archive = restore_from_dataset(repo_id=args.repo_id, data_root=root, token=args.token)
    if archive is None:
        outcome = "[restore] no previous backup found"
    else:
        outcome = f"[restore] extracted {archive.name}"
    print(outcome)
    return 0
# Script entry point: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())