kohido commited on
Commit
8bf25c8
·
1 Parent(s): 507e3c1
data/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .base import BaseDataModule
2
+ from .camelyon import CamelyonDataModule
3
+ from .breakhis import BreakhisDataModule
data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (302 Bytes). View file
 
data/__pycache__/base.cpython-310.pyc ADDED
Binary file (3.93 kB). View file
 
data/__pycache__/breakhis.cpython-310.pyc ADDED
Binary file (5.35 kB). View file
 
data/__pycache__/camelyon.cpython-310.pyc ADDED
Binary file (3.04 kB). View file
 
data/base.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset, DataLoader
2
+ from typing import *
3
+ from dataclasses import dataclass, field
4
+ from PIL import Image
5
+ from utils import parse_structure
6
+
7
+ import os
8
+ import lightning.pytorch as pl
9
+ import numpy as np
10
+ import torch
11
+
12
+
13
class BaseDataset(Dataset):
    """Image-folder classification dataset.

    Expects ``root_dir`` to contain one sub-directory per class; every file
    found inside a class directory is treated as one image sample.
    """

    def __init__(self, root_dir: str, image_size: Tuple[int, int]) -> None:
        """
        Args:
            root_dir: directory whose immediate sub-directories are the classes.
            image_size: (width, height) passed to ``PIL.Image.resize``.
        """
        self.root_dir = root_dir
        self.image_size = image_size
        # Sort the directory listing so the class -> index mapping is
        # deterministic: os.listdir order is arbitrary and filesystem
        # dependent, which would silently shuffle label ids between runs.
        self.classes = {folder: idx for idx, folder in enumerate(sorted(os.listdir(root_dir)))}
        self.image_paths = []
        self.labels = []

        for class_name, class_idx in self.classes.items():
            class_dir = os.path.join(root_dir, class_name)
            # NOTE(review): assumes every entry in a class directory is an
            # image file — stray files (e.g. .DS_Store) would be picked up too.
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                self.image_paths.append(img_path)
                self.labels.append(class_idx)

    def __len__(self) -> int:
        """Number of samples discovered at construction time."""
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Load, resize and scale one image.

        Returns:
            (image, label) where image is a (C, H, W) float tensor in [0, 1].
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        image = image.resize(self.image_size)
        image = np.array(image)
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        return image, label
39
+
40
+
41
@dataclass
class BaseDatasetConfig:
    """Configuration shared by the folder-based data modules.

    Paths point at the train/valid/test split roots; loader settings are
    forwarded to ``DataLoader`` by the data module.
    """

    data_source: str = ""
    train_path: str = ""
    valid_path: str = ""
    test_path: str = ""
    batch_size: int = 32
    shuffle: bool = True
    num_workers: int = 24
    image_size: Tuple[int, int] = (224, 224)
