"""Split timemachine-bench-verified.jsonl into train/eval with stratified difficulty. Usage: python code_migration/split_dataset.py [--eval-ratio 0.2] [--seed 42] Output: code_migration/data/train.jsonl code_migration/data/eval.jsonl """ import json import random import argparse from pathlib import Path from collections import Counter DATA_DIR = Path(__file__).parent / "data" SOURCE = DATA_DIR / "timemachine-bench-verified.jsonl" def main(): parser = argparse.ArgumentParser() parser.add_argument("--eval-ratio", type=float, default=0.2) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() # Load all tasks tasks = [] with open(SOURCE) as f: for line in f: tasks.append(json.loads(line.strip())) # Group by difficulty by_difficulty = {} for t in tasks: d = t.get("difficulty", "Easy") by_difficulty.setdefault(d, []).append(t) random.seed(args.seed) train, eval_ = [], [] print(f"Source: {SOURCE} ({len(tasks)} tasks)") print(f"Eval ratio: {args.eval_ratio}, Seed: {args.seed}") print() for difficulty in ["Easy", "Medium", "Hard"]: group = by_difficulty.get(difficulty, []) random.shuffle(group) n_eval = max(1, int(len(group) * args.eval_ratio)) n_train = len(group) - n_eval eval_.extend(group[:n_eval]) train.extend(group[n_eval:]) print(f" {difficulty:8s}: {len(group):3d} total → {n_train:3d} train, {n_eval:3d} eval") # Shuffle within splits random.shuffle(train) random.shuffle(eval_) # Write train_path = DATA_DIR / "train.jsonl" eval_path = DATA_DIR / "eval.jsonl" with open(train_path, "w") as f: for t in train: f.write(json.dumps(t, ensure_ascii=False) + "\n") with open(eval_path, "w") as f: for t in eval_: f.write(json.dumps(t, ensure_ascii=False) + "\n") print() print(f"Train: {train_path} ({len(train)} tasks)") print(f"Eval: {eval_path} ({len(eval_)} tasks)") # Verify distribution print() print("Train distribution:") for d, c in sorted(Counter(t["difficulty"] for t in train).items()): print(f" {d}: {c}") print("Eval distribution:") for d, c in sorted(Counter(t["difficulty"] for t in eval_).items()): print(f" {d}: {c}") if __name__ == "__main__": main()