Update dataset/datasets.py
Browse files- dataset/datasets.py +74 -4
dataset/datasets.py
CHANGED
|
@@ -28,10 +28,10 @@ class TrainDataset(torch.utils.data.Dataset):
|
|
| 28 |
mult_number = np.random.choice(range(self.mults_amount))
|
| 29 |
|
| 30 |
bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
|
| 31 |
-
dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
|
| 32 |
else:
|
| 33 |
bw_name = self.data[idx]
|
| 34 |
-
dfm_name = os.path.splitext(self.data[idx])[0] + '
|
| 35 |
|
| 36 |
|
| 37 |
bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
|
|
@@ -81,10 +81,10 @@ class FineTuningDataset(torch.utils.data.Dataset):
|
|
| 81 |
mult_number = np.random.choice(range(self.mults_amount))
|
| 82 |
|
| 83 |
bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
|
| 84 |
-
dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '
|
| 85 |
else:
|
| 86 |
bw_name = self.data[idx]
|
| 87 |
-
dfm_name = os.path.splitext(self.data[idx])[0] + '
|
| 88 |
|
| 89 |
|
| 90 |
bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', image_name)), 2)
|
|
@@ -104,4 +104,74 @@ class FineTuningDataset(torch.utils.data.Dataset):
|
|
| 104 |
|
| 105 |
color_img = (color_img - 0.5) / 0.5
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
return bw_img, dfm_img, color_img
|
|
|
|
| 28 |
mult_number = np.random.choice(range(self.mults_amount))
|
| 29 |
|
| 30 |
bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
|
| 31 |
+
dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
|
| 32 |
else:
|
| 33 |
bw_name = self.data[idx]
|
| 34 |
+
dfm_name = os.path.splitext(self.data[idx])[0] + '0_dfm.png'
|
| 35 |
|
| 36 |
|
| 37 |
bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
|
|
|
|
| 81 |
mult_number = np.random.choice(range(self.mults_amount))
|
| 82 |
|
| 83 |
bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
|
| 84 |
+
dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '_dfm.png'
|
| 85 |
else:
|
| 86 |
bw_name = self.data[idx]
|
| 87 |
+
dfm_name = os.path.splitext(self.data[idx])[0] + '_dfm.png'
|
| 88 |
|
| 89 |
|
| 90 |
bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', image_name)), 2)
|
|
|
|
| 104 |
|
| 105 |
color_img = (color_img - 0.5) / 0.5
|
| 106 |
|
| 107 |
+
return bw_img, dfm_img, color_img
|
| 108 |
+
|
| 109 |
+
bw_img = np.concatenate([bw_img, dfm_img], axis = 2)
|
| 110 |
+
|
| 111 |
+
if self.transform:
|
| 112 |
+
result = self.transform(image = color_img, mask = bw_img)
|
| 113 |
+
color_img = result['image']
|
| 114 |
+
bw_img = result['mask']
|
| 115 |
+
|
| 116 |
+
dfm_img = bw_img[:, :, 1]
|
| 117 |
+
bw_img = bw_img[:, :, 0]
|
| 118 |
+
|
| 119 |
+
color_img = self.ToTensor(color_img)
|
| 120 |
+
bw_img = self.ToTensor(bw_img)
|
| 121 |
+
|
| 122 |
+
dfm_img = self.ToTensor(dfm_img)
|
| 123 |
+
|
| 124 |
+
color_img = (color_img - 0.5) / 0.5
|
| 125 |
+
|
| 126 |
+
mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
|
| 127 |
+
hint = torch.cat((color_img * mask, mask), 0)
|
| 128 |
+
|
| 129 |
+
return bw_img, color_img, hint, dfm_img
|
| 130 |
+
|
| 131 |
+
class FineTuningDataset(torch.utils.data.Dataset):
|
| 132 |
+
def __init__(self, data_path, transform=None, mult_amount=1):
|
| 133 |
+
self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
|
| 134 |
+
self.color_data = [x for x in os.listdir(os.path.join(data_path, 'color'))]
|
| 135 |
+
self.data_path = data_path
|
| 136 |
+
self.transform = transform
|
| 137 |
+
self.mults_amount = mult_amount
|
| 138 |
+
self.ToTensor = transforms.ToTensor()
|
| 139 |
+
|
| 140 |
+
# Directorio para guardar las imágenes en blanco y negro
|
| 141 |
+
self.bw_directory = os.path.join(data_path, 'bw')
|
| 142 |
+
if not os.path.exists(self.bw_directory):
|
| 143 |
+
os.makedirs(self.bw_directory)
|
| 144 |
+
|
| 145 |
+
def __len__(self):
|
| 146 |
+
return len(self.data)
|
| 147 |
+
|
| 148 |
+
def __getitem__(self, idx):
|
| 149 |
+
color_img = plt.imread(os.path.join(self.data_path, 'color', self.color_data[idx]))
|
| 150 |
+
|
| 151 |
+
image_name = self.data[idx]
|
| 152 |
+
if self.mults_amount > 1:
|
| 153 |
+
mult_number = np.random.choice(range(self.mults_amount))
|
| 154 |
+
bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
|
| 155 |
+
dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '_dfm.png'
|
| 156 |
+
else:
|
| 157 |
+
bw_name = image_name
|
| 158 |
+
dfm_name = os.path.splitext(image_name)[0] + '_dfm.png'
|
| 159 |
+
|
| 160 |
+
bw_img = np.expand_dims(plt.imread(os.path.join(self.bw_directory, bw_name)), 2)
|
| 161 |
+
dfm_img = np.expand_dims(plt.imread(os.path.join(self.bw_directory, dfm_name)), 2)
|
| 162 |
+
|
| 163 |
+
if self.transform:
|
| 164 |
+
result = self.transform(image = color_img)
|
| 165 |
+
color_img = result['image']
|
| 166 |
+
|
| 167 |
+
result = self.transform(image = bw_img, mask = dfm_img)
|
| 168 |
+
bw_img = result['image']
|
| 169 |
+
dfm_img = result['mask']
|
| 170 |
+
|
| 171 |
+
color_img = self.ToTensor(color_img)
|
| 172 |
+
bw_img = self.ToTensor(bw_img)
|
| 173 |
+
dfm_img = self.ToTensor(dfm_img)
|
| 174 |
+
|
| 175 |
+
color_img = (color_img - 0.5) / 0.5
|
| 176 |
+
|
| 177 |
return bw_img, dfm_img, color_img
|