cc / tools /hf_backup.py
hequ's picture
Update tools/hf_backup.py
9abffe0 verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
import sys
from pathlib import Path
from huggingface_hub import HfApi
# Directory that restored archives are downloaded into; overridable through the
# HF_RESTORE_DIR environment variable. Kept aligned with the shell script's TMP_DIR.
RESTORE_DIR = os.environ.get("HF_RESTORE_DIR", "/tmp/crs_backup")
def list_backups(api: HfApi, repo_id: str, prefix: str):
    """Return the sorted names of the backup archives found in a dataset repo.

    Args:
        api: Hub client used to list the repository's files.
        repo_id: Identifier of the dataset repository to inspect.
        prefix: Only file names starting with this string are considered.

    Returns:
        A list of the file names that start with ``prefix`` and end with
        ``.tar.gz``, in ascending lexicographic order.
    """
    def is_backup(name):
        return name.startswith(prefix) and name.endswith(".tar.gz")

    repo_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    return sorted(filter(is_backup, repo_files))
def ensure_dataset(api: HfApi, repo_id: str):
    """Create the dataset repository when looking it up fails.

    Args:
        api: Hub client used for both the lookup and the creation.
        repo_id: Identifier of the dataset repository.

    Any exception raised by the lookup is treated as "repository missing" and
    triggers a creation attempt; ``exist_ok=True`` is passed to that call.
    """
    # The repo is created as a private dataset; switch "private" to False
    # here if it should be public instead.
    creation_options = {"repo_type": "dataset", "private": True, "exist_ok": True}
    try:
        api.dataset_info(repo_id=repo_id)
    except Exception:
        api.create_repo(repo_id=repo_id, **creation_options)
def upload(args):
    """Upload one archive to the dataset repo, then prune surplus backups.

    Args:
        args: Parsed CLI namespace providing ``token``, ``repo``, ``file``,
            ``prefix`` and ``max``.

    The archive is stored under its base name. Afterwards, when ``args.max``
    is a positive number, only the last ``args.max`` names returned by
    ``list_backups`` are kept. Pruning is best effort: a failed deletion is
    reported on stderr and does not fail the upload.
    """
    api = HfApi(token=args.token)
    ensure_dataset(api, args.repo)

    # Upload the current archive.
    api.upload_file(
        path_or_fileobj=args.file,
        path_in_repo=os.path.basename(args.file),
        repo_id=args.repo,
        repo_type="dataset",
    )

    # Retention is disabled for a missing, zero or negative limit.
    if not args.max or args.max <= 0:
        return

    # Keep only the newest N backups. list_backups returns names in ascending
    # order, so the surplus entries sit at the front of the list (these are
    # the oldest ones, assuming archive names sort chronologically).
    backs = list_backups(api, args.repo, args.prefix)
    surplus = max(len(backs) - args.max, 0)
    for stale in backs[:surplus]:
        try:
            api.delete_file(path_in_repo=stale, repo_id=args.repo, repo_type="dataset")
        except Exception as exc:
            # A failed deletion is not fatal, but it must not go unnoticed.
            # stderr is used so that stdout stays clean for calling scripts.
            print(f"warning: could not delete old backup {stale!r}: {exc}", file=sys.stderr)
def restore(args):
    """Download the last backup archive (by sorted name) and print its path.

    Args:
        args: Parsed CLI namespace providing ``token``, ``repo`` and
            ``prefix``.

    When no matching backup exists, nothing is printed and the function
    returns quietly; logging that case is left to the calling shell script.
    """
    client = HfApi(token=args.token)
    candidates = list_backups(client, args.repo, args.prefix)
    if candidates:
        newest = candidates[-1]
        # Download into a persistent directory rather than a temporary one,
        # which would be removed as soon as Python exits.
        target_dir = Path(RESTORE_DIR)
        target_dir.mkdir(parents=True, exist_ok=True)
        local_path = client.hf_hub_download(
            repo_id=args.repo,
            filename=newest,
            repo_type="dataset",
            local_dir=RESTORE_DIR,
            # Safer: avoids symlinks that point at files which get cleaned up.
            local_dir_use_symlinks=False,
        )
        # Print only the path, so the shell script can capture it.
        print(local_path)
def main():
    """Parse the command line and run the selected subcommand.

    Two subcommands are defined: ``upload`` (``--token``, ``--repo``,
    ``--file``, ``--prefix`` required; ``--max`` optional, default 10) and
    ``restore`` (``--token``, ``--repo``, ``--prefix`` required). A
    ``KeyboardInterrupt`` raised while the subcommand runs ends the process
    with exit status 130.
    """
    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(dest="cmd", required=True)

    upload_cmd = commands.add_parser("upload")
    for flag in ("--token", "--repo", "--file", "--prefix"):
        upload_cmd.add_argument(flag, required=True)
    upload_cmd.add_argument("--max", type=int, default=10)
    upload_cmd.set_defaults(func=upload)

    restore_cmd = commands.add_parser("restore")
    for flag in ("--token", "--repo", "--prefix"):
        restore_cmd.add_argument(flag, required=True)
    restore_cmd.set_defaults(func=restore)

    namespace = parser.parse_args()
    try:
        # Each subparser stored its handler under "func" via set_defaults.
        namespace.func(namespace)
    except KeyboardInterrupt:
        sys.exit(130)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()