Raid41 commited on
Commit
c340733
·
1 Parent(s): 1d9f16e

Update dataset/datasets.py

Browse files
Files changed (1) hide show
  1. dataset/datasets.py +39 -24
dataset/datasets.py CHANGED
@@ -4,11 +4,8 @@ import torchvision.transforms as transforms
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
 
7
- from utils.utils import generate_mask
8
-
9
-
10
  class TrainDataset(torch.utils.data.Dataset):
11
- def __init__(self, data_path, transform=None, mults_amount=1):
12
  self.data = os.listdir(os.path.join(data_path, 'color'))
13
  self.data_path = data_path
14
  self.transform = transform
@@ -21,20 +18,21 @@ class TrainDataset(torch.utils.data.Dataset):
21
  def __getitem__(self, idx):
22
  image_name = self.data[idx]
23
 
24
- color_img = plt.imread(get_file_path(image_name, self.data_path, self.mults_amount))
25
 
26
 
27
  if self.mults_amount > 1:
28
  mult_number = np.random.choice(range(self.mults_amount))
 
29
  bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
30
  dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
31
  else:
32
  bw_name = self.data[idx]
33
- dfm_name = os.path.splitext(self.data[idx])[0] + '_dfm.png'
34
 
35
 
36
- bw_img = np.expand_dims(plt.imread(get_file_path(bw_name, self.data_path, self.mults_amount)), 2)
37
- dfm_img = np.expand_dims(plt.imread(get_file_path(dfm_name, self.data_path, self.mults_amount)), 2)
38
 
39
  bw_img = np.concatenate([bw_img, dfm_img], axis = 2)
40
 
@@ -53,24 +51,41 @@ class TrainDataset(torch.utils.data.Dataset):
53
 
54
  color_img = (color_img - 0.5) / 0.5
55
 
56
- mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
57
  hint = torch.cat((color_img * mask, mask), 0)
58
 
59
  return bw_img, color_img, hint, dfm_img
60
 
61
  class FineTuningDataset(torch.utils.data.Dataset):
62
- def __init__(self, data_path, transform=None, mult_amount=1):
63
- self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
64
- self.color
65
-
66
- bw_img = plt.imread(get_file_path(bw_name, self.data_path, self.mults_amount))
67
- dfm_img = plt.imread(get_file_path(dfm_name, self.data_path, self.mults_amount))
68
-
69
- bw_img = bw_img.astype('float32')
70
- dfm_img = dfm_img.astype('float32')
71
-
72
- color_img = bw_img[:, :, 0]
73
- hint = dfm_img
74
-
75
- return bw_img, color_img, hint, dfm_img
76
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
 
 
 
 
7
  class TrainDataset(torch.utils.data.Dataset):
8
+ def __init__(self, data_path, transform = None, mults_amount = 1):
9
  self.data = os.listdir(os.path.join(data_path, 'color'))
10
  self.data_path = data_path
11
  self.transform = transform
 
18
  def __getitem__(self, idx):
19
  image_name = self.data[idx]
20
 
21
+ color_img = plt.imread(os.path.join(self.data_path, 'color', image_name))
22
 
23
 
24
  if self.mults_amount > 1:
25
  mult_number = np.random.choice(range(self.mults_amount))
26
+
27
  bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
28
  dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
29
  else:
30
  bw_name = self.data[idx]
31
+ dfm_name = 'dfm_' + self.data[idx]
32
 
33
 
34
+ bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
35
+ dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', dfm_name)), 2)
36
 
37
  bw_img = np.concatenate([bw_img, dfm_img], axis = 2)
38
 
 
51
 
52
  color_img = (color_img - 0.5) / 0.5
53
 
54
+ mask = generate_mask(bw_img.shape[1])
55
  hint = torch.cat((color_img * mask, mask), 0)
56
 
57
  return bw_img, color_img, hint, dfm_img
58
 
59
  class FineTuningDataset(torch.utils.data.Dataset):
60
+ def __init__(self, data_path, transform = None):
61
+ self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('dfm_') == -1] * 8
62
+ self.color_data = [x for x in os.listdir(os.path.join(data_path, 'color')) if x.find('left') == -1 and x.find('right') == -1] * 6
63
+ self.data_path = data_path
64
+ self.transform = transform
65
+
66
+ np.random.shuffle(self.color_data)
67
+
68
+ self.ToTensor = transforms.ToTensor()
69
+ def __len__(self):
70
+ return len(self.data)
71
+
72
+ def __getitem__(self, idx):
73
+ color_img = plt.imread(os.path.join(self.data_path, 'color', self.color_data[idx]))
74
+ bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', self.data[idx])), 2)
75
+ dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', 'dfm_' + self.data[idx])), 2)
76
+
77
+ if self.transform:
78
+ result = self.transform(image = color_img)
79
+ color_img = result['image']
80
+
81
+ result = self.transform(image = bw_img, mask = dfm_img)
82
+ bw_img = result['image']
83
+ dfm_img = result['mask']
84
+
85
+ color_img = self.ToTensor(color_img)
86
+ bw_img = self.ToTensor(bw_img)
87
+ dfm_img = self.ToTensor(dfm_img)
88
+
89
+ color_img = (color_img - 0.5) / 0.5
90
+
91
+ return bw_img, dfm_img, color_img