| from torch.utils import data |
| import os |
| import torch |
| import numpy as np |
| import cv2 |
| import random |
|
|
class myDataset(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader.

    Expected layout under ``train_data_dir`` (as built by the paths below):

    - ``hair/<name>/<view>.jpg``     -- images read for ``img_hair``
    - ``no_hair/<name>/<view>.jpg``  -- images read for ``img_non_hair``
    - ``ref_hair/<name>/*.jpg``      -- reference image(s); the first ``.jpg``
      in sorted order is used for ``ref_hair``
    - ``pose.npy``                   -- array indexed by view number

    where ``<view>`` is an integer in ``[0, num_views)``.

    Each item is a dict:

    - ``hair_pose`` / ``bald_pose``: ``pose.npy`` rows for the sampled
      hair / no_hair views, as tensors.
    - ``img_hair`` / ``img_non_hair`` / ``ref_hair``: RGB images resized to
      ``image_size`` x ``image_size``, scaled to [-1, 1], channel-first
      (3, H, W), dtype float64.
    """

    def __init__(self, train_data_dir, *, num_views=12, distinct_views=False,
                 image_size=512):
        """
        Args:
            train_data_dir: root directory containing ``hair/``, ``no_hair/``,
                ``ref_hair/`` and ``pose.npy``.
            num_views: view indices are drawn from ``range(num_views)``.
                Defaults to 12.
            distinct_views: if True, the ``no_hair`` view is drawn uniformly
                from the views other than the ``hair`` view; if False
                (default) both images and poses use the same view.
            image_size: side length every image is resized to. Defaults to 512.

        Raises:
            ValueError: if ``num_views`` is too small for the requested mode,
                or ``pose.npy`` has fewer than ``num_views`` entries.
        """
        if num_views < 1:
            raise ValueError(f"num_views must be >= 1, got {num_views}")
        if distinct_views and num_views < 2:
            raise ValueError("distinct_views=True requires num_views >= 2")

        self.img_path = os.path.join(train_data_dir, "hair")
        self.pose_path = os.path.join(train_data_dir, "pose.npy")
        self.non_hair_path = os.path.join(train_data_dir, "no_hair")
        self.ref_path = os.path.join(train_data_dir, "ref_hair")
        self.num_views = num_views
        self.distinct_views = distinct_views
        self.image_size = image_size

        # os.listdir order is filesystem-dependent: sort so that index -> sample
        # is reproducible across machines. Each name is used as a directory
        # below, so stray files (e.g. .DS_Store) are skipped.
        self.lists = sorted(
            name for name in os.listdir(self.img_path)
            if os.path.isdir(os.path.join(self.img_path, name))
        )
        self.len = len(self.lists)

        self.pose = np.load(self.pose_path)
        # __getitem__ indexes self.pose with any view in [0, num_views); fail
        # here rather than randomly at load time.
        if len(self.pose) < num_views:
            raise ValueError(
                f"{self.pose_path} has {len(self.pose)} entries, "
                f"need at least num_views={num_views}"
            )

    def _sample_views(self):
        """Return ``(hair_view, bald_view)``, both in ``[0, num_views)``."""
        hair_view = random.randrange(self.num_views)
        if not self.distinct_views:
            return hair_view, hair_view
        # Draw from the remaining num_views - 1 indices and skip over
        # hair_view: uniform over all views != hair_view, no rejection loop.
        bald_view = random.randrange(self.num_views - 1)
        if bald_view >= hair_view:
            bald_view += 1
        return hair_view, bald_view

    def _reference_path(self, name):
        """Return the path of the reference image for sample ``name``.

        The first ``.jpg`` in sorted order is used, so the choice does not
        depend on directory-listing order.

        Raises:
            FileNotFoundError: if ``ref_hair/<name>`` contains no ``.jpg``.
        """
        ref_folder = os.path.join(self.ref_path, name)
        files = sorted(f for f in os.listdir(ref_folder) if f.endswith('.jpg'))
        if not files:
            raise FileNotFoundError(f"no .jpg reference image in {ref_folder}")
        return os.path.join(ref_folder, files[0])

    def _load_image(self, path):
        """Read ``path`` into a (3, image_size, image_size) float64 tensor in [-1, 1].

        Raises:
            FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
                signals failure by returning None instead of raising.
        """
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        img = cv2.resize(img, (self.image_size, self.image_size))
        img = (img / 255.0) * 2 - 1  # uint8 [0, 255] -> [-1, 1]
        return torch.tensor(img).permute(2, 0, 1)  # HWC -> CHW

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        name = self.lists[index]
        hair_view, bald_view = self._sample_views()

        img_hair = self._load_image(
            os.path.join(self.img_path, name, f"{hair_view}.jpg"))
        img_non_hair = self._load_image(
            os.path.join(self.non_hair_path, name, f"{bald_view}.jpg"))
        ref_hair = self._load_image(self._reference_path(name))

        return {
            'hair_pose': torch.tensor(self.pose[hair_view]),
            'img_hair': img_hair,
            'bald_pose': torch.tensor(self.pose[bald_view]),
            'img_non_hair': img_non_hair,
            'ref_hair': ref_hair,
        }

    def __len__(self):
        """Number of sample directories found under ``hair/``."""
        return self.len
|
|
|
|
if __name__ == "__main__":
    # Smoke test: push every sample under ./data (hair/, no_hair/, ref_hair/,
    # pose.npy) through a DataLoader once and print the resulting batches.
    train_dataset = myDataset("./data")
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        num_workers=1,
    )

    # A single pass exercises every sample; looping len(train_dataset) + 1
    # times would load the whole dataset O(N) times over for no benefit.
    for step, batch in enumerate(train_dataloader):
        print(f"step {step}:")
        for key in ("hair_pose", "img_hair", "bald_pose", "img_non_hair", "ref_hair"):
            print(f"batch[{key}]:", batch[key])
|
|