Keiser41 commited on
Commit
98166fe
1 Parent(s): 1f6c58d

Update dataset/datasets.py

Browse files
Files changed (1) hide show
  1. dataset/datasets.py +86 -54
dataset/datasets.py CHANGED
@@ -8,103 +8,135 @@ from utils.utils import generate_mask
8
 
9
 
10
  class TrainDataset(torch.utils.data.Dataset):
11
- def __init__(self, data_path, transform = None, mults_amount = 1):
12
- self.data = os.listdir(os.path.join(data_path, 'color'))
13
  self.data_path = data_path
14
  self.transform = transform
15
  self.mults_amount = mults_amount
16
-
17
  self.ToTensor = transforms.ToTensor()
 
18
  def __len__(self):
19
  return len(self.data)
20
-
21
  def __getitem__(self, idx):
22
- image_name = self.data[idx]
23
-
24
  try:
25
- color_img = plt.imread(os.path.join(self.data_path, 'color', image_name))
26
  except SyntaxError:
27
  print(f"Archivo {image_name} no es un PNG v谩lido. Saltando...")
28
  return None # O alguna otra acci贸n que prefieras
29
 
30
  if self.mults_amount > 1:
31
  mult_number = np.random.choice(range(self.mults_amount))
32
-
33
- bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
34
- dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
 
 
 
 
 
 
 
35
  else:
36
  bw_name = self.data[idx]
37
- dfm_name = os.path.splitext(self.data[idx])[0] + '0_dfm.png'
38
-
39
-
40
- bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
41
- dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', dfm_name)), 2)
42
-
43
- bw_img = np.concatenate([bw_img, dfm_img], axis = 2)
44
-
 
 
 
45
  if self.transform:
46
- result = self.transform(image = color_img, mask = bw_img)
47
- color_img = result['image']
48
- bw_img = result['mask']
49
-
50
  dfm_img = bw_img[:, :, 1]
51
  bw_img = bw_img[:, :, 0]
52
-
53
  color_img = self.ToTensor(color_img)
54
  bw_img = self.ToTensor(bw_img)
55
-
56
  dfm_img = self.ToTensor(dfm_img)
57
-
58
  color_img = (color_img - 0.5) / 0.5
59
-
60
  mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
61
  hint = torch.cat((color_img * mask, mask), 0)
62
-
63
  return bw_img, color_img, hint, dfm_img
64
-
 
65
  class FineTuningDataset(torch.utils.data.Dataset):
66
- def __init__(self, data_path, transform = None, mult_amount = 1):
67
- self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
68
- self.color_data = [x for x in os.listdir(os.path.join(data_path, 'color'))]
 
 
 
 
69
  self.data_path = data_path
70
  self.transform = transform
71
  self.mults_amount = mult_amount
72
-
73
  np.random.shuffle(self.color_data)
74
-
75
  self.ToTensor = transforms.ToTensor()
 
76
  def __len__(self):
77
  return len(self.data)
78
-
79
  def __getitem__(self, idx):
80
- color_img = plt.imread(os.path.join(self.data_path, 'color', self.color_data[idx]))
81
-
 
 
82
  image_name = self.data[idx]
83
  if self.mults_amount > 1:
84
  mult_number = np.random.choice(range(self.mults_amount))
85
-
86
- bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
87
- dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '_dfm.png'
 
 
 
 
 
 
 
 
 
 
88
  else:
89
  bw_name = self.data[idx]
90
- dfm_name = os.path.splitext(self.data[idx])[0] + '_dfm.png'
91
-
92
-
93
- bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', image_name)), 2)
94
- dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', dfm_name)), 2)
95
-
 
 
 
96
  if self.transform:
97
- result = self.transform(image = color_img)
98
- color_img = result['image']
99
-
100
- result = self.transform(image = bw_img, mask = dfm_img)
101
- bw_img = result['image']
102
- dfm_img = result['mask']
103
-
104
  color_img = self.ToTensor(color_img)
105
  bw_img = self.ToTensor(bw_img)
106
  dfm_img = self.ToTensor(dfm_img)
107
-
108
  color_img = (color_img - 0.5) / 0.5
109
-
110
  return bw_img, dfm_img, color_img
 
8
 
9
 
10
  class TrainDataset(torch.utils.data.Dataset):
11
+ def __init__(self, data_path, transform=None, mults_amount=1):
12
+ self.data = os.listdir(os.path.join(data_path, "color"))
13
  self.data_path = data_path
14
  self.transform = transform
15
  self.mults_amount = mults_amount
16
+
17
  self.ToTensor = transforms.ToTensor()
18
+
19
  def __len__(self):
