Raid41 commited on
Commit
b9d8bf0
·
1 Parent(s): b6bb0b1

Update dataset/datasets.py

Browse files
Files changed (1) hide show
  1. dataset/datasets.py +84 -58
dataset/datasets.py CHANGED
@@ -1,81 +1,107 @@
1
- import os
2
  import torch
 
3
  import torchvision.transforms as transforms
4
  import matplotlib.pyplot as plt
5
- from torch.utils.data import Dataset
6
- from PIL import Image
7
  import numpy as np
8
 
 
 
 
9
  class TrainDataset(torch.utils.data.Dataset):
10
- def __init__(self, data_path, transform=None, mults_amount=1):
11
  self.data = os.listdir(os.path.join(data_path, 'color'))
12
  self.data_path = data_path
13
  self.transform = transform
14
  self.mults_amount = mults_amount
 
15
  self.ToTensor = transforms.ToTensor()
16
-
17
- self.file_list = os.listdir(os.path.join(data_path, 'color'))
18
-
19
  def __len__(self):
20
- return len(self.file_list)
21
-
22
  def __getitem__(self, idx):
23
- image_name = self.file_list[idx]
24
-
25
- # Extract mult_number from the image name (last part before the extension)
26
- mult_number = int(image_name.split('_')[-1].split('.')[0])
27
-
28
- # Construct corresponding bw and dfm image names
29
- bw_name = f"{image_name.split('.')[0]}_{mult_number}_dfm.png"
30
- color_img_path = os.path.join(self.data_path, 'color', image_name)
31
- bw_img_path = os.path.join(self.data_path, 'bw', bw_name)
32
-
33
- color_img = Image.open(color_img_path).convert('RGB')
34
- bw_img = Image.open(bw_img_path).convert('L')
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  if self.transform:
37
- color_img = self.transform(color_img)
38
- bw_img = self.transform(bw_img)
39
-
40
- # Assuming generate_mask is a function generating a mask for hint
41
- mask = generate_mask(bw_img.size[1], bw_img.size[0])
42
- hint = torch.cat((bw_img * mask, mask), 0)
43
-
44
- return bw_img, color_img, hint
 
 
 
 
 
45
 
 
 
 
 
 
46
  class FineTuningDataset(torch.utils.data.Dataset):
47
- def __init__(self, data_path, transform=None, mult_amount=1):
48
- self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.endswith('_dfm.png')]
49
- print("Files in FineTuning Dataset:", self.data) # Debug print
50
- # rest of your code remains unchanged
51
-
52
  self.data_path = data_path
53
  self.transform = transform
54
- self.mult_amount = mult_amount
 
 
 
55
  self.ToTensor = transforms.ToTensor()
56
-
57
- self.file_list = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.endswith('_dfm.png')]
58
-
59
  def __len__(self):
60
- return len(self.file_list)
61
-
62
  def __getitem__(self, idx):
63
- image_name = self.file_list[idx]
64
-
65
- # Extract mult_number from the image name (last part before '_dfm.png')
66
- mult_number = int(image_name.split('_')[-2])
67
-
68
- # Construct corresponding bw and dfm image names
69
- bw_name = f"{image_name.split('_')[0]}_{mult_number}.png"
70
- color_img_path = os.path.join(self.data_path, 'color', bw_name)
71
- bw_img_path = os.path.join(self.data_path, 'real_manga', image_name)
72
-
73
- color_img = Image.open(color_img_path).convert('RGB')
74
- bw_img = Image.open(bw_img_path).convert('L')
75
-
 
 
 
76
  if self.transform:
77
- color_img = self.transform(color_img)
78
- bw_img = self.transform(bw_img)
79
-
80
- return bw_img, color_img
81
-
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import os
3
  import torchvision.transforms as transforms
4
  import matplotlib.pyplot as plt
 
 
5
  import numpy as np
6
 
7
+ from utils.utils import generate_mask
8
+
9
+
10
  class TrainDataset(torch.utils.data.Dataset):
11
+ def __init__(self, data_path, transform = None, mults_amount = 1):
12
  self.data = os.listdir(os.path.join(data_path, 'color'))
