Spaces:
Configuration error
Configuration error
File size: 2,859 Bytes
5000a45 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 | #!/usr/bin/env python3
from __future__ import annotations
import argparse
import csv
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List
def _load_tasks(path: Path) -> Dict[str, List[dict]]:
    """Read the measurement CSV and bucket its rows by ``task_id``.

    Args:
        path: Location of the CSV file. Its header row must include a
            ``task_id`` column.

    Returns:
        A ``defaultdict(list)`` mapping each task id to the row dicts (as
        produced by :class:`csv.DictReader`) that carry that id, in file
        order. Task ids appear in order of first occurrence.

    Raises:
        KeyError: If the CSV header has no ``task_id`` column.
    """
    rows_by_task: Dict[str, List[dict]] = defaultdict(list)
    # newline="" lets the csv module handle line endings itself.
    with path.open("r", newline="", encoding="utf-8") as csv_file:
        reader = csv.DictReader(csv_file)
        for record in reader:
            task_key = record["task_id"]
            rows_by_task[task_key].append(record)
    return rows_by_task
def _task_n(task_id: str) -> int:
    """Return the integer size suffix of a task id.

    The suffix is the text after the *last* ``"_n"`` in ``task_id``, e.g.
    ``"gemm_n128"`` -> ``128`` and ``"conv_norm_n16"`` -> ``16``. When
    ``"_n"`` does not occur, the whole id is parsed.

    Raises:
        ValueError: If that text is not an integer. The message names the
            offending task id so the bad CSV row can be located. This
            function is used as a sort key, and without the id the error
            would be an anonymous ``int()`` failure.
    """
    suffix = task_id.rpartition("_n")[2]
    try:
        return int(suffix)
    except ValueError as err:
        raise ValueError(
            f"Task id {task_id!r} does not end in an integer '_n<size>' suffix."
        ) from err
def build_splits(grouped: Dict[str, List[dict]], heldout_family: str | None) -> Dict[str, object]:
    """Derive the two benchmark split manifests from grouped measurement rows.

    Args:
        grouped: Mapping of task id to its measurement rows. A task's family
            is taken from the ``family`` field of its first row.
        heldout_family: Family to withhold entirely for the family-holdout
            split. ``None`` selects the alphabetically last family.

    Returns:
        A JSON-serialisable dict with these keys:

        * ``families_present``: sorted family names.
        * ``shape_generalization``: within each family, tasks are ordered by
          their ``_n`` size suffix. The largest two (families with at least
          four tasks) or the largest one (smaller families) go to
          ``test_tasks``, the rest to ``train_tasks``. A family with a
          single task contributes only to ``train_tasks``.
        * ``family_holdout``: every task of the held-out family in
          ``test_tasks`` and every other task in ``train_tasks``.

        All task lists in the result are sorted lexicographically.

    Raises:
        RuntimeError: If ``grouped`` is empty.
        ValueError: If ``heldout_family`` is not among the families present,
            or a task id has no integer size suffix.

    NOTE(review): with only one family, the family-holdout ``train_tasks``
    list is empty. Confirm downstream consumers tolerate that.
    """
    # Bucket task ids by family, keyed on the first row of each task.
    family_tasks: Dict[str, List[str]] = defaultdict(list)
    for task_id in grouped:
        first_row = grouped[task_id][0]
        family_tasks[first_row["family"]].append(task_id)

    # Smallest problem size first, so the tail of each list is the largest.
    for members in family_tasks.values():
        members.sort(key=_task_n)

    families = sorted(family_tasks)
    if not families:
        raise RuntimeError("No tasks found in measurement file.")

    # Shape generalisation: hold out the largest size(s) of every family,
    # but always keep at least one task for training.
    shape_train: List[str] = []
    shape_test: List[str] = []
    for members in family_tasks.values():
        n_test = 1 if len(members) < 4 else 2
        cut = max(1, len(members) - n_test)
        shape_train += members[:cut]
        shape_test += members[cut:]

    # Family holdout: one whole family is unseen during training.
    heldout_family = families[-1] if heldout_family is None else heldout_family
    if heldout_family not in family_tasks:
        raise ValueError(f"Held-out family {heldout_family} is not present.")

    family_train: List[str] = []
    for family, members in family_tasks.items():
        if family != heldout_family:
            family_train.extend(members)
    family_test = list(family_tasks[heldout_family])

    shape_split = {
        "train_tasks": sorted(shape_train),
        "test_tasks": sorted(shape_test),
    }
    holdout_split = {
        "heldout_family": heldout_family,
        "train_tasks": sorted(family_train),
        "test_tasks": sorted(family_test),
    }
    return {
        "families_present": families,
        "shape_generalization": shape_split,
        "family_holdout": holdout_split,
    }
def main() -> None:
    """Command-line entry point: build the split manifest and save it.

    Reads the measurement CSV given by ``--measurement-path`` and computes
    the splits with ``build_splits``. The result is written as indented
    JSON to ``--output`` (parent directories are created as needed), and
    the same JSON is echoed to stdout.

    The following input problems are reported through ``argparse`` (usage
    plus message on stderr, exit status 2) instead of a Python traceback:

    * a missing measurement file;
    * a file with no tasks;
    * an unknown ``--heldout-family``;
    * a task id without an integer size suffix.
    """
    parser = argparse.ArgumentParser(description="Build train/test split manifests for the multi-family benchmark.")
    parser.add_argument("--measurement-path", type=Path, default=Path("data/autotune_measurements.csv"))
    parser.add_argument("--output", type=Path, default=Path("data/benchmark_splits.json"))
    parser.add_argument("--heldout-family", type=str, default=None)
    args = parser.parse_args()

    if not args.measurement_path.is_file():
        parser.error(f"measurement file not found: {args.measurement_path}")
    try:
        splits = build_splits(_load_tasks(args.measurement_path), args.heldout_family)
    except (RuntimeError, ValueError) as exc:
        # Empty data, unknown held-out family, or malformed task ids are
        # bad user input. Report them like any other CLI argument error.
        parser.error(str(exc))

    # Serialise once so the saved file and the stdout echo always match.
    rendered = json.dumps(splits, indent=2)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    # Trailing newline keeps the output a well-formed text file.
    args.output.write_text(rendered + "\n", encoding="utf-8")
    print(rendered)


if __name__ == "__main__":
    main()
|