# Harshasnade's picture
# Deploy Backend (No Frontend)
# 0966609
import os
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from src.config import Config
class DeepfakeDataset(Dataset):
    """Real-vs-fake image dataset read from disk.

    Samples come either from explicit ``file_paths``/``labels`` lists or from
    a recursive scan of ``root_dir`` in which the label is inferred from the
    file path (0.0 = real, 1.0 = fake). Each item is ``(image, label)``:
    ``image`` is the output of the phase-specific albumentations pipeline and
    ``label`` is a float32 scalar tensor.
    """

    def __init__(self, root_dir=None, file_paths=None, labels=None, phase='train', max_samples=None):
        """
        Args:
            root_dir (str): Directory with subfolders containing images. (Optional if file_paths provided)
            file_paths (list): List of absolute paths to images.
            labels (list): List of labels corresponding to file_paths.
            phase (str): 'train' enables augmentation; any other value
                (e.g. 'val') gets the deterministic pipeline.
            max_samples (int): Optional limit for quick debugging; keeps the
                first ``max_samples`` entries.

        Raises:
            ValueError: If neither source is given, or if ``file_paths`` and
                ``labels`` differ in length.
        """
        self.phase = phase

        if file_paths is not None and labels is not None:
            # Fail early: a mismatch would otherwise surface much later as an
            # IndexError inside a DataLoader worker, or silently mis-pair.
            if len(file_paths) != len(labels):
                raise ValueError(
                    "file_paths and labels must have the same length "
                    f"(got {len(file_paths)} and {len(labels)})."
                )
            self.image_paths = file_paths
            self.labels = labels
        elif root_dir is not None:
            self.image_paths, self.labels = self.scan_directory(root_dir)
        else:
            raise ValueError("Either root_dir or (file_paths, labels) must be provided.")

        if max_samples:
            self.image_paths = self.image_paths[:max_samples]
            self.labels = self.labels[:max_samples]

        self.transform = self._get_transforms()
        print(f"Initialized {self.phase} dataset with {len(self.image_paths)} samples.")

    @staticmethod
    def _infer_label(path):
        """Return 0.0 (real), 1.0 (fake) or ``None`` (unknown) for ``path``.

        NOTE(review): this is a plain substring test on the whole lower-cased
        path, so parent directory names count too and short keywords such as
        "ai" also match inside words like "train". "real" is tested first and
        therefore wins whenever both kinds of keyword are present.
        """
        path_lower = path.lower()
        if "real" in path_lower:
            return 0.0
        fake_keywords = ("fake", "df", "synthesis", "generated", "ai")
        if any(keyword in path_lower for keyword in fake_keywords):
            return 1.0
        return None

    @staticmethod
    def scan_directory(root_dir):
        """Recursively collect labelled image files under ``root_dir``.

        Files whose path yields no label are skipped.

        Args:
            root_dir (str): Directory to walk.

        Returns:
            tuple: ``(image_paths, labels)`` as two parallel lists, in sorted
            traversal order.
        """
        image_paths = []
        labels = []
        print(f"Scanning dataset at {root_dir}...")
        # Valid extensions (compared against the lower-cased file name)
        exts = ('.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tif')

        for current_dir, dirs, files in os.walk(root_dir):
            # os.walk yields entries in arbitrary filesystem order; sorting
            # (in place for dirs, so the walk itself follows it) makes the
            # result, and thus any max_samples subset, reproducible.
            dirs.sort()
            for file_name in sorted(files):
                if not file_name.lower().endswith(exts):
                    continue
                path = os.path.join(current_dir, file_name)
                label = DeepfakeDataset._infer_label(path)
                if label is not None:
                    image_paths.append(path)
                    labels.append(label)

        return image_paths, labels

    @staticmethod
    def _image_compression():
        """Build the compression augmentation (quality 60-100, p=0.3).

        Newer albumentations releases take ``quality_range``; older ones only
        know ``quality_lower``/``quality_upper`` and reject the new keyword
        with a TypeError, so try the new spelling first and fall back.
        """
        try:
            return A.ImageCompression(quality_range=(60, 100), p=0.3)
        except TypeError:
            return A.ImageCompression(quality_lower=60, quality_upper=100, p=0.3)

    def _get_transforms(self):
        """Return the albumentations pipeline for ``self.phase``.

        Both phases resize to ``Config.IMAGE_SIZE``, normalize and convert to
        a tensor; only 'train' inserts the random augmentations in between.
        """
        size = Config.IMAGE_SIZE
        steps = [A.Resize(size, size)]

        if self.phase == 'train':
            steps += [
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.2),
                A.GaussNoise(p=0.2),
                self._image_compression(),
            ]

        steps += [
            # These values match the widely used ImageNet channel statistics.
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
        return A.Compose(steps)

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` with OpenCV; return an RGB array or ``None`` on failure."""
        try:
            image = cv2.imread(path)
            if image is None:
                return None
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        except Exception:
            # Deliberate best-effort: one unreadable file must not abort the
            # run; the caller substitutes another sample instead.
            return None

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``.

        If the image at ``idx`` cannot be read, the following samples are
        tried in order (wrapping around) and the first readable one is
        returned together with its own label.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no image in the dataset can be read.
        """
        total = len(self.image_paths)
        # Keep the sequence protocol intact: out-of-range must raise
        # IndexError rather than wrap around.
        if not -total <= idx < total:
            raise IndexError(f"Index {idx} out of range for dataset of size {total}.")

        start = idx % total
        # One bounded pass over the dataset instead of open-ended recursion,
        # which overflowed the stack on long runs of unreadable files.
        for offset in range(total):
            candidate = (start + offset) % total
            image = self._load_rgb(self.image_paths[candidate])
            if image is not None:
                label = self.labels[candidate]
                break
        else:
            raise RuntimeError(f"None of the {total} images in the dataset could be loaded.")

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, torch.tensor(label, dtype=torch.float32)