# dftest1 / src/data/dragon_dataset.py
# akcanca's picture
# Upload 110 files (#1)
# 07fe054 verified
import os
import glob
import json
import random
from PIL import Image
from torch.utils.data import Dataset
class DragonDataset(Dataset):
    """Dataset over the PNG images found directly inside a Dragon split directory.

    Every image in the Dragon dataset is generated, so every sample carries the
    label ``FAKE_LABEL`` (1). Items are returned as ``(image, label, img_path)``.
    """

    # Label assigned to every sample: all Dragon images are generated (fake).
    FAKE_LABEL = 1

    def __init__(self, root_dir, transform=None, sample_ratio=0.01, seed=42):
        """
        Args:
            root_dir (str): Path to the root directory (e.g., 'dataset/dragon/dragon_train_xs').
                Only ``*.png`` files directly inside it are collected (no recursion).
            transform (callable, optional): Optional transform to be applied on a sample.
            sample_ratio (float): Ratio of data to sample (0.0 to 1.0). If
                ``int(n_images * sample_ratio)`` is 0, all images are kept.
            seed (int): Random seed for reproducibility. A private RNG is used,
                so the global ``random`` module state is left untouched.

        Raises:
            FileNotFoundError: If ``root_dir`` is not an existing directory.
            ValueError: If ``sample_ratio`` is outside [0.0, 1.0].
        """
        if not os.path.isdir(root_dir):
            raise FileNotFoundError(f"Directory not found: {root_dir}")
        if not 0.0 <= sample_ratio <= 1.0:
            raise ValueError(f"sample_ratio must be in [0.0, 1.0], got {sample_ratio}")

        self.root_dir = root_dir
        self.transform = transform

        # glob order is filesystem-dependent; sort so the seeded subset is the
        # same on every machine.
        all_images = sorted(glob.glob(os.path.join(root_dir, '*.png')))
        all_samples = [(img_path, self.FAKE_LABEL) for img_path in all_images]

        sample_size = int(len(all_samples) * sample_ratio)
        if sample_size > 0:
            # Private RNG: reproducible, and does not reseed other users of `random`.
            rng = random.Random(seed)
            self.samples = rng.sample(all_samples, sample_size)
        else:
            # Ratio too small to select even one image: keep everything rather than nothing.
            self.samples = all_samples

        print(f"Loaded {len(self.samples)} samples from {len(all_samples)} total images.")

    def __len__(self):
        """Return the number of (sampled) images in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(image, label, img_path)`` where ``image`` is the RGB PIL
            image, passed through ``self.transform`` when one is set.

        Raises:
            Exception: Any error raised while opening/decoding the file or while
                applying the transform is printed with the offending path and
                re-raised unchanged.
        """
        img_path, label = self.samples[idx]
        try:
            # Context manager releases the file handle deterministically.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            raise  # bare raise preserves the original traceback
        return image, label, img_path