Update dataset/datasets.py
Browse files- dataset/datasets.py +6 -11
dataset/datasets.py
CHANGED
|
@@ -27,20 +27,14 @@ class TrainDataset(torch.utils.data.Dataset):
|
|
| 27 |
|
| 28 |
if self.mults_amount > 1:
|
| 29 |
mult_number = np.random.choice(range(self.mults_amount))
|
| 30 |
-
bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
|
| 31 |
-
dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
|
| 32 |
else:
|
| 33 |
-
|
| 34 |
-
dfm_name = os.path.splitext(image_name)[0] + '_dfm.png'
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
dfm_img_path = os.path.join(self.bw_directory, dfm_name)
|
| 42 |
-
plt.imsave(bw_img_path, bw_img.squeeze(), cmap='gray')
|
| 43 |
-
plt.imsave(dfm_img_path, dfm_img.squeeze(), cmap='gray')
|
| 44 |
|
| 45 |
# Normalizaci贸n y generaci贸n de m谩scara
|
| 46 |
bw_img = self.ToTensor(bw_img)
|
|
@@ -54,6 +48,7 @@ class TrainDataset(torch.utils.data.Dataset):
|
|
| 54 |
|
| 55 |
return bw_img, color_img, hint, dfm_img
|
| 56 |
|
|
|
|
| 57 |
|
| 58 |
class FineTuningDataset(torch.utils.data.Dataset):
|
| 59 |
def __init__(self, data_path, transform=None, mult_amount=1):
|
|
|
|
| 27 |
|
| 28 |
if self.mults_amount > 1:
|
| 29 |
mult_number = np.random.choice(range(self.mults_amount))
|
|
|
|
|
|
|
| 30 |
else:
|
| 31 |
+
mult_number = 0
|
|
|
|
| 32 |
|
| 33 |
+
bw_name = f"{os.path.splitext(image_name)[0]}_{mult_number}.png"
|
| 34 |
+
dfm_name = f"{os.path.splitext(image_name)[0]}_{mult_number}_dfm.png"
|
| 35 |
|
| 36 |
+
bw_img = np.expand_dims(plt.imread(os.path.join(self.bw_directory, bw_name)), 2)
|
| 37 |
+
dfm_img = np.expand_dims(plt.imread(os.path.join(self.bw_directory, dfm_name)), 2)
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# Normalizaci贸n y generaci贸n de m谩scara
|
| 40 |
bw_img = self.ToTensor(bw_img)
|
|
|
|
| 48 |
|
| 49 |
return bw_img, color_img, hint, dfm_img
|
| 50 |
|
| 51 |
+
# Resto del c贸digo...
|
| 52 |
|
| 53 |
class FineTuningDataset(torch.utils.data.Dataset):
|
| 54 |
def __init__(self, data_path, transform=None, mult_amount=1):
|