# migratron/code_migration/split_dataset.py
# Committer: amrithanandini
# Commit message: integrated backend and frontend
# Commit: 1b35d41
"""Split timemachine-bench-verified.jsonl into train/eval sets, stratified by difficulty.

Usage:
    python code_migration/split_dataset.py [--eval-ratio 0.2] [--seed 42]

Output:
    code_migration/data/train.jsonl
    code_migration/data/eval.jsonl
"""
import json
import random
import argparse
from pathlib import Path
from collections import Counter
# Directory holding the benchmark data; it sits next to this script.
DATA_DIR = Path(__file__).parent.joinpath("data")
# Input file: the full benchmark, one JSON task per line.
SOURCE = DATA_DIR.joinpath("timemachine-bench-verified.jsonl")
def _read_jsonl(path):
    """Return the JSON objects stored one per line in *path*, skipping blank lines."""
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # A stray blank line (e.g. at end of file) is not a record, and
            # json.loads("") would raise on it.
            if line:
                records.append(json.loads(line))
    return records


def _write_jsonl(path, records):
    """Write *records* to *path* as UTF-8 JSON Lines, one object per line."""
    # ensure_ascii=False emits raw non-ASCII characters, so the encoding is
    # pinned to UTF-8 instead of being left to the platform default.
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in records)


def _eval_count(group_size, eval_ratio):
    """Return how many of *group_size* tasks go to the eval split.

    A non-empty group always contributes at least one eval task; the result
    never exceeds the group size, and an empty group contributes nothing.
    """
    if group_size == 0:
        return 0
    return min(group_size, max(1, int(group_size * eval_ratio)))


def main(argv=None):
    """Split SOURCE into train/eval JSONL files, stratified by difficulty.

    Tasks are grouped by their "difficulty" field (tasks without one are
    treated as "Easy"). Each group is shuffled with a seeded RNG and its first
    ``max(1, int(len(group) * eval_ratio))`` tasks go to the eval split, the
    rest to the train split. Both splits are shuffled again and written to
    ``DATA_DIR / "train.jsonl"`` and ``DATA_DIR / "eval.jsonl"``. A summary is
    printed to stdout.

    Args:
        argv: Optional list of command-line arguments (``--eval-ratio``,
            ``--seed``). When None, argparse reads ``sys.argv[1:]``.

    Raises:
        SystemExit: If the arguments are invalid, including an ``--eval-ratio``
            outside the range [0, 1].
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval-ratio", type=float, default=0.2)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args(argv)
    # A ratio outside [0, 1] (or NaN) would request more eval tasks than exist.
    if not 0.0 <= args.eval_ratio <= 1.0:
        parser.error(f"--eval-ratio must be between 0 and 1, got {args.eval_ratio}")

    # Load all tasks
    tasks = _read_jsonl(SOURCE)

    # Group by difficulty
    by_difficulty = {}
    for t in tasks:
        by_difficulty.setdefault(t.get("difficulty", "Easy"), []).append(t)

    # The known levels come first, in their original order, so the RNG is
    # consumed in the same sequence as before for datasets that only use them.
    # Any other label follows, so those tasks are no longer silently dropped.
    known = ["Easy", "Medium", "Hard"]
    extras = sorted((d for d in by_difficulty if d not in known), key=str)

    # A private generator gives the same shuffles as seeding the module-level
    # one, without touching global random state.
    rng = random.Random(args.seed)
    train, eval_ = [], []

    print(f"Source: {SOURCE} ({len(tasks)} tasks)")
    print(f"Eval ratio: {args.eval_ratio}, Seed: {args.seed}")
    print()
    for difficulty in known + extras:
        group = by_difficulty.get(difficulty, [])
        rng.shuffle(group)
        n_eval = _eval_count(len(group), args.eval_ratio)
        n_train = len(group) - n_eval
        eval_.extend(group[:n_eval])
        train.extend(group[n_eval:])
        print(f" {str(difficulty):8s}: {len(group):3d} total → {n_train:3d} train, {n_eval:3d} eval")

    # Shuffle within splits so the output files are not ordered by difficulty
    rng.shuffle(train)
    rng.shuffle(eval_)

    # Write
    train_path = DATA_DIR / "train.jsonl"
    eval_path = DATA_DIR / "eval.jsonl"
    _write_jsonl(train_path, train)
    _write_jsonl(eval_path, eval_)
    print()
    print(f"Train: {train_path} ({len(train)} tasks)")
    print(f"Eval: {eval_path} ({len(eval_)} tasks)")

    # Verify distribution, using the same default as the grouping step so a
    # task without a "difficulty" field does not raise KeyError here.
    print()
    for title, split in (("Train distribution:", train), ("Eval distribution:", eval_)):
        print(title)
        counts = Counter(t.get("difficulty", "Easy") for t in split)
        for d, c in sorted(counts.items(), key=lambda item: str(item[0])):
            print(f" {d}: {c}")
# Run the split only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()