File size: 3,543 Bytes
b9d8bf0 1a260bf 5f240b7 b9d8bf0 35efc49 1a260bf 35efc49 5f240b7 35efc49 5f240b7 1a260bf 5f240b7 b9d8bf0 1a260bf 5f240b7 b9d8bf0 2e8d1a5 b9d8bf0 6b26f63 1a260bf 6b26f63 1a260bf 6b26f63 1a260bf b9d8bf0 f26a075 1a260bf f26a075 1a260bf f26a075 1a260bf f26a075 1a260bf 6b26f63 1a260bf f26a075 1a260bf f26a075 1a260bf f26a075 1a260bf f26a075 1a260bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import os
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from utils.utils import generate_mask
class TrainDataset(torch.utils.data.Dataset):
    """Training samples pairing colour images with pre-rendered B/W variants.

    Every file listed in ``<data_path>/color`` is one sample.  For a sample
    named ``<stem>.<ext>`` the dataset also reads ``<stem>_<k>.png`` and
    ``<stem>_<k>_dfm.png`` from ``<data_path>/bw``, where ``k`` is the variant
    index chosen in ``__getitem__``.

    Args:
        data_path: root folder holding the ``color`` and ``bw`` sub-folders.
        transform: kept as ``self.transform``; ``__getitem__`` does not use it.
        mults_amount: how many B/W variants exist per colour image.  When it
            is greater than 1 a variant index is drawn at random from
            ``range(mults_amount)``; otherwise index 0 is always used.

    ``__getitem__`` returns ``(bw_img, color_img, hint, dfm_img)``.
    """

    def __init__(self, data_path, transform=None, mults_amount=1):
        color_root = os.path.join(data_path, 'color')
        self.data = os.listdir(color_root)
        self.data_path, self.transform = data_path, transform
        self.mults_amount = mults_amount
        self.ToTensor = transforms.ToTensor()

        # Folder with the black-and-white renderings; created when missing.
        bw_root = os.path.join(data_path, 'bw')
        self.bw_directory = bw_root
        if not os.path.exists(bw_root):
            os.makedirs(bw_root)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        file_name = self.data[idx]
        colour = plt.imread(os.path.join(self.data_path, 'color', file_name))

        # Pick which pre-rendered variant to pair with this colour image.
        variant = (np.random.choice(range(self.mults_amount))
                   if self.mults_amount > 1 else 0)
        prefix = f"{os.path.splitext(file_name)[0]}_{variant}"

        def read_single_channel(fname):
            # Append a trailing channel axis before ToTensor; this assumes the
            # PNG decodes to a 2-D (H, W) array -- TODO confirm.
            return np.expand_dims(
                plt.imread(os.path.join(self.bw_directory, fname)), 2)

        bw = read_single_channel(f"{prefix}.png")
        dfm = read_single_channel(f"{prefix}_dfm.png")

        bw, colour, dfm = (self.ToTensor(arr) for arr in (bw, colour, dfm))
        # Rescale colour from [0, 1] to [-1, 1] (assuming ToTensor output is
        # in [0, 1]).
        colour = (colour - 0.5) / 0.5

        # The mask is sized from the B/W tensor's dims 1 and 2.
        mask = generate_mask(bw.shape[1], bw.shape[2])
        # Hint = masked colour image with the mask concatenated along dim 0
        # (the channel dim of ToTensor's CHW output).
        hint = torch.cat((colour * mask, mask), 0)
        return bw, colour, hint, dfm
# Resto del código...
class FineTuningDataset(torch.utils.data.Dataset):
    """Fine-tuning samples built from real manga pages and colour images.

    Sample names come from ``<data_path>/real_manga`` (entries containing
    ``_dfm`` are skipped).  The corresponding B/W image and its ``_dfm``
    companion are read from ``<data_path>/bw``.  A colour image is taken by
    index from an independent listing of ``<data_path>/color``; nothing ties
    it to the manga page's filename.

    Args:
        data_path: root folder with ``real_manga``, ``color`` and ``bw``.
        transform: kept as ``self.transform``; ``__getitem__`` does not use it.
        mult_amount: number of B/W variants per page.  When greater than 1 a
            variant ``k`` is drawn from ``range(mult_amount)`` and the files
            ``<stem>_<k>.png`` / ``<stem>_<k>_dfm.png`` are read; otherwise
            the page file itself and ``<stem>_dfm.png`` are read.

    ``__getitem__`` returns ``(bw_img, dfm_img, color_img)``.  ``color_img``
    is rescaled with ``(x - 0.5) / 0.5``.
    """

    def __init__(self, data_path, transform=None, mult_amount=1):
        self.data = [
            name for name in os.listdir(os.path.join(data_path, 'real_manga'))
            if '_dfm' not in name
        ]
        self.color_data = os.listdir(os.path.join(data_path, 'color'))
        self.data_path = data_path
        self.transform = transform
        self.mults_amount = mult_amount
        self.ToTensor = transforms.ToTensor()
        # Directory holding the black-and-white renderings; created if absent.
        self.bw_directory = os.path.join(data_path, 'bw')
        if not os.path.exists(self.bw_directory):
            os.makedirs(self.bw_directory)

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _variant_names(image_name, mult_number=None):
        """Return ``(bw_name, dfm_name)`` for the page *image_name*.

        With ``mult_number`` None the page file itself is the B/W image and
        its companion is ``<stem>_dfm.png``.  Otherwise the names of variant
        ``mult_number`` (``<stem>_<k>.png`` / ``<stem>_<k>_dfm.png``) are
        returned.  ``os.path.splitext`` is used so that names without an
        extension keep their full stem.
        """
        stem = os.path.splitext(image_name)[0]
        if mult_number is None:
            return image_name, f"{stem}_dfm.png"
        return f"{stem}_{mult_number}.png", f"{stem}_{mult_number}_dfm.png"

    def __getitem__(self, idx):
        image_name = self.data[idx]
        if not self.color_data:
            raise IndexError(
                f"no colour images found in {os.path.join(self.data_path, 'color')}"
            )
        # The colour listing is independent of, and may be shorter than, the
        # manga listing that defines len(self); wrap around instead of
        # indexing past its end.
        color_name = self.color_data[idx % len(self.color_data)]
        color_img = plt.imread(os.path.join(self.data_path, 'color', color_name))

        if self.mults_amount > 1:
            # Name the files after the drawn variant index.  Using
            # self.mults_amount here would point outside
            # range(self.mults_amount) and ignore the random draw.
            mult_number = np.random.choice(range(self.mults_amount))
            bw_name, dfm_name = self._variant_names(image_name, mult_number)
        else:
            bw_name, dfm_name = self._variant_names(image_name)

        # Add a trailing channel axis before ToTensor; assumes the PNGs decode
        # to 2-D (H, W) arrays -- TODO confirm.
        bw_img = np.expand_dims(plt.imread(os.path.join(self.bw_directory, bw_name)), 2)
        dfm_img = np.expand_dims(plt.imread(os.path.join(self.bw_directory, dfm_name)), 2)

        # Convert to tensors and rescale the colour image to [-1, 1]
        # (assuming ToTensor yields values in [0, 1]).
        bw_img = self.ToTensor(bw_img)
        color_img = self.ToTensor(color_img)
        dfm_img = self.ToTensor(dfm_img)
        color_img = (color_img - 0.5) / 0.5
        return bw_img, dfm_img, color_img
|