Spaces:
Sleeping
Sleeping
File size: 2,033 Bytes
37163a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import os
import random
from PIL import Image, ImageFile
from datasets import register
from torch.utils.data import Dataset
from torchvision import transforms
# Raise Pillow's decompression-bomb pixel limit so very large images can be
# opened (Pillow warns above this pixel count and refuses images over twice it).
Image.MAX_IMAGE_PIXELS = 933120000
# Let Pillow decode truncated / partially written files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Filename suffixes treated as images. Matching is case-sensitive (ImageFolder
# checks them with str.endswith), so both lower- and upper-case forms are listed;
# '.WEBP' completes the set so upper-case WebP files are not silently skipped.
IMAGE_EXTS = ('.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.webp', '.WEBP')
@register('image_folder')
class ImageFolder(Dataset):
    """Map-style dataset over the image files in a single directory.

    Only the top level of ``root_path`` is listed (``os.listdir``, no recursion).
    Entries whose names end with one of ``IMAGE_EXTS`` are kept, sorted by name.
    Each item is an RGB PIL image after these optional steps, in order:
    resize, center square crop, random square crop, random horizontal flip.

    Args:
        root_path: Directory containing the image files.
        resize: ``None`` keeps the original size. An ``int`` scales the shorter
            side to that many pixels, preserving aspect ratio. Any other value
            is passed unchanged as the ``size`` argument of ``Image.resize``.
        square_crop: If true, center-crop to a square whose side equals the
            shorter image dimension.
        rand_crop: If not ``None``, side length of a square crop taken at a
            uniformly random position.
        rand_flip: If true, apply torchvision's ``RandomHorizontalFlip()``
            with its default settings.
    """

    def __init__(self, root_path, resize=None, square_crop=False, rand_crop=None, rand_flip=False):
        names = sorted(os.listdir(root_path))
        self.files = [os.path.join(root_path, name) for name in names if name.endswith(IMAGE_EXTS)]
        self.resize = resize
        self.square_crop = square_crop
        self.rand_crop = rand_crop
        self.rand_flip = transforms.RandomHorizontalFlip() if rand_flip else None

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return the transformed RGB image for ``idx``.

        If the file at ``idx`` cannot be loaded, the following files are tried
        in order, wrapping around, and the first one that loads is returned.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no file in the dataset can be loaded.
        """
        image = self._load_with_fallback(idx)
        return self._transform(image)

    def _load_with_fallback(self, idx):
        """Load file ``idx``, or the next loadable file after it (wrapping)."""
        # Index once outside any try so an out-of-range idx raises IndexError.
        self.files[idx]
        n = len(self.files)
        # Bounded scan instead of recursion: at most one full pass over the
        # files, so an all-broken dataset cannot hit the recursion limit.
        for offset in range(n):
            path = self.files[(idx + offset) % n]
            try:
                return self._open_rgb(path)
            except Exception as e:  # best-effort skip of unreadable files
                print('Error loading image:', path, f'({type(e).__name__}: {e})')
        raise RuntimeError(f'None of the {n} image files could be loaded')

    @staticmethod
    def _open_rgb(path):
        """Decode ``path`` into an independent RGB image and close the file."""
        # The context manager guarantees the underlying file handle is closed;
        # convert() returns a new, fully loaded image that outlives it.
        with Image.open(path) as img:
            return img.convert('RGB')

    def _transform(self, image):
        """Apply the configured resize / crops / flip to ``image``."""
        if self.resize is not None:
            size = self.resize
            if isinstance(size, int):
                w, h = image.size
                # Scale so the shorter side equals `size`, keeping aspect ratio.
                if w < h:
                    size = (size, int(h / w * size))
                else:
                    size = (int(w / h * size), size)
            image = image.resize(size, Image.LANCZOS)
        if self.square_crop:
            w, h = image.size
            side = min(w, h)
            left, upper = (w - side) // 2, (h - side) // 2
            image = image.crop((left, upper, left + side, upper + side))
        if self.rand_crop is not None:
            w, h = image.size
            crop = self.rand_crop
            if w < crop or h < crop:
                raise ValueError(f'rand_crop={crop} is larger than image size {w}x{h}')
            left = random.randint(0, w - crop)
            upper = random.randint(0, h - crop)
            image = image.crop((left, upper, left + crop, upper + crop))
        if self.rand_flip is not None:
            image = self.rand_flip(image)
        return image
|