"""Build Markdown tables of sampled query/positive-document examples.

The module reads three collections of rows (queries, corpus, and relevance
judgements), pairs each sampled query with one of its positive documents and
renders the pairs as a two-column Markdown table.  It can print a table for a
single dataset split, or rewrite the ``## Example Data`` section of every task
document found below a docs root (``--update-docs``).
"""

from __future__ import annotations

import argparse
import random
import re
import sys
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any

# NOTE: ``yaml`` is imported inside ``_task_reference_from_doc`` (its only
# user), mirroring the lazy ``datasets`` import below, so the pure table
# helpers can be imported without the optional dependencies installed.

# Matches the body of the "Example Data" section: everything after the heading
# (and its blank line) up to the newline preceding "## Dataset Information".
EXAMPLE_SECTION_RE = re.compile(r"(## Example Data\n\n).*?(?=\n## Dataset Information)", re.DOTALL)
# Captures the contents of the first fenced ```yaml block in a document.
METADATA_RE = re.compile(r"\s*```yaml\n(.*?)\n```", re.DOTALL)
DEFAULT_TEXT_LIMIT = 225
DEFAULT_SAMPLE_SIZE = 5
DEFAULT_SEED = 42

# Candidate column names, in priority order, for the fields this module reads.
_ROW_ID_KEYS = ("_id", "id", "query-id", "query_id", "corpus-id", "corpus_id")
_ROW_TEXT_KEYS = (
    "text",
    "query",
    "question",
    "contents",
    "content",
    "document",
    "passage",
    "answer",
)
_QREL_QUERY_KEYS = ("query-id", "query_id", "qid", "query", "_id")
_QREL_CORPUS_KEYS = ("corpus-id", "corpus_id", "docid", "document_id", "doc_id")


@dataclass(frozen=True)
class TaskReference:
    """Dataset coordinates parsed from a task document's metadata block."""

    # Identifier passed to ``load_dataset``.
    dataset_id: str
    # Name of the split to read from each config.
    split_name: str


def _as_text(value: Any) -> str:
    """Return *value* as a string, mapping ``None`` to the empty string."""
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    return str(value)


def _first_present(row: Mapping[str, Any], keys: Sequence[str]) -> Any:
    """Return the value of the first key in *keys* that is set (not ``None``).

    Raises:
        KeyError: if none of *keys* is present in *row* with a non-``None``
            value.
    """
    for key in keys:
        if key in row and row[key] is not None:
            return row[key]
    raise KeyError(f"none of the expected keys are present: {', '.join(keys)}")


def _row_id(row: Mapping[str, Any]) -> str:
    """Return the identifier of a query or corpus row.

    The id is stripped of surrounding whitespace so that it compares equal to
    the (also stripped) ids produced by ``_qrel_query_id`` and
    ``_qrel_corpus_id``; otherwise rows whose ids carry stray whitespace could
    never be joined with their relevance judgements.
    """
    return _as_text(_first_present(row, _ROW_ID_KEYS)).strip()


def _row_text(row: Mapping[str, Any]) -> str:
    """Return the display text of a query or corpus row.

    The body is taken from the first populated text-like column.  When the row
    also has a title that the body does not already start with, the title is
    prepended, separated by a blank line.  A row that has a title but no body
    (body column missing, ``None`` or blank) yields the title alone.

    Raises:
        KeyError: if the row has neither a text-like column nor a title.
    """
    title = _as_text(row.get("title")).strip()
    try:
        raw_text = _first_present(row, _ROW_TEXT_KEYS)
    except KeyError:
        # Title-only rows are valid; only fail when there is nothing to show.
        if not title:
            raise
        return title

    text = _as_text(raw_text).strip()
    if title and text and not text.startswith(title):
        return f"{title}\n\n{text}"
    if title and not text:
        return title
    return text


def _qrel_query_id(row: Mapping[str, Any]) -> str:
    """Return the stripped query id referenced by a relevance-judgement row."""
    return _as_text(_first_present(row, _QREL_QUERY_KEYS)).strip()


def _qrel_corpus_id(row: Mapping[str, Any]) -> str:
    """Return the stripped document id referenced by a relevance-judgement row."""
    return _as_text(_first_present(row, _QREL_CORPUS_KEYS)).strip()


