Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| from torch.utils.data import Dataset | |
| from datasets import load_dataset | |
| from PIL import Image | |
| import numpy as np | |
| import io | |
class HuggingFaceDataset(Dataset):
    """PyTorch ``Dataset`` wrapper around one split of a Hugging Face dataset.

    Indexing yields a ``(image, label, img_path)`` tuple: the image converted
    to RGB (and passed through ``transform`` when one is set), the item's
    ``'label'`` value (``1`` when the item has no such key), and either the
    item's ``'file_name'`` or a synthetic identifier built from the dataset
    name, split and index.
    """

    def __init__(self, dataset_name, split='train', transform=None, sample_ratio=1.0, seed=42):
        """Load the split and optionally keep a seeded random subset of it.

        Args:
            dataset_name (str): Name of the Hugging Face dataset
                (e.g., "Tungtom2004/Google_Nano_Banana_Edited_Images").
            split (str): Dataset split to load (default: 'train').
            transform (callable, optional): Applied to the RGB image of every
                sample before it is returned.
            sample_ratio (float): Fraction of the split to keep. Values below
                1.0 keep ``int(len * sample_ratio)`` shuffled samples; values
                of 1.0 or more keep the whole split.
            seed (int): Seed passed to the shuffle that precedes sampling.

        Raises:
            ValueError: If ``sample_ratio`` is negative.
            Exception: Whatever ``load_dataset`` raises is logged and
                re-raised unchanged.
        """
        # Fail fast, before any download: a negative ratio used to produce a
        # silently empty dataset via ``range(<negative>)``.
        if sample_ratio < 0:
            raise ValueError(f"sample_ratio must be non-negative, got {sample_ratio}")

        self.dataset_name = dataset_name
        self.split = split
        self.transform = transform

        print(f"Loading Hugging Face dataset: {dataset_name} ({split})")
        try:
            self.hf_dataset = load_dataset(dataset_name, split=split)
        except Exception as e:
            print(f"Error loading dataset {dataset_name}: {e}")
            # Bare ``raise`` re-raises the active exception as-is.
            raise

        if sample_ratio < 1.0:
            print(f"Sampling {sample_ratio*100}% of the dataset...")
            keep = int(len(self.hf_dataset) * sample_ratio)
            self.hf_dataset = self.hf_dataset.shuffle(seed=seed).select(range(keep))

        print(f"Loaded {len(self.hf_dataset)} samples.")

    def __len__(self):
        """Return the number of samples kept after optional sampling."""
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        """Return ``(image, label, img_path)`` for the sample at ``idx``.

        Args:
            idx: Sample index. A single-element tensor is converted to a
                plain Python number first, so the lookup and the synthetic
                path both see an ordinary index.

        Returns:
            tuple: ``(image, label, img_path)``. ``image`` is the RGB image,
            or the output of ``transform`` when one is set. ``label`` is
            ``item['label']`` or ``1`` when that key is absent. ``img_path``
            is ``item['file_name']`` when present, otherwise
            ``"hf_<dataset_name>_<split>_<idx>"``.

        Raises:
            ValueError: If no image can be located in the item.
        """
        if torch.is_tensor(idx):
            idx = idx.item()

        item = self.hf_dataset[idx]
        image = self._find_image(item)

        # NOTE(review): assumes the image object exposes PIL's ``mode`` and
        # ``convert`` -- confirm for datasets that store raw bytes or arrays.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Items without a 'label' key default to 1 (treated as "fake").
        label = item.get('label', 1)

        # Items do not necessarily carry a file path, so fall back to a
        # synthetic identifier that is still unique per dataset/split/index.
        if 'file_name' in item:
            img_path = item['file_name']
        else:
            img_path = f"hf_{self.dataset_name}_{self.split}_{idx}"

        if self.transform:
            image = self.transform(image)

        return image, label, img_path

    @staticmethod
    def _find_image(item):
        """Pick the image out of a dataset item.

        Prefers the ``'image'`` key, then ``'img'``, then the first value
        that is a ``PIL.Image.Image`` instance.

        Raises:
            ValueError: If none of the above yields an image.
        """
        if 'image' in item:
            return item['image']
        if 'img' in item:
            return item['img']
        for value in item.values():
            if isinstance(value, Image.Image):
                return value
        raise ValueError(f"Could not find image in dataset item keys: {item.keys()}")