Spaces:
Configuration error
Configuration error
| from typing import Callable | |
| from torch.utils.data import DataLoader | |
| from src.config import Config | |
| from src.utils import logger | |
| from .augmentations import init_augmentations | |
| from .base import BaseDataModule | |
| from .dataset import DeepfakeDataset | |
class DeepfakeDataModule(BaseDataModule):
    """Data module that builds ``DeepfakeDataset`` splits and wraps them in DataLoaders.

    Dataset construction is driven by ``self.config``:
    - file lists (``trn_files`` / ``val_files`` / ``tst_files``);
    - the label mode (``binary_labels``);
    - per-split file limits;
    - augmentation specs.

    The optional ``preprocess`` callable is forwarded to every dataset.
    """

    def __init__(self, config: Config, preprocess: None | Callable = None):
        """Initialise the module via ``BaseDataModule``.

        Args:
            config: Project configuration; ``setup`` and the dataloader methods
                read their settings from ``self.config``.
            preprocess: Optional callable forwarded to every ``DeepfakeDataset``.
        """
        # NOTE(review): self.config / self.preprocess are read below but never
        # assigned here -- presumably BaseDataModule stores them; confirm.
        super().__init__(config, preprocess)

    def setup(self, stage: str):
        """Create the datasets required for ``stage``.

        "fit" and "validate" build both the training and validation datasets.
        "test" builds the test dataset. Any other stage builds nothing.
        Each dataset prints its statistics right after construction.

        Args:
            stage: Name of the stage being prepared.
        """
        cfg = self.config

        if stage in ("fit", "validate"):
            # NOTE(review): the (augmented) training set is also built for
            # "validate"; confirm whether anything reads train_dataset then.
            logger.print("\n[blue]Creating training dataset")
            self.train_dataset = DeepfakeDataset(
                cfg.trn_files,
                self.preprocess,
                augmentations=init_augmentations(cfg.augmentations),
                binary=cfg.binary_labels,
                limit_files=cfg.limit_trn_files,
                load_pairs=cfg.load_pairs,
            )
            self.train_dataset.print_statistics()

            logger.print("\n[blue]Creating validation dataset")
            # Validation gets no augmentations; shuffle=True is passed to the
            # dataset itself (its effect is defined by DeepfakeDataset).
            self.val_dataset = DeepfakeDataset(
                cfg.val_files,
                self.preprocess,
                shuffle=True,
                binary=cfg.binary_labels,
                limit_files=cfg.limit_val_files,
            )
            self.val_dataset.print_statistics()

        if stage == "test":
            logger.print("\nCreating test dataset")
            # The test split uses its own augmentation spec (test_augmentations).
            self.test_dataset = DeepfakeDataset(
                cfg.tst_files,
                self.preprocess,
                augmentations=init_augmentations(cfg.test_augmentations),
                binary=cfg.binary_labels,
                limit_files=cfg.limit_tst_files,
            )
            self.test_dataset.print_statistics()

    def _build_loader(self, dataset, **extra) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader with the shared batch/worker settings.

        Every split uses ``config.mini_batch_size``, ``config.num_workers`` and
        pinned memory. ``extra`` carries the split-specific DataLoader keyword
        arguments (e.g. ``shuffle``, ``drop_last``).
        """
        return DataLoader(
            dataset,
            batch_size=self.config.mini_batch_size,
            num_workers=self.config.num_workers,
            pin_memory=True,
            **extra,
        )

    def train_dataloader(self):
        """Return the training loader (reshuffled each epoch, last partial batch dropped)."""
        return self._build_loader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Return the validation loader, iterated in dataset order."""
        return self._build_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the test loader, iterated in dataset order."""
        return self._build_loader(self.test_dataset)