def _is_positive_qrel(row: Mapping[str, Any]) -> bool:
    """Return whether a relevance-judgement row marks a positive pair.

    A missing, ``None`` or non-numeric score is treated as positive; a numeric
    score is positive only when it is greater than zero.
    """
    if "score" not in row or row["score"] is None:
        return True
    try:
        return float(row["score"]) > 0
    except (TypeError, ValueError):
        return True


def _normalize_visible_text(text: str) -> str:
    """Collapse every whitespace run to one space and trim the ends."""
    return re.sub(r"\s+", " ", text).strip()


def _escape_markdown_cell(text: str) -> str:
    """Escape pipe characters so *text* cannot break out of a table cell."""
    return text.replace("|", r"\|")


def format_example_text(text: str, *, text_limit: int = DEFAULT_TEXT_LIMIT) -> str:
    """Render *text* as a single-line Markdown table cell.

    Whitespace is normalised, the visible part is cut to *text_limit*
    characters when longer, and the length of the stripped original is
    appended.  Pipes are escaped after the suffix has been added.

    Args:
        text: Raw example text.
        text_limit: Maximum number of visible characters before truncation.

    Returns:
        The cell content, ending in ``(<n> chars)`` and carrying a truncation
        marker when the text was shortened.
    """
    stripped = text.strip()
    # Length is reported for the stripped original, not the normalised form.
    full_chars = len(stripped)
    visible = _normalize_visible_text(stripped)
    if len(visible) > text_limit:
        visible = visible[:text_limit].rstrip()
        return _escape_markdown_cell(f"{visible} ... [truncated {text_limit} chars]({full_chars} chars)")
    return _escape_markdown_cell(f"{visible} ({full_chars} chars)")


def _materialize_by_id(rows: Iterable[Mapping[str, Any]]) -> dict[str, str]:
    """Map each row's id to its display text; later duplicates overwrite."""
    return {_row_id(row): _row_text(row) for row in rows}


def _positive_docs_by_query(rows: Iterable[Mapping[str, Any]]) -> dict[str, list[str]]:
    """Group positive document ids by query id, preserving row order."""
    positives: dict[str, list[str]] = {}
    for row in rows:
        if not _is_positive_qrel(row):
            continue
        positives.setdefault(_qrel_query_id(row), []).append(_qrel_corpus_id(row))
    return positives


def build_example_table(
    *,
    queries: Iterable[Mapping[str, Any]],
    corpus: Iterable[Mapping[str, Any]],
    qrels: Iterable[Mapping[str, Any]],
    sample_size: int = DEFAULT_SAMPLE_SIZE,
    seed: int = DEFAULT_SEED,
    text_limit: int = DEFAULT_TEXT_LIMIT,
) -> str:
    """Build a Markdown table of sampled query / positive-document pairs.

    Args:
        queries: Query rows (id plus a text-like column).
        corpus: Document rows (id plus a text-like column and optional title).
        qrels: Relevance-judgement rows linking query ids to document ids.
        sample_size: Maximum number of queries to include.
        seed: Seed for the sampling RNG; the same inputs and seed give the
            same table.
        text_limit: Visible character limit applied to every cell.

    Returns:
        The table as a newline-joined string (header, separator, one row per
        sampled query).

    Raises:
        ValueError: if no query has a positive document such that both the
            query and the document exist in the given collections.
        KeyError: if a row lacks the id or text columns this module expects.
    """
    queries_by_id = _materialize_by_id(queries)
    corpus_by_id = _materialize_by_id(corpus)
    positives_by_query = _positive_docs_by_query(qrels)

    # Sorting makes the sample independent of the qrels row order.
    eligible_query_ids = sorted(
        query_id
        for query_id, corpus_ids in positives_by_query.items()
        if query_id in queries_by_id and any(corpus_id in corpus_by_id for corpus_id in corpus_ids)
    )
    if not eligible_query_ids:
        raise ValueError("no query-positive pairs with matching query and corpus records were found")

    rng = random.Random(seed)
    selected_query_ids = rng.sample(eligible_query_ids, k=min(sample_size, len(eligible_query_ids)))

    lines = [
        "| Query | Positive document |",
        "| --- | --- |",
    ]
    for query_id in selected_query_ids:
        # Eligibility guarantees at least one positive is present in the corpus.
        corpus_id = next(corpus_id for corpus_id in positives_by_query[query_id] if corpus_id in corpus_by_id)
        query_text = format_example_text(queries_by_id[query_id], text_limit=text_limit)
        document_text = format_example_text(corpus_by_id[corpus_id], text_limit=text_limit)
        lines.append(f"| {query_text} | {document_text} |")
    return "\n".join(lines)


