Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Wed Apr 24 13:57:44 2024 | |
| @author: tjy | |
| """ | |
| import configparser | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from torchvision import transforms | |
| import sys | |
| import os | |
| STEP3_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| SRC_DIR = os.path.join(STEP3_DIR, "Src") | |
| if SRC_DIR not in sys.path: | |
| sys.path.insert(0, SRC_DIR) | |
| from common.common_util import pre_processing, simple_nms, remove_borders, \ | |
| sample_keypoint_desc | |
| from model.super_retina import SuperRetina | |
| from PIL import Image | |
| import re | |
| import shutil | |
| import pandas as pd | |
| from sklearn.linear_model import LinearRegression #import linear regression module | |
| from sklearn.preprocessing import PolynomialFeatures | |
| class Predictor: | |
def __init__(self, config, img_path=None, savepath=None):
    """Create a predictor from the ``PREDICT`` section of ``config``.

    Reads the device, checkpoint path, NMS and matching thresholds and the
    model input size from ``config['PREDICT']``, loads the SuperRetina
    weights stored under the checkpoint's ``'net'`` key, and prepares the
    brute-force L2 matcher and the resize/to-tensor preprocessing chain.

    Args:
        config: Mapping with a ``'PREDICT'`` entry holding the keys
            ``device``, ``model_save_path``, ``nms_size``, ``nms_thresh``,
            ``knn_thresh``, ``model_image_width`` and ``model_image_height``.
        img_path: Stored on the instance as ``self.img_path``.
        savepath: Root directory under which per-case outputs are written.
    """
    settings = config['PREDICT']

    # The configured device is only honoured when CUDA is available;
    # otherwise everything runs on the CPU.
    requested_device = settings['device']
    self.device = torch.device(requested_device if torch.cuda.is_available() else "cpu")
    checkpoint_path = settings['model_save_path']

    # Detection / matching hyper-parameters.
    self.nms_size = settings['nms_size']
    self.nms_thresh = settings['nms_thresh']
    self.scale = 8
    self.knn_thresh = settings['knn_thresh']

    # Size of the most recently read image pair; filled in by the readers.
    self.image_width = None
    self.image_height = None

    # Resolution every image is resized to before it is fed to the network.
    self.model_image_width = settings['model_image_width']
    self.model_image_height = settings['model_image_height']

    # Restore the network weights and switch to inference mode.
    state = torch.load(checkpoint_path, map_location=self.device)
    network = SuperRetina()
    network.load_state_dict(state['net'])
    network.to(self.device)
    network.eval()
    self.model = network

    self.knn_matcher = cv2.BFMatcher(cv2.NORM_L2)
    # NOTE: the attribute name is misspelled, but other methods read it
    # under exactly this name, so it has to stay.
    self.trasformer = transforms.Compose([
        transforms.Resize((self.model_image_height, self.model_image_width)),
        transforms.ToTensor(),
    ])

    # Crop bookkeeping: boxes are stored as [[y0, x0], [y1, x1]].
    self.crop = None
    self.refer_crop = None
    self.query_crop = None
    self.crop_path = None
    self.query_crop_path = None
    self.crop_mode = "refer_box_only"

    # Output location and per-case state.
    self.savepath = savepath
    self.name = None
    self.img_path = img_path
    self.last_match_debug = {}
| def _load_crop_box(self, crop_path, name): | |
| crop_size = np.asarray(np.loadtxt(crop_path), dtype=np.float32) | |
| crop_size = np.atleast_2d(crop_size) | |
| if crop_size.shape != (2, 2): | |
| raise ValueError("{} invalid format: {}".format(name, crop_path)) | |
| crop_size = np.rint(crop_size).astype(np.int32) | |
| crop_size[:, 0] = np.clip(crop_size[:, 0], 0, 1000) | |
| crop_size[:, 1] = np.clip(crop_size[:, 1], 0, 1000) | |
| if crop_size[1][0] <= crop_size[0][0] or crop_size[1][1] <= crop_size[0][1]: | |
| raise ValueError("{} invalid: {}".format(name, crop_size.tolist())) | |
| return crop_size | |
def image_read(self, query_path, refer_path, query_is_image=False, crop=False):
    """Read a query/reference pair and bring both to 1000x1000 ``uint8``.

    Both images are reduced to their green channel, passed through
    ``pre_processing`` and resized to 1000x1000. Non-zero pixels are then
    inverted (``1 - value``) before scaling to ``uint8``; the inversion is
    skipped for the query when it was supplied as an in-memory image.

    When ``crop`` is true the reference is first cut to a crop box
    ``[[y0, x0], [y1, x1]]`` expressed in the 1000x1000 frame. The box
    comes from ``self.crop_path`` when that file exists, otherwise it is
    derived from landmark text files next to the images. In
    ``"dual_crop"`` mode the query is additionally cut to the box stored
    at ``self.query_crop_path``.

    Side effects: sets ``self.crop``, ``self.refer_crop``,
    ``self.query_crop``, ``self.image_height`` and ``self.image_width``.

    Args:
        query_path: Path of the query image, or the image array itself
            when ``query_is_image`` is true.
        refer_path: Path of the reference image.
        query_is_image: Treat ``query_path`` as an already loaded image.
        crop: Crop the reference (and, in dual-crop mode, the query).

    Returns:
        ``(query_image, refer_image)`` as ``uint8`` arrays.

    Raises:
        FileNotFoundError: If an image file cannot be read.
        ValueError: If a crop box is invalid, the dual-crop query box is
            missing, or the landmark fallback is requested for an
            in-memory query image.
    """
    if query_is_image:
        query_image = query_path
    else:
        query_image = cv2.imread(query_path, cv2.IMREAD_COLOR)
        if query_image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError("cannot read query image: {}".format(query_path))
        # Green channel only.
        query_image = pre_processing(query_image[:, :, 1])

    refer_image = cv2.imread(refer_path, cv2.IMREAD_COLOR)
    if refer_image is None:
        raise FileNotFoundError("cannot read reference image: {}".format(refer_path))

    refer_image = cv2.resize(refer_image, (1000, 1000), interpolation=cv2.INTER_LINEAR)
    query_image = cv2.resize(query_image, (1000, 1000), interpolation=cv2.INTER_LINEAR)

    # Forget the crop state of any previous pair.
    self.query_crop = None
    self.refer_crop = None
    self.crop = None

    if crop:
        crop_mode = getattr(self, "crop_mode", "refer_box_only")
        if self.crop_path is not None and os.path.exists(self.crop_path):
            refer_crop_size = self._load_crop_box(self.crop_path, "refer_crop_box_1000")
        else:
            if query_is_image:
                # The fallback derives a file name from query_path, which
                # is impossible when the query is an array.
                raise ValueError(
                    "landmark-based crop needs query_path to be a file path"
                )
            # Landmark files share the image name minus its last 11
            # characters (presumably a suffix such as "_vessel.jpg" —
            # TODO confirm against the dataset layout).
            cfp_od_fovea = np.loadtxt(refer_path[:-11] + '.txt')
            octa_od_fovea = np.loadtxt(query_path[:-11] + '.txt')
            octa_len = abs(octa_od_fovea[1][1] - octa_od_fovea[0][1])
            cfp_len = abs(cfp_od_fovea[1][1] - cfp_od_fovea[0][1])
            rate = cfp_len / octa_len
            x0 = max(cfp_od_fovea[1][1] * 1000 - octa_od_fovea[1][1] * 1000 * rate, 0)
            x1 = min(cfp_od_fovea[1][1] * 1000 + octa_od_fovea[1][1] * 1000 * rate, 1000)
            y0 = max(cfp_od_fovea[1][0] * 1000 - octa_od_fovea[1][0] * 1000 * rate, 0)
            y1 = min(cfp_od_fovea[1][0] * 1000 + octa_od_fovea[1][0] * 1000 * rate, 1000)
            # Keep behavior consistent with original test_on_OCTA_regression*.py:
            # the box stays in float and is truncated with int() only when slicing.
            refer_crop_size = [[y0, x0], [y1, x1]]

        if (refer_crop_size[1][0] <= refer_crop_size[0][0]
                or refer_crop_size[1][1] <= refer_crop_size[0][1]):
            raise ValueError("crop_size invalid: {}".format(refer_crop_size))
        refer_image = refer_image[
            int(refer_crop_size[0][0]):int(refer_crop_size[1][0]),
            int(refer_crop_size[0][1]):int(refer_crop_size[1][1]),
        ]
        self.refer_crop = refer_crop_size
        self.crop = refer_crop_size

        if crop_mode == "dual_crop":
            if self.query_crop_path is None or not os.path.exists(self.query_crop_path):
                raise ValueError("dual_crop mode is missing a valid query_crop_box_1000 path")
            query_crop_size = self._load_crop_box(self.query_crop_path, "query_crop_box_1000")
            query_image = query_image[
                int(query_crop_size[0][0]):int(query_crop_size[1][0]),
                int(query_crop_size[0][1]):int(query_crop_size[1][1]),
            ]
            query_image = cv2.resize(query_image, (1000, 1000), interpolation=cv2.INTER_LINEAR)
            self.query_crop = query_crop_size

    # Green channel of the (possibly cropped) reference, back to 1000x1000.
    refer_image = pre_processing(refer_image[:, :, 1])
    refer_image = cv2.resize(refer_image, (1000, 1000), interpolation=cv2.INTER_LINEAR)

    # Internal invariant: both images were just resized to the same size.
    assert query_image.shape[:2] == refer_image.shape[:2]
    self.image_height, self.image_width = query_image.shape[:2]

    # Invert every non-zero pixel and leave exact zeros untouched.
    refer_image = np.where(refer_image > 0.00, 1 - refer_image, refer_image)
    if not query_is_image:
        query_image = np.where(query_image > 0.00, 1 - query_image, query_image)

    query_image = (query_image * 255).astype(np.uint8)
    refer_image = (refer_image * 255).astype(np.uint8)
    return query_image, refer_image
| def _coordinate_transform(self, keypoints, crop_size): | |
| old_width = crop_size[1][1] - crop_size[0][1] | |
| old_height = crop_size[1][0] - crop_size[0][0] | |
| resize_width = old_width / self.model_image_width | |
| resize_height = old_height / self.model_image_height | |
| return [ | |
| cv2.KeyPoint( | |
| int((i[0] * resize_width + crop_size[0][1]) * self.image_width / 1000), | |
| int((i[1] * resize_height + crop_size[0][0]) * self.image_height / 1000), | |
| 30, | |
| ) | |
| for i in keypoints | |
| ] | |
def coordinate_transform(self, refer_keypoints):
    """Map reference keypoints back through the current reference crop box."""
    crop_box = self.refer_crop
    return self._coordinate_transform(refer_keypoints, crop_box)
def query_coordinate_transform(self, query_keypoints):
    """Map query keypoints back through the current query crop box."""
    crop_box = self.query_crop
    return self._coordinate_transform(query_keypoints, crop_box)
| def _ensure_case_dir(self): | |
| case_dir = os.path.join(self.savepath, self.name) | |
| if not os.path.exists(case_dir): | |
| os.mkdir(case_dir) | |
| return case_dir | |
| def _points_from_matches(self, query_keypoints, refer_keypoints, matches): | |
| if len(matches) == 0: | |
| empty = np.zeros((0, 2), dtype=np.float32) | |
| return empty, empty | |
| query_points = np.array( | |
| [query_keypoints[m.queryIdx].pt for m in matches], | |
| dtype=np.float32, | |
| ) | |
| refer_points = np.array( | |
| [refer_keypoints[m.trainIdx].pt for m in matches], | |
| dtype=np.float32, | |
| ) | |
| return query_points, refer_points | |
| def _save_pair_points(self, save_path, query_points, refer_points, normalize=False): | |
| if len(query_points) == 0: | |
| pairs = np.zeros((0, 4), dtype=np.float64) | |
| else: | |
| pairs = np.column_stack((query_points, refer_points)).astype(np.float64) | |
| if normalize: | |
| pairs = pairs / float(self.image_height) | |
| np.savetxt(save_path, pairs, fmt="%.12f") | |
| return save_path | |
| def _to_bgr(self, image): | |
| if len(image.shape) == 2: | |
| return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) | |
| return image.copy() | |
def _save_match_visualization(self, query_image, refer_image, query_points, refer_points, save_path, title):
    """Save a side-by-side picture of the matched point pairs.

    The query image is placed on the left and the reference on the right.
    Each pair is joined by a green line with a red dot at both ends, and
    ``title`` is printed in white on a 70-pixel black banner above.

    Args:
        query_image: Query image (grayscale or 3-channel).
        refer_image: Reference image (grayscale or 3-channel).
        query_points: ``(N, 2)`` pixel coordinates in the query image.
        refer_points: ``(N, 2)`` pixel coordinates in the reference image.
        save_path: Destination image file.
        title: Banner text.

    Returns:
        ``save_path``.
    """
    left = self._to_bgr(query_image)
    right = self._to_bgr(refer_image)
    left_h, left_w = left.shape[:2]
    right_h, right_w = right.shape[:2]

    # Paste both images next to each other on a black background.
    side_by_side = np.zeros((max(left_h, right_h), left_w + right_w, 3), dtype=np.uint8)
    side_by_side[0:left_h, 0:left_w] = left
    side_by_side[0:right_h, left_w:left_w + right_w] = right

    line_color = (0, 255, 0)
    dot_color = (0, 0, 255)
    for query_pt, refer_pt in zip(query_points, refer_points):
        start = (int(round(query_pt[0])), int(round(query_pt[1])))
        # Reference points are shifted right by the width of the query image.
        end = (int(round(refer_pt[0] + left_w)), int(round(refer_pt[1])))
        cv2.line(side_by_side, start, end, line_color, 1, lineType=cv2.LINE_AA)
        cv2.circle(side_by_side, start, 2, dot_color, -1, lineType=cv2.LINE_AA)
        cv2.circle(side_by_side, end, 2, dot_color, -1, lineType=cv2.LINE_AA)

    # Add the title banner on top of the picture.
    banner_height = 70
    framed = np.zeros(
        (side_by_side.shape[0] + banner_height, side_by_side.shape[1], 3),
        dtype=np.uint8,
    )
    framed[banner_height:, :] = side_by_side
    cv2.putText(
        framed,
        title,
        (20, 45),
        cv2.FONT_HERSHEY_SIMPLEX,
        1.2,
        (255, 255, 255),
        2,
        lineType=cv2.LINE_AA,
    )
    cv2.imwrite(save_path, framed)
    return save_path
def draw_result(self, query_image, refer_image, cv_kpts_query, cv_kpts_refer, matches, status, crop=False):
    """Plot all keypoints and the accepted matches, and save the result.

    Writes ``pairs_kps.txt`` (accepted pairs, every coordinate divided by
    the query image height) and ``paired_kps.png`` into
    ``<savepath>/<name>``, then shows the figure.

    Args:
        query_image: Query image.
        refer_image: Reference image.
        cv_kpts_query: Query keypoints exposing ``.pt``.
        cv_kpts_refer: Reference keypoints exposing ``.pt``.
        matches: k-NN matches; the first entry of each pair is used.
        status: Per-match flags; a match is drawn when its flag equals 1.
        crop: Unused; kept for interface compatibility.
    """
    def drawMatches(imageA, imageB, kpsA, kpsB, matches, status):
        """Draw accepted matches side by side and save their coordinates."""
        (hA, wA) = imageA.shape[:2]
        (hB, wB) = imageB.shape[:2]
        vis = np.zeros((max(hA, hB), wA + wB, 3), dtype="uint8")
        if len(imageA.shape) == 2:
            imageA = cv2.cvtColor(imageA, cv2.COLOR_GRAY2RGB)
            imageB = cv2.cvtColor(imageB, cv2.COLOR_GRAY2RGB)
        vis[0:hA, 0:wA] = imageA
        vis[0:hB, wA:] = imageB

        pt = []
        for (match, _), s in zip(matches, status):
            # Only matches that passed the ratio test are drawn and saved.
            if s != 1:
                continue
            trainIdx, queryIdx = match.trainIdx, match.queryIdx
            ptA = (int(kpsA[queryIdx].pt[0]), int(kpsA[queryIdx].pt[1]))
            ptB = (int(kpsB[trainIdx].pt[0]) + wA, int(kpsB[trainIdx].pt[1]))
            pt.append([
                kpsA[queryIdx].pt[0], kpsA[queryIdx].pt[1],
                kpsB[trainIdx].pt[0], kpsB[trainIdx].pt[1],
            ])
            cv2.line(vis, ptA, ptB, (0, 255, 0), 2)

        # Coordinates are stored relative to the query image height.
        pt = np.array(pt) / hA
        np.savetxt(os.path.join(self.savepath, self.name, 'pairs_kps.txt'), pt)
        return vis

    # reshape(-1, 2) keeps the arrays two-dimensional even when a keypoint
    # list is empty, so the column indexing below cannot fail.
    query_np = np.array([kp.pt for kp in cv_kpts_query], dtype=np.float64).reshape(-1, 2)
    refer_np = np.array([kp.pt for kp in cv_kpts_refer], dtype=np.float64).reshape(-1, 2)
    # Reference keypoints are drawn on the right half of the canvas.
    refer_np[:, 0] += query_image.shape[1]

    matched_image = drawMatches(query_image, refer_image, cv_kpts_query, cv_kpts_refer, matches, status)

    fig = plt.figure(dpi=300, facecolor='black')
    ax = fig.add_subplot(111)
    ax.set_facecolor('black')
    ax.scatter(query_np[:, 0], query_np[:, 1], s=0.1, c='r')
    ax.scatter(refer_np[:, 0], refer_np[:, 1], s=0.02, c='r')
    ax.axis('off')
    # np.sum accepts plain lists as well as arrays.
    ax.set_title('Match Result, #goodMatch: {}'.format(np.sum(status)), color='white')
    ax.imshow(cv2.cvtColor(matched_image, cv2.COLOR_BGR2RGB))
    fig.savefig(
        os.path.join(self.savepath, self.name, 'paired_kps.png'),
        dpi=300,
        facecolor=fig.get_facecolor(),
        bbox_inches='tight',
        pad_inches=0,
    )
    plt.show()
    plt.close(fig)
def model_run_pair(self, query_tensor, refer_tensor):
    """Detect keypoints and sample descriptors for a query/reference pair.

    Both tensors are stacked into one batch (query first). Detector
    responses are thinned with ``simple_nms``, thresholded with
    ``self.nms_thresh`` and stripped of points within 4 pixels of the
    border before descriptors are sampled at the surviving locations.

    Args:
        query_tensor: Query image tensor without a batch dimension.
        refer_tensor: Reference image tensor without a batch dimension.

    Returns:
        ``(keypoints, descriptors)``: ``keypoints`` is an object-dtype
        NumPy array built from the per-image CPU tensors of ``(x, y)``
        positions (index 0 = query, 1 = reference); ``descriptors`` is a
        list with one CPU tensor per image.
    """
    batch = torch.cat((query_tensor.unsqueeze(0), refer_tensor.unsqueeze(0)))
    batch = batch.to(self.device)
    with torch.no_grad():
        detector_pred, descriptor_pred = self.model(batch)

    _, _, height, width = detector_pred.shape
    score_maps = simple_nms(detector_pred, self.nms_size).reshape(-1, height, width)

    keypoints = []
    descriptors = []
    for score_map, dense_descriptors in zip(score_maps, descriptor_pred):
        # (row, col) locations whose response exceeds the threshold.
        locations = torch.nonzero(score_map > self.nms_thresh)
        strengths = score_map[tuple(locations.t())]
        # Discard keypoints near the image borders.
        locations, strengths = remove_borders(locations, strengths, 4, height, width)
        # Flip (row, col) into (x, y).
        xy = torch.flip(locations, [1]).float().data
        descriptors.append(sample_keypoint_desc(xy[None], dense_descriptors[None], 8)[0].cpu())
        keypoints.append(xy.cpu())

    keypoints = np.array(keypoints, dtype=object)
    return keypoints, descriptors
def match_scores(self, H_m, size, ngoodmatch):
    """Sanity-check a homography by projecting the image corners.

    The transform is rejected when there are fewer than 15 good matches,
    when no homography is available, when a corner is projected beyond
    ``size`` on either axis, or when the projected corners change their
    relative order (which indicates a flipped or folded mapping).

    Args:
        H_m: 3x3 homography (nested sequence or array), or ``None``.
        size: Side length of the square image.
        ngoodmatch: Number of good matches behind ``H_m``.

    Returns:
        ``True`` if the homography looks plausible, otherwise ``False``.
    """
    if ngoodmatch < 15 or H_m is None:
        return False

    # Corners as (x, y): top-left, bottom-left, top-right, bottom-right.
    corners = [[0, 0], [0, size], [size, 0], [size, size]]
    projected = []
    for x, y in corners:
        w = x * H_m[2][0] + y * H_m[2][1] + H_m[2][2]
        if w == 0:
            # The corner is mapped to infinity: not a usable transform.
            return False
        new_x = (x * H_m[0][0] + y * H_m[0][1] + H_m[0][2]) / w
        new_y = (x * H_m[1][0] + y * H_m[1][1] + H_m[1][2]) / w
        if new_x > size or new_y > size:
            return False
        projected.append((new_x, new_y))

    top_left, bottom_left, top_right, bottom_right = projected
    order_kept = (
        bottom_left[1] >= top_left[1]
        and top_right[0] >= top_left[0]
        and bottom_right[0] >= bottom_left[0]
        and bottom_right[1] >= top_right[1]
    )
    return True if order_kept else False
def match(self, query_path, refer_path, show=False, query_is_image=False, crop=False, save_match_images=True):
    """Detect, describe and match keypoints between two images.

    Keypoints are produced in model-input space and mapped to two
    coordinate systems: the *input* space (the possibly cropped images
    fed to the network) and the *restored* space (the uncropped images).
    Without cropping both spaces are identical. Candidate matches from a
    2-nearest-neighbour search are kept when they pass the ratio test
    with ``self.knn_thresh``.

    Side effects: writes the matched pairs (and optionally the images and
    visualisations) into ``<savepath>/<name>`` and stores the intermediate
    results in ``self.last_match_debug``.

    Args:
        query_path: Query image path, or the image itself when
            ``query_is_image`` is true.
        refer_path: Reference image path.
        show: Also save match visualisations.
        query_is_image: Treat ``query_path`` as an already loaded image.
        crop: Crop the inputs as described in ``image_read``.
        save_match_images: Save the input and restored images.

    Returns:
        ``(goodMatch, cv_kpts_query, cv_kpts_refer, query_image_restored,
        refer_image_restored)`` with keypoints in restored space.
    """
    query_image, refer_image = self.image_read(query_path, refer_path, query_is_image, crop)
    query_tensor = self.trasformer(Image.fromarray(query_image))
    refer_tensor = self.trasformer(Image.fromarray(refer_image))
    keypoints, descriptors = self.model_run_pair(query_tensor, refer_tensor)
    query_keypoints, refer_keypoints = keypoints[0], keypoints[1]
    # The matcher expects one descriptor per row.
    query_desc = descriptors[0].permute(1, 0).numpy()
    refer_desc = descriptors[1].permute(1, 0).numpy()

    def to_input_space(points):
        """Scale model-input coordinates to the size of the input images."""
        return [
            cv2.KeyPoint(
                int(p[0] / self.model_image_width * self.image_width),
                int(p[1] / self.model_image_height * self.image_height),
                30,
            )
            for p in points
        ]

    cv_kpts_query_input = to_input_space(query_keypoints)
    cv_kpts_refer_input = to_input_space(refer_keypoints)
    query_image_input = query_image.copy()
    refer_image_input = refer_image.copy()

    # Without cropping the restored space equals the input space.
    cv_kpts_query = cv_kpts_query_input
    cv_kpts_refer = cv_kpts_refer_input
    query_image_restored = query_image_input
    refer_image_restored = refer_image_input
    if crop:
        crop_mode = getattr(self, "crop_mode", "refer_box_only")
        if crop_mode == "dual_crop":
            cv_kpts_query = self.query_coordinate_transform(query_keypoints)
        cv_kpts_refer = self.coordinate_transform(refer_keypoints)
        # Re-read the pair without cropping. This must come after the
        # transforms above because image_read resets the crop boxes.
        # query_is_image is forwarded so an in-memory query is not
        # mistaken for a file path.
        query_image_restored, refer_image_restored = self.image_read(
            query_path, refer_path, query_is_image
        )

    goodMatch = []
    if len(query_desc) > 0 and len(refer_desc) > 0:
        try:
            matches = self.knn_matcher.knnMatch(query_desc, refer_desc, k=2)
        except cv2.error as err:
            # Matching is best effort: report the failure and continue
            # with no matches instead of hiding it silently.
            print("Descriptor matching failed: {}".format(err))
            matches = []
        for candidates in matches:
            # The ratio test needs two neighbours; fewer can be returned
            # when the reference has fewer than two descriptors.
            if len(candidates) < 2:
                continue
            best, runner_up = candidates[0], candidates[1]
            if best.distance < self.knn_thresh * runner_up.distance:
                goodMatch.append(best)

    case_dir = self._ensure_case_dir()
    query_points_input, refer_points_input = self._points_from_matches(
        cv_kpts_query_input,
        cv_kpts_refer_input,
        goodMatch,
    )
    query_points_restored, refer_points_restored = self._points_from_matches(
        cv_kpts_query,
        cv_kpts_refer,
        goodMatch,
    )

    if save_match_images:
        cv2.imwrite(os.path.join(case_dir, "match_query_input.png"), query_image_input)
        cv2.imwrite(os.path.join(case_dir, "match_refer_input.png"), refer_image_input)
        cv2.imwrite(os.path.join(case_dir, "match_query_restored.png"), query_image_restored)
        cv2.imwrite(os.path.join(case_dir, "match_refer_restored.png"), refer_image_restored)

    self._save_pair_points(
        os.path.join(case_dir, "pairs_kps_input_space_pixels.txt"),
        query_points_input,
        refer_points_input,
        normalize=False,
    )
    self._save_pair_points(
        os.path.join(case_dir, "pairs_kps_restored_space_pixels.txt"),
        query_points_restored,
        refer_points_restored,
        normalize=False,
    )
    self._save_pair_points(
        os.path.join(case_dir, "pairs_kps.txt"),
        query_points_restored,
        refer_points_restored,
        normalize=True,
    )

    # Kept so later stages can filter and visualise the same pairs.
    self.last_match_debug = {
        "query_image_input": query_image_input,
        "refer_image_input": refer_image_input,
        "query_image_restored": query_image_restored,
        "refer_image_restored": refer_image_restored,
        "query_points_input": query_points_input,
        "refer_points_input": refer_points_input,
        "query_points_restored": query_points_restored,
        "refer_points_restored": refer_points_restored,
        "good_match_count": int(len(goodMatch)),
    }

    if show:
        self._save_match_visualization(
            query_image_input,
            refer_image_input,
            query_points_input,
            refer_points_input,
            os.path.join(case_dir, "descriptor_matches_input_space.png"),
            "Descriptor Match (Input Space), #goodMatch: {}".format(len(goodMatch)),
        )
        self._save_match_visualization(
            query_image_restored,
            refer_image_restored,
            query_points_restored,
            refer_points_restored,
            os.path.join(case_dir, "paired_kps.png"),
            "Descriptor Match (Restored Space), #goodMatch: {}".format(len(goodMatch)),
        )
    return goodMatch, cv_kpts_query, cv_kpts_refer, query_image_restored, refer_image_restored
def compute_homography(self, query_path, refer_path, query_is_image=False, show=False, crop=False):
    """Estimate the query-to-reference homography with LMedS.

    Args:
        query_path: Query image path, or the image itself when
            ``query_is_image`` is true.
        refer_path: Reference image path.
        query_is_image: Treat ``query_path`` as an already loaded image.
        show: Forwarded to ``match``.
        crop: Forwarded to ``match``.

    Returns:
        ``(H_m, inliers_num_rate, raw_query_image, raw_refer_image, flag)``.
        ``H_m`` is ``None`` and the rate is ``0`` when fewer than four
        good matches exist or the estimation fails. ``flag`` is currently
        always ``True``.
    """
    goodMatch, cv_kpts_query, cv_kpts_refer, raw_query_image, raw_refer_image = \
        self.match(query_path, refer_path, query_is_image=query_is_image, show=show, crop=crop)

    H_m = None
    inliers_num_rate = 0
    # A homography needs at least four point correspondences.
    if len(goodMatch) >= 4:
        src_pts = np.float32(
            [cv_kpts_query[m.queryIdx].pt for m in goodMatch]
        ).reshape(-1, 1, 2)
        dst_pts = np.float32(
            [cv_kpts_refer[m.trainIdx].pt for m in goodMatch]
        ).reshape(-1, 1, 2)
        H_m, mask = cv2.findHomography(src_pts, dst_pts, cv2.LMEDS)
        # findHomography yields no mask when it cannot estimate a model.
        if mask is not None:
            inliers_num_rate = mask.sum() / len(mask.ravel())

    flag = True
    return H_m, inliers_num_rate, raw_query_image, raw_refer_image, flag
def regression(self, query_path, refer_path, query_is_image=False, show=False, crop=False):
    """Warp the query onto the reference with a quadratic polynomial map.

    Good matches are first filtered with a RANSAC homography (reprojection
    threshold 10 px). Two degree-2 polynomial regressions are then fitted
    on the inliers, one predicting the reference x and one the reference
    y from the query ``(x, y)``. Every query pixel is forward-mapped with
    these models into an empty image of the same size.

    Side effects: prints the match counts, writes the RANSAC inlier pairs
    and their visualisations into the case directory when the data stored
    by ``match`` is consistent with the inlier mask, and shows the warped
    image with matplotlib.

    Args:
        query_path: Query image path, or the image itself when
            ``query_is_image`` is true.
        refer_path: Reference image path.
        query_is_image: Treat ``query_path`` as an already loaded image.
        show: Forwarded to ``match``.
        crop: Forwarded to ``match``.

    Returns:
        ``(warps, raw_query_image, raw_refer_image, flag)``. ``warps`` is
        ``None`` and ``flag`` is ``False`` when fewer than four matches
        (or inliers) are available or RANSAC fails.
    """
    goodMatch, cv_kpts_query, cv_kpts_refer, raw_query_image, raw_refer_image = \
        self.match(query_path, refer_path, query_is_image=query_is_image, show=show, crop=crop)
    failure = (None, raw_query_image, raw_refer_image, False)
    if len(goodMatch) < 4:
        return failure

    src_pts = np.float32([cv_kpts_query[m.queryIdx].pt for m in goodMatch]).reshape(-1, 1, 2)
    dst_pts = np.float32([cv_kpts_refer[m.trainIdx].pt for m in goodMatch]).reshape(-1, 1, 2)
    print(len(goodMatch))
    # The homography itself is not used; only its inlier mask is.
    H_m, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransacReprojThreshold=10.0)
    if mask is None:
        return failure
    inlier_mask = mask.ravel().astype(bool)
    inliers = [m for m, keep in zip(goodMatch, inlier_mask) if keep]
    print(len(inliers))

    case_dir = self._ensure_case_dir()
    debug = self.last_match_debug
    # Only reuse the stored points when they line up with the mask.
    if debug and len(debug.get("query_points_restored", [])) == len(inlier_mask):
        query_points_input = debug["query_points_input"][inlier_mask]
        refer_points_input = debug["refer_points_input"][inlier_mask]
        query_points_restored = debug["query_points_restored"][inlier_mask]
        refer_points_restored = debug["refer_points_restored"][inlier_mask]
        self._save_pair_points(
            os.path.join(case_dir, "ransac_inlier_pairs_input_space_pixels.txt"),
            query_points_input,
            refer_points_input,
            normalize=False,
        )
        self._save_pair_points(
            os.path.join(case_dir, "ransac_inlier_pairs_restored_space_pixels.txt"),
            query_points_restored,
            refer_points_restored,
            normalize=False,
        )
        self._save_match_visualization(
            debug["query_image_input"],
            debug["refer_image_input"],
            query_points_input,
            refer_points_input,
            os.path.join(case_dir, "ransac_inlier_pairs_input_space.png"),
            "RANSAC Inlier Match (Input Space), #inlier: {}".format(len(query_points_input)),
        )
        self._save_match_visualization(
            debug["query_image_restored"],
            debug["refer_image_restored"],
            query_points_restored,
            refer_points_restored,
            os.path.join(case_dir, "ransac_inlier_pairs_restored_space.png"),
            "RANSAC Inlier Match (Restored Space), #inlier: {}".format(len(query_points_restored)),
        )

    if len(inliers) < 4:
        return failure

    # Fit in float64 (keypoint coordinates are float32) so the polynomial
    # features are not computed in reduced precision.
    src_xy = np.float32([cv_kpts_query[m.queryIdx].pt for m in inliers]).astype(np.float64)
    dst_xy = np.float32([cv_kpts_refer[m.trainIdx].pt for m in inliers])

    # One degree-2 feature expansion serves both regressions.
    poly = PolynomialFeatures(degree=2)
    design = poly.fit_transform(src_xy)
    regr_X = LinearRegression().fit(design, dst_xy[:, 0])
    regr_Y = LinearRegression().fit(design, dst_xy[:, 1])

    # Every (x, y) pixel position of the query, x varying slowest.
    height, width = self.image_height, self.image_width
    xs, ys = np.meshgrid(np.arange(width), np.arange(height), indexing='ij')
    grid = np.column_stack((xs.ravel(), ys.ravel()))

    # transform (not fit_transform): the expansion is already fitted.
    grid_design = poly.transform(grid)
    predicted = np.column_stack((regr_X.predict(grid_design), regr_Y.predict(grid_design)))
    dest = np.array(np.around(predicted), dtype=int)

    # Forward-map each query pixel to its predicted position, dropping
    # the pixels that land outside the image.
    warps = np.zeros((height, width), dtype=np.uint8)
    inside = (
        (dest[:, 0] >= 0) & (dest[:, 0] < width)
        & (dest[:, 1] >= 0) & (dest[:, 1] < height)
    )
    source_xy = grid[inside]
    target_flat = dest[inside, 1] * width + dest[inside, 0]
    values = raw_query_image[source_xy[:, 1], source_xy[:, 0]]
    # Several sources can hit the same target. Keep the last one in scan
    # order explicitly, because fancy-index assignment gives no ordering
    # guarantee for repeated indices.
    _, first_in_reversed = np.unique(target_flat[::-1], return_index=True)
    last_writer = target_flat.size - 1 - first_in_reversed
    warps.reshape(-1)[target_flat[last_writer]] = values[last_writer]

    # NOTE(review): the preview is shown regardless of ``show``.
    plt.imshow(warps, cmap='gray')
    plt.show()
    plt.close()

    flag = True
    return warps, raw_query_image, raw_refer_image, flag
def align_image_pair_regression(self, query_path, refer_path, show=False, crop=False):
    """Register a pair with polynomial regression and build an overlay.

    The case name is the reference file name without its last four
    characters (the extension), and all outputs go to
    ``<savepath>/<name>``. On failure that directory is removed again.

    Args:
        query_path: Query image path.
        refer_path: Reference image path.
        show: Save ``merged.png`` and display the overlay.
        crop: Forwarded to ``regression``.

    Returns:
        An ``(H, W, 3)`` ``uint8`` overlay with the warped query in
        channel 0 and the reference in channel 1, or ``None`` when the
        registration failed.
    """
    self.name = os.path.split(refer_path)[-1][:-4]
    case_dir = os.path.join(self.savepath, self.name)
    # Also creates ``savepath`` itself when it does not exist yet.
    os.makedirs(case_dir, exist_ok=True)

    warps, raw_query_image, raw_refer_image, flag = self.regression(
        query_path, refer_path, show=show, crop=crop
    )
    if not flag:
        name = os.path.split(query_path)[-1]
        print(f"Matching Failed for {name}")
        # Drop the partial outputs of the failed case.
        shutil.rmtree(case_dir)
        return None
    if warps is None:
        print("Matched Failed!")
        return None

    h, w = self.image_height, self.image_width
    if len(raw_refer_image.shape) == 3:
        refer_gray = cv2.cvtColor(raw_refer_image, cv2.COLOR_BGR2GRAY)
    else:
        refer_gray = raw_refer_image
    merged = np.zeros((h, w, 3), dtype=np.uint8)
    merged[:, :, 0] = warps
    merged[:, :, 1] = refer_gray

    if show:
        fig = plt.figure(dpi=300, facecolor='black')
        ax = fig.add_subplot(111)
        ax.set_facecolor('black')
        ax.imshow(merged)
        ax.axis('off')
        ax.set_title('Registration Result', color='white')
        fig.savefig(
            os.path.join(self.savepath, self.name, 'merged.png'),
            dpi=300,
            facecolor=fig.get_facecolor(),
            bbox_inches='tight',
            pad_inches=0,
        )
        plt.show()
        plt.close(fig)
    return merged
def align_image_pair(self, query_path, refer_path, show=False, crop=False):
    """Register a pair with a homography and build an overlay.

    The case name is the reference file name without its last four
    characters (the extension), and all outputs go to
    ``<savepath>/<name>``. On failure that directory is removed again.

    Args:
        query_path: Query image path.
        refer_path: Reference image path.
        show: Forwarded to ``compute_homography``. The overlay figure
            itself is displayed regardless of this flag.
        crop: Forwarded to ``compute_homography``.

    Returns:
        An ``(H, W, 3)`` ``uint8`` overlay with the warped query in
        channel 0 and the reference in channel 1, or ``None`` when no
        homography could be estimated.
    """
    self.name = os.path.split(refer_path)[-1][:-4]
    case_dir = os.path.join(self.savepath, self.name)
    # Also creates ``savepath`` itself when it does not exist yet.
    os.makedirs(case_dir, exist_ok=True)

    H_m, inliers_num_rate, raw_query_image, raw_refer_image, flag = self.compute_homography(
        query_path, refer_path, show=show, crop=crop
    )
    if not flag:
        name = os.path.split(query_path)[-1]
        print(f"Matching Failed for {name}")
        # Drop the partial outputs of the failed case.
        shutil.rmtree(case_dir)
        return None
    if H_m is None:
        print("Matched Failed!")
        return None

    h, w = self.image_height, self.image_width
    query_align = cv2.warpPerspective(
        raw_query_image, H_m, (w, h),
        borderMode=cv2.BORDER_CONSTANT, borderValue=(0),
    )
    if len(query_align.shape) == 3:
        query_align = cv2.cvtColor(query_align, cv2.COLOR_BGR2GRAY)
    if len(raw_refer_image.shape) == 3:
        refer_gray = cv2.cvtColor(raw_refer_image, cv2.COLOR_BGR2GRAY)
    else:
        refer_gray = raw_refer_image
    merged = np.zeros((h, w, 3), dtype=np.uint8)
    merged[:, :, 0] = query_align
    merged[:, :, 1] = refer_gray

    fig = plt.figure(dpi=300, facecolor='black')
    ax = fig.add_subplot(111)
    ax.set_facecolor('black')
    ax.imshow(merged)
    ax.axis('off')
    ax.set_title('Registration Result', color='white')
    plt.show()
    plt.close(fig)
    return merged
def model_run_one_image(self, image_path, save_path=None):
    """Detect keypoints and descriptors for a single image file.

    The green channel of the image is pre-processed, resized to the model
    input size and run through the network. Detector responses are
    thinned with ``simple_nms``, thresholded with ``self.nms_thresh`` and
    stripped of points within 4 pixels of the border.

    Side effects: sets ``self.image_height`` / ``self.image_width`` to the
    size of the image on disk and, when ``save_path`` is given, stores the
    result with ``torch.save``.

    Args:
        image_path: Path of the image to process.
        save_path: Optional file to store ``{'kp', 'desc', 'scores'}`` in.
            The stored keypoints are divided by ``model_image_width`` on
            both axes.

    Returns:
        ``(keypoints, descriptors)`` for the image as CPU tensors, with
        keypoints as ``(x, y)`` in model-input pixels.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None, not by raising.
        raise FileNotFoundError("cannot read image: {}".format(image_path))
    image = image[:, :, 1]
    self.image_height, self.image_width = image.shape[:2]
    # NOTE(review): unlike image_read, the pre-processed image is neither
    # inverted nor converted to uint8 before Image.fromarray — confirm
    # this difference is intended.
    image = pre_processing(image)
    image_tensor = self.trasformer(Image.fromarray(image))
    inputs = image_tensor.unsqueeze(0).to(self.device)
    with torch.no_grad():
        detector_pred, descriptor_pred = self.model(inputs)

    scores = simple_nms(detector_pred, self.nms_size)
    _, _, h, w = detector_pred.shape
    scores = scores.reshape(-1, h, w)
    keypoints = [torch.nonzero(s > self.nms_thresh) for s in scores]
    scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
    # Discard keypoints near the image borders.
    keypoints, scores = list(zip(*[
        remove_borders(k, s, 4, h, w)
        for k, s in zip(keypoints, scores)]))
    # Flip (row, col) into (x, y).
    keypoints = [torch.flip(k, [1]).float().data for k in keypoints]
    descriptors = [sample_keypoint_desc(k[None], d[None], 8)[0].cpu()
                   for k, d in zip(keypoints, descriptor_pred)]
    keypoints = [k.cpu() for k in keypoints]

    if save_path is not None:
        save_info = {
            'kp': keypoints[0] / self.model_image_width,
            'desc': descriptors[0],
            'scores': scores,
        }
        torch.save(save_info, save_path)
        print("Successfully saved! ", keypoints[0].shape)
    return keypoints[0], descriptors[0]
