| import os
|
| import json
|
| import numpy as np
|
| from sklearn.model_selection import KFold
|
|
|
|
|
|
|
|
|
# Patient identifiers of the local dataset (labelled "Local_SAI" in the split
# metadata). IDs are kept as strings so they serialise to JSON unchanged.
local_patients_id = [
    "101228", "101627", "102035", "102313", "104252", "104280", "104420",
    "104447", "104453", "104474", "104518", "104520", "104670", "104797",
    "104810", "104871", "104899", "104937", "105074", "105302", "105465",
    "105549", "105597", "105755", "105911", "105917", "105978", "106063",
    "106200", "106270", "106506", "106536", "106639", "106780", "106905",
    "106976", "107130", "107233", "107455", "107508", "107539", "107630",
    "107680", "107739", "107966", "107997", "108295", "108344", "108444",
    "108726", "108807", "108975", "109141", "109267", "109395", "109654",
    "109816", "109923", "109944", "110012", "110157", "110218", "110280",
    "110327", "110497", "110540", "110543", "110784", "111008", "111140",
    "111189", "111489", "111691", "111852", "112055", "112378", "112414",
    "112657", "112659", "112730", "112765", "112776", "112997", "113046",
    "113394", "113845", "114058", "114128", "114266", "114304", "114454",
    "114525", "114585", "114770", "114836", "114903", "114990", "115588",
    "115628", "115788",
]
|
|
|
# Patient identifiers of the public dataset (labelled "Public_MSSEG" in the
# split metadata). The first three characters ("c01", "c07", "c08") are used
# further down as the acquisition-center key.
public_patients_id = [
    "c01p01", "c01p02", "c01p03", "c01p04", "c01p05",
    "c07p01", "c07p02", "c07p03", "c07p04", "c07p05",
    "c08p01", "c08p02", "c08p03", "c08p04", "c08p05",
]

