Raid41 commited on
Commit
1a260bf
1 Parent(s): f26a075

Update dataset/datasets.py

Browse files
Files changed (1) hide show
  1. dataset/datasets.py +33 -113
dataset/datasets.py CHANGED
@@ -1,133 +1,60 @@
1
- import torch
2
  import os
 
3
  import torchvision.transforms as transforms
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
-
7
  from utils.utils import generate_mask
8
 
9
-
10
  class TrainDataset(torch.utils.data.Dataset):
11
- def __init__(self, data_path, transform = None, mults_amount = 1):
12
  self.data = os.listdir(os.path.join(data_path, 'color'))
13
  self.data_path = data_path
14
  self.transform = transform
15
  self.mults_amount = mults_amount
16
-
17
  self.ToTensor = transforms.ToTensor()
 
 
 
 
 
 
18
  def __len__(self):
19
  return len(self.data)
20
-
21
  def __getitem__(self, idx):
22
  image_name = self.data[idx]
23
-
24
  color_img = plt.imread(os.path.join(self.data_path, 'color', image_name))
25
-
26
 
27
  if self.mults_amount > 1:
28
  mult_number = np.random.choice(range(self.mults_amount))
29
-
30
  bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
31
  dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
32
  else:
33
- bw_name = self.data[idx]
34
- dfm_name = os.path.splitext(self.data[idx])[0] + '0_dfm.png'
35
-
36
-
37
- bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
38
- dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', dfm_name)), 2)
39
-
40
- bw_img = np.concatenate([bw_img, dfm_img], axis = 2)
41
-
42
- if self.transform:
43
- result = self.transform(image = color_img, mask = bw_img)
44
- color_img = result['image']
45
- bw_img = result['mask']
46
-
47
- dfm_img = bw_img[:, :, 1]
48
- bw_img = bw_img[:, :, 0]
49
-
50
- color_img = self.ToTensor(color_img)
51
- bw_img = self.ToTensor(bw_img)
52
-
53
- dfm_img = self.ToTensor(dfm_img)
54
-
55
- color_img = (color_img - 0.5) / 0.5
56
-
57
- mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
58
- hint = torch.cat((color_img * mask, mask), 0)
59
-
60
- return bw_img, color_img, hint, dfm_img
61
-
62
- class FineTuningDataset(torch.utils.data.Dataset):
63
- def __init__(self, data_path, transform = None, mult_amount = 1):
64
- self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
65
- self.color_data = [x for x in os.listdir(os.path.join(data_path, 'color'))]
66
- self.data_path = data_path
67
- self.transform = transform
68
- self.mults_amount = mult_amount
69
-
70
- np.random.shuffle(self.color_data)
71
-
72
- self.ToTensor = transforms.ToTensor()
73
- def __len__(self):
74
- return len(self.data)
75
-
76
- def __getitem__(self, idx):
77
- color_img = plt.imread(os.path.join(self.data_path, 'color', self.color_data[idx]))
78
-
79
- image_name = self.data[idx]
80
- if self.mults_amount > 1:
81
- mult_number = np.random.choice(range(self.mults_amount))
82
-
83
- bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
84
- dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '_dfm.png'
85
- else:
86
- bw_name = self.data[idx]
87
- dfm_name = os.path.splitext(self.data[idx])[0] + '_dfm.png'
88
-
89
-
90
- bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', image_name)), 2)
91
- dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', dfm_name)), 2)
92
-
93
- if self.transform:
94
- result = self.transform(image = color_img)
95
- color_img = result['image']
96
-
97
- result = self.transform(image = bw_img, mask = dfm_img)
98
- bw_img = result['image']
99
- dfm_img = result['mask']
100
-
101
- color_img = self.ToTensor(color_img)
102
  bw_img = self.ToTensor(bw_img)
103
- dfm_img = self.ToTensor(dfm_img)
104
-
105
- color_img = (color_img - 0.5) / 0.5
106
-
107
- return bw_img, dfm_img, color_img
108
-
109
- bw_img = np.concatenate([bw_img, dfm_img], axis = 2)
110
-
111
- if self.transform:
112
- result = self.transform(image = color_img, mask = bw_img)
113
- color_img = result['image']
114
- bw_img = result['mask']
115
-
116
- dfm_img = bw_img[:, :, 1]
117
- bw_img = bw_img[:, :, 0]
118
-
119
  color_img = self.ToTensor(color_img)
120
- bw_img = self.ToTensor(bw_img)
121
-
122
  dfm_img = self.ToTensor(dfm_img)
