from __future__ import annotations
from pathlib import Path, PurePosixPath
from typing import Iterable
from huggingface_hub import CommitOperationAdd, HfApi
try:
from .validator import DOMAIN_TOKEN_RE, PreparedSubmission, TASK_ID_RE, normalize_domain_token
except ImportError:
from validator import DOMAIN_TOKEN_RE, PreparedSubmission, TASK_ID_RE, normalize_domain_token
DEFAULT_REPO_ID = 'InternScience/ResearchClawBench'
# Environment variables probed for a write-capable Hugging Face token,
# listed in priority order (Space-specific secret first, generic last).
TOKEN_ENV_KEYS = (
    'RCB_SPACE_HF_TOKEN',
    'HF_TOKEN',
    'HUGGINGFACEHUB_API_TOKEN',
    'HUGGINGFACE_TOKEN',
)
def load_hf_token() -> str | None:
    """Return the first non-empty token found in TOKEN_ENV_KEYS, or None."""
    import os
    candidates = (os.environ.get(key) for key in TOKEN_ENV_KEYS)
    return next((token for token in candidates if token), None)
def list_existing_task_ids(repo_id: str = DEFAULT_REPO_ID, token: str | None = None) -> set[str]:
    """Return the set of task IDs already present under ``tasks/`` in the dataset repo.

    Each remote path of the form ``tasks/<task_id>/...`` contributes its
    ``<task_id>`` segment; everything else in the repo is ignored.
    """
    api = HfApi(token=token)
    remote_files = api.list_repo_files(repo_id=repo_id, repo_type='dataset', token=token)
    split_paths = (PurePosixPath(remote).parts for remote in remote_files)
    return {parts[1] for parts in split_paths if len(parts) > 1 and parts[0] == 'tasks'}
def get_repo_head_sha(repo_id: str = DEFAULT_REPO_ID, token: str | None = None) -> str:
    """Return the commit SHA at the tip of ``main`` for the dataset repo.

    Raises:
        RuntimeError: if the repo info carries no usable ``sha`` attribute.
    """
    info = HfApi(token=token).repo_info(
        repo_id=repo_id, repo_type='dataset', revision='main', token=token
    )
    sha = getattr(info, 'sha', None)
    if not sha:
        raise RuntimeError(f'Failed to fetch HEAD SHA for dataset repo {repo_id}.')
    return sha
def allocate_next_task_id(domain: str, existing_task_ids: Iterable[str]) -> str:
    """Return the next sequential task ID for *domain* (``<domain>_NNN``).

    The number is one past the highest already used for this domain, or 000
    when the domain has no tasks yet; at most 1000 IDs (000-999) exist per
    domain.

    Raises:
        ValueError: if the normalized domain fails validation, or the
            domain's ID space is exhausted.
    """
    domain = normalize_domain_token(domain)
    if not DOMAIN_TOKEN_RE.fullmatch(domain):
        raise ValueError(
            'Domain must start with a letter and contain only letters, numbers, or hyphens '
            f'after normalization. Got: {domain!r}'
        )
    taken = [
        int(found.group(2))
        for task_id in existing_task_ids
        if (found := TASK_ID_RE.match(task_id)) and found.group(1) == domain
    ]
    candidate = max(taken) + 1 if taken else 0
    if candidate > 999:
        raise ValueError(f'No task IDs left for domain {domain}.')
    return f'{domain}_{candidate:03d}'
def build_commit_description(prepared: PreparedSubmission) -> str:
    """Compose the pull-request description summarizing a prepared submission.

    Lists submitter/paper metadata and archive stats, appends the submitter's
    notes when non-blank, and closes with a provenance line marking the PR as
    auto-generated.
    """
    meta = prepared.metadata
    stats = prepared.archive_stats
    body = [
        f'Submitter: {meta.submitter}',
        f'Contact email: {meta.email}',
        f'Domain: {meta.domain}',
        f'Assigned task id: {prepared.assigned_task_id}',
        f'Paper title: {meta.paper_title}',
        f'Paper URL/DOI: {meta.paper_url}',
        f'Archive files: {stats.file_count}',
        f'Archive total bytes: {stats.total_bytes}',
    ]
    notes = meta.notes.strip()
    if notes:
        body += ['', 'Submitter notes:', notes]
    body += ['', 'This PR was created automatically by the ResearchClawBench submission Space after passing format validation.']
    return '\n'.join(body)
def create_dataset_pr(
    prepared: PreparedSubmission,
    *,
    repo_id: str = DEFAULT_REPO_ID,
    token: str | None = None,
    parent_commit: str | None = None,
):
    """Open a pull request against the dataset repo adding the staged task files.

    Uploads every regular file under the staged task directory to
    ``tasks/<assigned_task_id>/...`` as a single commit with ``create_pr=True``.

    Raises:
        RuntimeError: if no write token is configured or the staged task
            directory is missing.
    """
    token = token or load_hf_token()
    if not token:
        raise RuntimeError('No Hugging Face write token configured. Set RCB_SPACE_HF_TOKEN or HF_TOKEN.')
    staged_task_dir = Path(prepared.staged_task_dir)
    if not staged_task_dir.is_dir():
        raise RuntimeError(f'Staged task directory does not exist: {staged_task_dir}')
    # Sort for a deterministic operation order across runs.
    additions = [
        CommitOperationAdd(
            path_in_repo=f'tasks/{prepared.assigned_task_id}/{local.relative_to(staged_task_dir).as_posix()}',
            path_or_fileobj=str(local),
        )
        for local in sorted(staged_task_dir.rglob('*'))
        if local.is_file()
    ]
    return HfApi(token=token).create_commit(
        repo_id=repo_id,
        repo_type='dataset',
        operations=additions,
        commit_message=f'Add task submission {prepared.assigned_task_id}',
        commit_description=build_commit_description(prepared),
        token=token,
        create_pr=True,
        revision='main',
        parent_commit=parent_commit,
    )