Spaces:
Sleeping
Sleeping
File size: 3,028 Bytes
37163a6 2279ae0 37163a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import os
import random
from PIL import Image, ImageFile
from datasets import register
from torch.utils.data import Dataset
from torchvision import transforms
import os
import random
# Override Pillow's global pixel-count limit (its decompression-bomb guard) so
# that very large images can be opened. This affects every PIL user in the process.
Image.MAX_IMAGE_PIXELS = 933120000
# Ask Pillow to load truncated image files rather than raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# File-name suffixes accepted as images. Matching is done with str.endswith and
# is case-sensitive, so only the exact spellings listed here are picked up
# (e.g. '.WEBP' is not).
IMAGE_EXTS = ('.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.webp')
@register('class_folder')
class ClassFolder(Dataset):
    """Image dataset in which every sub-directory of a root is one class.

    The immediate sub-directories of ``root_path`` are sorted by name and
    numbered from 0. Every file inside one of them whose name ends with a
    suffix in ``IMAGE_EXTS`` becomes a sample labelled with the index of its
    directory. Files are not searched recursively.
    """

    def __init__(self, root_path, resize=None, square_crop=False, rand_crop=None,
                 rand_flip=False, drop_label_p=0.0, image_only=False):
        """Index the image files below ``root_path``.

        Args:
            root_path: Directory whose sub-directories are the classes. May be
                absolute or relative.
            resize: ``None`` to keep the original size. An ``int`` resizes the
                shorter side to that length while keeping the aspect ratio.
                Any other value is passed to ``Image.resize`` as the target
                size.
            square_crop: If true, center-crop to a square whose side is the
                shorter image side (applied after ``resize``).
            rand_crop: ``None`` or the side length of a random square crop
                (applied after ``square_crop``).
            rand_flip: If true, apply a random horizontal flip.
            drop_label_p: Probability of replacing a sample's label with
                ``n_classes`` (an index one past the last real class).
            image_only: If true, ``__getitem__`` returns just the image.
        """
        folders = []
        print('root_path', root_path)
        for folder in sorted(os.listdir(root_path)):
            print('folder', folder)
            folder_path = os.path.join(root_path, folder)
            if os.path.isdir(folder_path):
                folders.append(folder_path)
        print('folders', folders)

        self.files = []
        self.labels = []
        for i, folder in enumerate(folders):
            # `folder` already contains root_path. Joining root_path onto it a
            # second time only works for absolute roots and breaks relative ones.
            for file in sorted(os.listdir(folder)):
                if file.endswith(IMAGE_EXTS):
                    self.files.append(os.path.join(folder, file))
                    self.labels.append(i)

        self.resize = resize
        self.square_crop = square_crop
        self.rand_crop = rand_crop
        self.rand_flip = transforms.RandomHorizontalFlip() if rand_flip else None
        self.n_classes = len(folders)
        self.drop_label_p = drop_label_p
        self.image_only = image_only

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load, transform and return the sample at ``idx``.

        If the file at ``idx`` cannot be loaded, the following files are tried
        in order (wrapping around), so the returned sample may belong to a
        different index than the one requested.

        Returns:
            The PIL image if ``image_only`` is set, otherwise a dict with the
            keys ``'image'`` (PIL image) and ``'class_labels'`` (int).

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no file in the dataset could be loaded.
            ValueError: If ``rand_crop`` is larger than the image being cropped.
        """
        n_files = len(self.files)
        # Bounded retry instead of unbounded recursion: each file is attempted
        # at most once before giving up.
        for _ in range(max(n_files, 1)):
            # Looked up outside the try so an out-of-range index raises
            # IndexError to the caller instead of being treated as a bad file.
            path = self.files[idx]
            try:
                # The context manager closes the underlying file handle;
                # convert() returns an independent, fully loaded image.
                with Image.open(path) as src:
                    image = src.convert('RGB')
                break
            except Exception:
                # Pillow raises many unrelated exception types for corrupt
                # files, hence the broad catch. Unlike a bare `except`, this
                # lets KeyboardInterrupt / SystemExit through.
                print('Error loading image:', path)
                idx = (idx + 1) % n_files
        else:
            raise RuntimeError(
                'ClassFolder: none of the %d image files could be loaded' % n_files
            )
        label = self.labels[idx]

        if self.resize is not None:
            r = self.resize
            if isinstance(r, int):
                # Scale so that the shorter side becomes `r`, keeping aspect.
                w, h = image.size
                if w < h:
                    r = (r, int(h / w * r))
                else:
                    r = (int(w / h * r), r)
            image = image.resize(r, Image.LANCZOS)

        if self.square_crop:
            # Center crop to the largest square that fits.
            w, h = image.size
            side = min(w, h)
            left, upper = (w - side) // 2, (h - side) // 2
            image = image.crop((left, upper, left + side, upper + side))

        if self.rand_crop is not None:
            # random.randint raises ValueError when the image is smaller than
            # the requested crop.
            w, h = image.size
            left = random.randint(0, w - self.rand_crop)
            upper = random.randint(0, h - self.rand_crop)
            image = image.crop((left, upper, left + self.rand_crop, upper + self.rand_crop))

        if self.rand_flip is not None:
            image = self.rand_flip(image)

        if self.drop_label_p > 0.0 and random.random() < self.drop_label_p:
            # Replace the label with the extra index one past the real classes.
            label = self.n_classes

        if self.image_only:
            return image
        return {
            'image': image,
            'class_labels': label,
        }
|