51
+
52
class BaseDataModule(pl.LightningDataModule):
    """Lightning data module wiring ``BaseDataset`` train/val/test splits."""

    cfg: BaseDatasetConfig

    def __init__(self, cfg: BaseDatasetConfig) -> None:
        super().__init__()
        # Normalise the raw config first, then read every field through the
        # parsed result (the original read raw ``cfg`` attributes, which is
        # inconsistent and breaks for plain-dict configs).
        self.cfg: BaseDatasetConfig = parse_structure(BaseDatasetConfig, cfg)
        self.train_path = self.cfg.train_path
        self.valid_path = self.cfg.valid_path
        self.test_path = self.cfg.test_path
        self.img_size = self.cfg.image_size

    def setup(self, stage=None) -> None:
        """Instantiate only the datasets the requested stage needs."""
        if stage in [None, "fit"]:
            self.train_dataset = BaseDataset(self.train_path, self.img_size)
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BaseDataset(self.valid_path, self.img_size)
        if stage in [None, "test", "predict"]:
            self.test_dataset = BaseDataset(self.test_path, self.img_size)

    def general_loader(self, dataset, batch_size, shuffle: bool = False) -> DataLoader:
        """Build a DataLoader with the configured worker count.

        ``shuffle`` defaults to False so existing callers keep the old
        behaviour; the train loader opts in explicitly.
        """
        return DataLoader(
            dataset,
            num_workers=self.cfg.num_workers,
            batch_size=batch_size,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        # Only the training loader honours cfg.shuffle.
        return self.general_loader(self.train_dataset, self.cfg.batch_size, shuffle=self.cfg.shuffle)

    def val_dataloader(self) -> DataLoader:
        return self.general_loader(self.val_dataset, self.cfg.batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.general_loader(self.test_dataset, self.cfg.batch_size)
data/breakhis.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset, DataLoader
2
+ from typing import *
3
+ from dataclasses import dataclass, field
4
+ from PIL import Image
5
+ from utils import parse_structure
6
+ from glob import glob
7
+ from random import shuffle
8
+ from torchvision.transforms import v2
9
+
10
+ import os
11
+ import lightning.pytorch as pl
12
+ import numpy as np
13
+ import torch
14
+ import random
15
+
16
+
17
class BreakhisDataset(Dataset):
    """BreaKHis histopathology dataset with a deterministic 80/10/10 split.

    Images are globbed per (class, sub-class, magnification) group, shuffled
    with a fixed seed, then sliced so every group contributes samples to each
    of the train/valid/test splits.
    """

    def __init__(self, root_dir: str, image_size: Tuple[int, int], subset: str, aug: dict = None) -> None:
        """
        Args:
            root_dir: BreaKHis root containing the ``benign``/``malignant`` trees.
            image_size: (width, height) for the PIL resize / v2.Resize.
            subset: one of ``'train'``, ``'valid'`` or ``'test'``.
            aug: optional mapping of torchvision ``v2`` transform class name
                to its kwargs, inserted between resize and normalisation.
        """
        self.root_dir = root_dir
        self.image_size = image_size
        self.classes = {
            'benign' : 0,
            'malignant' : 1
        }
        # Train and valid fractions; test takes the remaining 0.1.
        self.ratio = [0.8, 0.1]
        self.subset = subset
        self.aug = aug

        self.benign_subclasses = ['adenosis', 'fibroadenoma', 'phyllodes_tumor', 'tubular_adenoma']
        self.malignant_subclasses = ['ductal_carcinoma', 'lobular_carcinoma', 'mucinous_carcinoma', 'papillary_carcinoma']
        self.cls2sublst = {
            'benign' : self.benign_subclasses,
            'malignant' : self.malignant_subclasses
        }
        self.factors = ['100X', '200X', '400X', '40X']

        self.sample_paths = []
        self.sample_labels = []

        # Fixed seed so every process sees the same split boundaries.
        # NOTE(review): glob order is filesystem-dependent, so the shuffle is
        # only reproducible if the glob result order is too; sorting ``lst``
        # before shuffling would make the split fully deterministic.
        random.seed(42)
        for cate in ['benign', 'malignant']:
            for subcls in self.cls2sublst[cate]:
                for factor in self.factors:
                    lst = glob(os.path.join(self.root_dir, f'{cate}/*/{subcls}/*/{factor}/*.png'))
                    random.shuffle(lst)

                    sublst = self.get_subset(lst)
                    self.sample_paths += sublst
                    self.sample_labels += [self.classes[cate]] * len(sublst)

        if self.aug is not None:
            # Resize first, then user-configured augmentations, then the
            # standard ImageNet normalisation.
            self.transforms = [v2.Resize(self.image_size, antialias=True)] + \
                [getattr(v2, x)(**self.aug[x]) for x in self.aug] + \
                [
                    v2.ToDtype(torch.float32, scale=True),
                    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ]
        else:
            self.transforms = [
                v2.Resize(self.image_size, antialias=True),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]

        self.transform = v2.Compose(self.transforms)

    def get_subset(self, x: list):
        """Slice ``x`` into this dataset's split according to ``self.ratio``.

        Raises:
            ValueError: if ``self.subset`` is not train/valid/test.
        """
        if self.subset == 'train':
            return x[ : int(self.ratio[0] * len(x))]
        elif self.subset == 'valid':
            return x[int(self.ratio[0] * len(x)) : int((self.ratio[0] + self.ratio[1]) * len(x))]
        elif self.subset == 'test':
            return x[int((self.ratio[0] + self.ratio[1]) * len(x)) : ]
        else:
            # BUGFIX: this previously *returned* the ValueError instance
            # instead of raising it, so an unknown subset failed silently.
            raise ValueError('Unknown subset')

    def __len__(self) -> int:
        """Number of samples in this split."""
        return len(self.sample_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Load one sample; returns (transformed CHW tensor, binary label)."""
        img_path = self.sample_paths[idx]
        label = self.sample_labels[idx]
        image = Image.open(img_path).convert("RGB")
        # NOTE(review): the image is already resized here, which makes the
        # v2.Resize in the transform pipeline redundant work — confirm and
        # drop one of the two.
        image = image.resize(self.image_size)
        image = np.array(image)
        image = torch.from_numpy(image).permute(2, 0, 1)
        image = self.transform(image)

        return image, label
90
+
91
+
92
@dataclass
class BaseDatasetConfig:
    """Configuration for the BreaKHis data module.

    ``aug`` maps torchvision ``v2`` transform class names to their kwargs.
    """

    data_source: str = ""
    batch_size: int = 32
    shuffle: bool = True
    num_workers: int = 24
    image_size: Tuple[int, int] = (224, 224)
    aug: dict = field(default_factory=dict)
100
+
101
class BreakhisDataModule(pl.LightningDataModule):
    """Lightning data module producing BreaKHis train/valid/test datasets."""

    cfg: BaseDatasetConfig

    def __init__(self, cfg: BaseDatasetConfig) -> None:
        super().__init__()
        self.cfg: BaseDatasetConfig = parse_structure(BaseDatasetConfig, cfg)
        self.data_source = self.cfg.data_source
        self.img_size = self.cfg.image_size
        self.aug = self.cfg.aug

    def setup(self, stage=None) -> None:
        """Instantiate only the datasets the requested stage needs."""
        if stage in [None, "fit"]:
            self.train_dataset = BreakhisDataset(self.data_source, self.img_size, 'train', self.aug)
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BreakhisDataset(self.data_source, self.img_size, 'valid', self.aug)
        if stage in [None, "test", "predict"]:
            self.test_dataset = BreakhisDataset(self.data_source, self.img_size, 'test', self.aug)

    def general_loader(self, dataset, batch_size, shuffle: bool = False) -> DataLoader:
        """Build a DataLoader with the configured worker count.

        ``shuffle`` defaults to False so existing callers keep the old
        behaviour; the train loader opts in explicitly.
        """
        return DataLoader(
            dataset,
            num_workers=self.cfg.num_workers,
            batch_size=batch_size,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        # Only the training loader honours cfg.shuffle.
        return self.general_loader(self.train_dataset, self.cfg.batch_size, shuffle=self.cfg.shuffle)

    def val_dataloader(self) -> DataLoader:
        return self.general_loader(self.val_dataset, self.cfg.batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.general_loader(self.test_dataset, self.cfg.batch_size)
data/camelyon.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from wilds.datasets.camelyon17_dataset import Camelyon17Dataset
2
+ from .base import BaseDatasetConfig, BaseDataModule
3
+
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from typing import *
6
+ from dataclasses import dataclass, field
7
+ from PIL import Image
8
+ from utils import parse_structure
9
+
10
+ import os
11
+ import numpy as np
12
+ import torch
13
+ import albumentations as A
14
+
15
class CamelyonDataset(Dataset):
    """Camelyon17 (WILDS) patch dataset with albumentations transforms.

    Wraps ``Camelyon17Dataset`` and applies a subset-specific pipeline:
    heavy augmentation for "train", plain resize for "val"/"test".
    """

    def __init__(self, root_dir: str, subset: str, image_size: Tuple[int, int]) -> None:
        # subset must be exactly "train", "val" or "test" — anything else
        # raises KeyError from the transform-dict lookup below.
        self.root_dir = root_dir
        # download=True fetches the dataset on first use if absent.
        self.dataset = Camelyon17Dataset(root_dir=root_dir, download=True).get_subset(subset)
        self.transform = {
            "train" : A.Compose([
                A.HorizontalFlip(),
                # NOTE(review): scale=(-0.2, 0.2) looks suspicious — in
                # A.Affine the scale is a multiplicative factor, so values
                # <= 0 shrink/invert the image; (0.8, 1.2) was probably
                # intended. Confirm against albumentations docs before changing.
                A.Affine(scale=(-0.2, 0.2),
                         rotate=(-10, 10),
                         # shear=(-5, 5),
                         keep_ratio=True,
                         p=0.5),
                # Exactly one blur variant, applied half the time.
                A.OneOf([
                    A.MotionBlur(p=0.2),
                    A.MedianBlur(blur_limit=3, p=0.1),
                    A.Blur(blur_limit=3, p=0.1),
                ], p=0.5),
                # Contrast enhancement: CLAHE or brightness/contrast jitter.
                A.OneOf([
                    A.CLAHE(clip_limit=2),
                    A.RandomBrightnessContrast(),
                ], p=0.5),
                A.HueSaturationValue(p=0.25),
                A.Resize(image_size[0], image_size[1])
            ], p=1.0),
            "val" : A.Compose([
                A.Resize(image_size[0], image_size[1])
            ], p=1.0),
            "test" : A.Compose([
                A.Resize(image_size[0], image_size[1])
            ], p=1.0)
        }[subset]
        self.image_size = image_size

    def __len__(self) -> int:
        # Delegates to the wrapped WILDS subset.
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        # WILDS items are (image, label, metadata); metadata is discarded.
        (image, label, _) = self.dataset.__getitem__(idx)
        # image = image.resize(self.image_size)
        image = np.array(image)
        # albumentations takes/returns a dict keyed by "image" (HWC array).
        image = self.transform(image=image)["image"]
        # HWC uint8 -> CHW float tensor scaled to [0, 1].
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        return image, label
58
+
59
class CamelyonDataModule(BaseDataModule):
    """Data module producing Camelyon17 train/val/test datasets."""

    cfg: BaseDatasetConfig

    def __init__(self, cfg: BaseDatasetConfig) -> None:
        super().__init__(cfg)
        # BUGFIX: the annotation previously referenced the undefined name
        # ``DatasetConfig``; PEP 526 attribute annotations are evaluated at
        # runtime, so constructing this module raised NameError.
        self.cfg: BaseDatasetConfig = parse_structure(BaseDatasetConfig, cfg)
        # Read through the parsed config so plain-dict configs also work.
        self.img_size = self.cfg.image_size

    def setup(self, stage=None) -> None:
        """Instantiate only the datasets the requested stage needs."""
        if stage in [None, "fit"]:
            self.train_dataset = CamelyonDataset(self.cfg.data_source, "train", self.img_size)
        if stage in [None, "fit", "validate"]:
            self.val_dataset = CamelyonDataset(self.cfg.data_source, "val", self.img_size)
        if stage in [None, "test", "predict"]:
            self.test_dataset = CamelyonDataset(self.cfg.data_source, "test", self.img_size)
metrics/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .base import BaseMetrics, BaseMetricsConfig
metrics/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (228 Bytes). View file
 
metrics/__pycache__/base.cpython-310.pyc ADDED
Binary file (1.91 kB). View file
 
metrics/base.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchmetrics import classification
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Dict, Mapping
4
+ from utils import parse_structure
5
+ from torch import Tensor, nn
6
+
7
+ import lightning.pytorch as pl
8
+ import torch
9
+
10
@dataclass
class BaseMetricsConfig:
    """Names of torchmetrics classification metrics and their log keys.

    ``metrics_names`` are class names looked up in
    ``torchmetrics.classification``; ``metrics_short_names`` are the keys
    used when logging, paired positionally with ``metrics_names``.
    """

    metrics_names: list = field(default_factory=list)
    metrics_short_names: list = field(default_factory=list)
14
+
15
class BaseMetrics(pl.LightningModule):
    """Bundle of torchmetrics classification metrics addressed by short name.

    Metric classes are looked up by name in ``torchmetrics.classification``
    and instantiated with no arguments; calling the bundle scores binarised
    predictions and returns a ``"{prefix}/{short_name}" -> value`` dict.
    """

    def __init__(self, cfg: Dict, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        self.cfg: BaseMetricsConfig = parse_structure(BaseMetricsConfig, cfg)
        self.metrics_names = self.cfg.metrics_names
        self.metrics_short_names = self.cfg.metrics_short_names

        # ModuleDict keeps the metrics registered on the module so they move
        # with it (device placement, state_dict).
        self.metrics = nn.ModuleDict()
        for full_name, short_name in zip(self.metrics_names, self.metrics_short_names):
            self.metrics[short_name] = getattr(classification, full_name)()

        print(f"[INFO]: Metrics: {self.metrics}")

    def __call__(self, pred: Tensor, target: Tensor, prefix:str) -> Dict[str, float]:
        # Binarise logits (sigmoid then round to 0/1) before scoring.
        binary_pred = torch.sigmoid(pred).round()
        results = {}
        for short_name, metric in self.metrics.items():
            results[f'{prefix}/{short_name}'] = metric(binary_pred, target)
        return results