File size: 1,763 Bytes
ae51a24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
import math
import os
import random
import shutil
from glob import glob

def copy(src, dst_dir):
    """Copy the file *src* into the directory *dst_dir*, keeping its base name.

    The destination directory is created if it does not exist. Only the file
    contents are copied (no permissions or timestamps). An existing file with
    the same name in *dst_dir* is overwritten.

    Args:
        src: Path of the file to copy.
        dst_dir: Directory that receives the copy.

    Raises:
        FileNotFoundError: If *src* does not exist.
    """
    os.makedirs(dst_dir, exist_ok=True)
    dst = os.path.join(dst_dir, os.path.basename(src))

    # Opening the destination for writing truncates it; if it is the very
    # file we are about to read, the data would be destroyed. The file is
    # already where it should be, so there is nothing to do.
    if os.path.exists(dst) and os.path.samefile(src, dst):
        return

    # copyfile streams in chunks instead of loading the whole file in memory.
    shutil.copyfile(src, dst)

def _build_parser():
    """Create the command-line parser for the dataset splitter."""
    parser = argparse.ArgumentParser(
        description="Split per-class image folders into training and testing sets.",
    )
    path_options = (
        ("--mild_data", "Folder with the MildDemented images."),
        ("--moderate_data", "Folder with the ModerateDemented images."),
        ("--non_data", "Folder with the NonDemented images."),
        ("--very_mild_data", "Folder with the VeryMildDemented images."),
        ("--training_data_output", "Root folder that receives the training split."),
        ("--testing_data_output", "Root folder that receives the testing split."),
    )
    # Every path is needed by the split; requiring them up front turns a late
    # TypeError on a None path into a clear usage message.
    for flag, help_text in path_options:
        parser.add_argument(flag, required=True, help=help_text)
    parser.add_argument(
        "--split_size",
        type=int,
        required=True,
        help="Percentage (0-100) of each class's images placed in the testing split.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Optional random seed that makes the split reproducible.",
    )
    return parser


def _find_images(folder):
    """Return the sorted .jpg/.jpeg/.png files found recursively under *folder*."""
    found = []
    for pattern in ("*.jpg", "*.jpeg", "*.png"):
        found.extend(glob(os.path.join(folder, "**", pattern), recursive=True))
    # Skip directories that merely look like images, and sort so that the
    # shuffle starts from an order that does not depend on the filesystem.
    return sorted(path for path in found if os.path.isfile(path))


def _test_count(n_images, split_percent):
    """Return ceil(n_images * split_percent / 100) using integer arithmetic."""
    # Float math overshoots on some inputs (100 * 0.07 == 7.000000000000001,
    # whose ceiling is 8), so the ceiling is taken on exact integers instead.
    return (n_images * split_percent + 99) // 100


def main(argv=None):
    """Split each class folder into training and testing sets and copy the files.

    For every class, the images are shuffled, the first
    ceil(count * split_size / 100) of them are copied to
    ``<testing_data_output>/<class_name>`` and the rest to
    ``<training_data_output>/<class_name>``. Classes without images are
    skipped.

    Args:
        argv: Optional list of command-line arguments; defaults to
            ``sys.argv[1:]`` when ``None``.

    Raises:
        SystemExit: If a required option is missing or ``--split_size`` is
            outside the range 0-100.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    if not 0 <= args.split_size <= 100:
        parser.error("--split_size must be a percentage between 0 and 100")

    class_dirs = {
        "MildDemented": args.mild_data,
        "ModerateDemented": args.moderate_data,
        "NonDemented": args.non_data,
        "VeryMildDemented": args.very_mild_data,
    }

    # A dedicated generator keeps the split reproducible when --seed is given;
    # with seed=None it is seeded from system entropy, as before.
    rng = random.Random(args.seed)

    for class_name, folder in class_dirs.items():
        print(f"\nProcessing {class_name} at {folder}")

        images = _find_images(folder)
        print(f"Found {len(images)} images")

        if not images:
            continue

        # Files are copied under their base name only, so equal names coming
        # from different subfolders end up overwriting each other.
        base_names = [os.path.basename(path) for path in images]
        n_clashes = len(base_names) - len(set(base_names))
        if n_clashes:
            print(
                f"Warning: {n_clashes} images share a file name with another "
                "image and may overwrite each other in the output"
            )

        rng.shuffle(images)
        n_test = _test_count(len(images), args.split_size)

        test_files = images[:n_test]
        train_files = images[n_test:]

        train_out = os.path.join(args.training_data_output, class_name)
        test_out = os.path.join(args.testing_data_output, class_name)

        for path in test_files:
            copy(path, test_out)
        for path in train_files:
            copy(path, train_out)

    print("\n✔ Split complete")

# Run the split only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()