# File: dftest1/src/data/huggingface_dataset.py
# Hosting-page header: uploader "akcanca" - commit "Upload 110 files (#1)" - 07fe054 (verified)
import os
import torch
from torch.utils.data import Dataset
from datasets import load_dataset
from PIL import Image
import numpy as np
import io
class HuggingFaceDataset(Dataset):
    """PyTorch ``Dataset`` that serves images from a Hugging Face dataset.

    Every item is returned as an ``(image, label, img_path)`` tuple, where
    ``image`` is an RGB image (after the optional ``transform``), ``label``
    is the row's ``'label'`` value (``1`` when the row has none) and
    ``img_path`` is the row's ``'file_name'`` or a synthetic identifier.
    """

    # Column names tried, in order, before scanning the row for a PIL image.
    _IMAGE_KEYS = ('image', 'img')

    def __init__(self, dataset_name, split='train', transform=None, sample_ratio=1.0, seed=42):
        """
        Args:
            dataset_name (str): Name of the Hugging Face dataset (e.g., "Tungtom2004/Google_Nano_Banana_Edited_Images").
            split (str): Dataset split to load (default: 'train').
            transform (callable, optional): Optional transform to be applied on a sample.
            sample_ratio (float): Ratio of data to sample (0.0 to 1.0). Values of
                1.0 or more keep the whole split.
            seed (int): Random seed for reproducibility.

        Raises:
            ValueError: If ``sample_ratio`` is negative or NaN.
            Exception: Whatever ``load_dataset`` raises is logged and re-raised.
        """
        # Reject a nonsensical ratio before paying for the dataset load.
        # Written as ``not (x >= 0)`` so that NaN is rejected as well.
        if not (sample_ratio >= 0.0):
            raise ValueError(
                f"sample_ratio must be a non-negative number, got {sample_ratio!r}"
            )

        self.dataset_name = dataset_name
        self.split = split
        self.transform = transform

        print(f"Loading Hugging Face dataset: {dataset_name} ({split})")
        try:
            self.hf_dataset = load_dataset(dataset_name, split=split)
        except Exception as e:
            print(f"Error loading dataset {dataset_name}: {e}")
            raise  # bare re-raise: propagate the original exception unchanged

        # Sampling: shuffle deterministically, then keep the leading fraction.
        if sample_ratio < 1.0:
            print(f"Sampling {sample_ratio*100}% of the dataset...")
            keep = int(len(self.hf_dataset) * sample_ratio)
            self.hf_dataset = self.hf_dataset.shuffle(seed=seed).select(range(keep))

        print(f"Loaded {len(self.hf_dataset)} samples.")

    def __len__(self):
        """Return the number of rows available after optional sampling."""
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        """Return ``(image, label, img_path)`` for row ``idx``.

        Args:
            idx: Row index. A single-element tensor is converted to a Python
                number first.

        Returns:
            tuple: The RGB image (transformed when ``transform`` is set), the
            row's ``'label'`` (``1`` if absent) and an identifier string.

        Raises:
            ValueError: If no image can be located in the row.
        """
        # A tensor index would otherwise end up as "tensor(3)" in img_path
        # and is not a plain int for the underlying dataset lookup.
        if torch.is_tensor(idx):
            idx = idx.item()

        item = self.hf_dataset[idx]

        image = self._decode_image(self._find_image(item))

        # Ensure RGB
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Default to 1 (fake) if the row carries no label.
        label = item.get('label', 1)

        # HF datasets don't always have file paths: prefer the row's
        # 'file_name', otherwise build a pseudo-path from name/split/index.
        if 'file_name' in item:
            img_path = item['file_name']
        else:
            img_path = f"hf_{self.dataset_name}_{self.split}_{idx}"

        if self.transform:
            image = self.transform(image)

        return image, label, img_path

    @classmethod
    def _find_image(cls, item):
        """Return the raw image value of a row, or raise ``ValueError``."""
        for key in cls._IMAGE_KEYS:
            if key in item:
                return item[key]

        # Fallback: first value in the row that is a PIL image.
        for value in item.values():
            if isinstance(value, Image.Image):
                return value

        raise ValueError(f"Could not find image in dataset item keys: {item.keys()}")

    @staticmethod
    def _decode_image(raw):
        """Turn an undecoded image payload into a PIL image; pass others through."""
        # Undecoded image columns are dicts holding 'bytes' and/or 'path'.
        if isinstance(raw, dict):
            data = raw.get('bytes')
            source = io.BytesIO(data) if data is not None else raw.get('path')
            if source is None:
                raise ValueError("Image entry has neither 'bytes' nor 'path'")
            return Image.open(source)

        if isinstance(raw, (bytes, bytearray)):
            return Image.open(io.BytesIO(raw))

        # Already-decoded images (the common case) are used as they are.
        return raw