import io
import os

import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from torch.utils.data import Dataset


class HuggingFaceDataset(Dataset):
    """Map-style torch ``Dataset`` backed by a Hugging Face dataset split.

    Indexing yields ``(image, label, img_path)`` tuples, where ``image`` is
    an RGB PIL image (or whatever ``transform`` returns for it), ``label`` is
    the item's ``'label'`` value (1 when the column is absent) and
    ``img_path`` is an identifier string for tracking the sample.
    """

    def __init__(self, dataset_name, split='train', transform=None, sample_ratio=1.0, seed=42):
        """Load the split and optionally keep a random subset of it.

        Args:
            dataset_name (str): Name of the Hugging Face dataset
                (e.g., "Tungtom2004/Google_Nano_Banana_Edited_Images").
            split (str): Dataset split to load (default: 'train').
            transform (callable, optional): Applied to the RGB PIL image of
                every sample before it is returned.
            sample_ratio (float): Fraction of the split to keep. Values below
                1.0 trigger a seeded shuffle followed by truncation to
                ``int(len * sample_ratio)`` rows; values of 1.0 or more keep
                the split untouched.
            seed (int): Seed for the shuffle used when sub-sampling.

        Raises:
            Exception: Whatever ``load_dataset`` raises is logged to stdout
                and re-raised unchanged.
        """
        self.dataset_name = dataset_name
        self.split = split
        self.transform = transform

        print(f"Loading Hugging Face dataset: {dataset_name} ({split})")
        try:
            self.hf_dataset = load_dataset(dataset_name, split=split)
        except Exception as e:
            print(f"Error loading dataset {dataset_name}: {e}")
            # Bare raise keeps the original traceback intact.
            raise

        # Sampling: shuffle first so the kept prefix is a random subset.
        if sample_ratio < 1.0:
            print(f"Sampling {sample_ratio*100}% of the dataset...")
            keep = int(len(self.hf_dataset) * sample_ratio)
            self.hf_dataset = self.hf_dataset.shuffle(seed=seed).select(range(keep))

        print(f"Loaded {len(self.hf_dataset)} samples.")

    def __len__(self):
        """Return the number of samples left after optional sub-sampling."""
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        """Return ``(image, label, img_path)`` for the sample at ``idx``.

        Raises:
            ValueError: If no image can be located in the item.
        """
        item = self.hf_dataset[idx]
        image = self._find_image(item)

        # Ensure RGB so downstream transforms always see three channels.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Default to 1 (fake) if the dataset carries no label column.
        label = item.get('label', 1)

        # HF datasets don't always have file paths; fall back to a synthetic
        # identifier built from the dataset name, split and index.
        if 'file_name' in item:
            img_path = item['file_name']
        else:
            img_path = f"hf_{self.dataset_name}_{self.split}_{idx}"

        if self.transform:
            image = self.transform(image)

        return image, label, img_path

    def _find_image(self, item):
        """Locate the image in ``item``: 'image', then 'img', then any PIL value."""
        for key in ('image', 'img'):
            if key in item:
                return self._decode_image(item[key])

        # Fallback: first value that already is a PIL image.
        for value in item.values():
            if isinstance(value, Image.Image):
                return value

        raise ValueError(f"Could not find image in dataset item keys: {item.keys()}")

    @staticmethod
    def _decode_image(value):
        """Turn an undecoded image payload into a PIL image.

        Handles the ``{'bytes': ..., 'path': ...}`` dict form and raw
        ``bytes``/``bytearray``. Any other value (including an already
        decoded PIL image) is returned unchanged.
        """
        if isinstance(value, Image.Image):
            return value
        if isinstance(value, (bytes, bytearray)):
            return Image.open(io.BytesIO(value))
        if isinstance(value, dict):
            raw = value.get('bytes')
            if raw is not None:
                return Image.open(io.BytesIO(raw))
            path = value.get('path')
            if path:
                return Image.open(path)
        return value