Spaces:
Running
Running
| """OpenEEGBench eval worker — the engine behind the arena. | |
| The eval engine. By default the **Space runs this itself** in a background | |
| process (see the Dockerfile): a CPU ``--watch`` loop in an *isolated venv* | |
| (open-eeg-bench + its heavy deps, kept out of the web app's pinned stack) that | |
| polls the queue and publishes results. **A GPU is optional** — linear/ridge | |
| probing and inference run fine on CPU (the default); heavier strategies like | |
| ``full_finetune`` are just slower. It can also run on any other machine that has | |
| ``open-eeg-bench`` installed and an ``HF_TOKEN`` with write access to ``braindecode``: | |
| pip install -r scripts/worker-requirements.txt # open-eeg-bench (+ torch, braindecode) | |
| export HF_TOKEN=hf_xxx | |
| python -m scripts.run_eval_worker --dry-run # list pending, do nothing | |
| python -m scripts.run_eval_worker --limit 1 # process the queue once, CPU | |
| python -m scripts.run_eval_worker --watch --interval 300 # run as a service (what the Space does) | |
| python -m scripts.run_eval_worker --device cuda # ... use a GPU if you have one | |
| python -m scripts.run_eval_worker --slurm --infra-folder ./oeb-cache # ... or fan out via SLURM | |
| For each PENDING request it: | |
| 1. marks the request RUNNING (in braindecode/requests), | |
| 2. runs ``oeb.benchmark(**benchmark_kwargs)``, | |
| 3. maps the result DataFrame to the leaderboard schema (results_mapping), | |
| 4. appends the row(s) to the public braindecode/contents dataset, | |
| 5. marks the request FINISHED (or FAILED with the error). | |
| Without ``--watch`` it is one-shot by design (process the current queue, then exit) so it | |
| can be run from cron / a systemd timer / a SLURM job; with ``--watch`` it repeats that | |
| cycle every ``--interval`` seconds. ``open_eeg_bench`` is imported lazily, | |
| so this module imports fine without it (e.g. for --dry-run or tests). | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from huggingface_hub import HfApi | |
| # Import only import-light, dependency-free app modules so this worker can run in | |
| # its own isolated venv (open-eeg-bench + its heavy deps) without pulling the web | |
| # app's pinned stack. base.py is pure stdlib; results_mapping is pure Python. | |
| from app.config.base import HF_TOKEN, HF_ORGANIZATION | |
| from app.services.results_mapping import map_oeb_results_to_contents | |
# Hub dataset repos under the configured organization: the request queue that
# the worker polls, and the aggregated leaderboard contents it publishes to.
QUEUE_REPO = f"{HF_ORGANIZATION}/requests"
AGGREGATED_REPO = f"{HF_ORGANIZATION}/contents"
# Shared Hub client used for the queue snapshot, status uploads and publishing.
hf_api = HfApi(token=HF_TOKEN)
# Logging is configured at import time so every worker line carries a timestamp and level.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("eval_worker")
| def _now() -> str: | |
| return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") | |
def load_pending(limit=None):
    """Return ``[(path_in_repo, entry), ...]`` for PENDING open-eeg-bench requests.

    Downloads a snapshot of the queue dataset and scans every ``*.json`` file
    in it, in sorted path order. A file is kept only when it parses to a JSON
    object whose ``framework`` is ``"open-eeg-bench"`` and whose ``status`` is
    ``"PENDING"``. Files that cannot be read, are not valid JSON, or do not
    contain a JSON object are logged and skipped so that one bad file cannot
    abort the scan.

    Args:
        limit: Maximum number of requests to return. A falsy value (``None``
            or ``0``) means "no cap".

    Returns:
        A list of ``(path_in_repo, entry)`` tuples. ``path_in_repo`` is the
        file's path relative to the snapshot root, always with forward slashes
        so it can be used as a Hub path on any OS; ``entry`` is the parsed
        dict. If the queue cannot be downloaded (for example because no
        submission exists yet) the result is an empty list.
    """
    try:
        local_dir = hf_api.snapshot_download(
            repo_id=QUEUE_REPO, repo_type="dataset", token=HF_TOKEN
        )
    except Exception as e:  # noqa: BLE001 — deliberate best-effort: no queue, no work
        logger.info("Queue %s not available (%s); nothing to do.", QUEUE_REPO, e)
        return []

    root = Path(local_dir)
    pending = []
    for p in sorted(root.glob("**/*.json")):
        # Hub paths use "/" regardless of platform; str() would give "\\" on Windows.
        rel_path = p.relative_to(root).as_posix()
        try:
            # JSON request files are UTF-8; do not depend on the locale default.
            entry = json.loads(p.read_text(encoding="utf-8"))
        except Exception as e:  # noqa: BLE001 — skip unreadable/malformed files, keep scanning
            logger.warning("Skipping unreadable request %s (%s).", rel_path, e)
            continue
        if not isinstance(entry, dict):
            # Valid JSON that is not an object (list, string, ...) has no .get().
            logger.warning("Skipping request %s: expected a JSON object.", rel_path)
            continue
        if entry.get("framework") == "open-eeg-bench" and entry.get("status") == "PENDING":
            pending.append((rel_path, entry))
    return pending[:limit] if limit else pending
def set_status(path_in_repo, entry, status, **extra):
    """Write a request back to the queue dataset with a new status.

    Args:
        path_in_repo: Path of the request JSON inside the queue repo.
        entry: The request dict to base the upload on (not mutated).
        status: New status value, e.g. RUNNING / FINISHED / FAILED.
        **extra: Additional fields to set on the uploaded request.

    Returns:
        The dict that was uploaded: a copy of ``entry`` with ``status`` and
        the ``extra`` fields applied.
    """
    # Same merge order as a dict literal: entry first, then status, then extras.
    updated = dict(entry)
    updated["status"] = status
    updated.update(extra)

    payload = json.dumps(updated, indent=2).encode("utf-8")
    model_label = updated.get("model_name", "?")
    hf_api.upload_file(
        path_or_fileobj=payload,
        path_in_repo=path_in_repo,
        repo_id=QUEUE_REPO,
        repo_type="dataset",
        token=HF_TOKEN,
        commit_message=f"{model_label}: {status}",
    )
    return updated
def run_benchmark(entry, device, infra):
    """Execute ``oeb.benchmark()`` for one request and return its contents rows.

    Args:
        entry: Request dict; its ``benchmark_kwargs`` are forwarded as keyword
            arguments to ``oeb.benchmark``.
        device: Passed through as the ``device`` keyword.
        infra: Passed through as the ``infra`` keyword.

    Returns:
        Whatever ``map_oeb_results_to_contents`` produces for the benchmark
        result and the request.
    """
    # Imported here rather than at module level: heavy (torch/braindecode).
    import open_eeg_bench as oeb

    benchmark = oeb.benchmark
    requested_kwargs = entry["benchmark_kwargs"]
    results_df = benchmark(device=device, infra=infra, **requested_kwargs)
    contents_rows = map_oeb_results_to_contents(results_df, entry)
    return contents_rows
def _load_contents_df():
    """Read the current contents (the single ``train.parquet``); empty if absent.

    Returns:
        The current leaderboard as a ``pandas.DataFrame``. An empty frame is
        returned only when the Hub reports that the contents repo or its
        ``train.parquet`` does not exist yet.

    Raises:
        Exception: Any other failure (network error, unreadable parquet, ...)
            is propagated. Callers rewrite the whole file from this frame, so
            treating a transient error as "empty" would overwrite the existing
            leaderboard with only the new rows.
    """
    import pandas as pd
    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import (
        EntryNotFoundError,
        LocalEntryNotFoundError,
        RepositoryNotFoundError,
    )

    try:
        path = hf_hub_download(
            repo_id=AGGREGATED_REPO, filename="train.parquet",
            repo_type="dataset", token=HF_TOKEN,
        )
    except LocalEntryNotFoundError:
        # Subclass of EntryNotFoundError raised when the Hub is unreachable and
        # nothing is cached: that is an outage, not a missing file.
        raise
    except (EntryNotFoundError, RepositoryNotFoundError):
        # Nothing has been published yet: start from an empty leaderboard.
        return pd.DataFrame()
    return pd.read_parquet(path)
def publish(rows):
    """Merge ``rows`` into the public contents dataset and upload the result.

    The current contents are loaded, the new rows are appended, and when both
    ``fullname`` and ``adapter`` columns are present duplicates on that pair
    are dropped keeping the last occurrence, so a re-submission replaces the
    earlier row. The repo is created if needed and the merged table is stored
    as a single ``train.parquet`` written with pandas.

    Args:
        rows: Row records accepted by ``pandas.DataFrame(rows)``.

    Returns:
        The number of rows in the merged table after deduplication.
    """
    import os
    import tempfile
    import pandas as pd

    existing = _load_contents_df()
    incoming = pd.DataFrame(rows)
    merged = pd.concat([existing, incoming], ignore_index=True)

    dedupe_keys = ["fullname", "adapter"]
    if set(dedupe_keys).issubset(merged.columns):
        merged = merged.drop_duplicates(subset=dedupe_keys, keep="last")

    hf_api.create_repo(
        repo_id=AGGREGATED_REPO,
        repo_type="dataset",
        private=False,
        exist_ok=True,
        token=HF_TOKEN,
    )

    # Reserve a temp path, close the handle, then let pandas write to that path.
    handle = tempfile.NamedTemporaryFile(suffix=".parquet", delete=False)
    parquet_path = handle.name
    handle.close()
    try:
        merged.to_parquet(parquet_path, index=False)
        hf_api.upload_file(
            path_or_fileobj=parquet_path,
            path_in_repo="train.parquet",
            repo_id=AGGREGATED_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
            commit_message="Update leaderboard contents",
        )
    finally:
        # Always remove the temp file, whether or not the upload succeeded.
        os.unlink(parquet_path)
    return len(merged)
def process_queue(device="cpu", limit=None, infra=None):
    """Run every PENDING submission once. Returns how many were processed.

    For each pending request: mark it RUNNING, run the benchmark, publish the
    resulting rows, then mark it FINISHED. Any exception during that sequence
    is logged and recorded on the request as FAILED (error text truncated to
    500 characters), and processing continues with the next request.

    Args:
        device: Device forwarded to ``run_benchmark``.
        limit: Maximum number of requests to take from the queue this cycle.
        infra: Infra options forwarded to ``run_benchmark`` (or ``None``).

    Returns:
        The number of pending requests picked up in this cycle, whether they
        finished or failed.
    """
    pending = load_pending(limit)
    logger.info("Found %d pending submission(s).", len(pending))
    for path_in_repo, entry in pending:
        name = entry.get("model_name", "?")
        logger.info("RUNNING %s", name)
        # Keep the uploaded dict: later status writes must build on it, otherwise
        # the final request JSON would lose ``started_time``.
        running = set_status(path_in_repo, entry, "RUNNING", started_time=_now())
        try:
            # The benchmark still receives the request exactly as submitted.
            rows = run_benchmark(entry, device, infra)
            if not rows:
                raise RuntimeError("benchmark produced no completed results")
            total = publish(rows)
            set_status(path_in_repo, running, "FINISHED", finished_time=_now(), n_rows=len(rows))
            logger.info("FINISHED %s — published %d row(s); contents now has %d.", name, len(rows), total)
        except Exception as e:  # noqa: BLE001 — record the failure on the request and move on
            logger.exception("FAILED %s", name)
            try:
                set_status(path_in_repo, running, "FAILED", finished_time=_now(), error=str(e)[:500])
            except Exception:  # noqa: BLE001 — a Hub hiccup must not abort the remaining queue
                logger.exception("Could not record FAILED status for %s", name)
    return len(pending)
def main():
    """Command-line entry point for the eval worker.

    Parses the CLI flags and then does one of three things: with ``--dry-run``
    it logs the pending requests and returns; with ``--watch`` it processes
    the queue forever, sleeping ``--interval`` seconds between cycles and
    logging (not raising) any per-cycle error; otherwise it processes the
    queue once.
    """
    import time

    parser = argparse.ArgumentParser(
        description="Run queued OpenEEGBench submissions and publish results."
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="List pending submissions and exit.",
    )
    parser.add_argument(
        "--watch", action="store_true",
        help="Keep polling the queue (run as a long-lived service).",
    )
    parser.add_argument(
        "--interval", type=int, default=300,
        help="Seconds between polls in --watch mode (default 300).",
    )
    parser.add_argument(
        "--device", default="cpu",
        help="Torch device for benchmark() (default: cpu; use cuda if a GPU is available).",
    )
    parser.add_argument(
        "--limit", type=int, default=None,
        help="Max submissions to process per cycle.",
    )
    parser.add_argument(
        "--infra-folder", default=None,
        help="oeb cache/results folder (enables caching + SLURM).",
    )
    parser.add_argument(
        "--slurm", action="store_true",
        help="Submit experiments via SLURM (oeb infra cluster=slurm).",
    )
    opts = parser.parse_args()

    if opts.dry_run:
        for _path, request in load_pending(opts.limit):
            kwargs = request.get("benchmark_kwargs") or {}
            logger.info("PENDING %s — %s", request.get("model_name"), kwargs.get("model_cls"))
        return

    # Collect only the infra options that were given; none at all means None.
    infra_options = {}
    if opts.infra_folder:
        infra_options["folder"] = opts.infra_folder
    if opts.slurm:
        infra_options["cluster"] = "slurm"
    infra = infra_options if infra_options else None

    if not opts.watch:
        process_queue(opts.device, opts.limit, infra)
        return

    logger.info(
        "Worker watching %s every %ds (device=%s, limit=%s).",
        QUEUE_REPO, opts.interval, opts.device, opts.limit,
    )
    while True:
        try:
            process_queue(opts.device, opts.limit, infra)
        except Exception:  # noqa: BLE001 — never let one cycle kill the worker
            logger.exception("worker cycle error")
        time.sleep(opts.interval)
# Run the CLI only when executed as a script / ``python -m``, not on import.
if __name__ == "__main__":
    main()