def homography_from_tensor(self, query_info, refer_info):
    """Estimate a homography from precomputed keypoints and descriptors.

    Args:
        query_info: Mapping with ``'kp'`` (keypoint ``(x, y)`` positions)
            and ``'desc'`` (descriptor tensor, one column per keypoint).
        refer_info: Same structure for the reference image.

    Returns:
        ``(H_m, inliers_num)``. ``H_m`` is ``None`` and the count is ``0``
        when fewer than four good matches exist or the estimation fails.
    """
    query_keypoints, query_desc = query_info['kp'], query_info['desc']
    refer_keypoints, refer_desc = refer_info['kp'], refer_info['desc']
    # The matcher expects one descriptor per row.
    query_desc = query_desc.permute(1, 0).numpy()
    refer_desc = refer_desc.permute(1, 0).numpy()

    def to_cv_keypoints(points):
        """Scale keypoints from model-input size to the image size."""
        # NOTE(review): keypoints saved by model_run_one_image are already
        # divided by model_image_width; confirm inputs here are in
        # model-input pixels, otherwise they get divided twice.
        return [
            cv2.KeyPoint(
                int(p[0] / self.model_image_width * self.image_width),
                int(p[1] / self.model_image_height * self.image_height),
                30,
            )
            for p in points
        ]

    cv_kpts_query = to_cv_keypoints(query_keypoints)
    cv_kpts_refer = to_cv_keypoints(refer_keypoints)

    goodMatch = []
    if len(query_desc) > 0 and len(refer_desc) > 0:
        try:
            matches = self.knn_matcher.knnMatch(query_desc, refer_desc, k=2)
        except cv2.error as err:
            # Matching is best effort: report the failure and continue
            # with no matches instead of hiding it silently.
            print("Descriptor matching failed: {}".format(err))
            matches = []
        for candidates in matches:
            # The ratio test needs two neighbours per query descriptor.
            if len(candidates) < 2:
                continue
            best, runner_up = candidates[0], candidates[1]
            if best.distance < self.knn_thresh * runner_up.distance:
                goodMatch.append(best)

    H_m = None
    inliers_num = 0
    # A homography needs at least four point correspondences.
    if len(goodMatch) >= 4:
        src_pts = np.float32(
            [cv_kpts_query[m.queryIdx].pt for m in goodMatch]
        ).reshape(-1, 1, 2)
        dst_pts = np.float32(
            [cv_kpts_refer[m.trainIdx].pt for m in goodMatch]
        ).reshape(-1, 1, 2)
        H_m, mask = cv2.findHomography(src_pts, dst_pts, cv2.LMEDS)
        # findHomography yields no mask when it cannot estimate a model.
        if mask is not None:
            inliers_num = mask.sum()
    return H_m, inliers_num
