Spaces:
Configuration error
Configuration error
| from abc import abstractmethod | |
| from typing import Callable, Optional | |
| import lightning as pl | |
| import numpy as np | |
| from PIL import Image | |
| from torch.utils.data import DataLoader, Dataset | |
| from src.config import Config | |
| from src.utils.logger import print | |
| class BaseDataset(Dataset): | |
| def __init__( | |
| self, | |
| files: list[str], | |
| labels: list[int], | |
| preprocess: None | Callable = None, | |
| augmentations: None | Callable = None, | |
| shuffle: bool = False, # Shuffles the dataset once | |
| dataset2files: Optional[dict[str, list[str]]] = None, | |
| ): | |
| self.files = files | |
| self.labels = labels | |
| self.preprocess = preprocess | |
| self.augmentations = augmentations | |
| self.dataset2files = dataset2files | |
| if shuffle: | |
| self.shuffle() | |
| def shuffle(self): | |
| # create fixed seed for reproducibility | |
| idx = np.random.RandomState(42).permutation(len(self.files)) | |
| self.files = [self.files[i] for i in idx] | |
| self.labels = [self.labels[i] for i in idx] | |
| def __len__(self): | |
| return len(self.files) | |
| def __getitem__(self, idx): | |
| path = self.files[idx] | |
| image = Image.open(path) | |
| if self.augmentations is not None: | |
| image = self.augmentations(image) | |
| if self.preprocess is not None: | |
| image = self.preprocess(image) | |
| return { | |
| "image": image, | |
| "label": self.labels[idx], | |
| "path": path, | |
| } | |
| def print_statistics(self): | |
| print(f"Number of samples: {len(self.files)}") | |
| unique, counts = np.unique(self.labels, return_counts=True) | |
| print("Class distribution") | |
| names = self.get_class_names() | |
| for u, c in zip(unique, counts): | |
| print(f"Class {u} ({names[u]}): {c}") | |
| def get_class_names(self) -> dict[int, str]: | |
| raise NotImplementedError | |
class BaseDataModule(pl.LightningDataModule):
    """Lightning data module that wraps its datasets in ``DataLoader`` objects.

    ``train_dataset``, ``val_dataset`` and ``test_dataset`` are not created
    here; they have to be assigned (presumably by a subclass) before the
    matching dataloader method is called.
    """

    def __init__(self, config: Config, preprocess: None | Callable = None):
        """Keep the configuration and the optional preprocessing callable."""
        super().__init__()
        self.config = config
        self.preprocess = preprocess

    def _build_loader(self, dataset, **extra):
        """Create a loader for ``dataset`` with the settings shared by all splits.

        Batch size and worker count are read from the config and memory
        pinning is always enabled. ``extra`` is forwarded to ``DataLoader``
        unchanged, so arguments a split does not specify keep the
        ``DataLoader`` defaults.
        """
        settings = self.config
        return DataLoader(
            dataset,
            batch_size=settings.mini_batch_size,
            num_workers=settings.num_workers,
            pin_memory=True,
            **extra,
        )

    def train_dataloader(self):
        """Return the shuffled loader over ``train_dataset``."""
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the loader over ``val_dataset``."""
        return self._build_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the loader over ``test_dataset``."""
        return self._build_loader(self.test_dataset)