# GenD-Sentinel / src/dataset/data_module.py
# (commit c29babb, "init", by yermandy)
from typing import Callable
from torch.utils.data import DataLoader
from src.config import Config
from src.utils import logger
from .augmentations import init_augmentations
from .base import BaseDataModule
from .dataset import DeepfakeDataset
class DeepfakeDataModule(BaseDataModule):
    """Data module wiring `DeepfakeDataset` splits to PyTorch dataloaders.

    Builds the train/val datasets for the "fit"/"validate" stages and the
    test dataset for the "test" stage, then exposes one `DataLoader` per
    split via the usual `*_dataloader` hooks.
    """

    def __init__(self, config: Config, preprocess: None | Callable = None):
        """Store `config` and the optional `preprocess` transform on the base class."""
        super().__init__(config, preprocess)

    def setup(self, stage: str) -> None:
        """Create the datasets required for `stage` ("fit", "validate" or "test")."""
        # The stages are mutually exclusive, so an if/elif chain is used
        # instead of two independent `if` checks.
        if stage in ("fit", "validate"):
            logger.print("\n[blue]Creating training dataset")
            self.train_dataset = DeepfakeDataset(
                self.config.trn_files,
                self.preprocess,
                augmentations=init_augmentations(self.config.augmentations),
                binary=self.config.binary_labels,
                limit_files=self.config.limit_trn_files,
                load_pairs=self.config.load_pairs,
            )
            self.train_dataset.print_statistics()

            logger.print("\n[blue]Creating validation dataset")
            # NOTE(review): validation applies no augmentations; shuffling is
            # delegated to the dataset itself (shuffle=True here) rather than
            # to the DataLoader — presumably to fix the sample order across
            # epochs. Confirm against DeepfakeDataset.
            self.val_dataset = DeepfakeDataset(
                self.config.val_files,
                self.preprocess,
                shuffle=True,
                binary=self.config.binary_labels,
                limit_files=self.config.limit_val_files,
            )
            self.val_dataset.print_statistics()
        elif stage == "test":
            logger.print("\nCreating test dataset")
            # Test split gets its own augmentation config (test_augmentations).
            self.test_dataset = DeepfakeDataset(
                self.config.tst_files,
                self.preprocess,
                augmentations=init_augmentations(self.config.test_augmentations),
                binary=self.config.binary_labels,
                limit_files=self.config.limit_tst_files,
            )
            self.test_dataset.print_statistics()

    def train_dataloader(self) -> DataLoader:
        """Shuffled training loader; drops the last incomplete batch."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.config.mini_batch_size,
            num_workers=self.config.num_workers,
            pin_memory=True,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Validation loader in dataset order (any shuffling happens in the dataset)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.config.mini_batch_size,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self) -> DataLoader:
        """Test loader in dataset order, no last-batch dropping."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.config.mini_batch_size,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )