Spaces:
Sleeping
Sleeping
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
| import glob | |
| import os | |
| from PIL import Image | |
class ImageNetDataset(Dataset):
    """Unlabeled dataset of square, resized RGB image tensors from a directory tree.

    Every regular file below ``image_dir`` whose name ends in ``.png``, ``.jpg``
    or ``.jpeg`` (case-insensitive) is included, in sorted path order. Each item
    is center-cropped to the largest square that fits, resized to
    ``resize_to_size`` x ``resize_to_size`` and converted with
    ``transforms.ToTensor()``. Only the image tensor is returned (no label), and
    no mean/std normalization is applied.
    """

    # Lower-case file-name suffixes treated as images.
    _IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, image_dir, resize_to_size: int):
        """
        Args:
            image_dir: Root directory searched recursively for image files.
            resize_to_size: Side length, in pixels, of the square output images.
        """
        self.image_dir = image_dir
        # Runs after the square center crop in __getitem__, so this square
        # resize does not distort the aspect ratio.
        self.transform = transforms.Compose([
            transforms.Resize((resize_to_size, resize_to_size)),
            transforms.ToTensor(),
        ])
        self.image_files = self._find_image_files(image_dir)

    @classmethod
    def _find_image_files(cls, image_dir):
        """Return the sorted paths of image files found recursively under *image_dir*."""
        pattern = os.path.join(image_dir, "**", "*.*")
        return sorted(
            path
            for path in glob.glob(pattern, recursive=True)
            # isfile() excludes directories whose names happen to end in an
            # image extension; they would otherwise fail in __getitem__.
            if path.lower().endswith(cls._IMAGE_EXTENSIONS) and os.path.isfile(path)
        )

    def __len__(self):
        """Return the number of image files discovered at construction time."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image *idx* and return it as a tensor produced by ``self.transform``.

        For an RGB input, ``ToTensor`` yields a float tensor of shape
        (3, resize_to_size, resize_to_size) with values in [0, 1].
        """
        img_path = self.image_files[idx]
        # The context manager releases the file handle deterministically;
        # convert() loads the pixel data before the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        # Crop to the largest centered square so the subsequent square resize
        # preserves the aspect ratio.
        image = transforms.CenterCrop(min(image.size))(image)
        # Resize to the target size and convert to a tensor (no normalization).
        return self.transform(image)