# dftest1 / src / data / genimage_dataset.py
# akcanca's picture
# Upload 110 files (#1)
# 07fe054 verified
import os
import glob
import random
from PIL import Image
from torch.utils.data import Dataset
class GenImageDataset(Dataset):
    """Dataset over the ``ai`` images found beneath per-generator folders.

    Expected layout is ``root_dir/<generator>/ai/**/<image>``. If ``root_dir``
    contains a ``test`` subdirectory that itself has subdirectories, that
    ``test`` directory is scanned instead. Every collected image receives
    label ``1``; only the class folders listed in ``self.classes`` are read.

    Each item is a ``(image, label, img_path)`` tuple.
    """

    # Lower-case suffixes treated as images; matching is case-insensitive.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')

    def __init__(self, root_dir, transform=None, sample_ratio=0.01, seed=42):
        """
        Args:
            root_dir (str): Path to the root directory (e.g., 'genimage_test/test').
            transform (callable, optional): Optional transform to be applied on a sample.
            sample_ratio (float): Ratio of data to sample (0.0 to 1.0). When the
                ratio selects fewer than one image, every image is kept.
            seed (int): Random seed for reproducibility.

        Raises:
            FileNotFoundError: If ``root_dir`` does not exist.
            ValueError: If ``sample_ratio`` asks for more images than were found.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['ai']
        self.samples = []

        if not os.path.exists(root_dir):
            raise FileNotFoundError(f"Directory not found: {root_dir}")

        # Auto-detect a 'test' subdirectory (common in genimage datasets) and
        # descend into it when it holds at least one folder.
        test_dir = os.path.join(root_dir, 'test')
        if os.path.isdir(test_dir) and self._list_subdirs(test_dir):
            print(f"Detected 'test' subdirectory, using {test_dir}")
            root_dir = test_dir

        # os.listdir order is filesystem-dependent; _list_subdirs sorts so the
        # seeded sample below is the same on every machine.
        generators = self._list_subdirs(root_dir)
        print(f"Found {len(generators)} generator folders: {generators}")

        all_samples = []
        for generator in generators:
            for cls in self.classes:
                cls_path = os.path.join(root_dir, generator, cls)
                if not os.path.isdir(cls_path):
                    continue
                # Label: 0 for nature (real), 1 for ai (fake). Only 'ai' is loaded.
                label = 1
                all_samples.extend(
                    (img_path, label) for img_path in self._find_images(cls_path)
                )

        # Sampling. A private RNG keeps the draw reproducible without
        # reseeding the process-wide `random` module state.
        sample_size = int(len(all_samples) * sample_ratio)
        if sample_size > len(all_samples):
            raise ValueError(
                f"sample_ratio must be at most 1.0, got {sample_ratio}"
            )
        if sample_size > 0:
            self.samples = random.Random(seed).sample(all_samples, sample_size)
        else:
            # Fallback if ratio is too small but we want something.
            self.samples = all_samples

        print(f"Loaded {len(self.samples)} samples from {len(all_samples)} total images.")

    @staticmethod
    def _list_subdirs(path):
        """Return the sorted names of the directories directly inside *path*."""
        return sorted(
            name for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))
        )

    @classmethod
    def _find_images(cls, directory):
        """Return sorted, de-duplicated image paths found recursively under *directory*."""
        # Escape the directory so characters such as '[' in a path are not
        # interpreted as glob syntax; one recursive scan replaces one per suffix.
        pattern = os.path.join(glob.escape(directory), '**', '*')
        return sorted({
            path
            for path in glob.glob(pattern, recursive=True)
            if os.path.splitext(path)[1].lower() in cls.IMAGE_EXTENSIONS
        })

    def __len__(self):
        """Return the number of sampled images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the sample at *idx*.

        Returns:
            tuple: ``(image, label, img_path)`` where ``image`` is the RGB image
            after ``transform`` (if one was given).

        Raises:
            Exception: Any error raised while opening, converting or
                transforming the image is logged and re-raised.
        """
        img_path, label = self.samples[idx]
        try:
            # The context manager releases the underlying file handle as soon
            # as the RGB copy has been made.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label, img_path
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            # In a real training loop, we might want to return None and use a
            # collate_fn to filter. For simplicity here, we re-raise.
            raise