# hispath/data/camelyon.py
# Author: kohido — commit 8bf25c8 ("init")
from wilds.datasets.camelyon17_dataset import Camelyon17Dataset
from .base import BaseDatasetConfig, BaseDataModule
from torch.utils.data import Dataset, DataLoader
from typing import *
from dataclasses import dataclass, field
from PIL import Image
from utils import parse_structure
import os
import numpy as np
import torch
import albumentations as A
class CamelyonDataset(Dataset):
    """Camelyon17 (WILDS) split wrapped as a map-style torch ``Dataset``.

    Each item is ``(image, label)``. ``image`` is a float tensor in
    channels-first layout, divided by 255, produced after the split-specific
    albumentations pipeline has run: random augmentation plus resize for
    ``"train"``, resize only for ``"val"`` / ``"test"``.
    """

    # Split names this wrapper has a transform pipeline for.
    _SUBSETS = ("train", "val", "test")

    def __init__(self, root_dir: str, subset: str, image_size: Tuple[int, int]) -> None:
        """Load the WILDS Camelyon17 split and build its transform.

        Args:
            root_dir: Directory passed to ``Camelyon17Dataset``. It is
                constructed with ``download=True``.
            subset: Split name, one of ``"train"``, ``"val"``, ``"test"``.
            image_size: Pair passed positionally to ``A.Resize`` as
                ``(height, width)``.

        Raises:
            ValueError: If ``subset`` is not a supported split name.
        """
        # Validate the split name before touching WILDS, so a typo fails fast
        # instead of after a potential dataset download.
        if subset not in self._SUBSETS:
            raise ValueError(
                f"Unknown subset {subset!r}; expected one of {self._SUBSETS}"
            )
        self.root_dir = root_dir
        self.dataset = Camelyon17Dataset(root_dir=root_dir, download=True).get_subset(subset)
        self.transform = self._build_transform(subset, image_size)
        self.image_size = image_size

    @staticmethod
    def _build_transform(subset: str, image_size: Tuple[int, int]) -> "A.Compose":
        """Return the albumentations pipeline for ``subset`` (only that one is built)."""
        resize = A.Resize(image_size[0], image_size[1])
        if subset != "train":
            # Evaluation splits are only resized, never augmented.
            return A.Compose([resize], p=1.0)
        return A.Compose(
            [
                A.HorizontalFlip(),
                # Affine ``scale`` is a multiplicative zoom factor where 1.0
                # means "unchanged"; (0.8, 1.2) gives a +/-20% zoom. A range
                # such as (-0.2, 0.2) would collapse or mirror the image.
                A.Affine(scale=(0.8, 1.2), rotate=(-10, 10), keep_ratio=True, p=0.5),
                A.OneOf(
                    [
                        A.MotionBlur(p=0.2),
                        A.MedianBlur(blur_limit=3, p=0.1),
                        A.Blur(blur_limit=3, p=0.1),
                    ],
                    p=0.5,
                ),
                A.OneOf(
                    [
                        A.CLAHE(clip_limit=2),
                        A.RandomBrightnessContrast(),
                    ],
                    p=0.5,
                ),
                A.HueSaturationValue(p=0.25),
                # Resize last so every training sample ends at image_size.
                resize,
            ],
            p=1.0,
        )

    def __len__(self) -> int:
        """Return the number of samples in the underlying WILDS split."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return ``(image, label)`` for sample ``idx``.

        The WILDS metadata (third element of each sample) is discarded. The
        label is passed through unchanged. NOTE(review): WILDS may yield it
        as a 0-d tensor rather than a plain ``int`` — confirm.
        """
        image, label, _metadata = self.dataset[idx]
        # Albumentations works on numpy arrays; assumes an HWC uint8 image
        # (e.g. an RGB PIL image) — TODO confirm for every WILDS version.
        image = np.array(image)
        image = self.transform(image=image)["image"]
        # HWC -> CHW, then scale 8-bit intensities into [0, 1].
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        return image, label
class CamelyonDataModule(BaseDataModule):
    """Data module that builds Camelyon17 train/val/test datasets per stage."""

    cfg: BaseDatasetConfig

    def __init__(self, cfg: BaseDatasetConfig) -> None:
        """Normalise ``cfg`` via ``parse_structure`` and cache the image size.

        Args:
            cfg: Dataset configuration. It is re-parsed into
                ``BaseDatasetConfig``; presumably a dict/DictConfig is also
                accepted here — TODO confirm against ``parse_structure``.
        """
        super().__init__(cfg)
        self.cfg: BaseDatasetConfig = parse_structure(BaseDatasetConfig, cfg)
        # Read from the parsed config, as ``setup`` does for ``data_source``,
        # rather than the raw argument, which may be an unparsed mapping.
        self.img_size = self.cfg.image_size

    def setup(self, stage: Optional[str] = None) -> None:
        """Instantiate the datasets needed for ``stage``.

        ``None`` builds all three splits. ``"fit"`` builds train and val,
        ``"validate"`` builds val, and ``"test"`` / ``"predict"`` build test.
        """
        if stage in (None, "fit"):
            self.train_dataset = CamelyonDataset(self.cfg.data_source, "train", self.img_size)
        if stage in (None, "fit", "validate"):
            self.val_dataset = CamelyonDataset(self.cfg.data_source, "val", self.img_size)
        if stage in (None, "test", "predict"):
            self.test_dataset = CamelyonDataset(self.cfg.data_source, "test", self.img_size)