Spaces:
Running
Running
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import sys | |
| from typing import cast | |
# Directory one level above the one containing this script. It is put at the
# front of sys.path (only if not already listed) so that the `hakari_bench`
# imports that follow can be resolved when the script is executed directly.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
| from hakari_bench.bm25 import BM25Config, BM25Tokenizer # noqa: E402 | |
| from hakari_bench.nano_dataset_builder import ( # noqa: E402 | |
| DEFAULT_BM25_TOP_K, | |
| DEFAULT_DOC_LIMIT, | |
| DEFAULT_QUERY_LIMIT, | |
| build_nano_dataset_from_hf_mteb, | |
| build_nano_dataset_from_local_source, | |
| ) | |
# Names accepted by the --bm25-tokenizer option (offered to argparse as
# `choices` in sorted order).
TOKENIZERS = {
    "english_porter",
    "english_porter_stop",
    "english_regex",
    "regex",
    "stemmer",
    "transformer",
    "whitespace",
    "wordseg",
}
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the Nano dataset builder.

    Exactly one of ``--source-dataset-id`` or ``--source-dir`` must be
    supplied; ``--dataset-name``, ``--dataset-id``, ``--split-name`` and
    ``--output-dir`` are always required.

    Args:
        argv: Argument strings to parse. ``None`` is handed to argparse
            unchanged, which then reads the process command line.

    Returns:
        The populated ``argparse.Namespace``.
    """
    description = (
        "Create a Nano-style retrieval dataset from MTEB/BEIR-style sources. "
        "The output uses corpus/queries/qrels/bm25 configs with task subsets as split names."
    )
    cli = argparse.ArgumentParser(description=description)

    # Input source: a Hugging Face dataset id or a local directory, never both.
    source_group = cli.add_mutually_exclusive_group(required=True)
    source_group.add_argument(
        "--source-dataset-id",
        help="Hugging Face MTEB/BEIR-style source dataset id.",
    )
    source_group.add_argument(
        "--source-dir",
        type=Path,
        help="Local parquet source root.",
    )

    # Mandatory identifiers of the dataset being produced.
    for flag, help_text in (
        ("--dataset-name", "Nano dataset name, e.g. NanoExample."),
        ("--dataset-id", "Final Hugging Face dataset id, e.g. hakari-bench/NanoExample."),
        ("--split-name", "Output Nano split/subset name."),
    ):
        cli.add_argument(flag, required=True, help=help_text)

    # Output locations.
    cli.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="Output dataset directory.",
    )
    cli.add_argument(
        "--dataset-config-dir",
        type=Path,
        default=None,
        help="Optional config/datasets directory where the HAKARI dataset YAML should be written.",
    )

    # Where to find the pieces inside the source.
    cli.add_argument(
        "--source-split-name",
        default=None,
        help="Local source split/task name.",
    )
    for flag, default in (
        ("--corpus-config", "corpus"),
        ("--queries-config", "queries"),
        ("--qrels-config", "default"),
        ("--corpus-split", "corpus"),
        ("--queries-split", "queries"),
        ("--qrels-split", "test"),
    ):
        cli.add_argument(flag, default=default)
    cli.add_argument("--revision", default=None)

    # Integer size limits; defaults come from the dataset builder module.
    for flag, default in (
        ("--query-limit", DEFAULT_QUERY_LIMIT),
        ("--doc-limit", DEFAULT_DOC_LIMIT),
        ("--top-k", DEFAULT_BM25_TOP_K),
    ):
        cli.add_argument(flag, type=int, default=default)

    # BM25 settings.
    cli.add_argument("--bm25-tokenizer", choices=sorted(TOKENIZERS), default="regex")
    cli.add_argument("--bm25-tokenizer-name", default=None)
    cli.add_argument("--bm25-stemmer-algorithm", default="english")
    cli.add_argument("--bm25-k1", type=float, default=1.5)
    cli.add_argument("--bm25-b", type=float, default=0.75)
    cli.add_argument("--show-progress", action="store_true")

    cli.add_argument(
        "--metadata-json",
        default=None,
        help="Optional JSON object for the generated HAKARI dataset YAML metadata.",
    )

    return cli.parse_args(argv)
| def _metadata(raw: str | None) -> dict[str, object] | None: | |
| if raw is None: | |
| return None | |
| parsed = json.loads(raw) | |
| if not isinstance(parsed, dict): | |
| raise ValueError("--metadata-json must be a JSON object.") | |
| return cast(dict[str, object], parsed) | |
def _bm25_config(args: argparse.Namespace) -> BM25Config:
    """Assemble a ``BM25Config`` from the parsed command-line options.

    Args:
        args: Namespace produced by ``parse_args``; the ``bm25_*``, ``top_k``
            and ``show_progress`` attributes are read.

    Returns:
        A ``BM25Config`` populated with those option values.
    """
    # argparse restricts this value to TOKENIZERS; the cast only informs the
    # type checker and does nothing at runtime.
    tokenizer = cast(BM25Tokenizer, args.bm25_tokenizer)
    config = BM25Config(
        tokenizer=tokenizer,
        tokenizer_name=args.bm25_tokenizer_name,
        stemmer_algorithm=args.bm25_stemmer_algorithm,
        top_k=args.top_k,
        k1=args.bm25_k1,
        b=args.bm25_b,
        show_progress=args.show_progress,
    )
    return config
def main(argv: list[str] | None = None) -> None:
    """Run the CLI: parse options, build the dataset, print a summary.

    When ``--source-dir`` is given the dataset is built from the local source;
    otherwise it is built from the Hugging Face dataset named by
    ``--source-dataset-id``. A ``key=value`` summary of the build result is
    then printed to stdout.

    Args:
        argv: Argument strings forwarded to ``parse_args``.
    """
    args = parse_args(argv)
    metadata = _metadata(args.metadata_json)
    bm25_config = _bm25_config(args)

    if args.source_dir is None:
        # Remote source: config/split names and revision select what to load.
        result = build_nano_dataset_from_hf_mteb(
            source_dataset_id=args.source_dataset_id,
            output_dir=args.output_dir,
            dataset_name=args.dataset_name,
            dataset_id=args.dataset_id,
            split_name=args.split_name,
            dataset_config_dir=args.dataset_config_dir,
            corpus_config=args.corpus_config,
            queries_config=args.queries_config,
            qrels_config=args.qrels_config,
            corpus_split=args.corpus_split,
            queries_split=args.queries_split,
            qrels_split=args.qrels_split,
            revision=args.revision,
            query_limit=args.query_limit,
            doc_limit=args.doc_limit,
            bm25_config=bm25_config,
            metadata=metadata,
        )
    else:
        # Local source: the source split falls back to the output split name
        # when --source-split-name is missing or empty.
        source_split_name = (
            args.source_split_name if args.source_split_name else args.split_name
        )
        result = build_nano_dataset_from_local_source(
            source_dir=args.source_dir,
            output_dir=args.output_dir,
            dataset_name=args.dataset_name,
            dataset_id=args.dataset_id,
            source_split_name=source_split_name,
            split_name=args.split_name,
            dataset_config_dir=args.dataset_config_dir,
            query_limit=args.query_limit,
            doc_limit=args.doc_limit,
            bm25_config=bm25_config,
            metadata=metadata,
        )

    # Summary of the build result, one key=value entry per line.
    print(f"dataset={result.dataset_name}")
    print(f"split={result.split_name}")
    print(f"output_dir={result.output_dir}")
    print(f"queries={result.queries} corpus={result.corpus} qrels={result.qrels}")
    print(f"source_non_positive_qrels={result.source_non_positive_qrels}")
    print(f"forced_doc_count={result.forced_doc_count}")
    print(f"missing_positive_doc_count_after_forcing={result.missing_positive_doc_count_after_forcing}")
    print(f"bm25_ndcg_at_10={result.bm25_ndcg_at_10:.4f}")

    # The config path is reported only when the result carries one.
    if result.dataset_config_path is not None:
        print(f"dataset_config={result.dataset_config_path}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()