init
Browse files- data/__init__.py +3 -0
- data/__pycache__/__init__.cpython-310.pyc +0 -0
- data/__pycache__/base.cpython-310.pyc +0 -0
- data/__pycache__/breakhis.cpython-310.pyc +0 -0
- data/__pycache__/camelyon.cpython-310.pyc +0 -0
- data/base.py +98 -0
- data/breakhis.py +146 -0
- data/camelyon.py +73 -0
- metrics/__init__.py +1 -0
- metrics/__pycache__/__init__.cpython-310.pyc +0 -0
- metrics/__pycache__/base.cpython-310.pyc +0 -0
- metrics/base.py +33 -0
data/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base import BaseDataModule
|
| 2 |
+
from .camelyon import CamelyonDataModule
|
| 3 |
+
from .breakhis import BreakhisDataModule
|
data/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (302 Bytes). View file
|
|
|
data/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (3.93 kB). View file
|
|
|
data/__pycache__/breakhis.cpython-310.pyc
ADDED
|
Binary file (5.35 kB). View file
|
|
|
data/__pycache__/camelyon.cpython-310.pyc
ADDED
|
Binary file (3.04 kB). View file
|
|
|
data/base.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import Dataset, DataLoader
|
| 2 |
+
from typing import *
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from utils import parse_structure
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import lightning.pytorch as pl
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BaseDataset(Dataset):
    """Generic image-classification dataset over a ``root_dir/<class>/<image>`` layout.

    Each immediate sub-directory of ``root_dir`` is one class; every file
    inside a class directory is assumed to be a loadable image.
    """

    def __init__(self, root_dir: str, image_size: Tuple[int, int]) -> None:
        self.root_dir = root_dir
        self.image_size = image_size
        # BUG FIX: os.listdir order is arbitrary and file-system dependent, so
        # the class -> index mapping (and therefore the labels) differed
        # between runs/machines. Sort for determinism, and skip stray files
        # sitting next to the class folders.
        class_names = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.classes = {folder: idx for idx, folder in enumerate(class_names)}
        self.image_paths = []
        self.labels = []

        for class_name, class_idx in self.classes.items():
            class_dir = os.path.join(root_dir, class_name)
            # Sorted for a deterministic sample order as well.
            for img_name in sorted(os.listdir(class_dir)):
                self.image_paths.append(os.path.join(class_dir, img_name))
                self.labels.append(class_idx)

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return a (C, H, W) float tensor scaled to [0, 1] and its class index."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        # NOTE: PIL's resize takes (width, height).
        image = image.resize(self.image_size)
        image = np.array(image)
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        return image, label
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
class BaseDatasetConfig:
    """Configuration shared by the path-based data modules."""

    # Root location of the dataset on disk.
    data_source: str = ''
    # Per-split directory roots, each laid out as <path>/<class>/<image>.
    train_path: str = ''
    valid_path: str = ''
    test_path: str = ''
    # DataLoader settings.
    batch_size: int = 32
    shuffle: bool = True
    num_workers: int = 24
    # Target (height, width) passed to the datasets.
    image_size: Tuple[int, int] = (224, 224)
