File size: 3,970 Bytes
07fe054
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import glob
import random
from PIL import Image
from torch.utils.data import Dataset

class GenImageDataset(Dataset):
    """Dataset over a GenImage-style directory tree.

    Expected layout: ``root_dir/<generator>/<class>/**/<image>``, where
    ``<class>`` is a key of ``LABELS`` (``nature`` or ``ai``). If ``root_dir``
    contains a ``test`` subdirectory that itself holds folders, that
    subdirectory is scanned instead.

    Each item is a ``(image, label, img_path)`` tuple.
    """

    # Label assigned to each supported class folder: 0 = real, 1 = generated.
    LABELS = {'nature': 0, 'ai': 1}
    # Recognised image file extensions, compared case-insensitively.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')

    def __init__(self, root_dir, transform=None, sample_ratio=0.01, seed=42,
                 classes=None):
        """
        Args:
            root_dir (str): Path to the root directory (e.g., 'genimage_test/test').
            transform (callable, optional): Optional transform to be applied on a sample.
            sample_ratio (float): Fraction of the discovered images to keep, in
                [0.0, 1.0]. If it rounds down to zero images, every image is kept.
            seed (int): Random seed for reproducible sampling. A private RNG is
                used, so the global ``random`` state is left untouched.
            classes (iterable of str, optional): Class folders to load; each must
                be a key of ``LABELS``. Defaults to ``['ai']`` (generated images only).

        Raises:
            ValueError: If ``sample_ratio`` is outside [0, 1] or a class is unknown.
            FileNotFoundError: If ``root_dir`` does not exist.
        """
        if not 0.0 <= sample_ratio <= 1.0:
            raise ValueError(f"sample_ratio must be in [0, 1], got {sample_ratio}")

        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['ai'] if classes is None else list(classes)
        unknown = [c for c in self.classes if c not in self.LABELS]
        if unknown:
            raise ValueError(
                f"Unknown class folder(s) {unknown}; expected a subset of {sorted(self.LABELS)}"
            )
        self.samples = []

        if not os.path.exists(root_dir):
            raise FileNotFoundError(f"Directory not found: {root_dir}")
        scan_root = self._resolve_scan_root(root_dir)

        # Sorted so the sample drawn for a given seed does not depend on the
        # filesystem's arbitrary os.listdir() order.
        generators = sorted(self._subdirs(scan_root))
        print(f"Found {len(generators)} generator folders: {generators}")

        all_samples = []
        for generator in generators:
            gen_path = os.path.join(scan_root, generator)
            for cls_name in self.classes:
                cls_path = os.path.join(gen_path, cls_name)
                if os.path.isdir(cls_path):
                    label = self.LABELS[cls_name]
                    all_samples.extend((p, label) for p in self._find_images(cls_path))

        # Private RNG: same draw as random.seed(seed) + random.sample, but
        # without resetting the process-wide random state.
        rng = random.Random(seed)
        sample_size = int(len(all_samples) * sample_ratio)
        if sample_size > 0:
            self.samples = rng.sample(all_samples, sample_size)
        else:
            # Ratio rounds down to zero images: keep everything rather than
            # return an empty dataset.
            self.samples = all_samples

        print(f"Loaded {len(self.samples)} samples from {len(all_samples)} total images.")

    @staticmethod
    def _subdirs(path):
        """Return the names of the immediate subdirectories of ``path``."""
        return [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]

    @classmethod
    def _resolve_scan_root(cls, root_dir):
        """Return ``root_dir/test`` if it exists and contains folders, else ``root_dir``."""
        test_dir = os.path.join(root_dir, 'test')
        if os.path.isdir(test_dir) and cls._subdirs(test_dir):
            print(f"Detected 'test' subdirectory, using {test_dir}")
            return test_dir
        return root_dir

    @classmethod
    def _find_images(cls, directory):
        """Return sorted paths of image files anywhere under ``directory``.

        Extensions match case-insensitively. As with glob, hidden files and
        hidden directories are skipped.
        """
        # Escape so glob metacharacters ('[', '*', '?') in the path are literal.
        pattern = os.path.join(glob.escape(directory), '**', '*')
        return sorted(
            p for p in glob.glob(pattern, recursive=True)
            if os.path.splitext(p)[1].lower() in cls.IMAGE_EXTENSIONS and os.path.isfile(p)
        )

    def __len__(self):
        """Return the number of (sampled) images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, label, img_path)``.

        The image is converted to RGB and passed through ``transform`` if set.
        Errors are reported with the offending path and re-raised.
        """
        img_path, label = self.samples[idx]
        try:
            # Context manager releases the file handle deterministically; the
            # converted image is an independent copy, so it outlives the file.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            # In a training loop, one might instead return None and filter
            # with a custom collate_fn; here the error propagates.
            raise
        return image, label, img_path