File size: 9,011 Bytes
8059bf0
 
 
 
 
 
 
 
 
 
 
 
9d3f59d
8059bf0
 
 
9d3f59d
 
8059bf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d3f59d
8059bf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d3f59d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8059bf0
9d3f59d
8059bf0
 
9d3f59d
8059bf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
from __future__ import annotations

import argparse
import os
import sys
import tarfile
import tempfile
from datetime import datetime, timezone
from pathlib import Path, PurePosixPath
from typing import Iterable, Mapping

# huggingface_hub is an optional dependency: it is only needed for the Hub
# upload/download paths. If any import below fails, placeholders are installed
# so this module still imports and the local archive helpers keep working;
# _require_hub_support() then rejects Hub operations up front.
# NOTE(review): the fallback also triggers when huggingface_hub is installed but
# one of these names/submodules is missing (e.g. a release without
# `huggingface_hub.errors`) — confirm the minimum supported version.
try:
    from huggingface_hub import CommitOperationAdd, CommitOperationDelete, HfApi, hf_hub_download
    from huggingface_hub.errors import EntryNotFoundError
    from huggingface_hub.utils import HfHubHTTPError, get_token
except ImportError:  # pragma: no cover - only exercised in runtime images without the dependency
    # None sentinels: _require_hub_support() checks these to detect the fallback.
    CommitOperationAdd = None  # type: ignore[assignment]
    CommitOperationDelete = None  # type: ignore[assignment]
    HfApi = None  # type: ignore[assignment]
    # Aliased to Exception only so the `except EntryNotFoundError` /
    # `except HfHubHTTPError` clauses elsewhere still name valid classes. Those
    # handlers sit behind a _require_hub_support() call, so the broad alias is
    # not expected to be reached.
    EntryNotFoundError = Exception  # type: ignore[assignment]
    HfHubHTTPError = Exception  # type: ignore[assignment]

    def get_token() -> str | None:  # type: ignore[no-redef]
        """Fallback token lookup: read the HF_TOKEN environment variable."""
        return os.getenv("HF_TOKEN")

    def hf_hub_download(*args, **kwargs):  # type: ignore[no-redef]
        """Fallback stub that always fails, since downloads need huggingface_hub."""
        raise RuntimeError("huggingface_hub is required for backup download support")


class BackupArchiveError(RuntimeError):
    """Signal that a backup archive is malformed or cannot be extracted safely.

    Subclasses RuntimeError so generic runtime-error handlers still catch it.
    """


def build_backup_filename(timestamp: str | None = None) -> str:
    """Return the archive file name ``backup-<timestamp>.tar.gz``.

    When *timestamp* is omitted, the current UTC time is rendered as
    ``YYYYMMDDTHHMMSSZ``. Every field is fixed-width and zero-padded, so
    generated names sort lexicographically in chronological order.
    """
    stamp = (
        timestamp
        if timestamp is not None
        else datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    )
    return f"backup-{stamp}.tar.gz"


def select_backups_to_delete(paths: Iterable[str], keep: int = 3) -> list[str]:
    """Return the timestamped backup archives that fall outside the retention window.

    Only files directly under ``backups/`` whose name matches
    ``backup-*.tar.gz`` are considered; ``backups/latest.tar.gz`` and files in
    nested folders are never selected. Names produced by
    build_backup_filename embed a fixed-width UTC timestamp, so a reverse
    lexicographic sort puts the newest first; everything after the first
    *keep* entries is returned.

    Args:
        paths: Repository-relative file paths (consumed once).
        keep: Number of most recent backups to retain; must be >= 0.

    Returns:
        Paths to delete, ordered newest to oldest.

    Raises:
        ValueError: If *keep* is negative.
    """
    # A negative keep would act as a slice offset from the end of the list and
    # delete an arbitrary subset of backups; reject it explicitly instead.
    if keep < 0:
        raise ValueError(f"keep must be >= 0, got {keep}")

    backups_dir = PurePosixPath("backups")
    candidates: list[str] = []
    for path in paths:
        pure = PurePosixPath(path)  # parse once instead of once per predicate
        name = pure.name
        if pure.parent == backups_dir and name.startswith("backup-") and name.endswith(".tar.gz"):
            candidates.append(path)

    candidates.sort(reverse=True)
    return candidates[keep:]


