import os
from dataclasses import dataclass, field
from typing import *

import albumentations as A
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from wilds.datasets.camelyon17_dataset import Camelyon17Dataset

from utils import parse_structure

from .base import BaseDataModule, BaseDatasetConfig


def _build_transform(subset: str, image_size: Tuple[int, int]) -> A.Compose:
    """Return the albumentations pipeline used for ``subset``.

    The ``"train"`` split gets random flip/affine/blur/colour augmentation
    followed by a resize; every other split (``"val"``, ``"test"``, ...) is
    only resized.

    Args:
        subset: WILDS split name.
        image_size: target size, passed to ``A.Resize`` as (height, width).
    """
    resize = A.Resize(image_size[0], image_size[1])
    if subset != "train":
        return A.Compose([resize], p=1.0)
    return A.Compose(
        [
            A.HorizontalFlip(),
            # Affine ``scale`` is a multiplicative zoom factor (1.0 = no
            # change), so (0.8, 1.2) gives a +/-20% random zoom.
            A.Affine(scale=(0.8, 1.2), rotate=(-10, 10), keep_ratio=True, p=0.5),
            A.OneOf(
                [
                    A.MotionBlur(p=0.2),
                    A.MedianBlur(blur_limit=3, p=0.1),
                    A.Blur(blur_limit=3, p=0.1),
                ],
                p=0.5,
            ),
            A.OneOf(
                [
                    A.CLAHE(clip_limit=2),
                    A.RandomBrightnessContrast(),
                ],
                p=0.5,
            ),
            A.HueSaturationValue(p=0.25),
            resize,
        ],
        p=1.0,
    )


class CamelyonDataset(Dataset):
    """Torch ``Dataset`` over one split of the WILDS Camelyon17 dataset.

    Args:
        root_dir: WILDS root directory; the dataset is downloaded there if
            missing (``download=True``).
        subset: WILDS split name passed to ``get_subset`` (e.g. "train",
            "val", "test"). "train" is augmented; other splits are resized only.
        image_size: target size, passed to ``A.Resize`` as (height, width).
    """

    def __init__(self, root_dir: str, subset: str, image_size: Tuple[int, int]) -> None:
        self.root_dir = root_dir
        self.dataset = Camelyon17Dataset(root_dir=root_dir, download=True).get_subset(subset)
        self.transform = _build_transform(subset, image_size)
        self.image_size = image_size

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to a numpy array, transformed, permuted from
        HWC to CHW and divided by 255 (i.e. scaled to [0, 1] for uint8 input).
        The WILDS metadata element is discarded.

        NOTE(review): WILDS usually yields the label as a tensor rather than a
        plain ``int`` — confirm against the return annotation.
        """
        image, label, _ = self.dataset[idx]
        pixels = self.transform(image=np.array(image))["image"]
        tensor = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255.0
        return tensor, label


class CamelyonDataModule(BaseDataModule):
    """Data module exposing the Camelyon17 train/val/test splits.

    ``cfg.data_source`` is used as the WILDS root directory and
    ``cfg.image_size`` as the resize target for every split.
    """

    cfg: BaseDatasetConfig

    def __init__(self, cfg: BaseDatasetConfig) -> None:
        super().__init__(cfg)
        # parse_structure presumably coerces ``cfg`` into a BaseDatasetConfig;
        # read all settings from the parsed copy so raw and parsed never diverge.
        self.cfg: BaseDatasetConfig = parse_structure(BaseDatasetConfig, cfg)
        self.img_size = self.cfg.image_size

    def setup(self, stage=None) -> None:
        """Create the datasets needed for ``stage``.

        ``None`` builds all three; "fit" builds train and val; "validate"
        builds val; "test" and "predict" build the test split.
        """
        root = self.cfg.data_source
        if stage in (None, "fit"):
            self.train_dataset = CamelyonDataset(root, "train", self.img_size)
        if stage in (None, "fit", "validate"):
            self.val_dataset = CamelyonDataset(root, "val", self.img_size)
        if stage in (None, "test", "predict"):
            self.test_dataset = CamelyonDataset(root, "test", self.img_size)