Spaces:
Running
Running
| import os | |
| import cv2 | |
| import json | |
| import random | |
| import datetime | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
class DataGen:
    """Batch generator for paired image/label PNG files stored on disk.

    The directory layout read by this class is ``<path>/train/images/``,
    ``<path>/train/labels/``, ``<path>/test/images/`` and
    ``<path>/test/labels/``.  Images and labels are paired by their position
    in the sorted filename lists.  The training pairs are shuffled together
    once at construction time; the first ``split_ratio`` fraction of the
    shuffled pairs becomes the validation set and the rest the training set.
    """

    def __init__(self, path, split_ratio, x, y, color_space='rgb'):
        """Collect the file lists and split train/validation.

        Args:
            path: Dataset root directory (with or without a trailing slash).
            split_ratio: Fraction of the training pairs used for validation.
            x: Required size of axis 0 of every loaded image and label.
            y: Required size of axis 1 of every loaded image and label.
            color_space: Stored on the instance; the color conversion call in
                ``generate_data`` is currently disabled, so it is unused here.
        """
        self.x = x
        self.y = y
        self.path = path
        self.color_space = color_space
        # os.path.join keeps the trailing slash of the second component and
        # inserts a separator only when `path` does not already end with one.
        self.path_train_images = os.path.join(path, "train/images/")
        self.path_train_labels = os.path.join(path, "train/labels/")
        self.path_test_images = os.path.join(path, "test/images/")
        self.path_test_labels = os.path.join(path, "test/labels/")
        self.image_file_list = get_png_filename_list(self.path_train_images)
        self.label_file_list = get_png_filename_list(self.path_train_labels)
        shuffled_images, shuffled_labels = self.shuffle_image_label_lists_together()
        self.image_file_list = list(shuffled_images)
        self.label_file_list = list(shuffled_labels)
        self.split_index = int(split_ratio * len(self.image_file_list))
        self.x_train_file_list = self.image_file_list[self.split_index:]
        self.y_train_file_list = self.label_file_list[self.split_index:]
        self.x_val_file_list = self.image_file_list[:self.split_index]
        self.y_val_file_list = self.label_file_list[:self.split_index]
        self.x_test_file_list = get_png_filename_list(self.path_test_images)
        self.y_test_file_list = get_png_filename_list(self.path_test_labels)

    def generate_data(self, batch_size, train=False, val=False, test=False):
        """Replaces Keras' native ImageDataGenerator.

        Endlessly yields ``(image_batch, label_batch)`` tuples of normalized
        float arrays, cycling through the selected split.  Images are read as
        3-channel color, labels as single-channel with a trailing axis added.
        A pair whose image or label does not match ``self.x`` by ``self.y``
        is reported and skipped as a whole, so a batch may hold fewer than
        ``batch_size`` samples but images and labels always stay aligned.

        Args:
            batch_size: Number of file pairs visited per yielded batch.
            train: Select the training split (takes precedence).
            val: Select the validation split.
            test: Select the test split.

        Raises:
            ValueError: On first iteration, if no split flag is set or the
                selected split contains no files.
            IOError: If an image or label file cannot be read.
        """
        if train:
            image_file_list = self.x_train_file_list
            label_file_list = self.y_train_file_list
            image_dir = self.path_train_images
            label_dir = self.path_train_labels
        elif val:
            # Validation files live in the training directories.
            image_file_list = self.x_val_file_list
            label_file_list = self.y_val_file_list
            image_dir = self.path_train_images
            label_dir = self.path_train_labels
        elif test:
            image_file_list = self.x_test_file_list
            label_file_list = self.y_test_file_list
            image_dir = self.path_test_images
            label_dir = self.path_test_labels
        else:
            raise ValueError('one of train or val or test need to be True')

        num_samples = len(image_file_list)
        if num_samples == 0:
            # Without this guard the loop below would spin forever.
            raise ValueError('no image files available for the requested split')

        index = 0
        while True:
            image_batch = []
            label_batch = []
            for _ in range(batch_size):
                # Wrap around on the length of the split actually in use.
                if index >= num_samples:
                    index = 0
                image_path = os.path.join(image_dir, image_file_list[index])
                label_path = os.path.join(label_dir, label_file_list[index])
                index += 1

                image = cv2.imread(image_path, 1)
                label = cv2.imread(label_path, 0)
                # cv2.imread signals failure by returning None, not by raising.
                if image is None:
                    raise IOError('could not read image: {}'.format(image_path))
                if label is None:
                    raise IOError('could not read label: {}'.format(label_path))
                # image, label = self.change_color_space(image, label, self.color_space)
                label = np.expand_dims(label, axis=2)

                image_ok = image.shape[0] == self.x and image.shape[1] == self.y
                label_ok = label.shape[0] == self.x and label.shape[1] == self.y
                if not image_ok:
                    print('the input image shape is not {}x{}'.format(self.x, self.y))
                if not label_ok:
                    print('the input label shape is not {}x{}'.format(self.x, self.y))
                # Append both or neither so the two batches never go out of step.
                if image_ok and label_ok:
                    image_batch.append(image.astype("float32"))
                    label_batch.append(label.astype("float32"))

            if image_batch and label_batch:
                # Normalization is computed per batch, separately for images and labels.
                image_batch = normalize(np.array(image_batch))
                label_batch = normalize(np.array(label_batch))
                yield (image_batch, label_batch)

    def get_num_data_points(self, train=False, val=False):
        """Return the number of files in the training or validation split.

        The training count is returned only when ``train`` is True and ``val``
        is False; every other combination (including neither flag being set)
        returns the validation count.
        """
        if train is True and val is False:
            return len(self.x_train_file_list)
        return len(self.x_val_file_list)

    def shuffle_image_label_lists_together(self):
        """Shuffle the image and label filename lists with the same permutation.

        Returns:
            An iterable of exactly two tuples: the shuffled image filenames
            and the shuffled label filenames.  Both are empty when there is
            nothing to shuffle.  If the lists differ in length, the surplus
            entries of the longer one are dropped by ``zip``.
        """
        combined = list(zip(self.image_file_list, self.label_file_list))
        if not combined:
            # zip(*[]) yields nothing, which callers could not unpack into two values.
            return iter(((), ()))
        random.shuffle(combined)
        return zip(*combined)