def create_backup_archive(archive_path: Path, sources: Mapping[str, Path]) -> Path:
    """Write a gzip-compressed tar archive of *sources* to *archive_path*.

    Each entry of *sources* maps the name used inside the archive to a local
    file or directory; directories are added recursively. Sources that do not
    exist are skipped, so the archive may hold only some (or none) of the
    requested entries.

    Symbolic and hard links are rejected. extract_backup_archive refuses
    archives that contain them, so writing one would produce a backup that
    can never be restored. Failing here surfaces the problem at backup time,
    before a broken archive replaces a good one.

    Args:
        archive_path: Destination file; missing parent directories are created.
        sources: Mapping of archive entry name -> local path.

    Returns:
        *archive_path*, for call chaining.

    Raises:
        BackupArchiveError: If any source tree contains a symlink or hardlink.
            A partially written file may be left at *archive_path*.
    """
    archive_path.parent.mkdir(parents=True, exist_ok=True)

    def _reject_links(member: tarfile.TarInfo) -> tarfile.TarInfo:
        # tar.add stores links as link members (dereference is off), which the
        # restore path forbids; refuse them here instead.
        if member.issym() or member.islnk():
            raise BackupArchiveError(
                f"links cannot be backed up (restore rejects them): {member.name}"
            )
        return member

    with tarfile.open(archive_path, "w:gz") as tar:
        for arcname, source in sources.items():
            if source.exists():
                tar.add(source, arcname=arcname, recursive=True, filter=_reject_links)
    return archive_path


def extract_backup_archive(archive_path: Path, destination_root: Path) -> None:
    destination_root.mkdir(parents=True, exist_ok=True)
    destination_root = destination_root.resolve()
    with tarfile.open(archive_path, "r:gz") as tar:
        members = tar.getmembers()
        for member in members:
            pure_path = PurePosixPath(member.name)
            if pure_path.is_absolute() or ".." in pure_path.parts:
                raise BackupArchiveError(f"unsafe member path: {member.name}")
            if member.issym() or member.islnk():
                raise BackupArchiveError(f"symlinks are not allowed in backup archives: {member.name}")

            target_path = (destination_root / Path(*pure_path.parts)).resolve()
            try:
                target_path.relative_to(destination_root)
            except ValueError as exc:  # pragma: no cover - defensive redundancy
                raise BackupArchiveError(f"unsafe extraction target: {member.name}") from exc

        tar.extractall(destination_root, filter="data")


def default_sources(data_root: Path) -> dict[str, Path]:
    """Map archive entry names to the state directories under *data_root*.

    Insertion order is data, postgres, redis; callers that iterate the
    mapping (e.g. when archiving) see the entries in that order.
    """
    entry_names = ("data", "postgres", "redis")
    return {name: data_root / name for name in entry_names}


def data_root_has_state(data_root: Path) -> bool:
    """Return True if any default source under *data_root* holds data.

    A source directory counts as state when it has at least one entry. A
    source path that exists but is not a directory (e.g. a stray file) also
    counts as state, because create_backup_archive would archive it too.
    Missing sources are ignored.
    """
    for source in default_sources(data_root).values():
        if source.is_dir():
            # any() stops at the first entry; Path objects are always truthy.
            if any(source.iterdir()):
                return True
        elif source.exists():
            # iterdir() would raise NotADirectoryError here; a non-directory at
            # a source path is still data that must not be treated as empty.
            return True
    return False


def _require_hub_support() -> None:
    """Fail fast when the huggingface_hub import fallback is active.

    Raises:
        RuntimeError: If HfApi, CommitOperationAdd, or CommitOperationDelete
            is unavailable (set to None by the import fallback).
    """
    hub_missing = any(
        symbol is None for symbol in (HfApi, CommitOperationAdd, CommitOperationDelete)
    )
    if hub_missing:
        raise RuntimeError("huggingface_hub is required for Hugging Face backup operations")


def _resolve_token(explicit_token: str | None = None) -> str | None:
    return explicit_token or os.getenv("HF_TOKEN") or get_token()


def _iter_repo_paths(api: HfApi, repo_id: str, token: str) -> list[str]:
    repo_paths: list[str] = []
    for item in api.list_repo_tree(
        repo_id=repo_id,
        path_in_repo="backups",
        recursive=True,
        repo_type="dataset",
        token=token,
    ):
        path = getattr(item, "path", None) or getattr(item, "rfilename", None)
        if path:
            repo_paths.append(path)
    return repo_paths


def backup_to_dataset(
    repo_id: str,
    data_root: Path,
    keep: int = 3,
    token: str | None = None,
    space_name: str | None = None,
) -> str:
    """Archive the state under *data_root* and commit it to a Hub dataset.

    A single commit uploads the archive twice, as ``backups/<timestamped name>``
    and as ``backups/latest.tar.gz``. The same commit deletes the timestamped
    backups that select_backups_to_delete places outside the *keep* window.
    Existing backups are listed before the commit, so the archive being
    uploaded is not counted toward *keep*. Afterwards, up to keep + 1
    timestamped archives remain.

    Args:
        repo_id: Target dataset repository id.
        data_root: Directory whose default sources are archived.
        keep: Retention count passed to select_backups_to_delete.
        token: Explicit HF token; otherwise resolved via _resolve_token.
        space_name: Optional label appended to the commit message.

    Returns:
        File name of the uploaded timestamped archive.

    Raises:
        RuntimeError: If huggingface_hub is unavailable or no token resolves.
    """
    _require_hub_support()
    hf_token = _resolve_token(token)
    if not hf_token:
        raise RuntimeError("HF token is required to upload backups")

    with tempfile.TemporaryDirectory(prefix="sub2api-backup-") as workdir:
        archive = create_backup_archive(
            Path(workdir) / build_backup_filename(),
            default_sources(data_root),
        )

        client = HfApi(token=hf_token)
        suffix = "" if not space_name else f" for {space_name}"
        existing = _iter_repo_paths(client, repo_id, hf_token)

        # Upload the same file under its timestamped name and as "latest".
        ops = [
            CommitOperationAdd(path_in_repo=target, path_or_fileobj=str(archive))
            for target in (f"backups/{archive.name}", "backups/latest.tar.gz")
        ]
        ops.extend(
            CommitOperationDelete(path_in_repo=stale)
            for stale in select_backups_to_delete(existing, keep=keep)
        )

        client.create_commit(
            repo_id=repo_id,
            operations=ops,
            repo_type="dataset",
            token=hf_token,
            commit_message=f"Update backup {archive.name}{suffix}",
        )
        return archive.name


