# xray-classification / training / dataloader.py
# Author: Flamekizer11 (uploaded in commit 64d0ccc, "Upload 27 files")
# Module for creating the data loaders for the training and validation datasets.
import torch
from torch.utils.data import DataLoader, Subset, random_split

from dataset import XRayDataset
def get_dataloaders(
    csv_path,
    images_dir,
    batch_size=32,
    val_split=0.2,
    *,
    seed=None,
    num_workers=0,
):
    """Build a shuffled training DataLoader and an ordered validation DataLoader.

    The samples described by ``csv_path`` are split at random into a training
    part and a validation part. Training samples are served by an
    ``XRayDataset(train=True)`` instance. Validation samples are served by a
    separate ``XRayDataset(train=False)`` instance, whose transform is intended
    to apply normalization only (no augmentation). Keeping the two instances
    separate means the training set retains its augmentation.

    Args:
        csv_path: Path to the CSV passed to ``XRayDataset``.
        images_dir: Image directory passed to ``XRayDataset``.
        batch_size: Batch size used by both loaders.
        val_split: Fraction of samples held out for validation; must be in
            ``[0, 1)``. The validation size is ``int(len(dataset) * val_split)``.
        seed: Optional integer seed for the split. When given, every call
            produces the same train/validation partition. When ``None``, the
            global torch RNG (``torch.default_generator``) is used, which is
            ``random_split``'s default.
        num_workers: Number of worker processes for each DataLoader
            (0 loads in the main process).

    Returns:
        A ``(train_loader, val_loader)`` tuple. The training loader reshuffles
        every epoch; the validation loader keeps a fixed order.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)``.
    """
    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")

    # random_split() returns Subsets that all reference the SAME dataset
    # object. Reassigning `.transform` on one subset's `.dataset` would
    # therefore also strip augmentation from the training subset. Use two
    # independent instances instead, one per transform pipeline.
    train_base = XRayDataset(csv_path=csv_path, images_dir=images_dir, train=True)
    val_base = XRayDataset(csv_path=csv_path, images_dir=images_dir, train=False)

    total = len(train_base)
    val_size = int(total * val_split)
    train_size = total - val_size

    generator = (
        torch.Generator().manual_seed(seed)
        if seed is not None
        else torch.default_generator
    )

    # Split index positions once, then apply the same positions to each base
    # dataset so train and validation never overlap.
    # NOTE(review): assumes both instances expose the same rows in the same
    # order (both are built from csv_path); confirm in XRayDataset.
    train_part, val_part = random_split(
        range(total), [train_size, val_size], generator=generator
    )
    train_ds = Subset(train_base, train_part.indices)
    val_ds = Subset(val_base, val_part.indices)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, val_loader