| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| import snowy | |
| import os | |
def get_resized_image(img, size):
    """Resize ``img`` so that its smaller spatial dimension becomes ``size``.

    A 2-D (grayscale) input is first stacked into three identical channels.
    If the first axis is strictly shorter than the second, the first axis is
    set to ``size``. Otherwise the second axis is set to ``size``. The other
    axis is scaled by the same ratio and rounded up with ``np.ceil``.
    Resizing uses ``cv2.INTER_AREA``. A float32 result is clipped in place
    to [0, 1].

    Presumably ``img`` is an (H, W) or (H, W, C) array such as the one
    returned by ``cv2.imread``. TODO confirm for other callers.
    """
    if len(img.shape) == 2:
        img = np.repeat(np.expand_dims(img, axis=2), 3, axis=2)

    rows, cols = img.shape[0], img.shape[1]
    if rows < cols:
        scale = rows / size
        dsize = (int(np.ceil(cols / scale)), size)
    else:
        scale = cols / size
        dsize = (size, int(np.ceil(rows / scale)))
    # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
    img = cv2.resize(img, dsize, interpolation=cv2.INTER_AREA)

    if img.dtype == 'float32':
        # INTER_AREA can push float data slightly outside [0, 1].
        np.clip(img, 0, 1, out=img)
    return img
def get_sketch_image(img, sketcher, mult_val):
    """Return ``sketcher.get_sketch_with_resize(img, ...)``.

    ``mult`` is forwarded as a keyword argument only when ``mult_val`` is
    truthy. ``None`` (and any other falsy value, including 0) leaves the
    sketcher's own default in effect.
    """
    extra_kwargs = {'mult': mult_val} if mult_val else {}
    return sketcher.get_sketch_with_resize(img, **extra_kwargs)
def get_dfm_image(sketch):
    """Compute a distance-field image from a sketch via ``snowy``.

    The steps are:
      1. Build a boolean mask that is True wherever ``1 - sketch`` is
         nonzero, with a trailing channel axis added.
      2. Pass the mask to ``snowy.generate_sdf``.
      3. Pass that result through ``snowy.unitize``.
      4. Drop singleton axes with ``squeeze``.

    NOTE(review): presumably ``sketch`` is a float image in [0, 1] where 1 is
    blank paper, and ``unitize`` rescales to [0, 1]. Confirm both against the
    sketcher and the snowy docs.
    """
    stroke_mask = np.expand_dims(1 - sketch, 2) != 0
    distance_field = snowy.generate_sdf(stroke_mask)
    return snowy.unitize(distance_field).squeeze()
def get_sketch(image, sketcher, dfm, mult = None):
    """Produce a uint8 sketch (and optionally a uint8 distance-field image).

    Returns ``(sketch_uint8, dfm_uint8_or_None)``. Both images are obtained
    by multiplying by 255 and truncating with ``astype('uint8')``. This
    presumes the helper outputs lie in [0, 1]; TODO confirm. When ``dfm`` is
    falsy the second element is ``None``.
    """
    raw_sketch = get_sketch_image(image, sketcher, mult)
    if not dfm:
        return (raw_sketch * 255).astype('uint8'), None
    # The distance field is derived from the float sketch, before quantization.
    raw_dfm = get_dfm_image(raw_sketch)
    return (raw_sketch * 255).astype('uint8'), (raw_dfm * 255).astype('uint8')
def get_sketches(image, sketcher, mult_list, dfm):
    """Lazily yield ``get_sketch(image, sketcher, dfm, m)`` for each ``m`` in ``mult_list``."""
    yield from (get_sketch(image, sketcher, dfm, multiplier) for multiplier in mult_list)
