| | import os |
| | import cv2 |
| | import torch |
| | import numpy as np |
| | from torch.utils.data import Dataset |
| | import albumentations as A |
| | from albumentations.pytorch import ToTensorV2 |
| | from src.config import Config |
| |
|
class DeepfakeDataset(Dataset):
    """Binary real-vs-fake image dataset backed by a list of image paths.

    Samples come either from an explicit ``file_paths``/``labels`` pair or from
    scanning ``root_dir`` with :meth:`scan_directory`. Labels are floats:
    ``0.0`` for real images and ``1.0`` for fake ones.
    """

    # Lower-case file extensions treated as images when scanning a directory.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tif', '.tiff')

    # Keywords long enough to be matched safely as plain substrings.
    _FAKE_SUBSTRINGS = ("fake", "synthesis", "generated")
    # Two-letter keywords must match a whole token; as substrings they would
    # hit unrelated names such as "train" or "main".
    _FAKE_TOKENS = frozenset({"df", "ai"})

    # Per-channel normalisation constants (the standard ImageNet RGB statistics).
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, root_dir=None, file_paths=None, labels=None, phase='train', max_samples=None):
        """
        Args:
            root_dir (str): Directory with subfolders containing images. (Optional if file_paths provided)
            file_paths (list): List of paths to images. Takes precedence over
                root_dir when given together with labels.
            labels (list): List of labels corresponding to file_paths.
            phase (str): 'train' enables augmentation; any other value (e.g.
                'val') uses the deterministic resize + normalise pipeline.
            max_samples (int): Optional limit for quick debugging. A falsy
                value (None or 0) means "no limit".

        Raises:
            ValueError: If neither root_dir nor (file_paths, labels) is given,
                or if file_paths and labels differ in length.
        """
        self.phase = phase

        if file_paths is not None and labels is not None:
            if len(file_paths) != len(labels):
                raise ValueError(
                    f"file_paths and labels must have the same length "
                    f"(got {len(file_paths)} and {len(labels)})."
                )
            self.image_paths = file_paths
            self.labels = labels
        elif root_dir is not None:
            self.image_paths, self.labels = self.scan_directory(root_dir)
        else:
            raise ValueError("Either root_dir or (file_paths, labels) must be provided.")

        if max_samples:
            self.image_paths = self.image_paths[:max_samples]
            self.labels = self.labels[:max_samples]

        self.transform = self._get_transforms()

        print(f"Initialized {self.phase} dataset with {len(self.image_paths)} samples.")

    @staticmethod
    def _infer_label(text):
        """Return 0.0 (real), 1.0 (fake) or None (unknown) for a path fragment.

        "real" is checked first, so it wins when both kinds of keyword occur.
        """
        lowered = text.lower()
        if "real" in lowered:
            return 0.0
        if any(keyword in lowered for keyword in DeepfakeDataset._FAKE_SUBSTRINGS):
            return 1.0
        # Split on every non-alphanumeric character so that "celeb-df" or
        # "ai_images" yield the tokens "df" / "ai", while "train" does not.
        tokens = "".join(ch if ch.isalnum() else " " for ch in lowered).split()
        if DeepfakeDataset._FAKE_TOKENS.intersection(tokens):
            return 1.0
        return None

    @staticmethod
    def scan_directory(root_dir):
        """Recursively collect labelled image paths below ``root_dir``.

        The label of each image is inferred from keywords in its path. The
        part of the path *below* ``root_dir`` is examined first, so that the
        name of the dataset root (e.g. ``real_vs_fake``) cannot override the
        class folder. Only when that part carries no keyword is the root path
        itself consulted, which keeps calls such as
        ``scan_directory("data/real")`` working. Images for which no label can
        be inferred are skipped.

        Directories and files are visited in sorted order so the result is
        reproducible across runs and filesystems.

        Args:
            root_dir (str): Directory to scan.

        Returns:
            tuple: ``(image_paths, labels)``, two parallel lists; labels are
            ``0.0`` (real) or ``1.0`` (fake).
        """
        image_paths = []
        labels = []
        print(f"Scanning dataset at {root_dir}...")

        # Fallback label derived from the root path, computed once.
        root_label = DeepfakeDataset._infer_label(os.fspath(root_dir))

        for current_dir, dirs, files in os.walk(root_dir):
            # Sorting in place makes os.walk descend in a deterministic order.
            dirs.sort()
            relative_dir = os.path.relpath(current_dir, root_dir)

            for file_name in sorted(files):
                if not file_name.lower().endswith(DeepfakeDataset._IMAGE_EXTENSIONS):
                    continue

                label = DeepfakeDataset._infer_label(os.path.join(relative_dir, file_name))
                if label is None:
                    label = root_label
                if label is None:
                    continue

                image_paths.append(os.path.join(current_dir, file_name))
                labels.append(label)

        return image_paths, labels

    def _get_transforms(self):
        """Build the albumentations pipeline for the current phase.

        Every phase resizes to ``Config.IMAGE_SIZE``, normalises and converts
        to a tensor; the 'train' phase inserts random augmentations between
        the resize and the normalisation.
        """
        size = Config.IMAGE_SIZE
        steps = [A.Resize(size, size)]

        if self.phase == 'train':
            steps += [
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.2),
                A.GaussNoise(p=0.2),
                A.ImageCompression(quality_lower=60, quality_upper=100, p=0.3),
            ]

        steps += [
            A.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
            ToTensorV2(),
        ]
        return A.Compose(steps)

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    @staticmethod
    def _read_rgb(path):
        """Load ``path`` as an RGB array, or return None if it cannot be read."""
        import warnings

        try:
            image = cv2.imread(path)
            if image is not None:
                # OpenCV decodes to BGR; the normalisation constants are RGB.
                return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            reason = "Image not found or corrupt"
        except Exception as exc:
            # Deliberately broad: one undecodable file must not abort an
            # epoch. The failure is reported below instead of being hidden.
            reason = str(exc)

        warnings.warn(f"Skipping unreadable image {path!r}: {reason}", RuntimeWarning)
        return None

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at ``idx``.

        If the image at ``idx`` cannot be read, the following samples are
        tried in order (wrapping around) and the first readable one is
        returned together with its own label.

        Returns:
            tuple: The transformed image and the label as a float32 tensor.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no image in the dataset can be read.
        """
        # Index with the caller's idx first so an out-of-range index raises
        # IndexError exactly as a plain list would.
        path = self.image_paths[idx]
        label = self.labels[idx]
        image = self._read_rgb(path)

        # Bounded fallback: visit every other sample at most once instead of
        # recursing without limit.
        total = len(self)
        remaining = total - 1
        while image is None:
            if remaining <= 0:
                raise RuntimeError(
                    f"No readable image found in dataset (last tried: {path!r})."
                )
            remaining -= 1
            idx = (idx + 1) % total
            path = self.image_paths[idx]
            label = self.labels[idx]
            image = self._read_rgb(path)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, torch.tensor(label, dtype=torch.float32)
|