Spaces:
Configuration error
Configuration error
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import json | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from typing import Dict, List | |
def _load_tasks(path: Path) -> Dict[str, List[dict]]:
    """Read the measurement CSV at *path* and bucket its rows by task id.

    Args:
        path: Location of a UTF-8 CSV file with a header row that
            includes a ``task_id`` column.

    Returns:
        A ``defaultdict`` mapping each ``task_id`` value to the list of
        row dicts (as produced by ``csv.DictReader``) that carry it, in
        file order.

    Raises:
        KeyError: If a row has no ``task_id`` field.
    """
    rows_by_task: Dict[str, List[dict]] = defaultdict(list)
    # newline="" lets the csv module handle line endings itself.
    with path.open("r", newline="", encoding="utf-8") as csv_file:
        reader = csv.DictReader(csv_file)
        for record in reader:
            task_id = record["task_id"]
            rows_by_task[task_id].append(record)
    return rows_by_task
def _task_n(task_id: str) -> int:
    """Return the integer that follows the last ``_n`` marker in *task_id*.

    When *task_id* contains no ``_n`` marker, the whole string is parsed
    as the integer.

    Raises:
        ValueError: If the extracted text is not a valid integer literal.
    """
    # rpartition yields the text after the final "_n", or the entire
    # string when the marker is absent.
    _, _, suffix = task_id.rpartition("_n")
    return int(suffix)
def build_splits(grouped: Dict[str, List[dict]], heldout_family: str | None) -> Dict[str, object]:
    """Derive the two train/test split manifests from grouped measurement rows.

    Args:
        grouped: Mapping of task id to that task's rows. A task's family is
            read from the ``family`` field of its first row.
        heldout_family: Family whose tasks form the family-holdout test set.
            ``None`` selects the alphabetically last family present.

    Returns:
        A dict with three keys:

        * ``families_present`` - sorted list of family names.
        * ``shape_generalization`` - per family, tasks are ordered by
          ``_task_n`` and the trailing two (families with at least four
          tasks) or one (smaller families) go to ``test_tasks``; at least
          one task per family always stays in ``train_tasks``.
        * ``family_holdout`` - every task of the held-out family is a test
          task, every other task is a train task.

        All task lists in the result are sorted as plain strings.

    Raises:
        RuntimeError: If *grouped* contains no tasks.
        ValueError: If the held-out family is not among the families found.
    """
    family_to_tasks: Dict[str, List[str]] = defaultdict(list)
    for task_id, rows in grouped.items():
        family_to_tasks[rows[0]["family"]].append(task_id)
    # Order each family's tasks by their numeric suffix so that the
    # hold-out below takes the tasks with the largest suffixes.
    for members in family_to_tasks.values():
        members.sort(key=_task_n)

    families = sorted(family_to_tasks)
    if not families:
        raise RuntimeError("No tasks found in measurement file.")

    shape_train: List[str] = []
    shape_test: List[str] = []
    for members in family_to_tasks.values():
        holdout_count = 2 if len(members) >= 4 else 1
        # Never leave a family without a training task.
        cut = max(1, len(members) - holdout_count)
        shape_train += members[:cut]
        shape_test += members[cut:]

    heldout_family = families[-1] if heldout_family is None else heldout_family
    if heldout_family not in family_to_tasks:
        raise ValueError(f"Held-out family {heldout_family} is not present.")

    family_train = [
        task_id
        for family, members in family_to_tasks.items()
        if family != heldout_family
        for task_id in members
    ]
    family_test = family_to_tasks[heldout_family]

    shape_section = {
        "train_tasks": sorted(shape_train),
        "test_tasks": sorted(shape_test),
    }
    holdout_section = {
        "heldout_family": heldout_family,
        "train_tasks": sorted(family_train),
        "test_tasks": sorted(family_test),
    }
    return {
        "families_present": families,
        "shape_generalization": shape_section,
        "family_holdout": holdout_section,
    }
def main() -> None:
    """Command-line entry point: build the split manifest and save it.

    Parses ``--measurement-path``, ``--output`` and ``--heldout-family``,
    builds the splits from the measurement CSV, writes them as indented
    JSON to the output path (creating parent directories as needed), and
    echoes the same JSON to stdout.
    """
    cli = argparse.ArgumentParser(description="Build train/test split manifests for the multi-family benchmark.")
    cli.add_argument("--measurement-path", type=Path, default=Path("data/autotune_measurements.csv"))
    cli.add_argument("--output", type=Path, default=Path("data/benchmark_splits.json"))
    cli.add_argument("--heldout-family", type=str, default=None)
    options = cli.parse_args()

    grouped = _load_tasks(options.measurement_path)
    splits = build_splits(grouped, options.heldout_family)

    # Serialize once; the file and stdout receive the same text.
    manifest = json.dumps(splits, indent=2)
    destination = options.output
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(manifest, encoding="utf-8")
    print(manifest)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()