13
  self.data_path = data_path
14
  self.transform = transform
15
  self.mults_amount = mults_amount
16
+
17
  self.ToTensor = transforms.ToTensor()
 
 
 
18
  def __len__(self):
19
+ return len(self.data)
20
+
21
  def __getitem__(self, idx):
22
+ image_name = self.data[idx]
23
+
24
+ color_img = plt.imread(os.path.join(self.data_path, 'color', image_name))
25
+
 
 
 
 
 
 
 
 
26
 
27
+ if self.mults_amount > 1:
28
+ mult_number = np.random.choice(range(self.mults_amount))
29
+
30
+ bw_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '.png'
31
+ dfm_name = image_name[:image_name.rfind('.')] + '_' + str(mult_number) + '_dfm.png'
32
+ else:
33
+ bw_name = self.data[idx]
34
+ dfm_name = os.path.splitext(self.data[idx])[0] + '0_dfm.png'
35
+
36
+
37
+ bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', bw_name)), 2)
38
+ dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'bw', dfm_name)), 2)
39
+
40
+ bw_img = np.concatenate([bw_img, dfm_img], axis = 2)
41
+
42
  if self.transform:
43
+ result = self.transform(image = color_img, mask = bw_img)
44
+ color_img = result['image']
45
+ bw_img = result['mask']
46
+
47
+ dfm_img = bw_img[:, :, 1]
48
+ bw_img = bw_img[:, :, 0]
49
+
50
+ color_img = self.ToTensor(color_img)
51
+ bw_img = self.ToTensor(bw_img)
52
+
53
+ dfm_img = self.ToTensor(dfm_img)
54
+
55
+ color_img = (color_img - 0.5) / 0.5
56
 
57
+ mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
58
+ hint = torch.cat((color_img * mask, mask), 0)
59
+
60
+ return bw_img, color_img, hint, dfm_img
61
+
62
  class FineTuningDataset(torch.utils.data.Dataset):
63
+ def __init__(self, data_path, transform = None, mult_amount = 1):
64
+ self.data = [x for x in os.listdir(os.path.join(data_path, 'real_manga')) if x.find('_dfm') == -1]
65
+ self.color_data = [x for x in os.listdir(os.path.join(data_path, 'color'))]
 
 
66
  self.data_path = data_path
67
  self.transform = transform
68
+ self.mults_amount = mult_amount
69
+
70
+ np.random.shuffle(self.color_data)
71
+
72
  self.ToTensor = transforms.ToTensor()
 
 
 
73
  def __len__(self):
74
+ return len(self.data)
75
+
76
  def __getitem__(self, idx):
77
+ color_img = plt.imread(os.path.join(self.data_path, 'color', self.color_data[idx]))
78
+
79
+ image_name = self.data[idx]
80
+ if self.mults_amount > 1:
81
+ mult_number = np.random.choice(range(self.mults_amount))
82
+
83
+ bw_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '.png'
84
+ dfm_name = image_name[:image_name.rfind('.')] + '_' + str(self.mults_amount) + '_dfm.png'
85
+ else:
86
+ bw_name = self.data[idx]
87
+ dfm_name = os.path.splitext(self.data[idx])[0] + '_dfm.png'
88
+
89
+
90
+ bw_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', image_name)), 2)
91
+ dfm_img = np.expand_dims(plt.imread(os.path.join(self.data_path, 'real_manga', dfm_name)), 2)
92
+
93
  if self.transform:
94
+ result = self.transform(image = color_img)
95
+ color_img = result['image']
96
+
97
+ result = self.transform(image = bw_img, mask = dfm_img)
98
+ bw_img = result['image']
99
+ dfm_img = result['mask']
100
+
101
+ color_img = self.ToTensor(color_img)
102
+ bw_img = self.ToTensor(bw_img)
103
+ dfm_img = self.ToTensor(dfm_img)
104
+
105
+ color_img = (color_img - 0.5) / 0.5
106
+
107
+ return bw_img, dfm_img, color_img