File size: 3,173 Bytes
604568a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import shutil
import random
import argparse

class DataSplitter:
    """Class to split Nii or nii.gz files into training and validation datasets."""

    # File suffixes recognised as NIfTI volumes (matched case-sensitively).
    EXTENSIONS = ('.nii', '.nii.gz')

    def __init__(self, input_folder, train_ratio, train_folder, val_folder,
                 seed=None):
        """
        Initialize the DataSplitter with folder paths and data ratio.

        :param input_folder: Path to the folder containing the Nii files.
        :param train_ratio: Fraction of the files (0 to 1 inclusive) to place
            in the training folder; the remainder goes to validation.
        :param train_folder: Name of the training sub-folder, created inside
            ``input_folder``.
        :param val_folder: Name of the validation sub-folder, created inside
            ``input_folder``.
        :param seed: Optional seed for the shuffle. When given, the split is
            reproducible; when ``None`` the module-level ``random`` generator
            is used, as before.
        :raises ValueError: If ``train_ratio`` is outside [0, 1] (or NaN), or
            if both sub-folders resolve to the same path.
        """
        # Reject out-of-range ratios up front: a negative ratio would
        # otherwise yield a negative slice index and a meaningless split.
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(
                f"train_ratio must be between 0 and 1, got {train_ratio!r}")
        self.input_folder = input_folder
        self.train_ratio = train_ratio
        self.train_folder = os.path.join(input_folder, train_folder)
        self.val_folder = os.path.join(input_folder, val_folder)
        if (os.path.normpath(self.train_folder)
                == os.path.normpath(self.val_folder)):
            raise ValueError("train_folder and val_folder must be different")
        self.seed = seed

    def split_data(self):
        """
        Shuffle the Nii files in ``input_folder`` and move them into the
        training and validation folders (both are created if missing).

        :return: Tuple ``(train_files, val_files)`` of the moved file names.
        """
        # Sort first: os.listdir order is arbitrary, so without sorting the
        # same seed could still give different splits on different systems.
        # The isfile check keeps directories named "*.nii" from being moved.
        files = sorted(
            f for f in os.listdir(self.input_folder)
            if f.endswith(self.EXTENSIONS)
            and os.path.isfile(os.path.join(self.input_folder, f)))
        rng = random if self.seed is None else random.Random(self.seed)
        rng.shuffle(files)  # Shuffle files for random splitting

        # Round away float representation error before truncating
        # (e.g. 100 * 0.29 == 28.999999999999996 must give 29, not 28),
        # while keeping the original floor semantics (7 * 0.8 -> 5).
        split_index = int(round(len(files) * self.train_ratio, 9))
        train_files = files[:split_index]
        val_files = files[split_index:]

        # Create training and validation folders if they don't exist
        os.makedirs(self.train_folder, exist_ok=True)
        os.makedirs(self.val_folder, exist_ok=True)

        # Move files to the respective folders
        for file in train_files:
            shutil.move(os.path.join(self.input_folder, file),
                        self.train_folder)

        for file in val_files:
            shutil.move(os.path.join(self.input_folder, file),
                        self.val_folder)

        print(f"Files split into {len(train_files)} training "
              f"and {len(val_files)} validation.")
        return train_files, val_files

def main(argv=None):
    """
    Main function to handle command line arguments and initiate data splitting.

    :param argv: Optional list of argument strings. When ``None`` argparse
        reads ``sys.argv[1:]``, so calling ``main()`` behaves as before.
    """
    parser = argparse.ArgumentParser(
        description='Split Nii files into training and validation datasets.')

    # Define command line arguments
    parser.add_argument('--input', type=str, required=True,
                        help='Input folder with Nii files.')
    parser.add_argument('--train-ratio', type=float, default=0.8,
                        help='Training data ratio (default: 0.8).')
    parser.add_argument('--train-folder', type=str, default='training',
                        help='Folder for training data (default: "training").')
    parser.add_argument('--val-folder', type=str, default='validation',
                        help='Folder for validation data (default: "validation").')
    parser.add_argument('--seed', type=int, default=None,
                        help='Random seed for a reproducible split '
                             '(default: unseeded).')

    args = parser.parse_args(argv)

    # Report bad input as a usage error (exit status 2) instead of a
    # traceback or a silently nonsensical split: a negative ratio would
    # otherwise turn into a negative slice index.
    if not 0.0 <= args.train_ratio <= 1.0:
        parser.error(
            f'--train-ratio must be between 0 and 1, got {args.train_ratio}')
    if not os.path.isdir(args.input):
        parser.error(f'--input is not a directory: {args.input}')

    if args.seed is not None:
        # Seed the module-level generator that the splitter's shuffle uses.
        random.seed(args.seed)

    # Create a DataSplitter instance and split the data
    splitter = DataSplitter(args.input, args.train_ratio,
                            args.train_folder, args.val_folder)
    splitter.split_data()

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()