File size: 4,060 Bytes
0966609
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from src.config import Config

class DeepfakeDataset(Dataset):
    """Image dataset for real-vs-fake (deepfake) binary classification.

    Yields ``(image_tensor, label_tensor)`` pairs where the label is a
    float32 scalar: 0.0 = real, 1.0 = fake.
    """

    # Fake-indicating keywords that are safe to match as substrings anywhere
    # in the path ("deepfake", "fakes", "generated_images", ...).
    _FAKE_SUBSTRINGS = ("fake", "synthesis", "generated")
    # Short, ambiguous keywords matched only as whole path tokens.
    # BUG FIX: the old substring check `"ai" in path` was true for any path
    # containing "train", "domain", etc., mislabeling those images as fake.
    _FAKE_TOKENS = frozenset({"df", "ai"})

    def __init__(self, root_dir=None, file_paths=None, labels=None, phase='train', max_samples=None):
        """
        Args:
            root_dir (str): Directory with subfolders containing images.
                (Optional if file_paths provided.)
            file_paths (list): List of absolute paths to images.
            labels (list): List of labels corresponding to file_paths.
            phase (str): 'train' or 'val'; selects the augmentation pipeline.
            max_samples (int): Optional limit for quick debugging.

        Raises:
            ValueError: if neither root_dir nor (file_paths, labels) is given.
        """
        self.phase = phase

        if file_paths is not None and labels is not None:
            # Explicit lists take priority over directory scanning.
            self.image_paths = file_paths
            self.labels = labels
        elif root_dir is not None:
            self.image_paths, self.labels = self.scan_directory(root_dir)
        else:
            raise ValueError("Either root_dir or (file_paths, labels) must be provided.")

        if max_samples:
            # Truncate both lists in lockstep so paths and labels stay aligned.
            self.image_paths = self.image_paths[:max_samples]
            self.labels = self.labels[:max_samples]

        self.transform = self._get_transforms()

        print(f"Initialized {self.phase} dataset with {len(self.image_paths)} samples.")

    @staticmethod
    def scan_directory(root_dir):
        """Recursively collect image paths under ``root_dir`` and infer labels.

        Label inference (case-insensitive, applied to the full path):
          * contains "real"                     -> 0.0 (checked first)
          * contains "fake"/"synthesis"/"generated" as a substring,
            or has "df"/"ai" as a whole path token (e.g. "celeb-df") -> 1.0
          * otherwise the file is skipped (no label can be inferred).

        Returns:
            tuple[list[str], list[float]]: parallel lists of paths and labels.
        """
        image_paths = []
        labels = []
        print(f"Scanning dataset at {root_dir}...")

        # Valid extensions
        exts = ('.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tif')

        for root, dirs, files in os.walk(root_dir):
            for file in files:
                if not file.lower().endswith(exts):
                    continue
                path = os.path.join(root, file)
                path_lower = path.lower()
                # Tokenize on any non-alphanumeric character so that short
                # keywords like "ai"/"df" only match whole path components
                # ("celeb-df" -> yes; "train" -> no).
                tokens = "".join(
                    ch if ch.isalnum() else " " for ch in path_lower
                ).split()

                label = None
                # Prioritize explicit "real" folder/file names.
                if "real" in path_lower:
                    label = 0.0
                elif (any(s in path_lower for s in DeepfakeDataset._FAKE_SUBSTRINGS)
                        or any(t in DeepfakeDataset._FAKE_TOKENS for t in tokens)):
                    label = 1.0

                if label is not None:
                    image_paths.append(path)
                    labels.append(label)

        return image_paths, labels

    def _get_transforms(self):
        """Build the albumentations pipeline for the current phase.

        Training adds light augmentation (flip, brightness/contrast, noise,
        JPEG compression); validation only resizes and normalizes.
        ImageNet mean/std are used to match pretrained backbones.
        """
        size = Config.IMAGE_SIZE
        if self.phase == 'train':
            return A.Compose([
                A.Resize(size, size),
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.2),
                A.GaussNoise(p=0.2),
                # A.GaussianBlur(p=0.1),
                # NOTE(review): quality_lower/quality_upper were removed in
                # albumentations >= 1.4 in favor of quality_range=(60, 100);
                # confirm against the pinned albumentations version.
                A.ImageCompression(quality_lower=60, quality_upper=100, p=0.3),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ])
        else:
            return A.Compose([
                A.Resize(size, size),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, augment, and return ``(image, label)`` for sample ``idx``.

        Unreadable/corrupt images are skipped by probing subsequent indices.
        BUG FIX: the old fallback recursed into __getitem__, which could hit
        the recursion limit (or never terminate) when many files are corrupt;
        this version probes at most one full pass, then raises.
        """
        n = len(self.image_paths)
        for attempt in range(n):
            i = (idx + attempt) % n
            path = self.image_paths[i]
            label = self.labels[i]

            try:
                image = cv2.imread(path)
                if image is None:
                    raise ValueError("Image not found or corrupt")
                # OpenCV loads BGR; models expect RGB.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            except Exception:
                continue  # best-effort: try the next sample

            if self.transform:
                augmented = self.transform(image=image)
                image = augmented['image']

            return image, torch.tensor(label, dtype=torch.float32)

        raise RuntimeError(f"No readable image found in dataset (started at index {idx}).")