# Dataset definitions: plain image folders, CelebA attribute-filtered images,
# and helper datasets used for metrics (FID) evaluation.
import os
from torch.utils.data import Dataset
from PIL import Image
from utils import data_utils
from torchvision import transforms
class ImageDataset(Dataset):
    """Dataset yielding RGB images loaded from files under ``root``.

    File discovery is delegated to ``data_utils.make_dataset``; paths are
    sorted so indexing is deterministic across runs.
    """

    def __init__(self, root, transform=None):
        self.paths = sorted(data_utils.make_dataset(root))
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Force RGB so grayscale/RGBA files yield a consistent channel layout.
        img = Image.open(self.paths[index]).convert("RGB")
        return self.transform(img) if self.transform else img
class CelebaAttributeDataset(Dataset):
    """CelebA images filtered by a single binary attribute.

    Keeps only the images whose value for ``attr`` in the CelebA
    attribute annotation file is ``1`` (when ``use_attr`` is True) or
    ``-1`` (when ``use_attr`` is False).

    Args:
        images_root: directory scanned by ``data_utils.make_dataset``.
        attr: attribute name as it appears in the annotation header line.
        transform: optional callable applied to each loaded PIL image.
        attributes_root: path to the CelebA attribute annotation file.
        use_attr: keep images that have (True) or lack (False) the attribute.
    """

    def __init__(self, images_root, attr, transform=None, attributes_root="", use_attr=True):
        self.paths = data_utils.make_dataset(images_root)
        self.transform = transform
        with open(attributes_root, "r") as f:
            lines = f.readlines()
        # The second line of the file is the header listing attribute names.
        attr_num = -1
        for i, data_attr in enumerate(lines[1].split(" ")):
            if data_attr.strip() == attr.strip():
                attr_num = i
                break
        assert attr_num > -1, f"Can not find attribute {attr}"
        # "1" marks images that have the attribute, "-1" those that don't.
        wanted = "1" if use_attr else "-1"
        filtered_paths = []
        for path in self.paths:
            # Numeric image id from the file name; os.path handles any
            # platform's separators (the old `path.split("/")` did not).
            pic_num = int(os.path.splitext(os.path.basename(path))[0])
            # NOTE(review): assumes the annotation line for image N sits at
            # lines[N + 2] and that dropping the first two split tokens aligns
            # the remaining columns with the header — verify against the
            # actual annotation file layout.
            pic_attrs = lines[pic_num + 2].strip().split(" ")
            pic_attrs = pic_attrs[2:]
            if pic_attrs[attr_num] == wanted:
                filtered_paths.append(path)
        self.paths = sorted(filtered_paths)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        from_path = self.paths[index]
        image = Image.open(from_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image
class FIDDataset(Dataset):
    """Wraps an in-memory sequence of PIL images for FID computation.

    Unlike the path-based datasets in this module, ``files`` holds already
    opened image objects, not paths.
    """

    def __init__(self, files, transforms=None):
        self.files = files
        self.transforms = transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        # Normalize channel layout before any transform is applied.
        img = self.files[i].convert("RGB")
        if self.transforms is not None:
            img = self.transforms(img)
        return img
class MetricsPathsDataset(Dataset):
    """Pairs generated images in ``root_path`` with ground-truth images in ``gt_dir``.

    Only ``.jpg``/``.png`` files are indexed; ground-truth counterparts are
    looked up by the same file name with ``.png`` mapped to ``.jpg``.

    Args:
        root_path: directory of generated images.
        gt_dir: directory of ground-truth images. Defaults to None but is
            effectively required when any image file is present — TODO confirm
            callers always pass it.
        transform: optional callable applied to both images of a pair.
        transform_train: stored but unused here; kept for interface parity.
        return_path: if True, ``__getitem__`` also returns the file name.
        ignore: optional iterable of file names to skip.
    """

    def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None, return_path=False, ignore=None):
        # `ignore=None` avoids the mutable-default-argument pitfall of `ignore=[]`.
        skip = set(ignore) if ignore else set()
        self.pairs = []
        self.paths = []
        self.names = []
        # sorted(): os.listdir order is arbitrary; sorting keeps the
        # index -> file mapping deterministic across runs and machines.
        for f in sorted(os.listdir(root_path)):
            if f in skip:
                continue
            # Record a name ONLY when a pair is recorded, so self.names stays
            # aligned with self.pairs (the original appended every non-ignored
            # file to names, which misaligned __getitem__'s returned name).
            if not (f.endswith(".jpg") or f.endswith(".png")):
                continue
            image_path = os.path.join(root_path, f)
            # Ground-truth files are stored as .jpg even for .png inputs.
            gt_path = os.path.join(gt_dir, f).replace(".png", ".jpg")
            self.names.append(f)
            self.pairs.append([image_path, gt_path, None])
            self.paths.append(image_path)
        self.transform = transform
        self.transform_train = transform_train
        self.return_path = return_path

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        from_path, to_path, _ = self.pairs[index]
        from_im = Image.open(from_path).convert("RGB")
        to_im = Image.open(to_path).convert("RGB")
        if self.transform:
            to_im = self.transform(to_im)
            from_im = self.transform(from_im)
        if not self.return_path:
            return from_im, to_im
        else:
            return from_im, to_im, self.names[index]
class MetricsDataDataset(Dataset):
    """Yields ``(target, fake)`` pairs from pre-loaded, index-aligned data.

    Length follows ``fake_data``; ``paths`` and ``transform_train`` are
    stored but not used by ``__getitem__``.
    """

    def __init__(
        self, paths, target_data, fake_data, transform=None, transform_train=None
    ):
        self.fake_data = fake_data
        self.target_data = target_data
        self.paths = paths
        self.transform = transform
        self.transform_train = transform_train

    def __len__(self):
        return len(self.fake_data)

    def __getitem__(self, index):
        target_im, fake_im = self.target_data[index], self.fake_data[index]
        if self.transform:
            # The same transform is applied to both sides of the pair.
            fake_im = self.transform(fake_im)
            target_im = self.transform(target_im)
        return target_im, fake_im