| |
| """ |
| MINDI 1.5 Vision-Coder — Train / Validation / Test Split |
| |
| Splits mindi_filtered.jsonl into: |
| - train.jsonl (90%) |
| - val.jsonl (5%) |
| - test.jsonl (5%) |
| |
| Stratified by source to ensure proportional representation. |
| Deterministic with a fixed random seed. |
| |
| Usage: |
| python scripts/split_data.py # Default 90/5/5 |
| python scripts/split_data.py --train 0.85 --val 0.10 --test 0.05 |
| python scripts/split_data.py --seed 42 |
| python scripts/split_data.py --dry-run |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import random |
| import sys |
| import time |
| from collections import Counter |
| from pathlib import Path |
|
|
| |
|
|
# Parent of the directory that contains this file (the usage examples in the
# module docstring run it as scripts/split_data.py, so this is the repo root).
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# The filtered input and all three split outputs share one directory.
OUTPUT_DIR = PROJECT_ROOT.joinpath("data", "processed")
INPUT_FILE = OUTPUT_DIR.joinpath("mindi_filtered.jsonl")

# One JSONL file per split.
TRAIN_FILE = OUTPUT_DIR.joinpath("train.jsonl")
VAL_FILE = OUTPUT_DIR.joinpath("val.jsonl")
TEST_FILE = OUTPUT_DIR.joinpath("test.jsonl")
|
|
|
|
def _source_of(line: str) -> str:
    """Return the stratification key for one raw JSONL line.

    Lines that are not valid JSON, are not JSON objects, or have a missing or
    null ``source`` are grouped under ``"unknown"``. Any other non-string
    ``source`` is stringified so that every key is hashable, sortable and
    printable with a string format spec.
    """
    try:
        example = json.loads(line)
    except json.JSONDecodeError:
        # Best-effort: malformed lines are kept in the output, not dropped.
        return "unknown"
    if not isinstance(example, dict):
        return "unknown"
    source = example.get("source")
    if source is None:
        return "unknown"
    return source if isinstance(source, str) else str(source)


def _split_counts(n: int, val_ratio: float, test_ratio: float) -> tuple[int, int, int]:
    """Return ``(n_train, n_val, n_test)`` for a source holding *n* examples.

    Sources with fewer than 3 examples go entirely to train. Otherwise each
    held-out split with a non-zero ratio receives at least one example, and
    the held-out total is clamped so the train count can never go negative.
    """
    if n < 3:
        return n, 0, 0
    n_val = max(1, round(n * val_ratio)) if val_ratio > 0 else 0
    n_test = max(1, round(n * test_ratio)) if test_ratio > 0 else 0
    # Rounding both held-out counts up can overshoot n when train_ratio is ~0.
    n_test = min(n_test, n - n_val)
    return n - n_val - n_test, n_val, n_test


def _load_lines_by_source(path: Path) -> dict[str, list[str]]:
    """Read *path* and group its non-blank, stripped lines by source key."""
    groups: dict[str, list[str]] = {}
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            groups.setdefault(_source_of(line), []).append(line)
    return groups


def run_split(
    train_ratio: float = 0.90,
    val_ratio: float = 0.05,
    test_ratio: float = 0.05,
    seed: int = 42,
    dry_run: bool = False,
    *,
    input_file: Path | str | None = None,
    output_dir: Path | str | None = None,
) -> None:
    """Split filtered data into train/val/test with stratification by source.

    Each source is shuffled and divided independently, then every split is
    shuffled again so sources are interleaved. All randomness comes from one
    ``random.Random(seed)`` instance, so a given input and seed always yield
    the same split.

    Args:
        train_ratio: Fraction of examples for the training split.
        val_ratio: Fraction of examples for the validation split.
        test_ratio: Fraction of examples for the test split.
        seed: Seed for the random number generator.
        dry_run: If true, print the summary but write no files.
        input_file: JSONL file to split. Defaults to ``INPUT_FILE``.
        output_dir: Directory receiving ``train.jsonl``, ``val.jsonl``,
            ``test.jsonl`` and ``split_meta.json``. Defaults to ``OUTPUT_DIR``.

    Raises:
        SystemExit: With code 1 if the ratios are invalid, the input file is
            missing, or the input contains no examples.
    """
    total_ratio = train_ratio + val_ratio + test_ratio
    if abs(total_ratio - 1.0) > 0.001:
        print(f"ERROR: Ratios must sum to 1.0, got {total_ratio:.3f}")
        sys.exit(1)
    # A valid sum does not rule out negative or >1 components (e.g. 1.2/-0.1/-0.1),
    # which would produce negative slice bounds further down.
    for label, ratio in (("train", train_ratio), ("val", val_ratio), ("test", test_ratio)):
        if not 0.0 <= ratio <= 1.0:
            print(f"ERROR: {label} ratio must be between 0 and 1, got {ratio}")
            sys.exit(1)

    # Module-level defaults are resolved at call time so that explicit paths
    # never touch them.
    src_path = INPUT_FILE if input_file is None else Path(input_file)
    if output_dir is None:
        out_dir = OUTPUT_DIR
        train_path, val_path, test_path = TRAIN_FILE, VAL_FILE, TEST_FILE
    else:
        out_dir = Path(output_dir)
        train_path = out_dir / "train.jsonl"
        val_path = out_dir / "val.jsonl"
        test_path = out_dir / "test.jsonl"

    if not src_path.exists():
        print(f"ERROR: Input file not found: {src_path}")
        print(" Run quality_filter.py first to generate mindi_filtered.jsonl")
        sys.exit(1)

    print(f"Loading examples from {src_path.name} ...")
    start = time.time()

    source_lines = _load_lines_by_source(src_path)
    total = sum(len(lines) for lines in source_lines.values())

    load_time = time.time() - start
    print(f" Loaded {total:,} examples in {load_time:.1f}s")
    print(f" Sources: {len(source_lines)}")
    print()

    if total == 0:
        # Nothing to split; the percentage summary below would divide by zero.
        print(f"ERROR: No examples found in {src_path}")
        sys.exit(1)

    print(f"Split ratios: train={train_ratio:.0%} / val={val_ratio:.0%} / test={test_ratio:.0%}")
    print(f"Random seed: {seed}")
    print(f"Dry run: {dry_run}")
    print()

    rng = random.Random(seed)

    train_lines: list[str] = []
    val_lines: list[str] = []
    test_lines: list[str] = []
    source_stats: dict[str, dict[str, int]] = {}

    # Sources are visited in sorted order so the RNG is consumed identically
    # on every run, regardless of the order sources appear in the input.
    for source in sorted(source_lines):
        lines = source_lines[source]
        rng.shuffle(lines)

        n_train, n_val, n_test = _split_counts(len(lines), val_ratio, test_ratio)

        train_lines.extend(lines[:n_train])
        val_lines.extend(lines[n_train:n_train + n_val])
        test_lines.extend(lines[n_train + n_val:])

        source_stats[source] = {
            "total": len(lines),
            "train": n_train,
            "val": n_val,
            "test": n_test,
        }

    # Interleave sources within each split.
    rng.shuffle(train_lines)
    rng.shuffle(val_lines)
    rng.shuffle(test_lines)

    print("=" * 60)
    print(" SPLIT SUMMARY")
    print("=" * 60)
    print(f" Total: {total:>10,}")
    print(f" Train: {len(train_lines):>10,} ({len(train_lines)/total*100:.1f}%)")
    print(f" Validation: {len(val_lines):>10,} ({len(val_lines)/total*100:.1f}%)")
    print(f" Test: {len(test_lines):>10,} ({len(test_lines)/total*100:.1f}%)")
    print()

    print(" Per-source breakdown:")
    print(f" {'Source':<25s} {'Total':>8s} {'Train':>8s} {'Val':>8s} {'Test':>8s}")
    print(f" {'-'*25} {'-'*8} {'-'*8} {'-'*8} {'-'*8}")
    for source in sorted(source_stats):
        s = source_stats[source]
        print(f" {source:<25s} {s['total']:>8,} {s['train']:>8,} {s['val']:>8,} {s['test']:>8,}")
    print()

    if not dry_run:
        out_dir.mkdir(parents=True, exist_ok=True)

        print("Writing files ...")
        for path, lines, name in [
            (train_path, train_lines, "train"),
            (val_path, val_lines, "val"),
            (test_path, test_lines, "test"),
        ]:
            with open(path, "w", encoding="utf-8") as f:
                f.writelines(f"{line}\n" for line in lines)
            size_mb = path.stat().st_size / (1024 * 1024)
            print(f" {name:<12s} → {path.name:<20s} ({len(lines):>10,} examples, {size_mb:>8.1f} MB)")

        # Record the realised split next to the data so it can be audited later.
        meta = {
            "total": total,
            "train_count": len(train_lines),
            "val_count": len(val_lines),
            "test_count": len(test_lines),
            "train_pct": round(len(train_lines) / total * 100, 2),
            "val_pct": round(len(val_lines) / total * 100, 2),
            "test_pct": round(len(test_lines) / total * 100, 2),
            "seed": seed,
            "source_breakdown": source_stats,
        }
        meta_path = out_dir / "split_meta.json"
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)
        print(f" Metadata → {meta_path.name}")

    elapsed = time.time() - start
    print(f"\n Done in {elapsed:.1f}s")
    print("=" * 60)
|
|
|
|
| |
|
|
def main() -> None:
    """Parse the command-line options and run the split with them.

    Options: ``--train``, ``--val`` and ``--test`` set the split ratios,
    ``--seed`` sets the random seed, and ``--dry-run`` previews the split
    without writing any files.
    """
    parser = argparse.ArgumentParser(
        description="MINDI Data Splitter — stratified train/val/test split",
    )
    parser.add_argument("--train", type=float, default=0.90, help="Train ratio (default: 0.90)")
    parser.add_argument("--val", type=float, default=0.05, help="Validation ratio (default: 0.05)")
    parser.add_argument("--test", type=float, default=0.05, help="Test ratio (default: 0.05)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    parser.add_argument("--dry-run", action="store_true", help="Preview split without writing files")

    args = parser.parse_args()
    run_split(
        train_ratio=args.train,
        val_ratio=args.val,
        test_ratio=args.test,
        seed=args.seed,
        dry_run=args.dry_run,
    )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|