File size: 4,923 Bytes
d6c8a4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9255a5f
 
 
 
d6c8a4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from __future__ import annotations

import json
import os
import threading
from typing import Any, Literal

from datasets import Dataset, load_dataset

# A submission belongs to exactly one of the two competition teams.
TeamName = Literal["blue", "red"]

# Environment variable names that configure storage locations and backends.
DATA_DIR_ENV = "HACKATHON_DATA_DIR"
BLUE_PATH_ENV = "HACKATHON_BLUE_DATA_PATH"
RED_PATH_ENV = "HACKATHON_RED_DATA_PATH"
HF_DATASET_ENV = "HACKATHON_HF_DATASET"          # legacy shared dataset (rejected at runtime)
HF_BLUE_DATASET_ENV = "HACKATHON_HF_BLUE_DATASET"
HF_RED_DATASET_ENV = "HACKATHON_HF_RED_DATASET"
HF_TOKEN_ENV = "HF_TOKEN_SUBMISSIONS"            # preferred token variable
HF_TOKEN_FALLBACK_ENV = "HF_TOKEN"               # fallback token variable

# Local JSON fallback paths, resolved once at import time from the environment.
DEFAULT_DATA_DIR = os.environ.get(DATA_DIR_ENV, os.path.join(os.getcwd(), "hackathon-data"))
DEFAULT_BLUE_PATH = os.environ.get(BLUE_PATH_ENV, os.path.join(DEFAULT_DATA_DIR, "blue_submissions.json"))
DEFAULT_RED_PATH = os.environ.get(RED_PATH_ENV, os.path.join(DEFAULT_DATA_DIR, "red_submissions.json"))

# Default on-disk JSON file per team (used when no explicit data_path is given).
_TEAM_PATHS = {
    "blue": DEFAULT_BLUE_PATH,
    "red": DEFAULT_RED_PATH,
}

# Serializes local file reads/writes within this process.
_LOCK = threading.Lock()


def _resolve_hf_dataset(team: TeamName) -> str | None:
    """Return the Hugging Face dataset name configured for *team*, or None.

    Raises ValueError if only the legacy single-dataset variable is set,
    since blue and red submissions must live in separate datasets.
    """
    env_var = HF_BLUE_DATASET_ENV if team == "blue" else HF_RED_DATASET_ENV
    dataset_name = os.environ.get(env_var, "").strip()
    if dataset_name:
        return dataset_name

    # Reject the retired shared-dataset configuration with a clear message.
    legacy_dataset = os.environ.get(HF_DATASET_ENV, "").strip()
    if legacy_dataset:
        raise ValueError(
            "Set HACKATHON_HF_BLUE_DATASET and HACKATHON_HF_RED_DATASET for separate datasets. "
            f"HACKATHON_HF_DATASET is no longer supported: {legacy_dataset}"
        )
    return None


def _require_hf_token() -> str:
    """Return an HF access token, preferring the submissions-specific variable.

    Raises ValueError when neither environment variable holds a token.
    """
    for env_var in (HF_TOKEN_ENV, HF_TOKEN_FALLBACK_ENV):
        token = os.environ.get(env_var)
        if token:
            return token
    raise ValueError(
        "HF_TOKEN_SUBMISSIONS (or HF_TOKEN) is required to access the private submissions dataset."
    )


def _load_hf_submissions(dataset_name: str, token: str) -> list[dict[str, Any]]:
    """Fetch the ``train`` split of *dataset_name* as a list of row dicts.

    Authorization failures surface as ValueError; a dataset that is missing
    or has no data yet is treated as an empty submission list. Any other
    load error propagates unchanged.
    """
    # Heuristic substrings matched against the lowercased error message.
    auth_markers = ("401", "403", "permission", "unauthorized", "forbidden")
    missing_markers = (
        "not found", "404", "doesn't exist", "no such dataset", "split",
        "doesn't contain any data", "an error occurred while generating the dataset",
    )
    try:
        dataset = load_dataset(dataset_name, split="train", token=token)
    except Exception as exc:
        message = str(exc).lower()
        if any(marker in message for marker in auth_markers):
            raise ValueError(
                f"HF_TOKEN_SUBMISSIONS does not have access to the private dataset: {dataset_name}"
            ) from exc
        if any(marker in message for marker in missing_markers):
            # No data yet is a normal state, not an error.
            return []
        raise
    return dataset.to_list()