def change_color_space(image, label, color_space):
    """Convert an image/label pair from BGR to the requested color space.

    ``color_space`` is matched case-insensitively.  'hsi' and 'hsv' both map
    to OpenCV's BGR->HSV conversion and 'lab' to BGR->LAB.  Any other value
    returns both inputs unchanged.

    Returns:
        The ``(image, label)`` pair, converted or as given.
    """
    space = color_space.lower()
    if space in ('hsi', 'hsv'):
        conversion = cv2.COLOR_BGR2HSV
    elif space == 'lab':
        conversion = cv2.COLOR_BGR2LAB
    else:
        # Unrecognized spaces (e.g. 'rgb') are passed through untouched.
        return image, label
    converted_image = cv2.cvtColor(image, conversion)
    converted_label = cv2.cvtColor(label, conversion)
    return converted_image, converted_label
def normalize(arr):
    """Scale *arr* by the spread of its values.

    Every element is divided by ``max(arr) - min(arr)`` taken over the whole
    array.  The minimum is not subtracted first.  When all values are equal
    (spread of zero) the divisor falls back to 255.
    """
    spread = np.amax(arr) - np.amin(arr)
    if spread == 0:
        spread = 255
    return arr / np.absolute(spread)
def get_png_filename_list(path):
    """Return the sorted names of the PNG files located directly in *path*.

    Only the top-level directory is inspected (no recursion).  The extension
    is matched case-insensitively at the end of the name.  A directory that
    does not exist produces an empty list.

    Args:
        path: Directory to list.

    Returns:
        Bare filenames (no directory part), sorted lexicographically.
    """
    # The first os.walk() entry describes `path` itself; os.walk() yields
    # nothing for a missing directory, hence the empty default.
    _, _, top_level_files = next(os.walk(path), (path, [], []))
    return sorted(
        filename for filename in top_level_files
        if filename.lower().endswith(".png")
    )
def get_jpg_filename_list(path):
    """Return the sorted names of the ``.jpg`` files located directly in *path*.

    Only the top-level directory is inspected (no recursion).  The extension
    is matched case-insensitively at the end of the name.  A directory that
    does not exist produces an empty list.

    Args:
        path: Directory to list.

    Returns:
        Bare filenames (no directory part), sorted lexicographically.
    """
    # The first os.walk() entry describes `path` itself; os.walk() yields
    # nothing for a missing directory, hence the empty default.
    _, _, top_level_files = next(os.walk(path), (path, [], []))
    return sorted(
        filename for filename in top_level_files
        if filename.lower().endswith(".jpg")
    )
def load_jpg_images(path):
    """Load every ``.jpg`` file found by ``get_jpg_filename_list(path)``.

    Each file is read with OpenCV as a 3-channel color image and converted
    to float32.

    Args:
        path: Directory containing the images (trailing slash optional).

    Returns:
        A tuple ``(images, file_list)``: the images stacked with ``np.array``
        and the filenames in the same order.

    Raises:
        IOError: If a listed file cannot be read as an image.
    """
    file_list = get_jpg_filename_list(path)
    images = []
    for filename in file_list:
        full_path = os.path.join(path, filename)
        img = cv2.imread(full_path, 1)
        # cv2.imread signals failure by returning None, not by raising.
        if img is None:
            raise IOError('could not read image: {}'.format(full_path))
        images.append(img.astype("float32"))
    return np.array(images), file_list
