# Source: Hugging Face Hub file view (5.43 kB).
# Last commit e587765 by AdithyaSK (HF Staff):
#   "Recognise registry.json as a positive Harbor-dataset signal"
"""Discover Harbor task-spec datasets on the Hugging Face Hub.
Harbor datasets are tagged `harbor` on the Hub — the same filter as
https://huggingface.co/datasets?other=harbor . This module lists them (fast,
no per-dataset round-trips) and computes per-dataset task counts on demand
(a few shallow `list_repo_tree` calls, memoised with a short TTL).
All listing is done live against the Hub so the UI always reflects the latest
published datasets (no stale snapshot).
"""
from __future__ import annotations
import logging
import os
import time
from dataclasses import dataclass
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Hub tag that marks a dataset as a Harbor dataset; passed as the `filter=`
# value when listing datasets.
_HARBOR_TAG = "harbor"
@dataclass(slots=True)
class HubDataset:
id: str
downloads: int = 0
likes: int = 0
updated: str | None = None
private: bool = False
def as_dict(self) -> dict:
return {
"id": self.id,
"downloads": self.downloads,
"likes": self.likes,
"updated": self.updated,
"private": self.private,
}
def _token() -> str | None:
    """Return the `HF_TOKEN` environment variable, or None if unset or empty."""
    value = os.environ.get("HF_TOKEN")
    # An empty string is treated the same as "not set".
    return value if value else None
def list_harbor_datasets(query: str | None = None, sort: str = "downloads",
                         limit: int = 500) -> list[HubDataset]:
    """List datasets tagged `harbor` on the Hub. Always live (no caching).

    Args:
        query: Optional substring filter on the dataset id; forwarded to the
            Hub as a server-side `search`.
        sort: One of `downloads`, `likes`, `lastModified`, `trending`
            (`trendingScore` is accepted as a synonym for `trending`). Any
            other value is ignored and no `sort` is sent to the Hub.
        limit: Maximum number of datasets requested from the Hub.

    Returns:
        One `HubDataset` per listed dataset, in the order the Hub yielded
        them.
    """
    from huggingface_hub import HfApi

    # Accepted `sort` values -> the key actually sent to the Hub. The
    # documented name `trending` must be translated to `trendingScore`;
    # previously it was silently dropped.
    sort_keys = {
        "downloads": "downloads",
        "likes": "likes",
        "lastModified": "lastModified",
        "trending": "trendingScore",
        "trendingScore": "trendingScore",
    }

    api = HfApi(token=_token())
    # `filter=` matches the `other:harbor` tag used by the Hub UI.
    kwargs: dict = {"filter": _HARBOR_TAG, "limit": limit}
    sort_key = sort_keys.get(sort)
    if sort_key is not None:
        kwargs["sort"] = sort_key
    if query:
        kwargs["search"] = query

    out: list[HubDataset] = []
    for d in api.list_datasets(**kwargs):
        lm = getattr(d, "last_modified", None)
        out.append(HubDataset(
            id=d.id,
            # Missing or None counters are normalised to 0.
            downloads=int(getattr(d, "downloads", 0) or 0),
            likes=int(getattr(d, "likes", 0) or 0),
            updated=lm.isoformat() if lm else None,
            private=bool(getattr(d, "private", False)),
        ))
    return out
# Task-id memo: {(dataset_id, revision or "head"): (task_ids, stored_at)}.
# Entries are derived from shallow tree listings, never a download.
# Short TTL so freshly-pushed tasks still surface.
_TASKS_CACHE: dict[tuple[str, str], tuple[list[str], float]] = {}
_TASKS_TTL = 120.0 # seconds; default `ttl` of `list_hf_tasks`
def _is_dir(entry) -> bool:
    """Return True when `entry`'s class is named `RepoFolder`.

    The check is by class name only, so no Hub types need to be imported.
    """
    class_name = entry.__class__.__name__
    return class_name == "RepoFolder"
def list_hf_tasks(dataset_id: str, revision: str | None = None, *, ttl: float = _TASKS_TTL) -> list[str]:
    """Task ids in a Hub dataset WITHOUT downloading it.

    Uses *shallow* (non-recursive) tree listings instead of walking every
    file: if a top-level `tasks/` folder exists we list its immediate child
    folders (Repo2RLEnv's nested layout); otherwise we treat the top-level
    folders as flat task dirs.

    Args:
        dataset_id: Hub dataset id to inspect.
        revision: Optional revision passed to the tree listing; None is
            cached under the key `"head"`.
        ttl: Maximum age in seconds of a memoised result before it is
            recomputed.

    Returns:
        A fresh sorted list of task ids (possibly empty). The list is a copy,
        so mutating it does not affect the memo.

    Raises:
        Whatever the root (or `tasks/`) tree listing raises; only the
        per-candidate `task.toml` probes are best-effort.
    """
    key = (dataset_id, revision or "head")
    now = time.time()
    hit = _TASKS_CACHE.get(key)
    if hit is not None and (now - hit[1]) < ttl:
        # Hand out a copy so callers cannot corrupt the memoised list.
        return list(hit[0])

    from huggingface_hub import HfApi

    api = HfApi(token=_token())
    root = list(api.list_repo_tree(dataset_id, repo_type="dataset", revision=revision, recursive=False))
    names = {e.path: e for e in root}
    # `registry.json` at the root is a positive signal that this is a Harbor
    # dataset (Repo2RLEnv pushes it; harbor's --registry-path consumes it).
    # It's *not* required — terminal-bench-2.0, dabstep-harbor, titanbench all
    # ship without one — but its presence skips the task.toml sampling below.
    has_registry = "registry.json" in names

    if "tasks" in names and _is_dir(names["tasks"]):
        sub = api.list_repo_tree(dataset_id, "tasks", repo_type="dataset", revision=revision, recursive=False)
        ids = sorted(e.path.split("/")[-1] for e in sub if _is_dir(e))
    else:
        # Flat layout: top-level folders MAY be tasks (skip dot-folders).
        # Some datasets (e.g. TaskTrove) have top-level dirs that aren't Harbor
        # tasks — they hold `tasks.parquet` or similar. Verify by sampling the
        # first few candidates for a `task.toml`. If `registry.json` is at the
        # root we already know this is a Harbor dataset and skip the check.
        candidates = sorted(e.path for e in root if _is_dir(e) and not e.path.startswith("."))
        if has_registry:
            ids = candidates
        else:
            ids = []
            for sample in candidates[:3]:
                try:
                    sub = list(api.list_repo_tree(dataset_id, sample, repo_type="dataset", revision=revision, recursive=False))
                except Exception as exc:  # noqa: BLE001
                    # Best-effort probe: note the failure and try the next one.
                    logger.debug("task.toml probe failed for %s/%s: %s", dataset_id, sample, exc)
                    continue
                # Compare the exact file name: a plain `endswith("task.toml")`
                # would also accept e.g. `subtask.toml`.
                if any(getattr(e, "path", "").rsplit("/", 1)[-1] == "task.toml" for e in sub):
                    ids = candidates
                    break

    _TASKS_CACHE[key] = (ids, now)
    return list(ids)
def count_tasks(dataset_id: str) -> int:
    """Count the Harbor tasks in a Hub dataset via `list_hf_tasks`.

    Returns -1 (after logging a warning) when the listing raises.
    """
    try:
        task_ids = list_hf_tasks(dataset_id)
    except Exception as exc:  # noqa: BLE001
        logger.warning("count_tasks(%s) failed: %s", dataset_id, exc)
        return -1
    return len(task_ids)