def _save_hf_submissions(dataset_name: str, token: str, submissions: list[dict]) -> None:
    """Replace the contents of the private HF dataset with *submissions*."""
    Dataset.from_list(submissions).push_to_hub(dataset_name, token=token, private=True)


def _resolve_data_path(team: TeamName, data_path: str | None) -> str:
    """Return *data_path* when given, otherwise the default path for *team*.

    Raises ValueError for a team name outside the known set.
    """
    if data_path:
        return data_path
    try:
        return _TEAM_PATHS[team]
    except KeyError:
        raise ValueError(f"Unknown team: {team}") from None


def load_submissions(team: TeamName, data_path: str | None = None) -> list[dict[str, Any]]:
    """Load all submissions for *team*.

    Uses the team's private HF dataset when one is configured via the
    environment; otherwise reads the local JSON file (*data_path* overrides
    the default per-team path). A missing local file yields an empty list.

    Raises:
        ValueError: when the HF token is missing or unauthorized, or when
            the local file does not contain a JSON list.
    """
    dataset_name = _resolve_hf_dataset(team)
    if dataset_name:
        token = _require_hf_token()
        return _load_hf_submissions(dataset_name, token)

    resolved_path = _resolve_data_path(team, data_path)
    if not os.path.exists(resolved_path):
        return []

    with _LOCK:
        # Explicit encoding: JSON is UTF-8 by spec; the platform default
        # codec (e.g. cp1252 on Windows) can corrupt non-ASCII submissions.
        with open(resolved_path, "r", encoding="utf-8") as f:
            data = json.load(f)

    if not isinstance(data, list):
        raise ValueError(f"Expected a list in {resolved_path}.")

    return data


def save_submissions(
    team: TeamName,
    submissions: list[dict[str, Any]],
    data_path: str | None = None,
) -> None:
    """Persist *submissions* for *team*.

    Pushes to the team's private HF dataset when configured; otherwise
    writes the local JSON file atomically (write to ``.tmp``, then rename)
    so concurrent readers never observe a partially written file.

    Raises:
        ValueError: when the HF token is missing or unauthorized.
    """
    dataset_name = _resolve_hf_dataset(team)
    if dataset_name:
        token = _require_hf_token()
        _save_hf_submissions(dataset_name, token, submissions)
        return

    resolved_path = _resolve_data_path(team, data_path)
    data_dir = os.path.dirname(resolved_path)
    if data_dir:
        os.makedirs(data_dir, exist_ok=True)

    tmp_path = f"{resolved_path}.tmp"
    with _LOCK:
        try:
            # Explicit UTF-8: JSON is UTF-8 by spec regardless of platform.
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(submissions, f, indent=2)
            os.replace(tmp_path, resolved_path)
        except Exception:
            # Don't leave a stale .tmp behind if serialization fails.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise


def append_submission(
    team: TeamName,
    submission: dict[str, Any],
    data_path: str | None = None,
) -> list[dict[str, Any]]:
    """Append one *submission* for *team* and return the full updated list.

    The entire read-modify-write cycle runs under ``_LOCK`` so concurrent
    appends in this process cannot interleave and drop a submission. (The
    original local-file branch called ``load_submissions`` and
    ``save_submissions``, which each take the lock separately, leaving a
    race window between the read and the write; the file I/O is therefore
    inlined here — ``threading.Lock`` is not reentrant.)

    Raises:
        ValueError: when the HF token is missing or unauthorized, or when
            the local file does not contain a JSON list.
    """
    dataset_name = _resolve_hf_dataset(team)
    if dataset_name:
        token = _require_hf_token()
        with _LOCK:
            submissions = _load_hf_submissions(dataset_name, token)
            submissions.append(submission)
            _save_hf_submissions(dataset_name, token, submissions)
        return submissions

    resolved_path = _resolve_data_path(team, data_path)
    with _LOCK:
        if os.path.exists(resolved_path):
            with open(resolved_path, "r", encoding="utf-8") as f:
                submissions = json.load(f)
            if not isinstance(submissions, list):
                raise ValueError(f"Expected a list in {resolved_path}.")
        else:
            submissions = []

        submissions.append(submission)

        data_dir = os.path.dirname(resolved_path)
        if data_dir:
            os.makedirs(data_dir, exist_ok=True)

        # Atomic replace so readers never see a partially written file.
        tmp_path = f"{resolved_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(submissions, f, indent=2)
        os.replace(tmp_path, resolved_path)
    return submissions