# NOTE: GitHub page residue ("yermandy / init / c29babb") converted to a comment
# so the file parses; original content carried no code.
from abc import abstractmethod
from typing import Callable, Optional
import lightning as pl
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from src.config import Config
from src.utils.logger import print
class BaseDataset(Dataset):
def __init__(
self,
files: list[str],
labels: list[int],
preprocess: None | Callable = None,
augmentations: None | Callable = None,
shuffle: bool = False, # Shuffles the dataset once
dataset2files: Optional[dict[str, list[str]]] = None,
):
self.files = files
self.labels = labels
self.preprocess = preprocess
self.augmentations = augmentations
self.dataset2files = dataset2files
if shuffle:
self.shuffle()
def shuffle(self):
# create fixed seed for reproducibility
idx = np.random.RandomState(42).permutation(len(self.files))
self.files = [self.files[i] for i in idx]
self.labels = [self.labels[i] for i in idx]
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
path = self.files[idx]
image = Image.open(path)
if self.augmentations is not None:
image = self.augmentations(image)
if self.preprocess is not None:
image = self.preprocess(image)
return {
"image": image,
"label": self.labels[idx],
"path": path,
}
def print_statistics(self):
print(f"Number of samples: {len(self.files)}")
unique, counts = np.unique(self.labels, return_counts=True)
print("Class distribution")
names = self.get_class_names()
for u, c in zip(unique, counts):
print(f"Class {u} ({names[u]}): {c}")
@abstractmethod
def get_class_names(self) -> dict[int, str]:
raise NotImplementedError
class BaseDataModule(pl.LightningDataModule):
    """Base Lightning data module: builds train/val/test dataloaders from config.

    Subclasses are expected to set ``self.train_dataset``, ``self.val_dataset``
    and ``self.test_dataset`` (presumably in ``setup()`` — not shown here).
    """

    def __init__(self, config: Config, preprocess: Callable | None = None):
        """
        Args:
            config: provides ``mini_batch_size`` and ``num_workers``.
            preprocess: stored for subclasses to pass into their datasets.
        """
        super().__init__()
        self.config = config
        self.preprocess = preprocess

    def _dataloader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
        # Single place for loader construction: the three public methods below
        # previously duplicated this call verbatim except for `shuffle`.
        return DataLoader(
            dataset,
            batch_size=self.config.mini_batch_size,
            num_workers=self.config.num_workers,
            pin_memory=True,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        # Only training shuffles (per-epoch, via the DataLoader).
        return self._dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._dataloader(self.val_dataset)

    def test_dataloader(self):
        return self._dataloader(self.test_dataset)