|
|
import os |
|
|
import cv2 |
|
|
import torch |
|
|
import numpy as np |
|
|
from torch.utils.data import Dataset |
|
|
import albumentations as A |
|
|
from albumentations.pytorch import ToTensorV2 |
|
|
from src.config import Config |
|
|
|
|
|
class DeepfakeDataset(Dataset):
    """Binary real-vs-fake image dataset.

    Yields ``(image_tensor, label)`` pairs where the label is ``0.0`` for
    real images and ``1.0`` for fake/synthetic images. Samples come either
    from explicit ``(file_paths, labels)`` lists or from a recursive scan
    of ``root_dir`` that infers labels from keywords in each file path.
    """

    # Keywords matched as plain substrings anywhere in the lowercased path.
    _REAL_SUBSTRINGS = ("real",)
    _FAKE_SUBSTRINGS = ("fake", "synthesis", "generated")
    # Short, ambiguous keywords matched only as whole path tokens: a plain
    # substring test would mislabel e.g. "train" (contains "ai") or
    # "pdf" (contains "df"), tagging every file under a train/ split as fake.
    _FAKE_TOKENS = ("df", "ai")
    # Translation table turning path separators/punctuation into spaces so a
    # path can be split into tokens ("train/real_faces" -> {"train","real","faces"}).
    _TOKEN_SEPARATORS = str.maketrans({c: " " for c in "\\/._- "})

    def __init__(self, root_dir=None, file_paths=None, labels=None, phase='train', max_samples=None):
        """
        Args:
            root_dir (str): Directory with subfolders containing images.
                (Optional if file_paths provided.)
            file_paths (list): List of absolute paths to images.
            labels (list): List of labels corresponding to file_paths.
            phase (str): 'train' or 'val'; selects the augmentation pipeline.
            max_samples (int): Optional limit for quick debugging.

        Raises:
            ValueError: If neither ``root_dir`` nor both ``file_paths`` and
                ``labels`` are provided.
        """
        self.phase = phase

        if file_paths is not None and labels is not None:
            self.image_paths = file_paths
            self.labels = labels
        elif root_dir is not None:
            self.image_paths, self.labels = self.scan_directory(root_dir)
        else:
            raise ValueError("Either root_dir or (file_paths, labels) must be provided.")

        if max_samples:
            # Truncate (not random-sample) — intended for quick debug runs only.
            self.image_paths = self.image_paths[:max_samples]
            self.labels = self.labels[:max_samples]

        self.transform = self._get_transforms()

        print(f"Initialized {self.phase} dataset with {len(self.image_paths)} samples.")

    @staticmethod
    def scan_directory(root_dir):
        """Recursively collect image paths under ``root_dir`` and infer labels.

        A path containing "real" is labeled 0.0; a path containing a fake
        keyword ("fake", "synthesis", "generated" as substrings; "df"/"ai"
        as whole path tokens) is labeled 1.0. Files matching neither rule
        are skipped.

        Returns:
            tuple[list[str], list[float]]: parallel lists of image paths
            and float labels.
        """
        image_paths = []
        labels = []
        print(f"Scanning dataset at {root_dir}...")

        exts = ('.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tif')
        cls = DeepfakeDataset

        for root, _dirs, files in os.walk(root_dir):
            for file in files:
                if not file.lower().endswith(exts):
                    continue

                path = os.path.join(root, file)
                path_lower = path.lower()
                # Token set for the short keywords that must not match inside
                # unrelated words (see _FAKE_TOKENS above).
                tokens = set(path_lower.translate(cls._TOKEN_SEPARATORS).split())

                label = None
                # "real" wins over fake keywords, preserving the original
                # precedence for paths containing both.
                if any(k in path_lower for k in cls._REAL_SUBSTRINGS):
                    label = 0.0
                elif (any(k in path_lower for k in cls._FAKE_SUBSTRINGS)
                      or any(t in tokens for t in cls._FAKE_TOKENS)):
                    label = 1.0

                if label is not None:
                    image_paths.append(path)
                    labels.append(label)

        return image_paths, labels

    def _get_transforms(self):
        """Build the albumentations pipeline for this phase.

        Train adds flip/brightness/noise/JPEG-compression augmentation;
        val only resizes. Both normalize with ImageNet statistics and
        convert to a CHW float tensor.
        """
        size = Config.IMAGE_SIZE
        if self.phase == 'train':
            return A.Compose([
                A.Resize(size, size),
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.2),
                A.GaussNoise(p=0.2),
                # Simulates re-compression artifacts common in shared media.
                # NOTE(review): quality_lower/quality_upper are deprecated in
                # albumentations >= 1.4 in favor of quality_range — confirm
                # the pinned version before upgrading.
                A.ImageCompression(quality_lower=60, quality_upper=100, p=0.3),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ])
        else:
            return A.Compose([
                A.Resize(size, size),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ])

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, decode, and transform one sample.

        Unreadable or corrupt files are skipped by falling through to the
        next index (wrapping around), with a bounded number of attempts
        instead of unbounded recursion.

        Returns:
            tuple: (transformed image tensor, float32 scalar label tensor).

        Raises:
            RuntimeError: If every file in the dataset fails to load.
        """
        n = len(self)
        for attempt in range(n):
            i = (idx + attempt) % n
            path = self.image_paths[i]
            label = self.labels[i]

            try:
                image = cv2.imread(path)  # BGR, or None on failure
                if image is None:
                    raise ValueError("Image not found or corrupt")
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            except Exception as e:
                # Log and fall through to the next sample rather than
                # silently recursing.
                print(f"Warning: failed to load {path} ({e}); trying next sample.")
                continue

            if self.transform:
                image = self.transform(image=image)['image']

            return image, torch.tensor(label, dtype=torch.float32)

        raise RuntimeError(f"All {n} images in the dataset failed to load.")
|
|
|