Raid41 commited on
Commit
2e8d1a5
·
1 Parent(s): 56ec4a2

Update dataset/datasets.py

Browse files
Files changed (1) hide show
  1. dataset/datasets.py +56 -86
dataset/datasets.py CHANGED
@@ -1,107 +1,77 @@
1
- import torch
2
  import os
 
3
  import torchvision.transforms as transforms
4
  import matplotlib.pyplot as plt
 
 
5
  import numpy as np
6
 
7
- from utils.utils import generate_mask
8
-
9
-
10
- class TrainDataset(torch.utils.data.Dataset):
11
- def __init__(self, data_path, transform = None, mults_amount = 1):
12
- self.data = os.listdir(os.path.join(data_path, 'color'))
13
  self.data_path = data_path
14
  self.transform = transform
15
  self.mults_amount = mults_amount
16
-
17
  self.ToTensor = transforms.ToTensor()
 
 
 
18
  def __len__(self):
19
- return len(self.data)
20
-
21
  def __getitem__(self, idx):
22
- image_name = self.data[idx]
23
-
24
- color_img = plt.imread(os.path.join(self.data_path, 'color', image_name))
25
-
26
-
27
- if self.mults_amount > 1:
28
- mult_number = np.random.choice(range(self.mults_amount))
29
-
30
- bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
31
- dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
32
- else:
33
- bw_name = self.data[idx]
34
- dfm_name = os.path.splitext(self.data[idx])[0] + '0_dfm.png'
35
-
36
-
37
- bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
38
- dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', dfm_name)), 2)
39
-
40
- bw_img = np.concatenate([bw_img, dfm_img], axis = 2)
41
-
42
  if self.transform:
43
- result = self.transform(image = color_img, mask = bw_img)
44
- color_img = result['image']
45
- bw_img = result['mask']
46
-
47
- dfm_img = bw_img[:, :, 1]
48
- bw_img = bw_img[:, :, 0]
49
-
50
- color_img = self.ToTensor(color_img)
51
- bw_img = self.ToTensor(bw_img)
52
-
53
- dfm_img = self.ToTensor(dfm_img)
54
-
55
- color_img = (color_img - 0.5) / 0.5
56
-
57
- mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
58
- hint = torch.cat((color_img * mask, mask), 0)
59
-
60
  return bw_img, color_img, hint, dfm_img
61
-
62
- class FineTuningDataset(torch.utils.data.Dataset):
63
- def __init__(self, data_path, transform = None, mult_amount = 1):
64
- self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
65
- self.color_data = [x for x in os.listdir(os.path.join(data_path, 'color'))]
66
  self.data_path = data_path
67
  self.transform = transform
68
  self.mults_amount = mult_amount
69
-
70
- np.random.shuffle(self.color_data)
71
-
72
  self.ToTensor = transforms.ToTensor()
 
 
 
73
  def __len__(self):
74
- return len(self.data)
75
-
76
  def __getitem__(self, idx):
77
- color_img = plt.imread(os.path.join(self.data_path, 'color', self.color_data[idx]))
78
-
79
- image_name = self.data[idx]
80
- if self.mults_amount > 1:
81
- mult_number = np.random.choice(range(self.mults_amount))
82
-
83
- bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
84
- dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '_dfm.png'
85
- else:
86
- bw_name = self.data[idx]
87
- dfm_name = os.path.splitext(self.data[idx])[0] + '_dfm.png'
88
-
89
-
90
- bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', image_name)), 2)
91
- dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', dfm_name)), 2)
92
-
93
  if self.transform:
94
- result = self.transform(image = color_img)
95
- color_img = result['image']
96
-
97
- result = self.transform(image = bw_img, mask = dfm_img)
98
- bw_img = result['image']
99
- dfm_img = result['mask']
100
-
101
- color_img = self.ToTensor(color_img)
102
- bw_img = self.ToTensor(bw_img)
103
- dfm_img = self.ToTensor(dfm_img)
104
-
105
- color_img = (color_img - 0.5) / 0.5
106
-
107
  return bw_img, dfm_img, color_img
 
 
1
  import os
2
+ import torch
3
  import torchvision.transforms as transforms
4
  import matplotlib.pyplot as plt
5
+ from torch.utils.data import Dataset
6
+ from PIL import Image
7
  import numpy as np
8
 
9
+ class TrainDataset(Dataset):
10
+ def __init__(self, data_path, transform=None, mults_amount=1):
 
 
 
 
11
  self.data_path = data_path
12
  self.transform = transform
13
  self.mults_amount = mults_amount
 
14
  self.ToTensor = transforms.ToTensor()
15
+
16
+ self.file_list = os.listdir(os.path.join(data_path, 'color'))
17
+
18
  def __len__(self):
19
+ return len(self.file_list)
20
+
21
  def __getitem__(self, idx):
22
+ image_name = self.file_list[idx]
23
+
24
+ # Extract mult_number from the image name
25
+ mult_number = int(image_name.split('_')[-1].split('.')[0])
26
+
27
+ # Construct corresponding bw and dfm image names
28
+ bw_name = image_name.replace(f'_{mult_number}.png', '.png')
29
+ dfm_name = bw_name.replace('.png', f'_{mult_number}_dfm.png')
30
+
31
+ color_img = Image.open(os.path.join(self.data_path, 'color', image_name)).convert('RGB')
32
+ bw_img = Image.open(os.path.join(self.data_path, 'bw', bw_name)).convert('L')
33
+ dfm_img = Image.open(os.path.join(self.data_path, 'bw', dfm_name)).convert('L')
34
+
 
 
 
 
 
 
 
35
  if self.transform:
36
+ color_img = self.transform(color_img)
37
+ bw_img = self.transform(bw_img)
38
+ dfm_img = self.transform(dfm_img)
39
+
40
+ # Assuming generate_mask is a function generating a mask for hint
41
+ mask = generate_mask(bw_img.size[1], bw_img.size[0])
42
+ hint = torch.cat((bw_img * mask, mask), 0)
43
+
 
 
 
 
 
 
 
 
 
44
  return bw_img, color_img, hint, dfm_img
45
+
46
+ class FineTuningDataset(Dataset):
47
+ def __init__(self, data_path, transform=None, mult_amount=1):
 
 
48
  self.data_path = data_path
49
  self.transform = transform
50
  self.mults_amount = mult_amount
 
 
 
51
  self.ToTensor = transforms.ToTensor()
52
+
53
+ self.file_list = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
54
+
55
  def __len__(self):
56
+ return len(self.file_list)
57
+
58
  def __getitem__(self, idx):
59
+ image_name = self.file_list[idx]
60
+
61
+ # Extract mult_number from the image name
62
+ mult_number = int(image_name.split('_')[-1])
63
+
64
+ # Construct corresponding bw and dfm image names
65
+ bw_name = f"{image_name.split('_')[0]}_{mult_number}.png"
66
+ dfm_name = f"{image_name.split('_')[0]}_{mult_number}_dfm.png"
67
+
68
+ color_img = Image.open(os.path.join(self.data_path, 'color', image_name)).convert('RGB')
69
+ bw_img = Image.open(os.path.join(self.data_path, 'real_manga', bw_name)).convert('L')
70
+ dfm_img = Image.open(os.path.join(self.data_path, 'real_manga', dfm_name)).convert('L')
71
+
 
 
 
72
  if self.transform:
73
+ color_img = self.transform(color_img)
74
+ bw_img = self.transform(bw_img)
75
+ dfm_img = self.transform(dfm_img)
76
+
 
 
 
 
 
 
 
 
 
77
  return bw_img, dfm_img, color_img