def load_example_table(
    *,
    dataset_id: str,
    split_name: str,
    queries_config: str = "queries",
    corpus_config: str = "corpus",
    qrels_config: str = "qrels",
    sample_size: int = DEFAULT_SAMPLE_SIZE,
    seed: int = DEFAULT_SEED,
    text_limit: int = DEFAULT_TEXT_LIMIT,
) -> str:
    """Load the three configs of a dataset split and build its example table.

    Args:
        dataset_id: Dataset identifier passed to ``load_dataset``.
        split_name: Split to read from each config.
        queries_config: Config name holding the queries.
        corpus_config: Config name holding the documents.
        qrels_config: Config name holding the relevance judgements.
        sample_size: Maximum number of queries to include.
        seed: Seed for the sampling RNG.
        text_limit: Visible character limit applied to every cell.

    Returns:
        The Markdown table produced by ``build_example_table``.
    """
    queries = _load_dataset_split(dataset_id, queries_config, split_name)
    corpus = _load_dataset_split(dataset_id, corpus_config, split_name)
    qrels = _load_dataset_split(dataset_id, qrels_config, split_name)
    return build_example_table(
        queries=queries,
        corpus=corpus,
        qrels=qrels,
        sample_size=sample_size,
        seed=seed,
        text_limit=text_limit,
    )


@lru_cache(maxsize=None)
def _load_dataset_config(dataset_id: str, config_name: str) -> Any:
    """Load one dataset config, caching the result per (dataset, config)."""
    # Imported lazily so the module can be used without ``datasets`` installed.
    from datasets import load_dataset

    return load_dataset(dataset_id, config_name)


def _load_dataset_split(dataset_id: str, config_name: str, split_name: str) -> Any:
    """Return one split of a dataset config.

    Raises:
        KeyError: if the config has no such split; the message lists the
            splits that are available.
    """
    dataset = _load_dataset_config(dataset_id, config_name)
    try:
        return dataset[split_name]
    except KeyError as exc:
        available = ", ".join(str(split) for split in getattr(dataset, "keys", lambda: [])())
        raise KeyError(f"{dataset_id}/{config_name} does not contain split {split_name!r}; available: {available}") from exc


def _task_reference_from_doc(path: Path) -> TaskReference:
    """Parse the dataset id and split name from a task document.

    The first fenced ``yaml`` block of the document must contain a
    ``benchmark_task_metadata`` mapping with ``dataset_id`` and either
    ``split_name`` or ``task_name``.

    Raises:
        ValueError: if the block is missing, malformed or incomplete.
    """
    # Imported lazily: only the --update-docs path needs PyYAML.
    import yaml

    text = path.read_text(encoding="utf-8")
    match = METADATA_RE.search(text)
    if not match:
        raise ValueError(f"missing benchmark task metadata: {path}")

    metadata = yaml.safe_load(match.group(1))
    task_metadata = metadata.get("benchmark_task_metadata") if isinstance(metadata, dict) else None
    if not isinstance(task_metadata, dict):
        raise ValueError(f"invalid benchmark task metadata: {path}")

    dataset_id = task_metadata.get("dataset_id")
    # ``task_name`` is accepted as a fallback for ``split_name``.
    split_name = task_metadata.get("split_name") or task_metadata.get("task_name")
    if not dataset_id or not split_name:
        raise ValueError(f"metadata must include dataset_id and split_name: {path}")
    return TaskReference(dataset_id=str(dataset_id), split_name=str(split_name))


