Spaces:
Sleeping
Sleeping
| import os | |
| import glob | |
| import json | |
| import random | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
class DragonDataset(Dataset):
    """Image dataset over the ``*.png`` files in a single Dragon split directory.

    Every image is labelled ``1`` (generated / fake) because the Dragon dataset
    contains only generated images. Each item is ``(image, label, img_path)``.
    """

    def __init__(self, root_dir, transform=None, sample_ratio=0.01, seed=42):
        """
        Args:
            root_dir (str): Path to the root directory (e.g., 'dataset/dragon/dragon_train_xs').
            transform (callable, optional): Transform applied to each loaded RGB PIL image.
            sample_ratio (float): Fraction of images to keep, in (0.0, 1.0]. The size is
                floored, but at least one image is kept whenever the directory has any.
            seed (int): Seed for the sampling RNG. The same directory contents and seed
                always select the same subset, independent of filesystem listing order.

        Raises:
            ValueError: If ``sample_ratio`` is outside (0.0, 1.0].
            FileNotFoundError: If ``root_dir`` is not an existing directory.
        """
        # `not a < x <= b` also rejects NaN.
        if not 0.0 < sample_ratio <= 1.0:
            raise ValueError(f"sample_ratio must be in (0.0, 1.0], got {sample_ratio}")
        if not os.path.isdir(root_dir):
            raise FileNotFoundError(f"Directory not found: {root_dir}")

        self.root_dir = root_dir
        self.transform = transform

        # Escape root_dir so '[', '*', '?' in directory names are taken literally.
        # Sort because glob order is filesystem-dependent; without this a fixed
        # seed would pick different files on different machines.
        all_images = sorted(glob.glob(os.path.join(glob.escape(root_dir), '*.png')))
        # All images in Dragon dataset are generated (fake) -> label 1.
        all_samples = [(img_path, 1) for img_path in all_images]

        # Keep at least one image when any exist. Previously a ratio that floored
        # to 0 fell back to the *whole* dataset. Cap at the population size so an
        # empty directory yields an empty dataset.
        sample_size = min(len(all_samples), max(1, int(len(all_samples) * sample_ratio)))
        # A private RNG gives reproducible sampling without reseeding the global
        # `random` module for the rest of the program.
        rng = random.Random(seed)
        self.samples = rng.sample(all_samples, sample_size)

        print(f"Loaded {len(self.samples)} samples from {len(all_samples)} total images.")

    def __len__(self):
        """Return the number of sampled images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image and apply the transform, if any.

        Returns:
            tuple: ``(image, label, img_path)``. ``image`` is the transformed RGB
            PIL image, or the transform's output when a transform is set.

        Raises:
            Exception: Whatever PIL raises while opening/decoding the file; the
                error is printed and re-raised unchanged.
        """
        img_path, label = self.samples[idx]
        try:
            # The context manager closes the file handle; convert() forces a full
            # load and returns an independent image, so it outlives the `with`.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            raise  # bare raise keeps the original traceback intact
        if self.transform is not None:
            image = self.transform(image)
        return image, label, img_path