Raid41 commited on
Commit
09800e1
·
1 Parent(s): c17106f

Update dataset/datasets.py

Browse files
Files changed (1) hide show
  1. dataset/datasets.py +3 -10
dataset/datasets.py CHANGED
@@ -8,11 +8,10 @@ import numpy as np
8
  from utils.utils import generate_mask
9
 
10
  class TrainDataset(torch.utils.data.Dataset):
11
- def __init__(self, data_path, transform=None, mults_amount=1):
12
  self.data = os.listdir(os.path.join(data_path, 'color'))
13
  self.data_path = data_path
14
  self.transform = transform
15
- self.mults_amount = mults_amount
16
  self.ToTensor = transforms.ToTensor()
17
 
18
  def __len__(self):
@@ -23,14 +22,8 @@ class TrainDataset(torch.utils.data.Dataset):
23
 
24
  color_img = Image.open(os.path.join(self.data_path, 'color', image_name)).convert('RGB')
25
 
26
- if self.mults_amount > 1:
27
- mult_number = np.random.choice(range(self.mults_amount))
28
-
29
- bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
30
- dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
31
- else:
32
- bw_name = self.data[idx]
33
- dfm_name = 'dfm_' + self.data[idx]
34
 
35
  bw_img = np.expand_dims(np.array(Image.open(os.path.join(self.data_path, 'bw', bw_name)).convert('L')), 2)
36
  dfm_img = np.expand_dims(np.array(Image.open(os.path.join(self.data_path, 'bw', dfm_name)).convert('L')), 2)
 
8
  from utils.utils import generate_mask
9
 
10
  class TrainDataset(torch.utils.data.Dataset):
11
+ def __init__(self, data_path, transform=None):
12
  self.data = os.listdir(os.path.join(data_path, 'color'))
13
  self.data_path = data_path
14
  self.transform = transform
 
15
  self.ToTensor = transforms.ToTensor()
16
 
17
  def __len__(self):
 
22
 
23
  color_img = Image.open(os.path.join(self.data_path, 'color', image_name)).convert('RGB')
24
 
25
+ bw_name = image_name
26
+ dfm_name = 'dfm_' + image_name
 
 
 
 
 
 
27
 
28
  bw_img = np.expand_dims(np.array(Image.open(os.path.join(self.data_path, 'bw', bw_name)).convert('L')), 2)
29
  dfm_img = np.expand_dims(np.array(Image.open(os.path.join(self.data_path, 'bw', dfm_name)).convert('L')), 2)