from __future__ import annotations

from pathlib import Path, PurePosixPath
from typing import Iterable

from huggingface_hub import CommitOperationAdd, HfApi

try:
    from .validator import DOMAIN_TOKEN_RE, PreparedSubmission, TASK_ID_RE, normalize_domain_token
except ImportError:
    from validator import DOMAIN_TOKEN_RE, PreparedSubmission, TASK_ID_RE, normalize_domain_token

DEFAULT_REPO_ID = 'InternScience/ResearchClawBench'
TOKEN_ENV_KEYS = (
    'RCB_SPACE_HF_TOKEN',
    'HF_TOKEN',
    'HUGGINGFACEHUB_API_TOKEN',
    'HUGGINGFACE_TOKEN',
)


def load_hf_token() -> str | None:
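    """Return the first Hugging Face token found in the known environment variables, or None."""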
    import os

    for key in TOKEN_ENV_KEYS:
        value = os.environ.get(key)
        if value:
            return value
    return None


def list_existing_task_ids(repo_id: str = DEFAULT_REPO_ID, token: str | None = None) -> set[str]:
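    """Return the set of task ids already present under ``tasks/`` in the dataset repo."""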
    api = HfApi(token=token)
    task_ids: set[str] = set()
    for remote_path in api.list_repo_files(repo_id=repo_id, repo_type='dataset', token=token):
        parts = PurePosixPath(remote_path).parts
        if len(parts) >= 2 and parts[0] == 'tasks':
            task_ids.add(parts[1])
    return task_ids


def get_repo_head_sha(repo_id: str = DEFAULT_REPO_ID, token: str | None = None) -> str:
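    """Return the HEAD commit SHA of the dataset repo's ``main`` branch."""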
    api = HfApi(token=token)
    info = api.repo_info(repo_id=repo_id, repo_type='dataset', revision='main', token=token)
    if not getattr(info, 'sha', None):
        raise RuntimeError(f'Failed to fetch HEAD SHA for dataset repo {repo_id}.')
    return info.sha


def allocate_next_task_id(domain: str, existing_task_ids: Iterable[str]) -> str:
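    """Return the next unused ``<domain>_NNN`` task id for the given domain.

    Raises ValueError if the domain is malformed after normalization or the
    domain's 000-999 id range is exhausted.
    """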
    domain = normalize_domain_token(domain)
    if not DOMAIN_TOKEN_RE.fullmatch(domain):
        raise ValueError(
            'Domain must start with a letter and contain only letters, numbers, or hyphens '
            f'after normalization. Got: {domain!r}'
        )
    used_numbers = []
    for task_id in existing_task_ids:
        match = TASK_ID_RE.match(task_id)
        if match and match.group(1) == domain:
            used_numbers.append(int(match.group(2)))
    next_number = (max(used_numbers) + 1) if used_numbers else 0
    if next_number > 999:
        raise ValueError(f'No task IDs left for domain {domain}.')
    return f'{domain}_{next_number:03d}'


def build_commit_description(prepared: PreparedSubmission) -> str:
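    """Build the human-readable PR description from the submission metadata and archive stats."""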
    metadata = prepared.metadata
    lines = [
        f'Submitter: {metadata.submitter}',
        f'Contact email: {metadata.email}',
        f'Domain: {metadata.domain}',
        f'Assigned task id: {prepared.assigned_task_id}',
        f'Paper title: {metadata.paper_title}',
        f'Paper URL/DOI: {metadata.paper_url}',
        f'Archive files: {prepared.archive_stats.file_count}',
        f'Archive total bytes: {prepared.archive_stats.total_bytes}',
    ]
    if metadata.notes.strip():
        lines.extend(['', 'Submitter notes:', metadata.notes.strip()])
    lines.extend([
        '',
        'This PR was created automatically by the ResearchClawBench submission Space after passing format validation.',
    ])
    return '\n'.join(lines)


def create_dataset_pr(
    prepared: PreparedSubmission,
    *,
    repo_id: str = DEFAULT_REPO_ID,
    token: str | None = None,
    parent_commit: str | None = None,
):
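    """Open a pull request that adds every staged task file under ``tasks/<assigned_task_id>/``."""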
    token = token or load_hf_token()
    if not token:
        raise RuntimeError('No Hugging Face write token configured. Set RCB_SPACE_HF_TOKEN or HF_TOKEN.')

    staged_task_dir = Path(prepared.staged_task_dir)
    if not staged_task_dir.is_dir():
        raise RuntimeError(f'Staged task directory does not exist: {staged_task_dir}')

    operations: list[CommitOperationAdd] = []
    for path in sorted(staged_task_dir.rglob('*')):
        if not path.is_file():
            continue
        rel_path = path.relative_to(staged_task_dir).as_posix()
        operations.append(
            CommitOperationAdd(
                path_in_repo=f'tasks/{prepared.assigned_task_id}/{rel_path}',
                path_or_fileobj=str(path),
            )
        )

    api = HfApi(token=token)
    return api.create_commit(
        repo_id=repo_id,
        repo_type='dataset',
        operations=operations,
        commit_message=f'Add task submission {prepared.assigned_task_id}',
        commit_description=build_commit_description(prepared),
        token=token,
        create_pr=True,
        revision='main',
        parent_commit=parent_commit,
    )
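

if __name__ == '__main__':
    # Minimal usage sketch, not exercised by the Space itself. It assumes a
    # write token is present in one of TOKEN_ENV_KEYS and uses 'physics' purely
    # as an example domain token; it only pins the parent commit and allocates
    # the next free id. Actually opening a PR would additionally require a
    # PreparedSubmission produced by the validator module.
    demo_token = load_hf_token()
    head_sha = get_repo_head_sha(token=demo_token)
    existing_ids = list_existing_task_ids(token=demo_token)
    next_id = allocate_next_task_id('physics', existing_ids)
    print(f'Next free task id: {next_id} (against HEAD {head_sha[:7]})')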