def restore_from_dataset(repo_id: str, data_root: Path, token: str | None = None) -> Path | None:
    """Download ``backups/latest.tar.gz`` from *repo_id* and unpack it into *data_root*.

    Returns:
        The local path the archive was downloaded to. Returns ``None`` when
        the dataset has no ``latest.tar.gz``, i.e. the Hub reports
        entry-not-found or HTTP 404. The download lives in a temporary
        directory that is removed before this function returns. Only the
        returned path's name is meaningful; the file no longer exists.

    Raises:
        RuntimeError: If huggingface_hub is unavailable or no token resolves.
        BackupArchiveError: If the downloaded archive fails validation.
    """
    _require_hub_support()
    hf_token = _resolve_token(token)
    if not hf_token:
        raise RuntimeError("HF token is required to restore backups")

    with tempfile.TemporaryDirectory(prefix="sub2api-restore-") as workdir:
        try:
            local_file = hf_hub_download(
                repo_id=repo_id,
                filename="backups/latest.tar.gz",
                repo_type="dataset",
                token=hf_token,
                local_dir=workdir,
                force_download=True,
            )
        except EntryNotFoundError:
            return None
        except HfHubHTTPError as exc:
            status = getattr(exc.response, "status_code", None)
            if status != 404:
                raise
            return None

        archive = Path(local_file)
        extract_backup_archive(archive, data_root)
        return archive


def _build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Manage Sub2API backups stored in a Hugging Face dataset.")
    subparsers = parser.add_subparsers(dest="command", required=True)

    backup_parser = subparsers.add_parser("backup", help="Create and upload a backup archive.")
    backup_parser.add_argument("--repo-id", default=os.getenv("HF_BACKUP_REPO"))
    backup_parser.add_argument("--data-root", default=os.getenv("DATA_ROOT", "/data"))
    backup_parser.add_argument("--keep", type=int, default=int(os.getenv("BACKUP_KEEP", "3")))
    backup_parser.add_argument("--token", default=None)
    backup_parser.add_argument("--space-name", default=os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"))

    restore_parser = subparsers.add_parser("restore", help="Restore the latest backup archive.")
    restore_parser.add_argument("--repo-id", default=os.getenv("HF_BACKUP_REPO"))
    restore_parser.add_argument("--data-root", default=os.getenv("DATA_ROOT", "/data"))
    restore_parser.add_argument("--token", default=None)

    return parser


def main(argv: list[str] | None = None) -> int:
    """Command-line entry point; returns the process exit status.

    Both subcommands are no-ops (exit 0) when no dataset repo is configured.

    ``restore`` runs only against a data root without existing state, so
    existing data is never overwritten.

    ``backup`` is skipped when the data root holds no state. Otherwise it
    would upload an empty archive as ``backups/latest.tar.gz``. Repeated runs
    would then, through retention, rotate out every real backup, for example
    when DATA_ROOT points at the wrong directory.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    repo_id = args.repo_id
    if not repo_id:
        print("[backup] skipped: HF_BACKUP_REPO is not configured", file=sys.stderr)
        return 0

    data_root = Path(args.data_root)

    if args.command == "restore":
        if data_root_has_state(data_root):
            print(f"[restore] skipped: {data_root} already contains data")
            return 0

        restored = restore_from_dataset(repo_id=repo_id, data_root=data_root, token=args.token)
        if restored is None:
            print("[restore] no previous backup found")
            return 0
        print(f"[restore] extracted {restored.name}")
        return 0

    # Never replace the last good backup with an empty archive.
    if not data_root_has_state(data_root):
        print(f"[backup] skipped: {data_root} contains no data to back up")
        return 0

    backup_name = backup_to_dataset(
        repo_id=repo_id,
        data_root=data_root,
        keep=args.keep,
        token=args.token,
        space_name=args.space_name,
    )
    print(f"[backup] uploaded {backup_name}")
    return 0


if __name__ == "__main__":
    # Exit with main()'s return value as the process status.
    sys.exit(main())