def create_resized_dataset(source_path, target_path, side_size):
    """Resize every image in ``source_path`` and save it as PNG in ``target_path``.

    Each entry of ``source_path`` is read with ``cv2.imread`` and resized
    with ``get_resized_image(image, side_size)``. It is written to
    ``<target_path>/<stem>.png``. Outputs that already exist are skipped,
    so the function can be re-run to resume. ``target_path`` is created if
    missing.

    Per-image failures (unreadable file, failed resize, failed write) are
    reported with ``print`` and processing continues with the next entry.
    KeyboardInterrupt / SystemExit are not swallowed.
    """
    if target_path:
        # cv2.imwrite returns False for a missing directory instead of raising.
        os.makedirs(target_path, exist_ok=True)
    for image_name in os.listdir(source_path):
        # splitext keeps extension-less names intact; the old rfind('.')
        # slice dropped their last character.
        stem = os.path.splitext(image_name)[0]
        new_path = os.path.join(target_path, stem + '.png')
        if os.path.exists(new_path):
            continue
        try:
            image = cv2.imread(os.path.join(source_path, image_name))
            if image is None:
                raise ValueError('cv2.imread could not read the file')
            image = get_resized_image(image, side_size)
            if not cv2.imwrite(new_path, image):
                raise IOError('cv2.imwrite failed for {}'.format(new_path))
        except Exception as err:
            print('Failed to process {}: {}'.format(image_name, err))
def create_sketches_dataset(source_path, target_path, sketcher, mult_list, dfm = False):
    """Write one sketch per multiplier for every image in ``source_path``.

    Sketches are generated from the original (un-resized) image via
    ``get_sketches``. They are written as ``<target_path>/<stem>_<i>.png``,
    where ``i`` is the index into ``mult_list``. When ``dfm`` is true, the
    distance-field image is also written as ``<stem>_<i>_dfm.png``.
    ``target_path`` is created if missing.

    Per-image failures are reported with ``print`` and processing continues.
    KeyboardInterrupt / SystemExit are not swallowed.
    """
    if target_path:
        os.makedirs(target_path, exist_ok=True)

    def save(path, img):
        # cv2.imwrite signals most failures through its return value.
        if not cv2.imwrite(path, img):
            raise IOError('cv2.imwrite failed for {}'.format(path))

    for image_name in os.listdir(source_path):
        stem = os.path.splitext(image_name)[0]
        try:
            image = cv2.imread(os.path.join(source_path, image_name))
            if image is None:
                raise ValueError('cv2.imread could not read the file')
            sketches = get_sketches(image, sketcher, mult_list, dfm)
            for number, (sketch_image, dfm_image) in enumerate(sketches):
                save(os.path.join(target_path, '{}_{}.png'.format(stem, number)), sketch_image)
                if dfm:
                    save(os.path.join(target_path, '{}_{}_dfm.png'.format(stem, number)), dfm_image)
        except Exception as err:
            print('Failed to process {}: {}'.format(image_name, err))
def create_dataset(source_path, target_path, sketcher, mult_list, side_size, dfm = False):
    """Build a paired colour / sketch dataset from every image in ``source_path``.

    For each readable image:
      * a resized copy (``get_resized_image(image, side_size)``) is written to
        ``<target_path>/color/<stem>.png``;
      * for each multiplier in ``mult_list``, a sketch of the *resized* image
        is written to ``<target_path>/bw/<stem>_<i>.png``, plus
        ``<stem>_<i>_dfm.png`` when ``dfm`` is true.

    Both output directories are created if missing. Per-image failures are
    reported with ``print`` and processing continues. KeyboardInterrupt /
    SystemExit are not swallowed.
    """
    images = os.listdir(source_path)
    color_path = os.path.join(target_path, 'color')
    sketch_path = os.path.join(target_path, 'bw')
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(color_path, exist_ok=True)
    os.makedirs(sketch_path, exist_ok=True)

    def save(path, img):
        # cv2.imwrite signals most failures through its return value.
        if not cv2.imwrite(path, img):
            raise IOError('cv2.imwrite failed for {}'.format(path))

    for image_name in images:
        # splitext keeps extension-less names intact (rfind('.') == -1 would
        # otherwise chop off the last character).
        stem = os.path.splitext(image_name)[0]
        try:
            image = cv2.imread(os.path.join(source_path, image_name))
            if image is None:
                raise ValueError('cv2.imread could not read the file')
            resized_image = get_resized_image(image, side_size)
            save(os.path.join(color_path, stem + '.png'), resized_image)
            sketches = get_sketches(resized_image, sketcher, mult_list, dfm)
            for number, (sketch_image, dfm_image) in enumerate(sketches):
                save(os.path.join(sketch_path, '{}_{}.png'.format(stem, number)), sketch_image)
                if dfm:
                    save(os.path.join(sketch_path, '{}_{}_dfm.png'.format(stem, number)), dfm_image)
        except Exception as err:
            print('Failed to process {}: {}'.format(image_name, err))