File size: 2,859 Bytes
c29babb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from abc import abstractmethod
from typing import Callable, Optional

import lightning as pl
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from src.config import Config
from src.utils.logger import print


class BaseDataset(Dataset):
    def __init__(
        self,
        files: list[str],
        labels: list[int],
        preprocess: None | Callable = None,
        augmentations: None | Callable = None,
        shuffle: bool = False,  # Shuffles the dataset once
        dataset2files: Optional[dict[str, list[str]]] = None,
    ):
        self.files = files
        self.labels = labels

        self.preprocess = preprocess
        self.augmentations = augmentations

        self.dataset2files = dataset2files

        if shuffle:
            self.shuffle()

    def shuffle(self):
        # create fixed seed for reproducibility
        idx = np.random.RandomState(42).permutation(len(self.files))
        self.files = [self.files[i] for i in idx]
        self.labels = [self.labels[i] for i in idx]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        image = Image.open(path)
        if self.augmentations is not None:
            image = self.augmentations(image)
        if self.preprocess is not None:
            image = self.preprocess(image)
        return {
            "image": image,
            "label": self.labels[idx],
            "path": path,
        }

    def print_statistics(self):
        print(f"Number of samples: {len(self.files)}")
        unique, counts = np.unique(self.labels, return_counts=True)
        print("Class distribution")
        names = self.get_class_names()
        for u, c in zip(unique, counts):
            print(f"Class {u} ({names[u]}): {c}")

    @abstractmethod
    def get_class_names(self) -> dict[int, str]:
        raise NotImplementedError


class BaseDataModule(pl.LightningDataModule):
    """Lightning data module that wraps its datasets in ``DataLoader`` objects.

    The loader methods read ``train_dataset``, ``val_dataset`` and
    ``test_dataset``. None of these is assigned in this class, so they have to
    be set elsewhere (e.g. by a subclass) before a loader is requested.
    """

    def __init__(self, config: Config, preprocess: None | Callable = None):
        """Keep the configuration and the optional preprocessing callable."""
        super().__init__()
        self.preprocess = preprocess
        self.config = config

    def _common_loader_options(self) -> dict:
        """Return the keyword arguments shared by every data loader."""
        cfg = self.config
        return {
            "batch_size": cfg.mini_batch_size,
            "num_workers": cfg.num_workers,
            "pin_memory": True,
        }

    def train_dataloader(self):
        """Build the training loader; the only one that shuffles."""
        dataset = self.train_dataset
        return DataLoader(dataset, shuffle=True, **self._common_loader_options())

    def val_dataloader(self):
        """Build the validation loader."""
        dataset = self.val_dataset
        return DataLoader(dataset, **self._common_loader_options())

    def test_dataloader(self):
        """Build the test loader."""
        dataset = self.test_dataset
        return DataLoader(dataset, **self._common_loader_options())