Raid41 commited on
Commit
a0561df
·
1 Parent(s): ecbd64d

Update dataset/datasets.py

Browse files
Files changed (1) hide show
  1. dataset/datasets.py +4 -4
dataset/datasets.py CHANGED
@@ -28,10 +28,10 @@ class TrainDataset(torch.utils.data.Dataset):
28
  mult_number = np.random.choice(range(self.mults_amount))
29
 
30
  bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
31
- dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
32
  else:
33
  bw_name = self.data[idx]
34
- dfm_name = os.path.splitext(self.data[idx])[0] + '0_dfm.png'
35
 
36
 
37
  bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
@@ -81,10 +81,10 @@ class FineTuningDataset(torch.utils.data.Dataset):
81
  mult_number = np.random.choice(range(self.mults_amount))
82
 
83
  bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
84
- dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '_dfm.png'
85
  else:
86
  bw_name = self.data[idx]
87
- dfm_name = os.path.splitext(self.data[idx])[0] + '_dfm.png'
88
 
89
 
90
  bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', image_name)), 2)
 
28
  mult_number = np.random.choice(range(self.mults_amount))
29
 
30
  bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
31
+ dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
32
  else:
33
  bw_name = self.data[idx]
34
+ dfm_name = os.path.splitext(self.data[idx])[0] + '0.png'
35
 
36
 
37
  bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
 
81
  mult_number = np.random.choice(range(self.mults_amount))
82
 
83
  bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
84
+ dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + 'dfm.png'
85
  else:
86
  bw_name = self.data[idx]
87
+ dfm_name = os.path.splitext(self.data[idx])[0] + 'dfm.png'
88
 
89
 
90
  bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', image_name)), 2)