Raid41 commited on
Commit
685aa20
·
1 Parent(s): 45c435a

Update dataset/datasets.py

Browse files
Files changed (1) hide show
  1. dataset/datasets.py +30 -78
dataset/datasets.py CHANGED
@@ -1,107 +1,59 @@
1
  import torch
2
- import os
3
  import torchvision.transforms as transforms
4
- import matplotlib.pyplot as plt
5
- import numpy as np
6
-
7
- from utils.utils import generate_mask
8
 
9
 
10
  class TrainDataset(torch.utils.data.Dataset):
11
- def __init__(self, data_path, transform = None, mults_amount = 1):
12
  self.data = os.listdir(os.path.join(data_path, 'color'))
13
  self.data_path = data_path
14
  self.transform = transform
15
  self.mults_amount = mults_amount
16
-
17
  self.ToTensor = transforms.ToTensor()
 
 
18
  def __len__(self):
19
  return len(self.data)
20
-
21
  def __getitem__(self, idx):
22
  image_name = self.data[idx]
23
-
24
  color_img = plt.imread(os.path.join(self.data_path, 'color', image_name))
25
-
26
 
27
  if self.mults_amount > 1:
28
  mult_number = np.random.choice(range(self.mults_amount))
29
-
30
  bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
31
  dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
32
  else:
33
  bw_name = self.data[idx]
34
- dfm_name = os.path.splitext(self.data[idx])[0] + '0_dfm.png'
35
-
36
-
37
  bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
38
  dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', dfm_name)), 2)
39
-
40
  bw_img = np.concatenate([bw_img, dfm_img], axis = 2)
41
-
42
  if self.transform:
43
  result = self.transform(image = color_img, mask = bw_img)
44
  color_img = result['image']
45
  bw_img = result['mask']
46
-
47
- dfm_img = bw_img[:, :, 1]
48
- bw_img = bw_img[:, :, 0]
49
-
50
- color_img = self.ToTensor(color_img)
51
- bw_img = self.ToTensor(bw_img)
52
-
53
- dfm_img = self.ToTensor(dfm_img)
54
-
55
- color_img = (color_img - 0.5) / 0.5
56
-
57
- mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
58
- hint = torch.cat((color_img * mask, mask), 0)
59
-
60
- return bw_img, color_img, hint, dfm_img
61
-
62
- class FineTuningDataset(torch.utils.data.Dataset):
63
- def __init__(self, data_path, transform = None, mult_amount = 1):
64
- self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
65
- self.color_data = [x for x in os.listdir(os.path.join(data_path, 'color'))]
66
- self.data_path = data_path
67
- self.transform = transform
68
- self.mults_amount = mult_amount
69
-
70
- np.random.shuffle(self.color_data)
71
-
72
- self.ToTensor = transforms.ToTensor()
73
- def __len__(self):
74
- return len(self.data)
75
-
76
- def __getitem__(self, idx):
77
- color_img = plt.imread(os.path.join(self.data_path, 'color', self.color_data[idx]))
78
-
79
- image_name = self.data[idx]
80
- if self.mults_amount > 1:
81
- mult_number = np.random.choice(range(self.mults_amount))
82
-
83
- bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
84
- dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '_dfm.png'
85
- else:
86
- bw_name = self.data[idx]
87
- dfm_name = os.path.splitext(self.data[idx])[0] + '_dfm.png'
88
-
89
-
90
- bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', image_name)), 2)
91
- dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', dfm_name)), 2)
92
-
93
- if self.transform:
94
- result = self.transform(image = color_img)
95
- color_img = result['image']
96
-
97
- result = self.transform(image = bw_img, mask = dfm_img)
98
- bw_img = result['image']
99
- dfm_img = result['mask']
100
-
101
- color_img = self.ToTensor(color_img)
102
- bw_img = self.ToTensor(bw_img)
103
- dfm_img = self.ToTensor(dfm_img)
104
-
105
- color_img = (color_img - 0.5) / 0.5
106
-
107
- return bw_img, dfm_img, color_img
 
1
  import torch
 
2
  import torchvision.transforms as transforms
3
+ from torch.utils.data import DataLoader, Sampler
 
 
 
4
 
5
 
6
  class TrainDataset(torch.utils.data.Dataset):
7
+ def __init__(self, data_path, transform=None, mults_amount=1):
8
  self.data = os.listdir(os.path.join(data_path, 'color'))
9
  self.data_path = data_path
10
  self.transform = transform
11
  self.mults_amount = mults_amount
12
+
13
  self.ToTensor = transforms.ToTensor()
14
+ self.Normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
15
+
16
  def __len__(self):
17
  return len(self.data)
18
+
19
  def __getitem__(self, idx):
20
  image_name = self.data[idx]
21
+
22
  color_img = plt.imread(os.path.join(self.data_path, 'color', image_name))
 
23
 
24
  if self.mults_amount > 1:
25
  mult_number = np.random.choice(range(self.mults_amount))
26
+
27
  bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
28
  dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
29
  else:
30
  bw_name = self.data[idx]
31
+ dfm_name = os.path.splitext(self.data[idx])[0] + '_dfm.png'
32
+
 
33
  bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
34
  dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', dfm_name)), 2)
35
+
36
  bw_img = np.concatenate([bw_img, dfm_img], axis = 2)
37
+
38
  if self.transform:
39
  result = self.transform(image = color_img, mask = bw_img)
40
  color_img = result['image']
41
  bw_img = result['mask']
42
+
43
+ bw_img = self.Normalize(bw_img)
44
+ color_img = self.Normalize(color_img)
45
+
46
+ return bw_img, color_img, dfm_img
47
+
48
+
49
+ def main():
50
+ train_dataset = TrainDataset(data_path='./train', transform=transforms.ToTensor())
51
+ train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, sampler=torch.utils.data.RandomSampler(train_dataset))
52
+
53
+ for batch in train_loader:
54
+ bw_images, color_images, dfm_images = batch
55
+ print(bw_images.shape, color_images.shape, dfm_images.shape)
56
+
57
+
58
+ if __name__ == '__main__':
59
+ main()