File size: 2,033 Bytes
37163a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import random
from PIL import Image, ImageFile

from datasets import register
from torch.utils.data import Dataset
from torchvision import transforms


# Raise Pillow's decompression-bomb pixel threshold so that very large
# dataset images can be opened instead of being rejected/warned about.
Image.MAX_IMAGE_PIXELS = 933_120_000
# Let Pillow decode truncated files rather than raising, so partially
# written/downloaded images still produce (partial) pixel data.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Filename suffixes accepted as images. They are matched with str.endswith,
# which is case-sensitive, so each spelling is listed explicitly.
# NOTE(review): upper-case '.WEBP' is not listed — confirm whether intended.
IMAGE_EXTS = (
    '.png', '.PNG',
    '.jpg', '.JPG',
    '.jpeg', '.JPEG',
    '.webp',
)


@register('image_folder')
class ImageFolder(Dataset):
    """Map-style dataset over the image files directly inside one directory.

    Registered under the name ``'image_folder'``. Every entry of ``root_path``
    whose name ends with one of ``IMAGE_EXTS`` becomes a sample, in sorted
    filename order. Each item is an RGB ``PIL.Image.Image`` with the optional
    transforms applied in this order: resize, center square crop, random
    square crop, random horizontal flip.

    Args:
        root_path: Directory to scan (non-recursive).
        resize: ``None`` keeps the original size. An ``int`` scales the
            shorter side to that many pixels, preserving aspect ratio. Any
            other value is passed to ``Image.resize`` unchanged as the target
            ``(width, height)``.
        square_crop: If true, center-crop to a square whose side equals the
            shorter image dimension.
        rand_crop: If not ``None``, side length of a square crop taken at a
            uniformly random position. The image must be at least this large
            in both dimensions, otherwise ``random.randint`` raises
            ``ValueError``.
        rand_flip: If true, apply ``transforms.RandomHorizontalFlip()``.
    """

    def __init__(self, root_path, resize=None, square_crop=False, rand_crop=None, rand_flip=False):
        # os.listdir returns entries in arbitrary order; sort so that sample
        # indices are stable across runs and platforms.
        names = sorted(os.listdir(root_path))
        self.files = [os.path.join(root_path, name) for name in names if name.endswith(IMAGE_EXTS)]

        self.resize = resize
        self.square_crop = square_crop
        self.rand_crop = rand_crop
        self.rand_flip = transforms.RandomHorizontalFlip() if rand_flip else None

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a transformed RGB image.

        If the file at ``idx`` cannot be loaded, the following files
        (wrapping around) are tried in turn, so a corrupt file is replaced by
        a neighbour instead of crashing the data pipeline.

        Raises:
            IndexError: if ``idx`` is out of range.
            RuntimeError: if no file in the dataset can be loaded.
        """
        image = self._load_with_fallback(idx)
        return self._apply_transforms(image)

    def _load_with_fallback(self, idx):
        """Load file ``idx`` as RGB, falling back to later files on failure."""
        n = len(self.files)
        # Keep sequence semantics: out-of-range indices raise IndexError
        # (plain ``for x in dataset`` iteration relies on this to stop).
        if not -n <= idx < n:
            raise IndexError(f'index {idx} out of range for {n} images')

        # Try every file at most once. A recursive fallback would overflow
        # the stack on long runs of unreadable files and never terminate
        # successfully if all of them are unreadable.
        for offset in range(n):
            path = self.files[(idx + offset) % n]
            try:
                # The context manager releases the file handle, which Pillow
                # may otherwise keep open (e.g. for multi-frame formats).
                with Image.open(path) as img:
                    return img.convert('RGB')
            except Exception:
                # Best-effort: skip undecodable files, but unlike a bare
                # ``except:`` do not swallow KeyboardInterrupt/SystemExit.
                print('Error loading image:', path)
        raise RuntimeError(f'None of the {n} images in the dataset could be loaded')

    def _apply_transforms(self, image):
        """Apply the configured resize / square crop / random crop / flip."""
        if self.resize is not None:
            size = self.resize
            if isinstance(size, int):
                # Scale the shorter side to `size`, keeping aspect ratio.
                # Clamp to >= 1 so extreme aspect ratios don't yield a
                # zero-sized target, which PIL rejects.
                w, h = image.size
                if w < h:
                    size = (size, max(1, int(h / w * size)))
                else:
                    size = (max(1, int(w / h * size)), size)
            image = image.resize(size, Image.LANCZOS)

        if self.square_crop:
            w, h = image.size
            side = min(w, h)
            left, upper = (w - side) // 2, (h - side) // 2
            image = image.crop((left, upper, left + side, upper + side))

        if self.rand_crop is not None:
            w, h = image.size
            left = random.randint(0, w - self.rand_crop)
            upper = random.randint(0, h - self.rand_crop)
            image = image.crop((left, upper, left + self.rand_crop, upper + self.rand_crop))

        if self.rand_flip is not None:
            image = self.rand_flip(image)

        return image