Raid41 commited on
Commit
f26a075
·
1 Parent(s): a0561df

Update dataset/datasets.py

Browse files
Files changed (1) hide show
  1. dataset/datasets.py +74 -4
dataset/datasets.py CHANGED
@@ -28,10 +28,10 @@ class TrainDataset(torch.utils.data.Dataset):
28
  mult_number = np.random.choice(range(self.mults_amount))
29
 
30
  bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
31
- dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
32
  else:
33
  bw_name = self.data[idx]
34
- dfm_name = os.path.splitext(self.data[idx])[0] + '0.png'
35
 
36
 
37
  bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
@@ -81,10 +81,10 @@ class FineTuningDataset(torch.utils.data.Dataset):
81
  mult_number = np.random.choice(range(self.mults_amount))
82
 
83
  bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
84
- dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + 'dfm.png'
85
  else:
86
  bw_name = self.data[idx]
87
- dfm_name = os.path.splitext(self.data[idx])[0] + 'dfm.png'
88
 
89
 
90
  bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', image_name)), 2)
@@ -104,4 +104,74 @@ class FineTuningDataset(torch.utils.data.Dataset):
104
 
105
  color_img = (color_img - 0.5) / 0.5
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  return bw_img, dfm_img, color_img
 
28
  mult_number = np.random.choice(range(self.mults_amount))
29
 
30
  bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
31
+ dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
32
  else:
33
  bw_name = self.data[idx]
34
+ dfm_name = os.path.splitext(self.data[idx])[0] + '0_dfm.png'
35
 
36
 
37
  bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
 
81
  mult_number = np.random.choice(range(self.mults_amount))
82
 
83
  bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
84
+ dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '_dfm.png'
85
  else:
86
  bw_name = self.data[idx]
87
+ dfm_name = os.path.splitext(self.data[idx])[0] + '_dfm.png'
88
 
89
 
90
  bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', image_name)), 2)
 
104
 
105
  color_img = (color_img - 0.5) / 0.5
106
 
107
+ return bw_img, dfm_img, color_img
108
+
109
+ bw_img = np.concatenate([bw_img, dfm_img], axis = 2)
110
+
111
+ if self.transform:
112
+ result = self.transform(image = color_img, mask = bw_img)
113
+ color_img = result['image']
114
+ bw_img = result['mask']
115
+
116
+ dfm_img = bw_img[:, :, 1]
117
+ bw_img = bw_img[:, :, 0]
118
+
119
+ color_img = self.ToTensor(color_img)
120
+ bw_img = self.ToTensor(bw_img)
121
+
122
+ dfm_img = self.ToTensor(dfm_img)
123
+
124
+ color_img = (color_img - 0.5) / 0.5
125
+
126
+ mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
127
+ hint = torch.cat((color_img * mask, mask), 0)
128
+
129
+ return bw_img, color_img, hint, dfm_img
130
+
131
+ class FineTuningDataset(torch.utils.data.Dataset):
132
+ def __init__(self, data_path, transform=None, mult_amount=1):
133
+ self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
134
+ self.color_data = [x for x in os.listdir(os.path.join(data_path, 'color'))]
135
+ self.data_path = data_path
136
+ self.transform = transform
137
+ self.mults_amount = mult_amount
138
+ self.ToTensor = transforms.ToTensor()
139
+
140
+ # Directorio para guardar las imágenes en blanco y negro
141
+ self.bw_directory = os.path.join(data_path, 'bw')
142
+ if not os.path.exists(self.bw_directory):
143
+ os.makedirs(self.bw_directory)
144
+
145
+ def __len__(self):
146
+ return len(self.data)
147
+
148
+ def __getitem__(self, idx):
149
+ color_img = plt.imread(os.path.join(self.data_path, 'color', self.color_data[idx]))
150
+
151
+ image_name = self.data[idx]
152
+ if self.mults_amount > 1:
153
+ mult_number = np.random.choice(range(self.mults_amount))
154
+ bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
155
+ dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '_dfm.png'
156
+ else:
157
+ bw_name = image_name
158
+ dfm_name = os.path.splitext(image_name)[0] + '_dfm.png'
159
+
160
+ bw_img = np.expand_dims(plt.imread(os.path.join(self.bw_directory, bw_name)), 2)
161
+ dfm_img = np.expand_dims(plt.imread(os.path.join(self.bw_directory, dfm_name)), 2)
162
+
163
+ if self.transform:
164
+ result = self.transform(image = color_img)
165
+ color_img = result['image']
166
+
167
+ result = self.transform(image = bw_img, mask = dfm_img)
168
+ bw_img = result['image']
169
+ dfm_img = result['mask']
170
+
171
+ color_img = self.ToTensor(color_img)
172
+ bw_img = self.ToTensor(bw_img)
173
+ dfm_img = self.ToTensor(dfm_img)
174
+
175
+ color_img = (color_img - 0.5) / 0.5
176
+
177
  return bw_img, dfm_img, color_img