"""Create data loaders for the training and validation splits of the X-ray dataset."""

import torch
from torch.utils.data import DataLoader, Subset, random_split

from dataset import XRayDataset


def get_dataloaders(
    csv_path,
    images_dir,
    batch_size=32,
    val_split=0.2,
    *,
    seed=None,
    num_workers=0,
):
    """Build training and validation ``DataLoader`` objects from one CSV/image source.

    The rows are split randomly into a training part and a validation part.
    The training part uses ``XRayDataset(..., train=True)``. The validation
    part uses ``XRayDataset(..., train=False)``, so it gets that instance's
    transform (per the original intent, normalization only, no augmentation).

    Args:
        csv_path: Path to the CSV, forwarded to ``XRayDataset``.
        images_dir: Image directory, forwarded to ``XRayDataset``.
        batch_size: Batch size for both loaders.
        val_split: Fraction of samples (in [0, 1]) assigned to validation.
            The validation size is ``int(len(dataset) * val_split)``.
        seed: Optional integer seed for a reproducible split. When ``None``,
            the global torch generator is used (the previous behaviour).
        num_workers: Worker processes for each ``DataLoader``.

    Returns:
        A ``(train_loader, val_loader)`` tuple. The training loader shuffles
        and the validation loader does not.

    Raises:
        ValueError: If ``val_split`` is outside [0, 1].
    """
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(f"val_split must be in [0, 1], got {val_split!r}")

    # Two independent instances, so each split keeps its own transform.
    # random_split's Subsets share their parent dataset, so mutating
    # `.dataset.transform` on one subset would also change the other.
    # NOTE(review): assumes both instances enumerate the same rows in the
    # same order (same csv_path/images_dir) — confirm in XRayDataset.
    train_base = XRayDataset(csv_path=csv_path, images_dir=images_dir, train=True)
    eval_base = XRayDataset(csv_path=csv_path, images_dir=images_dir, train=False)

    total = len(train_base)
    val_size = int(total * val_split)
    train_size = total - val_size

    generator = (
        torch.Generator().manual_seed(seed)
        if seed is not None
        else torch.default_generator
    )
    train_ds, val_split_ds = random_split(
        train_base, [train_size, val_size], generator=generator
    )
    # Reuse the validation indices on the eval-mode instance.
    val_ds = Subset(eval_base, val_split_ds.indices)

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, val_loader