File size: 8,990 Bytes
356d76f
4db3be0
c54d8a2
 
 
 
 
 
 
4db3be0
356d76f
 
c54d8a2
 
 
 
 
356d76f
 
 
 
 
 
 
 
 
 
 
4db3be0
 
 
 
356d76f
4db3be0
 
c54d8a2
 
 
 
 
 
4db3be0
 
c54d8a2
 
 
 
356d76f
 
 
 
 
 
 
 
 
 
4db3be0
356d76f
 
 
 
 
 
 
 
 
4db3be0
 
356d76f
4db3be0
 
 
 
356d76f
 
 
4db3be0
 
356d76f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4db3be0
 
 
356d76f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c54d8a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4db3be0
c54d8a2
 
356d76f
 
 
 
c54d8a2
 
 
 
356d76f
 
4db3be0
 
356d76f
c54d8a2
356d76f
 
 
 
 
 
 
 
 
 
c54d8a2
 
 
 
 
 
 
 
 
 
 
 
 
4db3be0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""OpenEEGBench eval worker — the engine behind the arena.

The eval engine. By default the **Space runs this itself** in a background
process (see the Dockerfile): a CPU ``--watch`` loop in an *isolated venv*
(open-eeg-bench + its heavy deps, kept out of the web app's pinned stack) that
polls the queue and publishes results. **A GPU is optional** — linear/ridge
probing and inference run fine on CPU (the default); heavier strategies like
``full_finetune`` are just slower. It can also run on any other machine that has
``open-eeg-bench`` installed and an ``HF_TOKEN`` with ``braindecode`` write:

    pip install -r scripts/worker-requirements.txt   # open-eeg-bench (+ torch, braindecode)
    export HF_TOKEN=hf_xxx
    python -m scripts.run_eval_worker --dry-run                 # list pending, do nothing
    python -m scripts.run_eval_worker --limit 1                 # process the queue once, CPU
    python -m scripts.run_eval_worker --watch --interval 300    # run as a service (what the Space does)
    python -m scripts.run_eval_worker --device cuda             # ... use a GPU if you have one
    python -m scripts.run_eval_worker --slurm --infra-folder ./oeb-cache   # ... or fan out via SLURM

For each PENDING request it:
  1. marks the request RUNNING (in braindecode/requests),
  2. runs ``oeb.benchmark(**benchmark_kwargs)``,
  3. maps the result DataFrame to the leaderboard schema (results_mapping),
  4. appends the row(s) to the public braindecode/contents dataset,
  5. marks the request FINISHED (or FAILED with the error).

Without ``--watch`` it is one-shot (process the current queue, then exit), so it can
also be run from cron / a systemd timer / a SLURM job; with ``--watch`` it keeps
polling the queue every ``--interval`` seconds. ``open_eeg_bench`` is imported lazily,
so this module imports fine without it (e.g. for --dry-run or tests).
"""
import argparse
import json
import logging
from datetime import datetime, timezone
from pathlib import Path

from huggingface_hub import HfApi

# Import only import-light, dependency-free app modules so this worker can run in
# its own isolated venv (open-eeg-bench + its heavy deps) without pulling the web
# app's pinned stack. base.py is pure stdlib; results_mapping is pure Python.
from app.config.base import HF_TOKEN, HF_ORGANIZATION
from app.services.results_mapping import map_oeb_results_to_contents

# Hub dataset repos under the configured organization: the submission queue (one
# JSON request file per submission) and the aggregated leaderboard contents that
# this worker publishes to.
QUEUE_REPO = f"{HF_ORGANIZATION}/requests"
AGGREGATED_REPO = f"{HF_ORGANIZATION}/contents"
# Shared authenticated Hub client. HF_TOKEN needs write access to both repos:
# request status updates go to QUEUE_REPO, results go to AGGREGATED_REPO.
hf_api = HfApi(token=HF_TOKEN)

# NOTE: configures the root logger at import time (basicConfig is a no-op if the
# root logger already has handlers), so importing this module affects the importer.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("eval_worker")


def _now() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ`` (second precision)."""
    utc_now = datetime.now(tz=timezone.utc)
    return utc_now.strftime("%Y-%m-%dT%H:%M:%SZ")


def load_pending(limit=None):
    """Return ``[(path_in_repo, entry), ...]`` for PENDING open-eeg-bench requests.

    Downloads a snapshot of the queue dataset and scans every ``*.json`` file in it,
    in sorted path order so processing order is deterministic. A request is selected
    when its ``framework`` is ``"open-eeg-bench"`` and its ``status`` is ``"PENDING"``.

    Args:
        limit: Maximum number of requests to return; ``None`` (or ``0``) means all.

    Returns:
        List of ``(repo-relative POSIX path, parsed request dict)`` tuples. A missing
        or unreachable queue dataset (no submissions yet) is treated as empty.
    """
    try:
        local_dir = hf_api.snapshot_download(
            repo_id=QUEUE_REPO, repo_type="dataset", token=HF_TOKEN
        )
    except Exception as e:
        logger.info("Queue %s not available (%s); nothing to do.", QUEUE_REPO, e)
        return []

    root = Path(local_dir)
    pending = []
    for p in sorted(root.glob("**/*.json")):
        # Hub paths always use "/", regardless of the local OS separator.
        rel_path = p.relative_to(root).as_posix()
        try:
            # Requests are uploaded as UTF-8 JSON; don't depend on the locale encoding.
            entry = json.loads(p.read_text(encoding="utf-8"))
        except Exception as e:  # noqa: BLE001 — one bad file must not block the queue
            logger.warning("Skipping unreadable request %s: %s", rel_path, e)
            continue
        if not isinstance(entry, dict):
            # A non-object JSON file would otherwise crash ``entry.get`` and abort the
            # whole scan on every cycle.
            logger.warning("Skipping request %s: top-level JSON is not an object", rel_path)
            continue
        if entry.get("framework") == "open-eeg-bench" and entry.get("status") == "PENDING":
            pending.append((rel_path, entry))
    return pending[:limit] if limit else pending


def set_status(path_in_repo, entry, status, **extra):
    """Overwrite a request's JSON in the queue repo with a new ``status``.

    Any ``extra`` keyword fields (timestamps, error text, row counts, ...) are merged
    on top. ``entry`` itself is left untouched; the merged dict that was uploaded is
    returned so callers can build on it.
    """
    record = dict(entry)
    record["status"] = status
    record.update(extra)

    payload = json.dumps(record, indent=2).encode("utf-8")
    label = record.get("model_name", "?")
    hf_api.upload_file(
        path_or_fileobj=payload,
        path_in_repo=path_in_repo,
        repo_id=QUEUE_REPO,
        repo_type="dataset",
        token=HF_TOKEN,
        commit_message="{}: {}".format(label, status),
    )
    return record


def run_benchmark(entry, device, infra):
    """Benchmark one request with OpenEEGBench and convert the output to contents rows.

    ``entry["benchmark_kwargs"]`` is forwarded to ``oeb.benchmark`` together with the
    worker-chosen ``device`` and ``infra``. The resulting DataFrame is mapped to the
    leaderboard schema by ``map_oeb_results_to_contents``.
    """
    # Deferred import: open_eeg_bench drags in torch/braindecode, and the module must
    # stay importable without them (--dry-run, tests).
    import open_eeg_bench as oeb

    user_kwargs = entry["benchmark_kwargs"]
    results_df = oeb.benchmark(device=device, infra=infra, **user_kwargs)
    return map_oeb_results_to_contents(results_df, entry)


def _load_contents_df():
    """Read the current contents (the single ``train.parquet``).

    Returns an empty DataFrame only when the contents repo or its ``train.parquet``
    does not exist yet (first publish). Every other failure is raised, including:

    - network or offline errors;
    - auth errors;
    - an unreadable parquet file.

    ``publish`` overwrites ``train.parquet`` with this result plus the new rows, so
    mistaking a transient error for "empty" would wipe the whole leaderboard.
    """
    import pandas as pd
    from huggingface_hub import hf_hub_download

    try:
        from huggingface_hub.errors import (
            EntryNotFoundError,
            LocalEntryNotFoundError,
            RepositoryNotFoundError,
        )
    except ImportError:  # older huggingface_hub releases only expose them in .utils
        from huggingface_hub.utils import (
            EntryNotFoundError,
            LocalEntryNotFoundError,
            RepositoryNotFoundError,
        )

    try:
        path = hf_hub_download(
            repo_id=AGGREGATED_REPO, filename="train.parquet",
            repo_type="dataset", token=HF_TOKEN,
        )
    except LocalEntryNotFoundError:
        # Subclass of EntryNotFoundError raised when the Hub can't be reached and the
        # file isn't cached: a connectivity problem, not proof the file is absent.
        raise
    except (RepositoryNotFoundError, EntryNotFoundError):
        return pd.DataFrame()
    return pd.read_parquet(path)


def publish(rows):
    """Merge ``rows`` into the public contents dataset and upload it.

    Steps:

    - Create the dataset repo if needed.
    - Load the existing contents and append the new rows.
    - When both key columns are present, drop duplicates on ``(fullname, adapter)``,
      keeping the last (newest) row, so a re-submission replaces its earlier result.
    - Write the table with pandas to a temporary parquet file.
    - Upload it as the single ``train.parquet`` that ``load_dataset(repo)["train"]``
      reads.

    Returns:
        Total number of rows in the uploaded contents.
    """
    import os
    import tempfile
    import pandas as pd

    frames = [_load_contents_df(), pd.DataFrame(rows)]
    combined = pd.concat(frames, ignore_index=True)
    key_cols = ["fullname", "adapter"]
    if set(key_cols) <= set(combined.columns):
        combined = combined.drop_duplicates(subset=key_cols, keep="last")

    hf_api.create_repo(
        repo_id=AGGREGATED_REPO,
        repo_type="dataset",
        private=False,
        exist_ok=True,
        token=HF_TOKEN,
    )

    fd, tmp_path = tempfile.mkstemp(suffix=".parquet")
    try:
        # pandas reopens the file by path, so release our own handle first.
        os.close(fd)
        combined.to_parquet(tmp_path, index=False)
        hf_api.upload_file(
            path_or_fileobj=tmp_path,
            path_in_repo="train.parquet",
            repo_id=AGGREGATED_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
            commit_message="Update leaderboard contents",
        )
    finally:
        os.unlink(tmp_path)
    return len(combined)


def process_queue(device="cpu", limit=None, infra=None):
    """Run every PENDING submission once and record the outcome on each request.

    Each request goes through these steps:

    1. Mark it RUNNING (with ``started_time``).
    2. Run the benchmark.
    3. Publish the mapped rows.
    4. Mark it FINISHED, or FAILED with a truncated error message if step 2 or 3
       raises.

    A failure while benchmarking or publishing one request, or while recording its
    FAILED status, does not stop the remaining requests.

    Args:
        device: Torch device passed to ``oeb.benchmark``.
        limit: Max number of requests to take from the queue (``None`` = all).
        infra: Optional oeb infra dict (cache folder / SLURM), or ``None``.

    Returns:
        How many pending requests were picked up (attempted), not only successes.
    """
    pending = load_pending(limit)
    logger.info("Found %d pending submission(s).", len(pending))
    for path_in_repo, entry in pending:
        name = entry.get("model_name", "?")
        logger.info("RUNNING %s", name)
        # Build the final status on top of the RUNNING record so its started_time
        # survives into the FINISHED/FAILED JSON instead of being dropped.
        running = set_status(path_in_repo, entry, "RUNNING", started_time=_now())
        try:
            rows = run_benchmark(entry, device, infra)
            if not rows:
                raise RuntimeError("benchmark produced no completed results")
            total = publish(rows)
            set_status(path_in_repo, running, "FINISHED", finished_time=_now(), n_rows=len(rows))
            logger.info("FINISHED %s — published %d row(s); contents now has %d.", name, len(rows), total)
        except Exception as e:  # noqa: BLE001 — record the failure on the request and move on
            logger.exception("FAILED %s", name)
            try:
                set_status(path_in_repo, running, "FAILED", finished_time=_now(), error=str(e)[:500])
            except Exception:  # noqa: BLE001 — a Hub hiccup here must not abort the rest of the queue
                logger.exception("Could not record FAILED status for %s", name)
    return len(pending)


def main():
    """CLI entry point: list, drain once, or continuously watch the request queue.

    - ``--dry-run`` only logs the PENDING requests.
    - Otherwise ``--infra-folder`` and ``--slurm`` become an oeb ``infra`` dict
      (``None`` when neither is given).
    - The queue is then processed once, or every ``--interval`` seconds forever
      with ``--watch``.
    """
    import time

    parser = argparse.ArgumentParser(
        description="Run queued OpenEEGBench submissions and publish results."
    )
    option = parser.add_argument
    option("--dry-run", action="store_true", help="List pending submissions and exit.")
    option("--watch", action="store_true", help="Keep polling the queue (run as a long-lived service).")
    option("--interval", type=int, default=300, help="Seconds between polls in --watch mode (default 300).")
    option("--device", default="cpu", help="Torch device for benchmark() (default: cpu; use cuda if a GPU is available).")
    option("--limit", type=int, default=None, help="Max submissions to process per cycle.")
    option("--infra-folder", default=None, help="oeb cache/results folder (enables caching + SLURM).")
    option("--slurm", action="store_true", help="Submit experiments via SLURM (oeb infra cluster=slurm).")
    args = parser.parse_args()

    if args.dry_run:
        for _path, request in load_pending(args.limit):
            bench_kwargs = request.get("benchmark_kwargs") or {}
            logger.info("PENDING %s — %s", request.get("model_name"), bench_kwargs.get("model_cls"))
        return

    infra_opts = {}
    if args.infra_folder:
        infra_opts["folder"] = args.infra_folder
    if args.slurm:
        infra_opts["cluster"] = "slurm"
    infra = infra_opts if infra_opts else None

    if not args.watch:
        process_queue(args.device, args.limit, infra)
        return

    logger.info(
        "Worker watching %s every %ds (device=%s, limit=%s).",
        QUEUE_REPO, args.interval, args.device, args.limit,
    )
    while True:
        try:
            process_queue(args.device, args.limit, infra)
        except Exception:  # noqa: BLE001 — never let one cycle kill the worker
            logger.exception("worker cycle error")
        time.sleep(args.interval)


# Run the CLI only when executed as a script (e.g. `python -m scripts.run_eval_worker`),
# not when the module is imported.
if __name__ == "__main__":
    main()