File size: 2,299 Bytes
4f54dc8
 
 
 
 
 
140619f
4f54dc8
 
 
140619f
 
 
 
 
 
 
 
 
 
 
 
 
 
4f54dc8
 
140619f
 
4f54dc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python3
"""
从 Hugging Face Dataset 仓库恢复 OpenCode 数据到 ~/.local/share/opencode。
需设置环境变量: HF_TOKEN, OPENCODE_DATASET_REPO。
"""
import os
import re
import shutil
import sys

def _normalize_repo_id(value):
    """接受 repo_id 或完整 URL,返回 namespace/repo_name。"""
    if not value or not value.strip():
        return None
    value = value.strip()
    # 若是 URL,提取最后两段路径作为 namespace/repo_name
    m = re.search(r"(?:huggingface\.co/datasets/|^)([\w.-]+/[\w.-]+)/?$", value)
    if m:
        return m.group(1)
    # 已是 namespace/repo_name 形式
    if "/" in value:
        return value
    return None

def _replace_entry(src, dst):
    """Copy *src* (file or directory) to *dst*, replacing what is there.

    A directory source replaces anything already at *dst*. A file source
    replaces an existing real directory; an existing file is overwritten in
    place by ``shutil.copy2``.
    """
    dst_is_real_dir = os.path.isdir(dst) and not os.path.islink(dst)
    if os.path.isdir(src):
        # os.path.lexists also sees broken symlinks, which would otherwise
        # make copytree fail with FileExistsError.
        if dst_is_real_dir:
            shutil.rmtree(dst)
        elif os.path.lexists(dst):
            os.remove(dst)
        shutil.copytree(src, dst)
    else:
        # Without this, copy2 would drop the file *inside* the directory
        # instead of replacing it.
        if dst_is_real_dir:
            shutil.rmtree(dst)
        shutil.copy2(src, dst)


def main():
    """Restore OpenCode data from a Hugging Face dataset repository.

    Reads ``HF_TOKEN`` and ``OPENCODE_DATASET_REPO`` from the environment,
    downloads a snapshot of the dataset into a temporary sibling directory
    and copies every top-level entry except ``.gitattributes`` into
    ``~/.local/share/opencode``, replacing entries of the same name.

    Returns:
        0 when the restore completed, and also when it was skipped: missing
        token or repo id, ``huggingface_hub`` not installed, listing the
        repo failed, or the repo holds nothing besides ``.gitattributes``.

    Raises:
        Any exception raised while downloading or copying propagates to the
        caller; the temporary directory is removed either way.
    """
    token = os.environ.get("HF_TOKEN")
    repo_id = _normalize_repo_id(os.environ.get("OPENCODE_DATASET_REPO"))
    data_dir = os.path.expanduser("~/.local/share/opencode")

    if not token or not repo_id:
        return 0

    try:
        from huggingface_hub import HfApi, snapshot_download
    except ImportError:
        print("restore: huggingface_hub not installed, skip restore", file=sys.stderr)
        return 0

    try:
        api = HfApi(token=token)
        files = api.list_repo_files(repo_id, repo_type="dataset")
        # An empty repo (or one holding only .gitattributes) has nothing to restore.
        if not files or set(files) <= {".gitattributes"}:
            return 0
    except Exception as e:
        print(f"restore: list repo failed ({e}), skip restore", file=sys.stderr)
        return 0

    os.makedirs(data_dir, exist_ok=True)
    tmp_dir = data_dir + ".restore_tmp"
    # A previous run that was killed mid-restore may have left this directory
    # behind; start clean so stale entries are not copied into data_dir.
    if os.path.isdir(tmp_dir):
        shutil.rmtree(tmp_dir, ignore_errors=True)
    try:
        snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir=tmp_dir,
            token=token,
        )
        # huggingface_hub keeps download bookkeeping under
        # <local_dir>/.cache/huggingface; it is not part of the backup and
        # must not end up in data_dir. The parent ".cache" is only removed
        # when that leaves it empty, so real repo content there is kept.
        shutil.rmtree(os.path.join(tmp_dir, ".cache", "huggingface"), ignore_errors=True)
        try:
            os.rmdir(os.path.join(tmp_dir, ".cache"))
        except OSError:
            pass

        for name in os.listdir(tmp_dir):
            if name == ".gitattributes":
                continue
            _replace_entry(os.path.join(tmp_dir, name), os.path.join(data_dir, name))
    finally:
        if os.path.isdir(tmp_dir):
            shutil.rmtree(tmp_dir, ignore_errors=True)
    return 0

if __name__ == "__main__":
    # Run the restore when executed as a script and use main()'s return
    # value as the process exit status.
    raise SystemExit(main())