File size: 5,201 Bytes
8bf25c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from torch.utils.data import Dataset, DataLoader
from typing import *
from dataclasses import dataclass, field
from PIL import Image
from utils import parse_structure
from glob import glob
from random import shuffle
from torchvision.transforms import v2

import os
import lightning.pytorch as pl
import numpy as np
import torch
import random


class BreakhisDataset(Dataset):
    """BreakHis histopathology images labelled benign (0) / malignant (1).

    Images are discovered with the glob pattern
    ``{root_dir}/{benign|malignant}/*/{subclass}/*/{magnification}/*.png``.
    For every (category, subclass, magnification) group the file list is
    sorted, shuffled with a fixed seed and cut into train / valid / test
    slices according to ``self.ratio`` (first 80%, next 10%, remainder).
    Because the split is deterministic, separately constructed ``'train'``,
    ``'valid'`` and ``'test'`` instances over the same ``root_dir`` are
    disjoint.

    NOTE(review): the split is per image, not per folder matched by the second
    ``*`` in the pattern; if that folder identifies a patient, images of one
    patient can land in several subsets — confirm this is acceptable.

    Args:
        root_dir: Dataset root directory.
        image_size: Target size as ``(height, width)``, the order used by
            ``torchvision.transforms.v2.Resize``.
        subset: One of ``'train'``, ``'valid'`` or ``'test'``.
        aug: Optional mapping ``{v2 transform name: kwargs}``; each entry is
            built as ``getattr(v2, name)(**kwargs)`` and applied after the
            resize, in mapping order.
        seed: Seed of the private RNG used to shuffle before splitting.

    Raises:
        ValueError: If ``subset`` is not one of ``SUBSETS``.
    """

    SUBSETS = ('train', 'valid', 'test')

    def __init__(self, root_dir: str, image_size: Tuple[int, int], subset: str,
                 aug: Optional[dict] = None, seed: int = 42) -> None:
        # Fail fast, before touching the file system.
        if subset not in self.SUBSETS:
            raise ValueError(f'Unknown subset {subset!r}; expected one of {self.SUBSETS}')

        self.root_dir = root_dir
        self.image_size = image_size
        self.classes = {
            'benign': 0,
            'malignant': 1,
        }
        # Fractions for train and valid; test receives the remainder.
        self.ratio = [0.8, 0.1]
        self.subset = subset
        self.aug = aug

        self.benign_subclasses = ['adenosis', 'fibroadenoma', 'phyllodes_tumor', 'tubular_adenoma']
        self.malignant_subclasses = ['ductal_carcinoma', 'lobular_carcinoma', 'mucinous_carcinoma', 'papillary_carcinoma']
        self.cls2sublst = {
            'benign': self.benign_subclasses,
            'malignant': self.malignant_subclasses,
        }
        self.factors = ['100X', '200X', '400X', '40X']

        self.sample_paths: List[str] = []
        self.sample_labels: List[int] = []
        self._collect_samples(seed)

        self.transforms = self._build_transforms()
        self.transform = v2.Compose(self.transforms)

    def _collect_samples(self, seed: int) -> None:
        """Glob, shuffle and split each (category, subclass, magnification) group."""
        # A private seeded RNG keeps the shuffle deterministic without
        # resetting the process-wide ``random`` state.
        rng = random.Random(seed)
        for cate in ['benign', 'malignant']:
            for subcls in self.cls2sublst[cate]:
                for factor in self.factors:
                    pattern = os.path.join(self.root_dir, f'{cate}/*/{subcls}/*/{factor}/*.png')
                    # glob order depends on the file system; sorting makes the
                    # seeded shuffle (and hence the split) reproducible.
                    lst = sorted(glob(pattern))
                    rng.shuffle(lst)

                    sublst = self.get_subset(lst)
                    self.sample_paths += sublst
                    self.sample_labels += [self.classes[cate]] * len(sublst)

    def _build_transforms(self) -> list:
        """Return resize -> configured augmentations -> float32 in [0, 1] -> normalisation."""
        augmentations = [getattr(v2, name)(**kwargs) for name, kwargs in (self.aug or {}).items()]
        return (
            [v2.Resize(self.image_size, antialias=True)]
            + augmentations
            + [
                v2.ToDtype(torch.float32, scale=True),
                # Widely used ImageNet channel statistics.
                v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def get_subset(self, x: list) -> list:
        """Return the slice of ``x`` that belongs to ``self.subset``.

        Raises:
            ValueError: If ``self.subset`` is not 'train', 'valid' or 'test'.
        """
        n = len(x)
        train_end = int(self.ratio[0] * n)
        valid_end = int((self.ratio[0] + self.ratio[1]) * n)
        if self.subset == 'train':
            return x[:train_end]
        if self.subset == 'valid':
            return x[train_end:valid_end]
        if self.subset == 'test':
            return x[valid_end:]
        raise ValueError('Unknown subset')

    def __len__(self) -> int:
        return len(self.sample_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Load sample ``idx`` and return ``(image, label)``.

        ``image`` is the channels-first tensor produced by ``self.transform``
        (float32, normalised); ``label`` is 0 for benign and 1 for malignant.
        """
        img_path = self.sample_paths[idx]
        label = self.sample_labels[idx]
        # The context manager releases the file handle PIL opens lazily.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        height, width = self.image_size
        # PIL takes (width, height); image_size is (height, width) like v2.Resize.
        image = image.resize((width, height))
        image = torch.from_numpy(np.array(image)).permute(2, 0, 1)  # HWC -> CHW
        return self.transform(image), label


@dataclass
class BaseDatasetConfig:
    """Dataset and DataLoader settings consumed by ``BreakhisDataModule``.

    Field order defines the positional constructor signature; keep it stable.
    """

    # Root directory handed to BreakhisDataset as ``root_dir``.
    data_source: str = ''
    # Samples per batch for every DataLoader.
    batch_size: int = 32
    # Shuffle flag for the training DataLoader.
    shuffle: bool = True
    # Number of DataLoader worker processes.
    num_workers: int = 24
    # Target image size forwarded to the dataset.
    image_size: Tuple[int, int] = (224, 224)
    # {v2 transform name: kwargs}; default_factory gives each instance its own dict.
    aug: dict = field(default_factory=dict)

class BreakhisDataModule(pl.LightningDataModule):
    """Lightning data module serving the BreakHis train / valid / test splits.

    ``cfg`` is normalised into a ``BaseDatasetConfig`` via ``parse_structure``.
    Augmentations from ``cfg.aug`` are applied to the training split only, so
    validation and test inputs are deterministic.

    NOTE(review): if ``cfg.aug`` holds deterministic transforms that evaluation
    must also see (e.g. a fixed crop that changes the output size), pass them
    to the eval datasets explicitly.
    """

    cfg: BaseDatasetConfig

    def __init__(self, cfg: BaseDatasetConfig) -> None:
        super().__init__()
        self.cfg: BaseDatasetConfig = parse_structure(BaseDatasetConfig, cfg)
        self.data_source = self.cfg.data_source
        self.img_size = self.cfg.image_size
        self.aug = self.cfg.aug

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets required by ``stage`` (all of them when ``None``)."""
        if stage in [None, "fit"]:
            self.train_dataset = BreakhisDataset(self.data_source, self.img_size, 'train', self.aug)
        # Evaluation splits are built without the (random) training augmentations.
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BreakhisDataset(self.data_source, self.img_size, 'valid')
        if stage in [None, "test", "predict"]:
            self.test_dataset = BreakhisDataset(self.data_source, self.img_size, 'test')

    def general_loader(self, dataset, batch_size, shuffle: bool = False) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using the configured worker count."""
        return DataLoader(
            dataset,
            num_workers=self.cfg.num_workers,
            batch_size=batch_size,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        return self.general_loader(self.train_dataset, self.cfg.batch_size, shuffle=self.cfg.shuffle)

    def val_dataloader(self) -> DataLoader:
        return self.general_loader(self.val_dataset, self.cfg.batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.general_loader(self.test_dataset, self.cfg.batch_size)

    def predict_dataloader(self) -> DataLoader:
        # setup() builds test_dataset for the "predict" stage; serve it here.
        return self.general_loader(self.test_dataset, self.cfg.batch_size)