| from pathlib import Path |
| import cv2 |
| import PIL |
| import numpy as np |
| import torch |
| import torch.utils |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| import glob |
|
|
| from .transforms.homographic_transforms import sample_homography |
| from kornia.geometry import warp_perspective,transform_points |
|
|
|
|
# Keyword arguments splatted into ``sample_homography`` (as
# ``sample_homography(size, **homography_params)``) by ``Hybrid_Dataset`` when
# it synthesises each warped target image.  Their exact meaning is defined in
# ``transforms.homographic_transforms``; the notes below are readings of the
# names only and should be confirmed there.
homography_params = dict(
    # Per-family on/off switches (presumably; semantics live in sample_homography).
    translation=True,
    rotation=True,
    scaling=True,
    perspective=True,
    # Perturbation magnitudes — presumably relative to the image size, confirm.
    scaling_amplitude=0.2,
    perspective_amplitude_x=0.2,
    perspective_amplitude_y=0.2,
    # Presumably the fraction of the image spanned by the sampled patch — confirm.
    patch_ratio=0.85,
    # Presumably radians; 1.57 is roughly pi/2 — confirm.
    max_angle=1.57,
    # Presumably lets the warped patch extend past the image border — confirm.
    allow_artifacts=True,
)
|
|
class Hybrid_Dataset(torch.utils.data.Dataset):
    """Reference / homography-warped image pairs cached as ``.npz`` archives.

    On construction, every ``*.png`` and ``*.jpg`` directly inside
    ``images_root`` (non-recursive) is read as grayscale, resized to
    ``self.size``, divided by 255 and warped with a homography drawn from
    ``sample_homography(self.size, **homography_params)``.  The result is
    saved next to the image as ``<stem>.npz`` with the keys ``ref_image``,
    ``target_image`` and ``homo_mat``.  Existing archives are reused unless
    ``overwrite`` is true.

    Item ``i`` of the dataset always corresponds to ``self.files[i]``.

    Args:
        datacfg: Stored as ``self.conf``; not otherwise used in this class.
        images_root: Directory searched for images.
        overwrite: Regenerate archives even when they already exist.

    Raises:
        ValueError: If no images are found, or an image cannot be decoded.
    """

    def __init__(self, datacfg=None, images_root=None, overwrite=False):
        self.conf = datacfg
        self.root = images_root
        self.overwrite = overwrite
        # (width, height) as consumed by cv2.resize / cv2.warpPerspective.
        self.size = (512, 512)

        # Sorted so dataset order is deterministic across runs.
        self.files = sorted(
            glob.glob(f'{images_root}/*.png') + glob.glob(f'{images_root}/*.jpg')
        )
        if not self.files:
            raise ValueError(f'Could not find any images in the path of {self.root}. Please check the input images root path.')

        # Exactly one archive path per image, in the same order as self.files,
        # so indices line up and stray *.npz files in the folder are ignored.
        self.npz_files = []
        for file in tqdm(self.files):
            npz_file = Path(file).with_suffix('.npz')
            if self.overwrite or not npz_file.exists():
                self._make_pair(file, npz_file)
            # NOTE(review): images sharing a stem (e.g. a.png and a.jpg) map to
            # the same archive, so both entries point at one file.
            self.npz_files.append(str(npz_file))

    def _make_pair(self, image_path, npz_file):
        """Build one (reference, warped, homography) sample and save it to ``npz_file``."""
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals missing/undecodable files by returning None
            # instead of raising; fail here with the offending path rather
            # than with an opaque assertion inside cv2.resize.
            raise ValueError(f'Could not read image file {image_path}.')
        image = cv2.resize(image, self.size).astype(np.float32) / 255.0

        # [0]: presumably strips a leading batch axis from the sampled
        # homography — confirm against sample_homography.
        H = sample_homography(self.size, **homography_params)[0]
        warped_image = cv2.warpPerspective(image, H, self.size).astype(np.float32)

        np.savez(npz_file, ref_image=image, target_image=warped_image, homo_mat=H)

    def get_dataset(self):
        """Return the archive paths, index-aligned with ``get_images()``."""
        return self.npz_files

    def get_images(self):
        """Return the sorted source image paths."""
        return self.files

    def len_dataset(self):
        """Return the number of source images (equals ``len(self)``)."""
        return len(self.files)

    def __len__(self):
        # Required by map-style Datasets for len() and DataLoader's sampler.
        return len(self.npz_files)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of NumPy arrays.

        The archive is read eagerly and closed before returning, so no file
        handle outlives the call; a plain dict also batches cleanly with the
        default DataLoader collate.
        """
        with np.load(self.npz_files[idx]) as data:
            return {key: data[key] for key in data.files}
|
|