Spaces:
Sleeping
Sleeping
# Module for creating data loaders for the training and validation datasets.
| from torch.utils.data import DataLoader, random_split | |
| from dataset import XRayDataset | |
def get_dataloaders(
    csv_path,
    images_dir,
    batch_size=32,
    val_split=0.2,
    seed=None,
):
    """Build training and validation DataLoaders from one CSV/image source.

    The dataset described by ``csv_path``/``images_dir`` is randomly split
    into a training part and a validation part. The training part is read
    from an ``XRayDataset`` built with ``train=True``. The validation part is
    read from a separate ``XRayDataset`` built with ``train=False``, so each
    split keeps its own ``transform`` rather than sharing one object.

    Args:
        csv_path: CSV path forwarded to ``XRayDataset``.
        images_dir: Image directory forwarded to ``XRayDataset``.
        batch_size: Samples per batch for both loaders.
        val_split: Fraction of samples assigned to validation. Must satisfy
            ``0 <= val_split < 1``; the validation count is rounded down.
        seed: Optional integer seed for the train/validation split. With
            ``None`` (the default), the split draws from PyTorch's global RNG.
            Pass an int for a reproducible split.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles each
        epoch; the validation loader iterates in a fixed order. Both load
        data in the main process (``num_workers=0``).

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)``.
    """
    import torch
    from torch.utils.data import Subset

    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")

    # Use two independent dataset objects. random_split returns Subsets that
    # share a single underlying dataset, so changing the transform through
    # one split would silently change it for the other split as well.
    train_base = XRayDataset(csv_path=csv_path, images_dir=images_dir, train=True)
    val_base = XRayDataset(csv_path=csv_path, images_dir=images_dir, train=False)
    # NOTE(review): assumes both instances expose the same rows in the same
    # order for a given csv_path, so the split indices line up — confirm.

    total = len(train_base)
    val_size = int(total * val_split)
    train_size = total - val_size

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)

    train_ds, val_holdout = random_split(
        train_base,
        [train_size, val_size],
        **split_kwargs,
    )
    # Reuse the sampled validation indices, but read those samples from the
    # train=False dataset so they get its (non-augmenting) transform.
    val_ds = Subset(val_base, val_holdout.indices)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )
    return train_loader, val_loader