Raid41 commited on
Commit
6b26f63
1 Parent(s): 1a260bf

Update dataset/datasets.py

Browse files
Files changed (1) hide show
  1. dataset/datasets.py +6 -11
dataset/datasets.py CHANGED
@@ -27,20 +27,14 @@ class TrainDataset(torch.utils.data.Dataset):
27
 
28
  if self.mults_amount > 1:
29
  mult_number = np.random.choice(range(self.mults_amount))
30
- bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
31
- dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
32
  else:
33
- bw_name = image_name
34
- dfm_name = os.path.splitext(image_name)[0] + '_dfm.png'
35
 
36
- bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
37
- dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', dfm_name)), 2)
38
 
39
- # Guardar im谩genes en el directorio bw_directory
40
- bw_img_path = os.path.join(self.bw_directory, bw_name)
41
- dfm_img_path = os.path.join(self.bw_directory, dfm_name)
42
- plt.imsave(bw_img_path, bw_img.squeeze(), cmap='gray')
43
- plt.imsave(dfm_img_path, dfm_img.squeeze(), cmap='gray')
44
 
45
  # Normalizaci贸n y generaci贸n de m谩scara
46
  bw_img = self.ToTensor(bw_img)
@@ -54,6 +48,7 @@ class TrainDataset(torch.utils.data.Dataset):
54
 
55
  return bw_img, color_img, hint, dfm_img
56
 
 
57
 
58
  class FineTuningDataset(torch.utils.data.Dataset):
59
  def __init__(self, data_path, transform=None, mult_amount=1):
 
27
 
28
  if self.mults_amount > 1:
29
  mult_number = np.random.choice(range(self.mults_amount))
 
 
30
  else:
31
+ mult_number = 0
 
32
 
33
+ bw_name = f"{os.path.splitext(image_name)[0]}_{mult_number}.png"
34
+ dfm_name = f"{os.path.splitext(image_name)[0]}_{mult_number}_dfm.png"
35
 
36
+ bw_img = np.expand_dims(plt.imread(os.path.join(self.bw_directory, bw_name)), 2)
37
+ dfm_img = np.expand_dims(plt.imread(os.path.join(self.bw_directory, dfm_name)), 2)
 
 
 
38
 
39
  # Normalizaci贸n y generaci贸n de m谩scara
40
  bw_img = self.ToTensor(bw_img)
 
48
 
49
  return bw_img, color_img, hint, dfm_img
50
 
51
+ # Resto del c贸digo...
52
 
53
  class FineTuningDataset(torch.utils.data.Dataset):
54
  def __init__(self, data_path, transform=None, mult_amount=1):