123
-
124
  color_img = (color_img - 0.5) / 0.5
125
-
126
  mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
127
  hint = torch.cat((color_img * mask, mask), 0)
128
-
129
  return bw_img, color_img, hint, dfm_img
130
-
 
131
  class FineTuningDataset(torch.utils.data.Dataset):
132
  def __init__(self, data_path, transform=None, mult_amount=1):
133
  self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
@@ -159,19 +86,12 @@ class FineTuningDataset(torch.utils.data.Dataset):
159
 
160
  bw_img = np.expand_dims(plt.imread(os.path.join(self.bw_directory, bw_name)), 2)
161
  dfm_img = np.expand_dims(plt.imread(os.path.join(self.bw_directory, dfm_name)), 2)
162
-
163
- if self.transform:
164
- result = self.transform(image = color_img)
165
- color_img = result['image']
166
-
167
- result = self.transform(image = bw_img, mask = dfm_img)
168
- bw_img = result['image']
169
- dfm_img = result['mask']
170
-
171
- color_img = self.ToTensor(color_img)
172
  bw_img = self.ToTensor(bw_img)
 
173
  dfm_img = self.ToTensor(dfm_img)
174
-
175
  color_img = (color_img - 0.5) / 0.5
176
-
177
- return bw_img, dfm_img, color_img
 
 
1
  import os
2
+ import torch
3
  import torchvision.transforms as transforms
4
  import matplotlib.pyplot as plt
5
  import numpy as np
 
6
  from utils.utils import generate_mask
7
 
 
8
  class TrainDataset(torch.utils.data.Dataset):
9
+ def __init__(self, data_path, transform=None, mults_amount=1):
10
  self.data = os.listdir(os.path.join(data_path, 'color'))
11
  self.data_path = data_path
12
  self.transform = transform
13
  self.mults_amount = mults_amount
 
14
  self.ToTensor = transforms.ToTensor()
15
+
16
+ # Directorio para guardar las im谩genes en blanco y negro
17
+ self.bw_directory = os.path.join(data_path, 'bw')
18
+ if not os.path.exists(self.bw_directory):
19
+ os.makedirs(self.bw_directory)
20
+
21
  def __len__(self):
22
  return len(self.data)
23
+
24
  def __getitem__(self, idx):
25
  image_name = self.data[idx]
 
26
  color_img = plt.imread(os.path.join(self.data_path, 'color', image_name))
 
27
 
28
  if self.mults_amount > 1:
29
  mult_number = np.random.choice(range(self.mults_amount))
 
30
  bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
31
  dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
32
  else:
33
+ bw_name = image_name
34
+ dfm_name = os.path.splitext(image_name)[0] + '_dfm.png'
35
+
36
+ bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
37
+ dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', dfm_name)), 2)
38
+
39
+ # Guardar im谩genes en el directorio bw_directory
40
+ bw_img_path = os.path.join(self.bw_directory, bw_name)
41
+ dfm_img_path = os.path.join(self.bw_directory, dfm_name)
42
+ plt.imsave(bw_img_path, bw_img.squeeze(), cmap='gray')
43
+ plt.imsave(dfm_img_path, dfm_img.squeeze(), cmap='gray')
44
+
45
+ # Normalizaci贸n y generaci贸n de m谩scara
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  bw_img = self.ToTensor(bw_img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  color_img = self.ToTensor(color_img)
 
 
48
  dfm_img = self.ToTensor(dfm_img)
49
+
50
  color_img = (color_img - 0.5) / 0.5
51
+
52
  mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
53
  hint = torch.cat((color_img * mask, mask), 0)
54
+
55
  return bw_img, color_img, hint, dfm_img
56
+
57
+
58
  class FineTuningDataset(torch.utils.data.Dataset):
59
  def __init__(self, data_path, transform=None, mult_amount=1):
60
  self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
 
86
 
87
  bw_img = np.expand_dims(plt.imread(os.path.join(self.bw_directory, bw_name)), 2)
88
  dfm_img = np.expand_dims(plt.imread(os.path.join(self.bw_directory, dfm_name)), 2)
89
+
90
+ # Normalizaci贸n
 
 
 
 
 
 
 
 
91
  bw_img = self.ToTensor(bw_img)
92
+ color_img = self.ToTensor(color_img)
93
  dfm_img = self.ToTensor(dfm_img)
94
+
95
  color_img = (color_img - 0.5) / 0.5
96
+
97
+ return bw_img, dfm_img, color_img