if __name__ == '__main__':
    # Demo entry point: register one image pair with the polynomial
    # regression pipeline, using the settings from config/test.yaml.
    import yaml

    config_path = 'config/test.yaml'
    if not os.path.exists(config_path):
        raise FileNotFoundError("Config File doesn't Exist")
    with open(config_path) as f:
        config = yaml.safe_load(f)

    savepath = './data/bad_images/result'
    img_path = './data/bad_images/'

    # Pair to register: f1 is the query, f2 the reference.
    f1 = './data/FIRE/Images/P35_1.jpg'
    f2 = './data/FIRE/Images/P35_2.jpg'
    # Other sample pairs that were used during development:
    # f1 = r'./data/Crop_vessel_test20/Images/1001_Z2203_c1.jpg'
    # f2 = r'./data/Crop_vessel_test20/Images/1001_Z2203_a0.jpg'
    # f1 = './data/samples/1007_Z2204_a2.jpg'
    # f2 = './data/samples/1007_Z2204_c0.jpg'
    # f1 = './data/OCTA-1/OCTA/8_OD_OCTA_vessel.jpg'
    # f2 = './data/OCTA-1/Images/8_OD_FP_vessel.jpg'
    # f1 = './data/bad_images/octa/shandong-mm-phase1-00170-20230901-OS_001_vessel.jpg'
    # f2 = './data/bad_images/cfp/shandong-mm-phase1-00170-20230901-OS.fundus.jpg'

    P = Predictor(config, img_path, savepath)
    # P.match(f1, f2, show=True, crop=True)
    P.align_image_pair_regression(f1, f2, crop=False, show=True)

    # Disabled: batch registration of every CFP image with its OCTA partner.
    # octa_imgs = os.listdir(os.path.join(img_path, 'octa'))
    # cfp_imgs = os.listdir(os.path.join(img_path, 'cfp'))
    # for img in cfp_imgs:
    #     # if os.path.exists(os.path.join(savepath, img[:-4])):
    #     #     print("pass")
    #     #     continue
    #     f2 = os.path.join(img_path, 'cfp', img)
    #     octa = [i for i in octa_imgs if img[:-11] in i]
    #     f1 = os.path.join(img_path, 'octa', octa[0])
    #     # try:
    #     P.align_image_pair(f1, f2, crop=True, show=True)
    #     # except Exception as e:
    #     #     shutil.rmtree(os.path.join(savepath, img[:-4]))
    #     #     print(e)

    # Disabled: export keypoints and descriptors for every .jpg in a folder.
    # path = './data/OCTA-1/Crop_Images/'
    # # path = './data/freiburg_sequence/'
    # files = os.listdir(path)
    # for file in files:
    #     if '.jpg' in file:
    #         f1 = path + file
    #         P.model_run_one_image(f1, f1[:-4] + '.pt')