"""Split per-class image folders into training and testing sets.

Each of the four input class folders is searched recursively for
``*.jpg``, ``*.jpeg`` and ``*.png`` files.  The files of every class are
shuffled, ``--split_size`` percent of them (rounded up) are copied into
``<testing_data_output>/<class name>`` and the rest into
``<training_data_output>/<class name>``.
"""
import argparse
import math
import os
import random
import shutil
from glob import glob

# Glob patterns (matched recursively) that identify image files.
IMAGE_PATTERNS = ("*.jpg", "*.jpeg", "*.png")


def copy(src, dst_dir, dst_name=None):
    """Copy the file *src* into the directory *dst_dir*.

    Args:
        src: Path of the file to copy.
        dst_dir: Destination directory; created if it does not exist.
        dst_name: Optional file name to use inside *dst_dir*.  Defaults to
            the base name of *src*.

    Raises:
        shutil.SameFileError: If source and destination are the same file
            (instead of silently truncating the source).
        OSError: If the file cannot be read or written.
    """
    os.makedirs(dst_dir, exist_ok=True)
    target = os.path.join(dst_dir, dst_name or os.path.basename(src))
    # copyfile streams the content in chunks rather than loading the whole
    # file into memory.
    shutil.copyfile(src, target)


def _find_images(folder):
    """Return the sorted paths of all image files found below *folder*."""
    found = []
    for pattern in IMAGE_PATTERNS:
        found.extend(glob(os.path.join(folder, "**", pattern), recursive=True))
    # Sorting removes the dependency on filesystem enumeration order, so a
    # seeded shuffle yields the same split on every run.
    return sorted(found)


def _unique_names(images, folder):
    """Map every path in *images* to a distinct output file name.

    The search is recursive but the output is flat, so two files with the
    same base name in different sub-folders would overwrite each other.
    The first path (in the given order) keeps its base name; later
    duplicates get a name built from their path relative to *folder*, with
    a numeric suffix as a last resort.
    """
    names = {}
    taken = set()
    for path in images:
        name = os.path.basename(path)
        if name in taken:
            relative = os.path.relpath(path, folder)
            name = relative.replace(os.sep, "_")
            stem, ext = os.path.splitext(name)
            counter = 1
            while name in taken:
                name = f"{stem}_{counter}{ext}"
                counter += 1
        taken.add(name)
        names[path] = name
    return names


def _count_test_files(total, split_size):
    """Return ``ceil(total * split_size / 100)`` using integer arithmetic.

    Going through floats is wrong here: ``100 * (7 / 100)`` evaluates to
    ``7.000000000000001`` and ``math.ceil`` would round it up to 8.
    """
    return (total * split_size + 99) // 100


def main(argv=None):
    """Parse the command line and perform the train/test split.

    Args:
        argv: Optional list of command-line arguments.  When ``None`` the
            arguments are taken from ``sys.argv`` (argparse default).
    """
    parser = argparse.ArgumentParser(
        description="Split class image folders into training and testing sets."
    )
    parser.add_argument("--mild_data", required=True)
    parser.add_argument("--moderate_data", required=True)
    parser.add_argument("--non_data", required=True)
    parser.add_argument("--very_mild_data", required=True)
    parser.add_argument("--training_data_output", required=True)
    parser.add_argument("--testing_data_output", required=True)
    parser.add_argument(
        "--split_size",
        type=int,
        required=True,
        help="percentage (0-100) of each class copied to the testing output",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="optional random seed for a reproducible split",
    )
    args = parser.parse_args(argv)

    if not 0 <= args.split_size <= 100:
        parser.error("--split_size must be a percentage between 0 and 100")

    class_dirs = {
        "MildDemented": args.mild_data,
        "ModerateDemented": args.moderate_data,
        "NonDemented": args.non_data,
        "VeryMildDemented": args.very_mild_data,
    }

    # Without a seed keep using the module-level generator, exactly as
    # before; with a seed use a private generator for reproducibility.
    rng = random if args.seed is None else random.Random(args.seed)

    for class_name, folder in class_dirs.items():
        print(f"\nProcessing {class_name} at {folder}")

        images = _find_images(folder)
        print(f"Found {len(images)} images")
        if not images:
            continue

        # Names are assigned before shuffling so they do not depend on the
        # random order.
        names = _unique_names(images, folder)
        rng.shuffle(images)

        n_test = _count_test_files(len(images), args.split_size)
        train_out = os.path.join(args.training_data_output, class_name)
        test_out = os.path.join(args.testing_data_output, class_name)

        for path in images[:n_test]:
            copy(path, test_out, names[path])
        for path in images[n_test:]:
            copy(path, train_out, names[path])

    print("\n✔ Split complete")


if __name__ == "__main__":
    main()