def load_png_images(path):
    """Load every PNG file found by ``get_png_filename_list(path)``.

    Each file is read with OpenCV as a 3-channel color image and converted
    to float32.

    Args:
        path: Directory containing the images (trailing slash optional).

    Returns:
        A tuple ``(images, file_list)``: the images stacked with ``np.array``
        and the filenames in the same order.

    Raises:
        IOError: If a listed file cannot be read as an image.
    """
    file_list = get_png_filename_list(path)
    images = []
    for filename in file_list:
        full_path = os.path.join(path, filename)
        img = cv2.imread(full_path, 1)
        # cv2.imread signals failure by returning None, not by raising.
        if img is None:
            raise IOError('could not read image: {}'.format(full_path))
        images.append(img.astype("float32"))
    return np.array(images), file_list
def load_data(path):
    """Load and normalize the full train and test sets under *path*.

    Reads ``train/images/``, ``train/labels/``, ``test/images/`` and
    ``test/labels/`` (appended directly to *path*) through
    ``load_png_images`` and passes each resulting array through
    ``normalize``.

    Returns:
        ``(x_train, y_train, x_test, y_test, test_label_filenames_list)``,
        where the last item lists the test label filenames.
    """
    # Only the test-label filenames are returned; the other lists are discarded.
    x_train, _ = load_png_images(path + "train/images/")
    y_train, _ = load_png_images(path + "train/labels/")
    x_test, _ = load_png_images(path + "test/images/")
    y_test, test_label_filenames_list = load_png_images(path + "test/labels/")

    return (
        normalize(x_train),
        normalize(y_train),
        normalize(x_test),
        normalize(y_test),
        test_label_filenames_list,
    )
def load_test_images(path):
    """Load and normalize the images in ``<path>test/images/``.

    Returns:
        ``(x_test, test_image_filenames_list)``: the normalized image array
        and the filenames it was built from.
    """
    images, test_image_filenames_list = load_png_images(path + "test/images/")
    return normalize(images), test_image_filenames_list
def save_results(np_array, color_space, outpath, test_label_filenames_list):
    """Write each prediction in *np_array* to ``outpath + filename``.

    Entry ``i`` of *np_array* is multiplied by 255 and written under the
    ``i``-th name in *test_label_filenames_list*.  ``color_space`` is
    accepted but not used: no color conversion is applied before writing.
    """
    for index, filename in enumerate(test_label_filenames_list):
        prediction = np_array[index]
        cv2.imwrite(outpath + filename, prediction * 255.)
def save_rgb_results(np_array, outpath, test_label_filenames_list):
    """Write each entry of *np_array* to ``outpath + filename``.

    Entry ``i`` is multiplied by 255 and written under the ``i``-th name in
    *test_label_filenames_list*.
    """
    for index, filename in enumerate(test_label_filenames_list):
        cv2.imwrite(outpath + filename, np_array[index] * 255.)
def save_history(model, model_name, training_history, dataset, n_filters, epoch, learning_rate, loss,
                 color_space, path=None, temp_name=None):
    """Save the model, its training history and a plot of the history curves.

    Three files sharing one base name are written: ``<base>.hdf5`` (via
    ``model.save``), ``<base>.json`` (``training_history.history``) and
    ``<base>.png`` (the plot).  The plot is also shown with ``plt.show()``
    and the figure is cleared afterwards.

    Args:
        model: Object exposing ``save(filename)``.
        model_name: Shown in the plot title.
        training_history: Object whose ``history`` attribute maps metric
            names to per-epoch value lists.
        dataset: Unused; kept for interface compatibility.
        n_filters: Shown in the plot title.
        epoch: Shown in the plot title.
        learning_rate: Shown in the plot title.
        loss: Unused; kept for interface compatibility.
        color_space: Shown in the plot title.
        path: Prefix prepended verbatim to the base name; ``None`` means no
            prefix (files go to the current working directory).
        temp_name: Base name to use; defaults to the current timestamp.
    """
    # Formatting None directly would produce filenames starting with "None".
    prefix = '' if path is None else path
    save_weight_filename = temp_name if temp_name else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    base_name = '{}{}'.format(prefix, save_weight_filename)

    model.save('{}.hdf5'.format(base_name))

    history = training_history.history
    with open('{}.json'.format(base_name), 'w') as f:
        json.dump(history, f, indent=2)

    print_list = ['loss', 'val_loss', 'dice_coef', 'val_dice_coef']
    for item in print_list:
        # A missing metric is plotted as an empty series so the fixed
        # four-entry legend below stays aligned with the plotted lines.
        plt.plot(list(history.get(item, [])))
    plt.title('model:{} lr:{} epoch:{} #filtr:{} Colorspaces:{}'.format(model_name, learning_rate,
                                                                        epoch, n_filters, color_space))
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train_loss', 'test_loss', 'train_dice', 'test_dice'], loc='upper left')
    plt.savefig('{}.png'.format(base_name))
    plt.show()
    plt.clf()