# image2painting / dataset.py
# Uploaded by Lasercatz — commit 97bca33 (verified): "Upload 9 files"
from torch.utils.data import Dataset
from torchvision import transforms
import glob
import os
from PIL import Image
class ImageNetDataset(Dataset):
    """Unlabelled image dataset built from every image file under a directory tree.

    Each item is the image converted to RGB, center-cropped to a square whose
    side is its shorter edge, resized to ``resize_to_size x resize_to_size`` and
    converted to a tensor with ``transforms.ToTensor`` (pixel values scaled to
    [0, 1]; no mean/std normalization is applied). No class labels are returned.
    """

    # Default accepted file suffixes; matching is case-insensitive (e.g. ".JPEG").
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_dir, resize_to_size: int, *, extensions=IMAGE_EXTENSIONS):
        """Index all image files below ``image_dir``.

        Args:
            image_dir: Root directory, searched recursively for image files.
            resize_to_size: Side length in pixels of the square output images.
            extensions: Accepted file suffixes (case-insensitive). Defaults to
                ``IMAGE_EXTENSIONS``.
        """
        self.image_dir = image_dir
        self.transform = transforms.Compose([
            transforms.Resize((resize_to_size, resize_to_size)),
            transforms.ToTensor(),
        ])
        self.image_files = self._collect_image_files(image_dir, extensions)

    @staticmethod
    def _collect_image_files(image_dir, extensions):
        """Return the sorted paths of files under ``image_dir`` ending in one of ``extensions``."""
        # str.endswith requires a tuple; lower-case both sides for case-insensitive matching.
        suffixes = tuple(ext.lower() for ext in extensions)
        # Escape the root so glob metacharacters in the directory name ("[", "*", "?")
        # are matched literally. With recursive=True, "**" also matches image_dir itself.
        pattern = os.path.join(glob.escape(os.fspath(image_dir)), "**", "*.*")
        return sorted(
            path for path in glob.glob(pattern, recursive=True)
            if path.lower().endswith(suffixes)
        )

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, center-crop it to a square, then resize and tensorize it."""
        img_path = self.image_files[idx]
        # The context manager closes the underlying file handle; convert() returns a
        # new image that remains usable after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        # Crop to the largest centered square so the square resize does not distort
        # the aspect ratio.
        crop_size = min(image.size)
        image = transforms.CenterCrop(crop_size)(image)
        # Resize to (resize_to_size, resize_to_size) and convert to a [0, 1] float tensor.
        return self.transform(image)