import os
import glob
import json
import random
from PIL import Image
from torch.utils.data import Dataset


class DragonDataset(Dataset):
    """Dataset of Dragon images; every image is labelled as generated (1).

    Items are ``(image, label, img_path)`` tuples, where ``image`` is an RGB
    PIL image, or whatever ``transform`` returns when a transform is given.
    """

    # All images in the Dragon dataset are generated (fake).
    FAKE_LABEL = 1

    def __init__(self, root_dir, transform=None, sample_ratio=0.01, seed=42):
        """
        Args:
            root_dir (str): Path to the root directory (e.g., 'dataset/dragon/dragon_train_xs').
                Only ``*.png`` files directly inside it are used (no recursion).
            transform (callable, optional): Optional transform to be applied on a sample.
            sample_ratio (float): Fraction of images to keep, in [0.0, 1.0]. At least
                one image is kept whenever the directory contains any.
            seed (int): Random seed for reproducibility. A private RNG is used, so the
                global ``random`` module state is left untouched.

        Raises:
            FileNotFoundError: If ``root_dir`` is not an existing directory.
            ValueError: If ``sample_ratio`` is outside [0.0, 1.0].
        """
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        if not os.path.isdir(root_dir):
            raise FileNotFoundError(f"Directory not found: {root_dir}")
        if not 0.0 <= sample_ratio <= 1.0:
            raise ValueError(f"sample_ratio must be in [0.0, 1.0], got {sample_ratio}")

        # Sort the paths: glob order depends on the filesystem, and an unordered
        # population would make the seeded sample differ between machines.
        # escape() keeps characters such as '[' in root_dir from acting as patterns.
        pattern = os.path.join(glob.escape(root_dir), '*.png')
        all_images = sorted(glob.glob(pattern))
        all_samples = [(img_path, self.FAKE_LABEL) for img_path in all_images]

        if all_samples:
            # int() truncates, so a small ratio can round down to 0; keep at least
            # one image rather than silently falling back to the whole dataset.
            sample_size = max(1, int(len(all_samples) * sample_ratio))
            rng = random.Random(seed)
            self.samples = rng.sample(all_samples, sample_size)

        print(f"Loaded {len(self.samples)} samples from {len(all_samples)} total images.")

    def __len__(self):
        """Return the number of sampled images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label, img_path)`` for the sample at ``idx``.

        The file is opened in a ``with`` block so its handle is released as soon
        as the pixels have been converted to RGB. Load failures are printed with
        the offending path and re-raised unchanged.
        """
        img_path, label = self.samples[idx]
        try:
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            raise
        if self.transform is not None:
            image = self.transform(image)
        return image, label, img_path