| import math |
| import torch |
| import torch.utils.data |
| from pathlib import Path |
| from torchvision import datasets, transforms |
| import multiprocessing |
|
|
| from .helpers import compute_mean_and_std, get_data_location |
| import matplotlib.pyplot as plt |
|
|
|
|
def get_data_loaders(
    batch_size: int = 32, valid_size: float = 0.2, num_workers: int = -1, limit: int = -1
):
    """
    Create and return the train, validation and test data loaders.

    The train and validation loaders read the same ``train`` folder, with different
    transforms. Each is restricted to a disjoint random subset of indices
    (``valid_size`` controls the split). The test loader reads the ``test`` folder.

    :param batch_size: size of the mini-batches
    :param valid_size: fraction of the dataset to use for validation. For example 0.2
                       means that 20% of the dataset will be used for validation
    :param num_workers: number of workers to use in the data loaders. Use -1 to mean
                        "use all my cores"
    :param limit: maximum number of data points to consider (values <= 0 mean no limit)
    :return: a dictionary with 3 keys: 'train', 'valid' and 'test' containing respectively
             the train, validation and test data loaders
    """
    if num_workers == -1:
        # -1 is shorthand for "use every available core".
        num_workers = multiprocessing.cpu_count()

    data_loaders = {"train": None, "valid": None, "test": None}

    base_path = Path(get_data_location())

    # Normalization statistics come from the project helper; convert once for reuse.
    mean, std = compute_mean_and_std()
    print(f"Dataset mean: {mean}, std: {std}")
    mean_list, std_list = mean.tolist(), std.tolist()

    data_transforms = {
        # Training: random crop + RandAugment + flip for augmentation.
        "train": transforms.Compose(
            [
                transforms.Resize(256),
                transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                transforms.RandAugment(2),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean_list, std_list),
            ]
        ),
        # Validation / test: deterministic resize only.
        "valid": transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean_list, std_list),
            ]
        ),
        "test": transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean_list, std_list),
            ]
        ),
    }

    # Two views of the same folder: augmented for training, plain for validation.
    train_data = datasets.ImageFolder(base_path / "train", data_transforms["train"])
    valid_data = datasets.ImageFolder(base_path / "train", data_transforms["valid"])

    n_tot = len(train_data)
    indices = torch.randperm(n_tot)

    if limit > 0:
        indices = indices[:limit]
        # Use the real count: a limit larger than the dataset must not inflate n_tot,
        # otherwise the validation split below would be oversized.
        n_tot = len(indices)

    split = int(math.ceil(valid_size * n_tot))
    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
    valid_sampler = torch.utils.data.SubsetRandomSampler(valid_idx)

    data_loaders["train"] = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
    )
    data_loaders["valid"] = torch.utils.data.DataLoader(
        valid_data,
        batch_size=batch_size,
        sampler=valid_sampler,
        num_workers=num_workers,
    )

    test_data = datasets.ImageFolder(base_path / "test", data_transforms["test"])

    if limit > 0:
        # In-order subset, consistent with shuffle=False below. The limit is clamped
        # to the test-set size so indexing never runs past the end of the dataset.
        test_sampler = range(min(limit, len(test_data)))
    else:
        test_sampler = None

    data_loaders["test"] = torch.utils.data.DataLoader(
        test_data,
        batch_size=batch_size,
        sampler=test_sampler,
        num_workers=num_workers,
        shuffle=False,
    )

    return data_loaders
|
|
|
|
def visualize_one_batch(data_loaders, max_n: int = 5):
    """
    Visualize one batch of training images, titled with their class names.

    :param data_loaders: dictionary containing data loaders; the "train" entry is used
                         and its ``dataset`` must expose a ``classes`` attribute
    :param max_n: maximum number of images to show (fewer are shown if the batch is
                  smaller)
    :return: None
    """
    # Grab a single batch from the training loader.
    images, labels = next(iter(data_loaders["train"]))

    # Undo Normalize(mean, std): first multiply by std, then add the mean back.
    mean, std = compute_mean_and_std()
    inv_trans = transforms.Compose(
        [
            transforms.Normalize(mean=[0.0, 0.0, 0.0], std=1 / std),
            transforms.Normalize(mean=-mean, std=[1.0, 1.0, 1.0]),
        ]
    )
    images = inv_trans(images)

    class_names = data_loaders["train"].dataset.classes

    # Move channels last for matplotlib (assumes an (N, C, H, W) batch) and clamp
    # values to [0, 1] so imshow does not warn about out-of-range floats.
    images = torch.permute(images, (0, 2, 3, 1)).clip(0, 1)

    # Never index past the end of a batch that holds fewer than max_n images.
    n_show = min(max_n, len(images))
    fig = plt.figure(figsize=(25, 4))
    for idx in range(n_show):
        ax = fig.add_subplot(1, n_show, idx + 1, xticks=[], yticks=[])
        ax.imshow(images[idx])
        ax.set_title(class_names[labels[idx].item()])
|
|
|
|
| |
| |
| |
| import pytest |
|
|
|
|
@pytest.fixture(scope="session")
def data_loaders():
    """Build the loaders once per test session with tiny batches and num_workers=0."""
    loaders = get_data_loaders(batch_size=2, num_workers=0)
    return loaders
|
|
|
|
def test_data_loaders_keys(data_loaders):
    """The loader dictionary must expose exactly the train/valid/test splits."""
    expected_keys = {"train", "valid", "test"}
    assert set(data_loaders) == expected_keys, "The keys of the data_loaders dictionary should be train, valid and test"
|
|
|
|
def test_data_loaders_output_type(data_loaders):
    """Batches from the train loader are tensors whose images are 224x224."""
    images, labels = next(iter(data_loaders["train"]))

    assert isinstance(images, torch.Tensor), "images should be a Tensor"
    assert isinstance(labels, torch.Tensor), "labels should be a Tensor"
    # Check both spatial dimensions (height and width), not only the last one.
    assert tuple(images.shape[-2:]) == (224, 224), (
        "The tensors returned by your dataloaders should be 224x224. Did you "
        "forget to resize and/or crop?"
    )
|
|
|
|
def test_data_loaders_output_shape(data_loaders):
    """With the fixture's batch_size=2, each batch holds two images and two labels."""
    batch_images, batch_labels = next(iter(data_loaders["train"]))

    n_images = len(batch_images)
    assert n_images == 2, f"Expected a batch of size 2, got size {n_images}"

    n_labels = len(batch_labels)
    assert n_labels == 2, f"Expected a labels tensor of size 2, got size {n_labels}"
|
|
|
|
def test_visualize_one_batch(data_loaders):
    """Smoke test: plotting two images from one training batch must not raise."""
    visualize_one_batch(data_loaders=data_loaders, max_n=2)
|
|