# Base seed; every random generator in this script is derived from it.
RANDOM_SEED = 42
# Number of cross-validation folds produced for each dataset.
N_FOLDS = 4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_folds_exact(trainval, n_val_per_fold, n_folds, rng):
    """Build ``n_folds`` folds with exactly ``n_val_per_fold`` validation items each.

    A copy of ``trainval`` is shuffled with ``rng``. The first
    ``n_folds * n_val_per_fold`` shuffled items form the validation pool,
    which is cut into consecutive, disjoint chunks (one per fold). The
    remaining items are never used for validation and appear in the training
    set of every fold.

    Args:
        trainval: Sequence of patient identifiers. It is copied, not mutated.
        n_val_per_fold: Number of validation items in each fold.
        n_folds: Number of folds to build.
        rng: Object exposing an in-place ``shuffle(array)`` method, such as a
            ``numpy.random.Generator``.

    Returns:
        Dict keyed ``"fold_0"``, ``"fold_1"``, ... whose values are dicts with
        ``"train_patients"``, ``"val_patients"``, ``"n_train"`` and ``"n_val"``.

    Raises:
        ValueError: If a size is negative or if ``trainval`` holds fewer than
            ``n_folds * n_val_per_fold`` items.
    """
    if n_folds < 0 or n_val_per_fold < 0:
        raise ValueError(
            f"n_folds ({n_folds}) and n_val_per_fold ({n_val_per_fold}) must be non-negative"
        )

    shuffled = np.array(trainval)  # np.array copies, so the caller's list is untouched
    rng.shuffle(shuffled)

    total_val_pool = n_folds * n_val_per_fold
    # Explicit raise instead of ``assert``: assertions are removed under
    # ``python -O`` and the folds would then silently come out too small.
    if total_val_pool > len(shuffled):
        raise ValueError(
            f"Not enough trainval ({len(shuffled)}) for {n_folds} x {n_val_per_fold} val = {total_val_pool}"
        )
    val_pool = shuffled[:total_val_pool]
    train_base = shuffled[total_val_pool:]

    folds = {}
    for fold_idx in range(n_folds):
        start = fold_idx * n_val_per_fold
        stop = start + n_val_per_fold
        val_pts = val_pool[start:stop].tolist()
        # Training set = the other folds' validation chunks + the fixed base.
        train_pts = np.concatenate(
            [val_pool[:start], val_pool[stop:], train_base]
        ).tolist()
        folds[f"fold_{fold_idx}"] = {
            "train_patients": train_pts,
            "val_patients": val_pts,
            "n_train": len(train_pts),
            "n_val": len(val_pts),
        }
    return folds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_folds_kfold(trainval, n_folds, rng):
    """Split ``trainval`` into ``n_folds`` folds using scikit-learn's ``KFold``.

    A copy of ``trainval`` is shuffled once with ``rng``; ``KFold`` is then run
    with ``shuffle=False`` on that shuffled array, so ``rng`` is the only
    source of randomness.

    Args:
        trainval: Sequence of patient identifiers. It is copied, not mutated.
        n_folds: Number of splits passed to ``KFold``.
        rng: Object exposing an in-place ``shuffle(array)`` method, such as a
            ``numpy.random.Generator``.

    Returns:
        Dict keyed ``"fold_0"``, ``"fold_1"``, ... whose values are dicts with
        ``"train_patients"``, ``"val_patients"``, ``"n_train"`` and ``"n_val"``.
    """
    shuffled = np.array(trainval)
    rng.shuffle(shuffled)

    splitter = KFold(n_splits=n_folds, shuffle=False)
    return {
        f"fold_{k}": {
            "train_patients": shuffled[train_rows].tolist(),
            "val_patients": shuffled[val_rows].tolist(),
            "n_train": len(train_rows),
            "n_val": len(val_rows),
        }
        for k, (train_rows, val_rows) in enumerate(splitter.split(shuffled))
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Local dataset: hold out a fixed test set, then build exact-size CV folds.
# ---------------------------------------------------------------------------
n_local = len(local_patients_id)
# NOTE: round() uses round-half-to-even, so e.g. 2.5 becomes 2.
n_local_test = round(n_local * 0.20)
n_local_val_per_fold = round(n_local * 0.10)

# Shuffle a copy of the IDs; the first n_local_test become the test set.
rng_local = np.random.default_rng(RANDOM_SEED)
local_arr = np.array(local_patients_id)
rng_local.shuffle(local_arr)

local_test = local_arr[:n_local_test].tolist()
local_trainval = local_arr[n_local_test:].tolist()

# A separate generator (seed + 1) drives the fold assignment.
local_folds = make_folds_exact(
    local_trainval,
    n_val_per_fold=n_local_val_per_fold,
    n_folds=N_FOLDS,
    rng=np.random.default_rng(RANDOM_SEED + 1),
)

# Each fold trains on every trainval patient that is not in its val chunk.
n_local_train_per_fold = len(local_trainval) - n_local_val_per_fold

local_split = {
    "metadata": {
        "dataset": "Local_SAI",
        "total_patients": n_local,
        "test_patients": n_local_test,
        "trainval_patients": len(local_trainval),
        "target_split": "70/10/20 (train/val/test)",
        # Derived from the real sizes so the metadata cannot drift away from
        # the data (the previous hard-coded text did not match them).
        "exact_counts": (
            f"train={n_local_train_per_fold}, val={n_local_val_per_fold}, "
            f"test={n_local_test} per fold"
        ),
        "n_folds": N_FOLDS,
        "random_seed": RANDOM_SEED,
    },
    "test_set": {"patients": local_test, "n_patients": n_local_test},
    "folds": local_folds,
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Public dataset: one test patient per center, K-fold CV on the remainder.
# ---------------------------------------------------------------------------
n_public = len(public_patients_id)

# Group patients by center, taken as the first three characters of the ID.
centers = {}
for pid in public_patients_id:
    centers.setdefault(pid[:3], []).append(pid)

public_test = []
public_trainval = []
for center, pids in sorted(centers.items()):
    arr = np.array(pids)
    # Seed from RANDOM_SEED plus the center's bytes. The built-in hash() of a
    # str is salted per interpreter process (unless PYTHONHASHSEED is fixed),
    # so seeding with hash(center) made the test hold-out differ between runs.
    center_rng = np.random.default_rng([RANDOM_SEED, *center.encode("utf-8")])
    center_rng.shuffle(arr)
    # First shuffled patient of each center goes to the test set; store a
    # plain str rather than a numpy scalar.
    public_test.append(str(arr[0]))
    public_trainval += arr[1:].tolist()

public_folds = make_folds_kfold(
    public_trainval,
    n_folds=N_FOLDS,
    rng=np.random.default_rng(RANDOM_SEED + 2),
)

public_split = {
    "metadata": {
        "dataset": "Public_MSSEG",
        "total_patients": n_public,
        "test_patients": len(public_test),
        "trainval_patients": len(public_trainval),
        "target_split": "60/20/20 (train/val/test)",
        "n_folds": N_FOLDS,
        "random_seed": RANDOM_SEED,
        "center_balanced_test": True,
    },
    "test_set": {"patients": public_test, "n_patients": len(public_test)},
    "folds": public_folds,
}
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Concatenated dataset: fold i = local fold i + public fold i.
# ---------------------------------------------------------------------------
_FOLD_FIELDS = ("train_patients", "val_patients", "n_train", "n_val")


def _merge_fold(local_fold, public_fold):
    """Combine two fold dicts field by field (lists concatenate, counts add)."""
    return {field: local_fold[field] + public_fold[field] for field in _FOLD_FIELDS}


concat_test = local_test + public_test
concat_folds = {
    fold_key: _merge_fold(local_fold, public_folds[fold_key])
    for fold_key, local_fold in local_folds.items()
}

concat_metadata = {
    "datasets": ["Local_SAI", "Public_MSSEG"],
    "total_patients": n_local + n_public,
    "test_patients": len(concat_test),
    "trainval_patients": len(local_trainval) + len(public_trainval),
    "local_split": "70/10/20",
    "public_split": "60/20/20",
    "n_folds": N_FOLDS,
    "random_seed": RANDOM_SEED,
}

concat_split = {
    "metadata": concat_metadata,
    "test_set": {"patients": concat_test, "n_patients": len(concat_test)},
    "folds": concat_folds,
}
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Write the three split definitions as JSON files next to this script.
# ---------------------------------------------------------------------------
output_dir = os.path.dirname(os.path.abspath(__file__))

outputs = {
    "local_fold_assignments.json": local_split,
    "public_fold_assignments.json": public_split,
    "concat_fold_assignments.json": concat_split,
}

for name, data in outputs.items():
    path = os.path.join(output_dir, name)
    with open(path, "w") as f:
        json.dump(data, f, indent=2)
    print(f"Saved: {path}")
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Sanity report: prints overlap counts for every split. It only reports and
# never raises, so a non-zero overlap must be spotted in the output.
# ---------------------------------------------------------------------------
print("\n=== SANITY CHECK ===")
reports = (
    ("LOCAL", local_split),
    ("PUBLIC", public_split),
    ("CONCAT", concat_split),
)
for label, split_data in reports:
    held_out = set(split_data["test_set"]["patients"])
    print(f"\n{label} (test={len(held_out)})")

    val_sets = []
    for fold_key, fold in split_data["folds"].items():
        train_ids = set(fold["train_patients"])
        val_ids = set(fold["val_patients"])
        val_sets.append(val_ids)
        tv_overlap = len(train_ids & val_ids)
        tst_overlap = len((train_ids | val_ids) & held_out)
        print(
            f" {fold_key}: train={len(train_ids):3d}, val={len(val_ids):2d} | "
            f"train/val overlap={tv_overlap} | (train+val)/test overlap={tst_overlap}"
        )

    # Collect every pair of folds whose validation sets share a patient.
    bad = []
    for i, first in enumerate(val_sets):
        for j in range(i + 1, len(val_sets)):
            if first & val_sets[j]:
                bad.append(f"f{i}&f{j}")
    verdict = "FAIL: " + str(bad) if bad else "OK"
    print(f" Val sets unique across folds: {verdict}")