|
| 51 |
+
|
| 52 |
+
class BaseDataModule(pl.LightningDataModule):
    """Lightning data module that builds one :class:`BaseDataset` per split.

    ``cfg`` is parsed into :class:`BaseDatasetConfig`; the three split paths
    each point at a ``<path>/<class>/<image>`` directory tree.
    """

    cfg: BaseDatasetConfig

    def __init__(self, cfg: BaseDatasetConfig) -> None:
        super().__init__()
        self.cfg: BaseDatasetConfig = parse_structure(BaseDatasetConfig, cfg)
        # CONSISTENCY FIX: read through the parsed config instead of the raw
        # ``cfg`` argument, so plain-dict inputs work uniformly.
        self.train_path = self.cfg.train_path
        self.valid_path = self.cfg.valid_path
        self.test_path = self.cfg.test_path
        self.img_size = self.cfg.image_size

    def setup(self, stage=None) -> None:
        """Instantiate the datasets required by the given Lightning stage."""
        if stage in [None, "fit"]:
            self.train_dataset = BaseDataset(self.train_path, self.img_size)
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BaseDataset(self.valid_path, self.img_size)
        if stage in [None, "test", "predict"]:
            self.test_dataset = BaseDataset(self.test_path, self.img_size)

    def general_loader(self, dataset, batch_size, shuffle: bool = False) -> DataLoader:
        """Build a DataLoader with the configured worker count.

        ``shuffle`` defaults to False, matching the previous behaviour of the
        val/test loaders and of this helper.
        """
        return DataLoader(
            dataset,
            num_workers=self.cfg.num_workers,
            batch_size=batch_size,
            shuffle=shuffle,
        )

    # DRY FIX: the three dataloaders duplicated the DataLoader construction
    # that ``general_loader`` already provides; route them through it.
    def train_dataloader(self) -> DataLoader:
        return self.general_loader(self.train_dataset, self.cfg.batch_size, shuffle=self.cfg.shuffle)

    def val_dataloader(self) -> DataLoader:
        return self.general_loader(self.val_dataset, self.cfg.batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.general_loader(self.test_dataset, self.cfg.batch_size)
|
data/breakhis.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import Dataset, DataLoader
|
| 2 |
+
from typing import *
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from utils import parse_structure
|
| 6 |
+
from glob import glob
|
| 7 |
+
from random import shuffle
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import lightning.pytorch as pl
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import random
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BreakhisDataset(Dataset):
    """BreaKHis histopathology dataset with a deterministic 80/10/10 split.

    Samples are globbed per (class, sub-class, magnification) group, shuffled
    with a fixed seed, and each group is split train/valid/test so every
    magnification factor is represented proportionally in every subset.
    """

    def __init__(self, root_dir: str, image_size: Tuple[int, int], subset: str, aug: dict = None) -> None:
        self.root_dir = root_dir
        self.image_size = image_size
        self.classes = {
            'benign': 0,
            'malignant': 1
        }
        # Fractions for train and valid; the remainder goes to test.
        self.ratio = [0.8, 0.1]
        self.subset = subset
        self.aug = aug

        self.benign_subclasses = ['adenosis', 'fibroadenoma', 'phyllodes_tumor', 'tubular_adenoma']
        self.malignant_subclasses = ['ductal_carcinoma', 'lobular_carcinoma', 'mucinous_carcinoma', 'papillary_carcinoma']
        self.cls2sublst = {
            'benign': self.benign_subclasses,
            'malignant': self.malignant_subclasses
        }
        self.factors = ['100X', '200X', '400X', '40X']

        self.sample_paths = []
        self.sample_labels = []

        # Fixed seed so all three subset instances see the same per-group
        # shuffle and therefore never overlap.
        random.seed(42)
        for cate in ['benign', 'malignant']:
            for subcls in self.cls2sublst[cate]:
                for factor in self.factors:
                    lst = glob(os.path.join(self.root_dir, f'{cate}/*/{subcls}/*/{factor}/*.png'))
                    random.shuffle(lst)

                    sublst = self.get_subset(lst)
                    self.sample_paths += sublst
                    self.sample_labels += [self.classes[cate]] * len(sublst)

        if self.aug is not None:
            # Each key of `aug` names a torchvision.transforms.v2 class; its
            # value holds that transform's kwargs. Augmentations run between
            # resize and dtype conversion / normalisation.
            self.transforms = [v2.Resize(self.image_size, antialias=True)] + \
                [getattr(v2, x)(**self.aug[x]) for x in self.aug] + \
                [
                    v2.ToDtype(torch.float32, scale=True),
                    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ]
        else:
            self.transforms = [
                v2.Resize(self.image_size, antialias=True),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]

        self.transform = v2.Compose(self.transforms)

    def get_subset(self, x: list) -> list:
        """Slice out this instance's portion of an already-shuffled list.

        Raises:
            ValueError: if ``self.subset`` is not train/valid/test.
        """
        if self.subset == 'train':
            return x[: int(self.ratio[0] * len(x))]
        elif self.subset == 'valid':
            return x[int(self.ratio[0] * len(x)): int((self.ratio[0] + self.ratio[1]) * len(x))]
        elif self.subset == 'test':
            return x[int((self.ratio[0] + self.ratio[1]) * len(x)):]
        else:
            # BUG FIX: the original *returned* the ValueError instead of
            # raising it, silently handing callers an exception object that
            # then got concatenated into the sample list.
            raise ValueError('Unknown subset')

    def __len__(self) -> int:
        return len(self.sample_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return a normalised (C, H, W) float tensor and its binary label."""
        img_path = self.sample_paths[idx]
        label = self.sample_labels[idx]
        image = Image.open(img_path).convert("RGB")
        # NOTE(review): PIL resize here plus v2.Resize inside `transform`
        # resizes twice; kept for behavioural parity. PIL takes (w, h).
        image = image.resize(self.image_size)
        image = np.array(image)
        image = torch.from_numpy(image).permute(2, 0, 1)
        image = self.transform(image)

        return image, label
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@dataclass
class BaseDatasetConfig:
    """Configuration for the BreaKHis data module."""

    # Root directory of the BreaKHis dataset on disk.
    data_source: str = ''
    # DataLoader settings.
    batch_size: int = 32
    shuffle: bool = True
    num_workers: int = 24
    # Target (height, width) passed to the dataset.
    image_size: Tuple[int, int] = (224, 224)
    # Mapping of torchvision v2 transform class name -> kwargs for that
    # transform; empty dict means no extra augmentation.
    aug: dict = field(default_factory=dict)
|
| 100 |
+
|
| 101 |
+
class BreakhisDataModule(pl.LightningDataModule):
    """Lightning data module wrapping :class:`BreakhisDataset` for all splits."""

    cfg: BaseDatasetConfig

    def __init__(self, cfg: BaseDatasetConfig) -> None:
        super().__init__()
        self.cfg: BaseDatasetConfig = parse_structure(BaseDatasetConfig, cfg)
        self.data_source = self.cfg.data_source
        self.img_size = self.cfg.image_size
        self.aug = self.cfg.aug

    def setup(self, stage=None) -> None:
        """Instantiate the datasets required by the given Lightning stage."""
        if stage in [None, "fit"]:
            self.train_dataset = BreakhisDataset(self.data_source, self.img_size, 'train', self.aug)
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BreakhisDataset(self.data_source, self.img_size, 'valid', self.aug)
        if stage in [None, "test", "predict"]:
            self.test_dataset = BreakhisDataset(self.data_source, self.img_size, 'test', self.aug)

    def general_loader(self, dataset, batch_size, shuffle: bool = False) -> DataLoader:
        """Build a DataLoader with the configured worker count.

        ``shuffle`` defaults to False, matching the previous behaviour of the
        val/test loaders and of this helper.
        """
        return DataLoader(
            dataset,
            num_workers=self.cfg.num_workers,
            batch_size=batch_size,
            shuffle=shuffle,
        )

    # DRY FIX: the three dataloaders duplicated the DataLoader construction
    # that ``general_loader`` already provides; route them through it.
    def train_dataloader(self) -> DataLoader:
        return self.general_loader(self.train_dataset, self.cfg.batch_size, shuffle=self.cfg.shuffle)

    def val_dataloader(self) -> DataLoader:
        return self.general_loader(self.val_dataset, self.cfg.batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.general_loader(self.test_dataset, self.cfg.batch_size)
|
data/camelyon.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from wilds.datasets.camelyon17_dataset import Camelyon17Dataset
|
| 2 |
+
from .base import BaseDatasetConfig, BaseDataModule
|
| 3 |
+
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
from typing import *
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from utils import parse_structure
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import albumentations as A
|
| 14 |
+
|
| 15 |
+
class CamelyonDataset(Dataset):
    """Camelyon17 (WILDS) subset with albumentations augmentation.

    Wraps ``Camelyon17Dataset(...).get_subset(subset)`` and applies a
    subset-specific albumentations pipeline: augmentation + resize for
    "train", resize only for "val"/"test".
    """

    def __init__(self, root_dir: str, subset: str, image_size: Tuple[int, int]) -> None:
        self.root_dir = root_dir
        # NOTE(review): download=True performs network I/O on first use —
        # constructing this dataset is not side-effect free.
        self.dataset = Camelyon17Dataset(root_dir=root_dir, download=True).get_subset(subset)
        # KeyError here (from the [subset] lookup) means an unknown subset name.
        self.transform = {
            "train" : A.Compose([
                A.HorizontalFlip(),
                # NOTE(review): albumentations' Affine `scale` is a zoom
                # multiplier (1.0 = no change); (-0.2, 0.2) shrinks/mirrors the
                # image and looks like it was meant to be (0.8, 1.2) — confirm.
                A.Affine(scale=(-0.2, 0.2),
                         rotate=(-10, 10),
                         # shear=(-5, 5),
                         keep_ratio=True,
                         p=0.5),
                # Exactly one blur variant (or none) per sample.
                A.OneOf([
                    A.MotionBlur(p=0.2),
                    A.MedianBlur(blur_limit=3, p=0.1),
                    A.Blur(blur_limit=3, p=0.1),
                ], p=0.5),
                # Contrast enhancement: CLAHE or brightness/contrast jitter.
                A.OneOf([
                    A.CLAHE(clip_limit=2),
                    A.RandomBrightnessContrast(),
                ], p=0.5),
                A.HueSaturationValue(p=0.25),
                A.Resize(image_size[0], image_size[1])
            ], p=1.0),
            "val" : A.Compose([
                A.Resize(image_size[0], image_size[1])
            ], p=1.0),
            "test" : A.Compose([
                A.Resize(image_size[0], image_size[1])
            ], p=1.0)
        }[subset]
        self.image_size = image_size

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return a (C, H, W) float tensor in [0, 1] and the sample's label.

        The WILDS subset yields (image, label, metadata); metadata is dropped.
        """
        (image, label, _) = self.dataset.__getitem__(idx)
        # image = image.resize(self.image_size)
        image = np.array(image)
        image = self.transform(image=image)["image"]
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        return image, label
|
| 58 |
+
|
| 59 |
+
class CamelyonDataModule(BaseDataModule):
    """Lightning data module building one :class:`CamelyonDataset` per split."""

    cfg: BaseDatasetConfig

    def __init__(self, cfg: BaseDatasetConfig) -> None:
        super().__init__(cfg)
        # BUG FIX: the annotation previously referenced the undefined name
        # ``DatasetConfig``; annotations on attribute targets are evaluated at
        # runtime, so constructing the module raised NameError.
        self.cfg: BaseDatasetConfig = parse_structure(BaseDatasetConfig, cfg)
        # Read through the parsed config (not the raw argument) so
        # plain-dict inputs work uniformly.
        self.img_size = self.cfg.image_size

    def setup(self, stage=None) -> None:
        """Instantiate the datasets required by the given Lightning stage."""
        if stage in [None, "fit"]:
            self.train_dataset = CamelyonDataset(self.cfg.data_source, "train", self.img_size)
        if stage in [None, "fit", "validate"]:
            self.val_dataset = CamelyonDataset(self.cfg.data_source, "val", self.img_size)
        if stage in [None, "test", "predict"]:
            self.test_dataset = CamelyonDataset(self.cfg.data_source, "test", self.img_size)
|
metrics/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .base import BaseMetrics, BaseMetricsConfig
|
metrics/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (228 Bytes). View file
|
|
|
metrics/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (1.91 kB). View file
|
|
|
metrics/base.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchmetrics import classification
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Any, Dict, Mapping
|
| 4 |
+
from utils import parse_structure
|
| 5 |
+
from torch import Tensor, nn
|
| 6 |
+
|
| 7 |
+
import lightning.pytorch as pl
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
@dataclass
class BaseMetricsConfig:
    """Configuration for :class:`BaseMetrics`."""

    # Class names resolved on torchmetrics.classification, one per metric.
    metrics_names: list = field(default_factory=list)
    # Short keys the instantiated metrics are stored and reported under;
    # zipped positionally with `metrics_names`.
    metrics_short_names: list = field(default_factory=list)
|
| 14 |
+
|
| 15 |
+
class BaseMetrics(pl.LightningModule):
    """Bundle of torchmetrics classification metrics addressed by short name.

    ``cfg.metrics_names`` lists class names resolved on
    ``torchmetrics.classification``; ``cfg.metrics_short_names`` supplies the
    key each instantiated metric is stored (and reported) under.
    """

    def __init__(self, cfg: Dict, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        self.cfg: BaseMetricsConfig = parse_structure(BaseMetricsConfig, cfg)
        self.metrics_names = self.cfg.metrics_names
        self.metrics_short_names = self.cfg.metrics_short_names

        # Instantiate each named metric class and register it in a ModuleDict
        # so the metrics move devices together with the module.
        self.metrics = nn.ModuleDict({
            short_name: getattr(classification, full_name)()
            for full_name, short_name in zip(self.metrics_names, self.metrics_short_names)
        })

        print(f"[INFO]: Metrics: {self.metrics}")

    def __call__(self, pred: Tensor, target: Tensor, prefix: str) -> Dict[str, float]:
        """Binarise logits with sigmoid + round and evaluate every metric.

        Returns a dict keyed by ``"{prefix}/{short_name}"``.
        """
        binarised = torch.sigmoid(pred).round()
        results = {}
        for short_name, metric in self.metrics.items():
            results[f'{prefix}/{short_name}'] = metric(binarised, target)
        return results
|