def _replace_example_section(text: str, table: str) -> str:
    """Replace the body of the first "Example Data" section with *table*.

    Only the first matching section is replaced; additional sections, if any,
    are left untouched.

    Raises:
        ValueError: if no "Example Data" section followed by a
            "Dataset Information" heading is found.
    """
    # A callable replacement keeps backslashes in the table (escaped pipes)
    # from being interpreted as regex group references.
    updated, count = EXAMPLE_SECTION_RE.subn(lambda match: f"{match.group(1)}{table}\n", text, count=1)
    if count != 1:
        raise ValueError("expected exactly one Example Data section followed by Dataset Information")
    return updated


def update_docs(
    *,
    docs_root: Path,
    sample_size: int,
    seed: int,
    text_limit: int,
    dry_run: bool,
) -> list[Path]:
    """Regenerate the "Example Data" section of every task document.

    Every ``*.md`` file below *docs_root* except ``index.md`` is considered;
    files without an ``## Example Data`` heading are skipped.

    Args:
        docs_root: Directory searched recursively for task documents.
        sample_size: Maximum number of queries per table.
        seed: Seed for the sampling RNG.
        text_limit: Visible character limit applied to every cell.
        dry_run: When true, detect changes but do not write any file.

    Returns:
        The paths whose content changed (or would change under *dry_run*), in
        sorted order.
    """
    changed: list[Path] = []
    task_docs = sorted(path for path in docs_root.rglob("*.md") if path.name != "index.md")
    for path in task_docs:
        text = path.read_text(encoding="utf-8")
        if "## Example Data" not in text:
            continue
        reference = _task_reference_from_doc(path)
        table = load_example_table(
            dataset_id=reference.dataset_id,
            split_name=reference.split_name,
            sample_size=sample_size,
            seed=seed,
            text_limit=text_limit,
        )
        updated = _replace_example_section(text, table)
        if updated == text:
            continue
        changed.append(path)
        if not dry_run:
            path.write_text(updated, encoding="utf-8")
    return changed


def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments (``sys.argv`` when *argv* is ``None``)."""
    parser = argparse.ArgumentParser(
        description="Extract deterministic random query-positive examples from Nano benchmark datasets."
    )
    parser.add_argument("dataset_id", nargs="?", help="Hugging Face dataset id, such as hakari-bench/NanoMMTEB-v2.")
    parser.add_argument("split_name", nargs="?", help="Dataset split/task name, such as argu_ana.")
    parser.add_argument("--queries-config", default="queries")
    parser.add_argument("--corpus-config", default="corpus")
    parser.add_argument("--qrels-config", default="qrels")
    parser.add_argument("--sample-size", type=int, default=DEFAULT_SAMPLE_SIZE)
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED)
    parser.add_argument("--text-limit", type=int, default=DEFAULT_TEXT_LIMIT)
    parser.add_argument("--update-docs", type=Path, help="Replace Example Data sections below this docs root.")
    parser.add_argument("--dry-run", action="store_true", help="Report changed files without writing them.")
    return parser.parse_args(argv)


def main(argv: Sequence[str] | None = None) -> int:
    """Run the command-line interface and return the process exit status.

    With ``--update-docs`` the changed paths are printed to stdout and a
    summary to stderr; otherwise the table for the given dataset split is
    printed to stdout.

    Raises:
        SystemExit: on invalid argument combinations or values.
    """
    args = parse_args(argv)
    if args.sample_size <= 0:
        raise SystemExit("--sample-size must be positive")
    if args.text_limit <= 0:
        raise SystemExit("--text-limit must be positive")

    # Explicit None check: a Path is truthy regardless of its value.
    if args.update_docs is not None:
        changed = update_docs(
            docs_root=args.update_docs,
            sample_size=args.sample_size,
            seed=args.seed,
            text_limit=args.text_limit,
            dry_run=args.dry_run,
        )
        action = "Would update" if args.dry_run else "Updated"
        for path in changed:
            print(path)
        print(f"{action} {len(changed)} files.", file=sys.stderr)
        return 0

    if not args.dataset_id or not args.split_name:
        raise SystemExit("dataset_id and split_name are required unless --update-docs is used")
    print(
        load_example_table(
            dataset_id=args.dataset_id,
            split_name=args.split_name,
            queries_config=args.queries_config,
            corpus_config=args.corpus_config,
            qrels_config=args.qrels_config,
            sample_size=args.sample_size,
            seed=args.seed,
            text_limit=args.text_limit,
        )
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())