File size: 619 Bytes
19c1f58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
from typing import List
import numpy as np
from sklearn.model_selection import KFold
def generate_crossval_split(train_identifiers: List[str], seed=12345, n_splits=5) -> List[dict[str, List[str]]]:
    """Partition identifiers into shuffled K-fold train/validation splits.

    Args:
        train_identifiers: Identifiers (e.g. case names) to distribute
            across the folds.
        seed: Passed to ``KFold`` as ``random_state`` so that the shuffled
            folds are reproducible for a given seed.
        n_splits: Number of folds to generate.

    Returns:
        A list with ``n_splits`` entries. Each entry is a dict with two keys:
        ``'train'`` holds the identifiers selected for training in that fold
        and ``'val'`` holds the held-out identifiers.

    Raises:
        ValueError: Propagated from ``KFold`` when ``n_splits`` is not a
            valid fold count for the number of identifiers supplied.
    """
    # Convert once up front: the array is identical for every fold, so
    # rebuilding it inside the loop is wasted work.
    identifiers = np.asarray(train_identifiers)
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed)

    # ``.tolist()`` yields native Python objects (plain ``str``) rather than
    # NumPy scalar types, which matches the declared return annotation.
    return [
        {
            'train': identifiers[train_idx].tolist(),
            'val': identifiers[val_idx].tolist(),
        }
        for train_idx, val_idx in kfold.split(identifiers)
    ]
|