20
  return len(self.data)
21
+
22
  def __getitem__(self, idx):
23
+ image_name = self.data[idx]
24
+
25
  try:
26
+ color_img = plt.imread(os.path.join(self.data_path, "color", image_name))
27
  except SyntaxError:
28
  print(f"Archivo {image_name} no es un PNG v谩lido. Saltando...")
29
  return None # O alguna otra acci贸n que prefieras
30
 
31
  if self.mults_amount > 1:
32
  mult_number = np.random.choice(range(self.mults_amount))
33
+
34
+ bw_name = (
35
+ image_name[: image_name.rfind(".")] + "_" + str(mult_number) + ".png"
36
+ )
37
+ dfm_name = (
38
+ image_name[: image_name.rfind(".")]
39
+ + "_"
40
+ + str(mult_number)
41
+ + "_dfm.png"
42
+ )
43
  else:
44
  bw_name = self.data[idx]
45
+ dfm_name = os.path.splitext(self.data[idx])[0] + "0_dfm.png"
46
+
47
+ bw_img = np.expand_dims(
48
+ plt.imread(os.path.join(self.data_path, "bw", bw_name)), 2
49
+ )
50
+ dfm_img = np.expand_dims(
51
+ plt.imread(os.path.join(self.data_path, "bw", dfm_name)), 2
52
+ )
53
+
54
+ bw_img = np.concatenate([bw_img, dfm_img], axis=2)
55
+
56
  if self.transform:
57
+ result = self.transform(image=color_img, mask=bw_img)
58
+ color_img = result["image"]
59
+ bw_img = result["mask"]
60
+
61
  dfm_img = bw_img[:, :, 1]
62
  bw_img = bw_img[:, :, 0]
63
+
64
  color_img = self.ToTensor(color_img)
65
  bw_img = self.ToTensor(bw_img)
66
+
67
  dfm_img = self.ToTensor(dfm_img)
68
+
69
  color_img = (color_img - 0.5) / 0.5
70
+
71
  mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
72
  hint = torch.cat((color_img * mask, mask), 0)
73
+
74
  return bw_img, color_img, hint, dfm_img
75
+
76
+
77
  class FineTuningDataset(torch.utils.data.Dataset):
78
+ def __init__(self, data_path, transform=None, mult_amount=1):
79
+ self.data = [
80
+ x
81
+ for x in os.listdir(os.path.join(data_path, "real_manga"))
82
+ if x.find("_dfm") == -1
83
+ ]
84
+ self.color_data = [x for x in os.listdir(os.path.join(data_path, "color"))]
85
  self.data_path = data_path
86
  self.transform = transform
87
  self.mults_amount = mult_amount
88
+
89
  np.random.shuffle(self.color_data)
90
+
91
  self.ToTensor = transforms.ToTensor()
92
+
93
  def __len__(self):
94
  return len(self.data)
95
+
96
  def __getitem__(self, idx):
97
+ color_img = plt.imread(
98
+ os.path.join(self.data_path, "color", self.color_data[idx])
99
+ )
100
+
101
  image_name = self.data[idx]
102
  if self.mults_amount > 1:
103
  mult_number = np.random.choice(range(self.mults_amount))
104
+
105
+ bw_name = (
106
+ image_name[: image_name.rfind(".")]
107
+ + "_"
108
+ + str(self.mults_amount)
109
+ + ".png"
110
+ )
111
+ dfm_name = (
112
+ image_name[: image_name.rfind(".")]
113
+ + "_"
114
+ + str(self.mults_amount)
115
+ + "_dfm.png"
116
+ )
117
  else:
118
  bw_name = self.data[idx]
119
+ dfm_name = os.path.splitext(self.data[idx])[0] + "_dfm.png"
120
+
121
+ bw_img = np.expand_dims(
122
+ plt.imread(os.path.join(self.data_path, "real_manga", image_name)), 2
123
+ )
124
+ dfm_img = np.expand_dims(
125
+ plt.imread(os.path.join(self.data_path, "real_manga", dfm_name)), 2
126
+ )
127
+
128
  if self.transform:
129
+ result = self.transform(image=color_img)
130
+ color_img = result["image"]
131
+
132
+ result = self.transform(image=bw_img, mask=dfm_img)
133
+ bw_img = result["image"]
134
+ dfm_img = result["mask"]
135
+
136
  color_img = self.ToTensor(color_img)
137
  bw_img = self.ToTensor(bw_img)
138
  dfm_img = self.ToTensor(dfm_img)
139
+
140
  color_img = (color_img - 0.5) / 0.5
141
+
142
  return bw_img, dfm_img, color_img