import glob
import os
import random

from PIL import Image
from torch.utils.data import Dataset


class GenImageDataset(Dataset):
    """Sampled dataset of AI-generated images from a GenImage-style tree.

    Expected layout: ``root_dir/<generator>/[nature|ai]/...`` or, when a
    ``test`` directory containing sub-folders exists directly under
    ``root_dir``, ``root_dir/test/<generator>/[nature|ai]/...``.

    Only the class folders listed in ``self.classes`` (currently just
    ``'ai'``) are scanned, so every sample carries label 1 and ``nature``
    folders are ignored.
    """

    # Glob patterns for supported image files; the upper-cased form of each
    # pattern is searched as well.
    _IMAGE_PATTERNS = ('*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff')

    # Label convention: 0 for nature (real), 1 for ai (fake).
    _LABELS = {'nature': 0, 'ai': 1}

    def __init__(self, root_dir, transform=None, sample_ratio=0.01, seed=42):
        """Index the image files and draw a seeded random subset.

        Args:
            root_dir (str): Path to the root directory
                (e.g., 'genimage_test/test').
            transform (callable, optional): Optional transform to be applied
                on a sample.
            sample_ratio (float): Ratio of data to sample (0.0 to 1.0). When
                ``int(total * sample_ratio)`` is 0, all images are kept
                (in scan order) instead of returning an empty dataset.
            seed (int): Random seed for reproducibility of the sampling.

        Raises:
            FileNotFoundError: If ``root_dir`` does not exist.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['ai']
        self.samples = []

        if not os.path.exists(root_dir):
            raise FileNotFoundError(f"Directory not found: {root_dir}")

        # Auto-detect a 'test' subdirectory (common in genimage datasets) and
        # descend into it when it holds at least one folder.
        test_dir = os.path.join(root_dir, 'test')
        if os.path.isdir(test_dir):
            test_contents = [
                d for d in os.listdir(test_dir)
                if os.path.isdir(os.path.join(test_dir, d))
            ]
            if test_contents:
                print(f"Detected 'test' subdirectory, using {test_dir}")
                root_dir = test_dir

        # os.listdir returns entries in arbitrary order; sort so that the
        # seeded sample below is the same on every machine/filesystem.
        generators = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        print(f"Found {len(generators)} generator folders: {generators}")

        all_samples = []
        for generator in generators:
            gen_path = os.path.join(root_dir, generator)
            for cls in self.classes:
                cls_path = os.path.join(gen_path, cls)
                if not os.path.isdir(cls_path):
                    continue
                label = self._LABELS[cls]
                all_samples.extend(
                    (img_path, label)
                    for img_path in self._find_images(cls_path)
                )

        # Sampling: use a private RNG so building a dataset does not reseed
        # the process-wide ``random`` module. Random(seed).sample draws the
        # same items as random.seed(seed) followed by random.sample.
        sample_size = int(len(all_samples) * sample_ratio)
        if sample_size > 0:
            self.samples = random.Random(seed).sample(all_samples, sample_size)
        else:
            # Fallback if ratio is too small but we want something.
            self.samples = all_samples

        print(f"Loaded {len(self.samples)} samples from {len(all_samples)} total images.")

    @classmethod
    def _find_images(cls, directory):
        """Return the sorted, de-duplicated image paths found under *directory*."""
        found = set()
        for pattern in cls._IMAGE_PATTERNS:
            # Search both lower- and upper-case extensions; the set removes
            # the duplicates this yields on case-insensitive filesystems.
            for variant in (pattern, pattern.upper()):
                found.update(
                    glob.glob(os.path.join(directory, '**', variant), recursive=True)
                )
        return sorted(found)

    def __len__(self):
        """Return the number of sampled images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            tuple: ``(image, label, img_path)`` where ``image`` is the RGB
            image after ``transform`` (if one was given).

        Raises:
            Exception: Any error raised while opening, converting or
                transforming the image is logged and re-raised unchanged.
        """
        img_path, label = self.samples[idx]
        try:
            # The context manager releases the file handle even for
            # multi-frame formats or when decoding fails; convert() returns
            # an independent in-memory image.
            with Image.open(img_path) as source:
                image = source.convert('RGB')
            if self.transform is not None:
                image = self.transform(image)
            return image, label, img_path
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            # In a real training loop one might return None and filter it in
            # a collate_fn; here the error is propagated to the caller.
            raise