# -*- coding: utf-8 -*-
"""
Created on Wed Apr 24 13:57:44 2024

@author: tjy

Keypoint detection, descriptor matching and image registration built on the
SuperRetina model.
"""

import configparser
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
import sys
import os

# Make the bundled "Src" directory importable for the project-local modules below.
STEP3_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.join(STEP3_DIR, "Src")
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

from common.common_util import pre_processing, simple_nms, remove_borders, \
    sample_keypoint_desc
from model.super_retina import SuperRetina
from PIL import Image
import re
import shutil
import pandas as pd
from sklearn.linear_model import LinearRegression  # import linear regression module
from sklearn.preprocessing import PolynomialFeatures

# Side length (pixels) of the square working canvas every image is resized to,
# and the coordinate frame of the crop-box files.
WORK_SIZE = 1000


class Predictor:
    """Run SuperRetina on a query/reference pair and register query onto reference.

    Images are resized to a ``WORK_SIZE`` x ``WORK_SIZE`` canvas by
    :meth:`image_read`.  Keypoints predicted at the model resolution are mapped
    back onto that canvas (undoing any crop) before matching.
    """

    def __init__(self, config, img_path=None, savepath=None):
        """Load the model and matching settings.

        Args:
            config: mapping with a ``'PREDICT'`` section providing ``device``,
                ``model_save_path``, ``nms_size``, ``nms_thresh``, ``knn_thresh``,
                ``model_image_width`` and ``model_image_height``.
            img_path: stored as ``self.img_path``.
            savepath: root directory under which one output folder per case
                (named after the reference file) is created.
        """
        predict_config = config['PREDICT']
        device = predict_config['device']
        # Fall back to CPU when CUDA is unavailable.
        device = torch.device(device if torch.cuda.is_available() else "cpu")
        model_save_path = predict_config['model_save_path']
        self.nms_size = predict_config['nms_size']
        self.nms_thresh = predict_config['nms_thresh']
        self.scale = 8
        # Lowe ratio-test threshold used by _ratio_test_matches().
        self.knn_thresh = predict_config['knn_thresh']
        self.image_width = None
        self.image_height = None
        self.model_image_width = predict_config['model_image_width']
        self.model_image_height = predict_config['model_image_height']

        checkpoint = torch.load(model_save_path, map_location=device)
        model = SuperRetina()
        model.load_state_dict(checkpoint['net'])
        model.to(device)
        model.eval()
        self.device = device
        self.model = model

        self.knn_matcher = cv2.BFMatcher(cv2.NORM_L2)
        self.trasformer = transforms.Compose([
            transforms.Resize((self.model_image_height, self.model_image_width)),
            transforms.ToTensor(),
        ])

        # Crop boxes ([[y0, x0], [y1, x1]] on the working canvas) of the most
        # recent image_read()/match() call; None when no crop was applied.
        self.crop = None
        self.refer_crop = None
        self.query_crop = None
        # Optional crop-box files; None unless set by the caller.
        self.crop_path = None
        self.query_crop_path = None
        # "refer_box_only" crops only the reference; "dual_crop" crops both images.
        self.crop_mode = "refer_box_only"
        self.savepath = savepath
        self.name = None
        self.img_path = img_path
        # Images/points of the last match(), reused by regression() for inlier dumps.
        self.last_match_debug = {}

    # ------------------------------------------------------------------ I/O

    def _load_crop_box(self, crop_path, name):
        """Load a 2x2 crop box ``[[y0, x0], [y1, x1]]`` from a text file.

        Values are rounded to int32 and clipped to ``[0, WORK_SIZE]``.

        Raises:
            ValueError: if the file is not 2x2 or the box has no positive extent.
        """
        crop_size = np.asarray(np.loadtxt(crop_path), dtype=np.float32)
        crop_size = np.atleast_2d(crop_size)
        if crop_size.shape != (2, 2):
            raise ValueError("{} invalid format: {}".format(name, crop_path))
        crop_size = np.rint(crop_size).astype(np.int32)
        crop_size[:, 0] = np.clip(crop_size[:, 0], 0, WORK_SIZE)
        crop_size[:, 1] = np.clip(crop_size[:, 1], 0, WORK_SIZE)
        if crop_size[1][0] <= crop_size[0][0] or crop_size[1][1] <= crop_size[0][1]:
            raise ValueError("{} invalid: {}".format(name, crop_size.tolist()))
        return crop_size

    def _fovea_crop_box(self, query_path, refer_path):
        """Derive the reference crop box from companion annotation files.

        Reads ``<path without its last 11 characters>.txt`` for both images.
        Each file is indexed as a 2x2 array; the variable names suggest row 0 /
        row 1 are optic disc / fovea (TODO confirm).  The query's row-1 point is
        scaled by the ratio of the row-0/row-1 distances along column 1, and the
        box is centred on the reference's row-1 point.  Coordinates are assumed
        normalised to [0, 1], since they are multiplied by ``WORK_SIZE`` — confirm.

        Raises:
            ValueError: if the query distance is zero or the box is empty.
        """
        cfp_od_fovea = np.loadtxt(refer_path[:-11] + '.txt')
        octa_od_fovea = np.loadtxt(query_path[:-11] + '.txt')
        octa_len = abs(octa_od_fovea[1][1] - octa_od_fovea[0][1])
        cfp_len = abs(cfp_od_fovea[1][1] - cfp_od_fovea[0][1])
        if octa_len == 0:
            # The scale ratio below would be infinite.
            raise ValueError("degenerate query annotation (zero distance): {}".format(
                query_path[:-11] + '.txt'))
        rate = cfp_len / octa_len
        x0 = max(cfp_od_fovea[1][1] * WORK_SIZE - octa_od_fovea[1][1] * WORK_SIZE * rate, 0)
        x1 = min(cfp_od_fovea[1][1] * WORK_SIZE + octa_od_fovea[1][1] * WORK_SIZE * rate, WORK_SIZE)
        y0 = max(cfp_od_fovea[1][0] * WORK_SIZE - octa_od_fovea[1][0] * WORK_SIZE * rate, 0)
        y1 = min(cfp_od_fovea[1][0] * WORK_SIZE + octa_od_fovea[1][0] * WORK_SIZE * rate, WORK_SIZE)
        # Keep behavior consistent with original test_on_OCTA_regression*.py:
        # Compute crop using float first, then apply int() truncation at slicing (no round).
        refer_crop_size = [[y0, x0], [y1, x1]]
        if refer_crop_size[1][0] <= refer_crop_size[0][0] or refer_crop_size[1][1] <= refer_crop_size[0][1]:
            raise ValueError("crop_size invalid: {}".format(refer_crop_size))
        return refer_crop_size

    def image_read(self, query_path, refer_path, query_is_image=False, crop=False):
        """Load, optionally crop, and pre-process a query/reference pair.

        Args:
            query_path: path of the query image, or the image array itself when
                ``query_is_image`` is true.  An in-memory query skips the
                green-channel extraction, ``pre_processing`` and intensity
                inversion applied to file-loaded queries.  It is still resized
                and multiplied by 255, so presumably it should already be a
                pre-processed image in [0, 1] (confirm with callers).
            refer_path: path of the reference image.
            query_is_image: see ``query_path``.
            crop: crop the reference (and the query in ``"dual_crop"`` mode).
                The reference box comes from ``self.crop_path`` when that file
                exists, otherwise from :meth:`_fovea_crop_box`.

        Returns:
            ``(query_image, refer_image)`` as uint8 ``WORK_SIZE`` x ``WORK_SIZE``
            arrays.  Also sets ``self.image_height``/``self.image_width`` and the
            ``self.crop``/``self.refer_crop``/``self.query_crop`` boxes.

        Raises:
            ValueError: for an invalid crop box, or a missing query box in
                ``"dual_crop"`` mode.
        """
        if query_is_image:
            query_image = query_path
        else:
            query_image = cv2.imread(query_path, cv2.IMREAD_COLOR)
            # green channel
            query_image = query_image[:, :, 1]
            query_image = pre_processing(query_image)
        refer_image = cv2.imread(refer_path, cv2.IMREAD_COLOR)
        size = (WORK_SIZE, WORK_SIZE)
        refer_image = cv2.resize(refer_image, size, interpolation=cv2.INTER_LINEAR)
        query_image = cv2.resize(query_image, size, interpolation=cv2.INTER_LINEAR)

        self.query_crop = None
        self.refer_crop = None
        self.crop = None
        if crop:
            crop_mode = getattr(self, "crop_mode", "refer_box_only")
            if self.crop_path is not None and os.path.exists(self.crop_path):
                refer_crop_size = self._load_crop_box(self.crop_path, "refer_crop_box_1000")
            else:
                refer_crop_size = self._fovea_crop_box(query_path, refer_path)
            refer_image = refer_image[
                int(refer_crop_size[0][0]):int(refer_crop_size[1][0]),
                int(refer_crop_size[0][1]):int(refer_crop_size[1][1]),
            ]
            self.refer_crop = refer_crop_size
            self.crop = refer_crop_size
            if crop_mode == "dual_crop":
                if self.query_crop_path is None or not os.path.exists(self.query_crop_path):
                    raise ValueError("dual_crop mode is missing a valid query_crop_box_1000 path")
                query_crop_size = self._load_crop_box(self.query_crop_path, "query_crop_box_1000")
                query_image = query_image[
                    int(query_crop_size[0][0]):int(query_crop_size[1][0]),
                    int(query_crop_size[0][1]):int(query_crop_size[1][1]),
                ]
                query_image = cv2.resize(query_image, size, interpolation=cv2.INTER_LINEAR)
                self.query_crop = query_crop_size

        # Reference: green channel -> pre_processing -> back to the working size
        # (a crop changes its size).
        refer_image = refer_image[:, :, 1]
        refer_image = pre_processing(refer_image)
        refer_image = cv2.resize(refer_image, size, interpolation=cv2.INTER_LINEAR)

        assert query_image.shape[:2] == refer_image.shape[:2]
        self.image_height, self.image_width = query_image.shape[:2]

        # Invert non-zero intensities (zero pixels stay zero); assumes values in [0, 1].
        refer_image = np.where(refer_image > 0.00, 1 - refer_image, refer_image)
        if not query_is_image:
            query_image = np.where(query_image > 0.00, 1 - query_image, query_image)
        query_image = (query_image * 255).astype(np.uint8)
        refer_image = (refer_image * 255).astype(np.uint8)
        return query_image, refer_image

    # ------------------------------------------------------- coordinate maps

    def _model_to_image_keypoints(self, keypoints):
        """Scale model-resolution (x, y) keypoints to image-resolution cv2.KeyPoints."""
        return [
            cv2.KeyPoint(
                int(i[0] / self.model_image_width * self.image_width),
                int(i[1] / self.model_image_height * self.image_height),
                30,
            )
            for i in keypoints
        ]

    def _coordinate_transform(self, keypoints, crop_size):
        """Map model-resolution keypoints of a cropped image back to the uncropped canvas.

        ``crop_size`` is ``[[y0, x0], [y1, x1]]`` on the ``WORK_SIZE`` canvas.
        """
        old_width = crop_size[1][1] - crop_size[0][1]
        old_height = crop_size[1][0] - crop_size[0][0]
        resize_width = old_width / self.model_image_width
        resize_height = old_height / self.model_image_height
        return [
            cv2.KeyPoint(
                int((i[0] * resize_width + crop_size[0][1]) * self.image_width / WORK_SIZE),
                int((i[1] * resize_height + crop_size[0][0]) * self.image_height / WORK_SIZE),
                30,
            )
            for i in keypoints
        ]

    def coordinate_transform(self, refer_keypoints):
        """Undo the reference crop (``self.refer_crop``) for reference keypoints."""
        return self._coordinate_transform(refer_keypoints, self.refer_crop)

    def query_coordinate_transform(self, query_keypoints):
        """Undo the query crop (``self.query_crop``) for query keypoints."""
        return self._coordinate_transform(query_keypoints, self.query_crop)

    # ----------------------------------------------------------- outputs

    def _ensure_case_dir(self):
        """Create (if needed) and return ``<savepath>/<name>``, including missing parents."""
        case_dir = os.path.join(self.savepath, self.name)
        os.makedirs(case_dir, exist_ok=True)
        return case_dir

    def _begin_case(self, refer_path):
        """Set ``self.name`` from the reference file name (last 4 chars dropped) and create its folder."""
        self.name = os.path.split(refer_path)[-1][:-4]
        return self._ensure_case_dir()

    def _points_from_matches(self, query_keypoints, refer_keypoints, matches):
        """Return matched (x, y) coordinates as two float32 ``(N, 2)`` arrays."""
        if len(matches) == 0:
            empty = np.zeros((0, 2), dtype=np.float32)
            return empty, empty
        query_points = np.array(
            [query_keypoints[m.queryIdx].pt for m in matches],
            dtype=np.float32,
        )
        refer_points = np.array(
            [refer_keypoints[m.trainIdx].pt for m in matches],
            dtype=np.float32,
        )
        return query_points, refer_points

    def _save_pair_points(self, save_path, query_points, refer_points, normalize=False):
        """Write ``qx qy rx ry`` rows to ``save_path``.

        With ``normalize`` every coordinate (x and y) is divided by
        ``self.image_height``.
        """
        if len(query_points) == 0:
            pairs = np.zeros((0, 4), dtype=np.float64)
        else:
            pairs = np.column_stack((query_points, refer_points)).astype(np.float64)
        if normalize:
            pairs = pairs / float(self.image_height)
        np.savetxt(save_path, pairs, fmt="%.12f")
        return save_path

    def _to_bgr(self, image):
        """Return a 3-channel copy of ``image`` (grayscale is expanded)."""
        if len(image.shape) == 2:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        return image.copy()

    def _save_match_visualization(self, query_image, refer_image, query_points, refer_points, save_path, title):
        """Save a side-by-side image with match lines and a title banner."""
        query_bgr = self._to_bgr(query_image)
        refer_bgr = self._to_bgr(refer_image)
        h_a, w_a = query_bgr.shape[:2]
        h_b, w_b = refer_bgr.shape[:2]
        canvas = np.zeros((max(h_a, h_b), w_a + w_b, 3), dtype=np.uint8)
        canvas[0:h_a, 0:w_a] = query_bgr
        canvas[0:h_b, w_a:w_a + w_b] = refer_bgr
        for query_pt, refer_pt in zip(query_points, refer_points):
            pt_a = (int(round(query_pt[0])), int(round(query_pt[1])))
            pt_b = (int(round(refer_pt[0] + w_a)), int(round(refer_pt[1])))
            cv2.line(canvas, pt_a, pt_b, (0, 255, 0), 1, lineType=cv2.LINE_AA)
            cv2.circle(canvas, pt_a, 2, (0, 0, 255), -1, lineType=cv2.LINE_AA)
            cv2.circle(canvas, pt_b, 2, (0, 0, 255), -1, lineType=cv2.LINE_AA)
        title_height = 70
        vis = np.zeros((canvas.shape[0] + title_height, canvas.shape[1], 3), dtype=np.uint8)
        vis[title_height:, :] = canvas
        cv2.putText(
            vis,
            title,
            (20, 45),
            cv2.FONT_HERSHEY_SIMPLEX,
            1.2,
            (255, 255, 255),
            2,
            lineType=cv2.LINE_AA,
        )
        cv2.imwrite(save_path, vis)
        return save_path

    def draw_result(self, query_image, refer_image, cv_kpts_query, cv_kpts_refer, matches, status, crop=False):
        """Plot kNN matches (pairs from knnMatch) filtered by ``status``, and save the plot.

        Writes ``pairs_kps.txt`` (coordinates divided by the query height) and
        ``paired_kps.png`` into ``<savepath>/<name>``.
        """
        def drawMatches(imageA, imageB, kpsA, kpsB, matches, status):
            # initialize the output visualization image
            (hA, wA) = imageA.shape[:2]
            (hB, wB) = imageB.shape[:2]
            vis = np.zeros((max(hA, hB), wA + wB, 3), dtype="uint8")
            if len(imageA.shape) == 2:
                imageA = cv2.cvtColor(imageA, cv2.COLOR_GRAY2RGB)
                imageB = cv2.cvtColor(imageB, cv2.COLOR_GRAY2RGB)
            vis[0:hA, 0:wA] = imageA
            vis[0:hB, wA:] = imageB
            pt = []
            # loop over the matches
            for (match, _), s in zip(matches, status):
                trainIdx, queryIdx = match.trainIdx, match.queryIdx
                # only process the match if the keypoint was successfully matched
                if s == 1:
                    ptA = (int(kpsA[queryIdx].pt[0]), int(kpsA[queryIdx].pt[1]))
                    ptB = (int(kpsB[trainIdx].pt[0]) + wA, int(kpsB[trainIdx].pt[1]))
                    # save matched points
                    pt.append([kpsA[queryIdx].pt[0], kpsA[queryIdx].pt[1],
                               kpsB[trainIdx].pt[0], kpsB[trainIdx].pt[1]])
                    cv2.line(vis, ptA, ptB, (0, 255, 0), 2)
            pt = np.array(pt) / hA
            np.savetxt(os.path.join(self.savepath, self.name, 'pairs_kps.txt'), pt)
            # return the visualization
            return vis

        # reshape(-1, 2) keeps the column indexing valid when a side has no keypoints.
        query_np = np.array([kp.pt for kp in cv_kpts_query], dtype=float).reshape(-1, 2)
        refer_np = np.array([kp.pt for kp in cv_kpts_refer], dtype=float).reshape(-1, 2)
        refer_np[:, 0] += query_image.shape[1]

        # show on raw images
        matched_image = drawMatches(query_image, refer_image, cv_kpts_query, cv_kpts_refer, matches, status)
        fig = plt.figure(dpi=300, facecolor='black')
        ax = fig.add_subplot(111)
        ax.set_facecolor('black')
        ax.scatter(query_np[:, 0], query_np[:, 1], s=0.1, c='r')
        ax.scatter(refer_np[:, 0], refer_np[:, 1], s=0.02, c='r')
        ax.axis('off')
        ax.set_title('Match Result, #goodMatch: {}'.format(status.sum()), color='white')
        ax.imshow(cv2.cvtColor(matched_image, cv2.COLOR_BGR2RGB))
        fig.savefig(
            os.path.join(self.savepath, self.name, 'paired_kps.png'),
            dpi=300,
            facecolor=fig.get_facecolor(),
            bbox_inches='tight',
            pad_inches=0,
        )
        plt.show()
        plt.close(fig)

    # ------------------------------------------------------------- model

    def _extract_keypoints(self, detector_pred, descriptor_pred):
        """Turn raw network outputs into per-image keypoints, descriptors and scores.

        Returns:
            ``(keypoints, descriptors, scores)``: keypoints are CPU float tensors of
            model-resolution (x, y); descriptors are CPU tensors from
            ``sample_keypoint_desc``; scores are the per-keypoint NMS scores.
        """
        scores = simple_nms(detector_pred, self.nms_size)
        _, _, h, w = detector_pred.shape
        scores = scores.reshape(-1, h, w)
        keypoints = [torch.nonzero(s > self.nms_thresh) for s in scores]
        scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
        # Discard keypoints near the image borders
        keypoints, scores = list(zip(*[
            remove_borders(k, s, 4, h, w)
            for k, s in zip(keypoints, scores)]))
        # nonzero() yields (row, col); flip to (x, y).
        keypoints = [torch.flip(k, [1]).float().data for k in keypoints]
        descriptors = [sample_keypoint_desc(k[None], d[None], 8)[0].cpu()
                       for k, d in zip(keypoints, descriptor_pred)]
        keypoints = [k.cpu() for k in keypoints]
        return keypoints, descriptors, scores

    def model_run_pair(self, query_tensor, refer_tensor):
        """Run the model on a query/reference tensor pair as one batch of two.

        Returns:
            ``(keypoints, descriptors)``, where ``keypoints`` is a numpy object array
            indexed ``[0]`` (query) / ``[1]`` (reference) and ``descriptors`` is a
            list in the same order.
        """
        inputs = torch.cat((query_tensor.unsqueeze(0), refer_tensor.unsqueeze(0)))
        inputs = inputs.to(self.device)
        with torch.no_grad():
            detector_pred, descriptor_pred = self.model(inputs)
        keypoints, descriptors, _ = self._extract_keypoints(detector_pred, descriptor_pred)
        keypoints = np.array(keypoints, dtype=object)
        return keypoints, descriptors

    # ----------------------------------------------------------- matching

    def _ratio_test_matches(self, query_desc, refer_desc):
        """Return kNN (k=2) matches that pass the ratio test with ``self.knn_thresh``.

        Args:
            query_desc, refer_desc: ``(N, D)`` descriptor arrays.

        Returns an empty list when either side is too small or OpenCV fails;
        such failures are reported instead of silently swallowed.
        """
        good = []
        # A ratio test needs two reference neighbours per query descriptor.
        if len(query_desc) == 0 or len(refer_desc) < 2:
            return good
        try:
            matches = self.knn_matcher.knnMatch(query_desc, refer_desc, k=2)
        except cv2.error as err:
            print("knnMatch failed: {}".format(err))
            return good
        for pair in matches:
            if len(pair) < 2:
                continue
            m, n = pair
            if m.distance < self.knn_thresh * n.distance:
                good.append(m)
        return good

    def match_scores(self, H_m, size, ngoodmatch):
        """Heuristic sanity check of a homography.

        Requires at least 15 matches, every mapped corner of the ``size`` square
        to stay at or below ``size`` in x and y, and a corner ordering that is not
        flipped.
        """
        if ngoodmatch < 15:
            return False
        points = [[0, 0], [0, size], [size, 0], [size, size]]
        new_pts = []
        for p in points:
            denom = p[0] * H_m[2][0] + p[1] * H_m[2][1] + H_m[2][2]
            x11 = (p[0] * H_m[0][0] + p[1] * H_m[0][1] + H_m[0][2]) / denom
            y11 = (p[0] * H_m[1][0] + p[1] * H_m[1][1] + H_m[1][2]) / denom
            if x11 > size or y11 > size:
                return False
            new_pts.append([x11, y11])
        if new_pts[1][1] < new_pts[0][1] or new_pts[2][0] < new_pts[0][0] or \
                new_pts[3][0] < new_pts[1][0] or new_pts[3][1] < new_pts[2][1]:
            return False
        return True

    def match(self, query_path, refer_path, show=False, query_is_image=False, crop=False,
              save_match_images=True):
        """Detect, describe and ratio-test match a query/reference pair.

        Two coordinate spaces are produced:

        * "input": the (possibly cropped) images fed to the model;
        * "restored": the uncropped working canvas (same as input when ``crop`` is false).

        Side effects:
            * writes images to ``<savepath>/<name>`` (only when ``save_match_images``);
            * always writes the pair-point text files there;
            * fills ``self.last_match_debug``;
            * saves visualisations when ``show``.

        Returns:
            ``(goodMatch, cv_kpts_query, cv_kpts_refer, query_image_restored,
            refer_image_restored)``, with keypoints in restored space.
        """
        query_image, refer_image = self.image_read(query_path, refer_path, query_is_image, crop)
        query_tensor = self.trasformer(Image.fromarray(query_image))
        refer_tensor = self.trasformer(Image.fromarray(refer_image))
        keypoints, descriptors = self.model_run_pair(query_tensor, refer_tensor)
        query_keypoints, refer_keypoints = keypoints[0], keypoints[1]
        query_desc = descriptors[0].permute(1, 0).numpy()
        refer_desc = descriptors[1].permute(1, 0).numpy()

        # mapping keypoints to scaled keypoints
        cv_kpts_query_input = self._model_to_image_keypoints(query_keypoints)
        cv_kpts_refer_input = self._model_to_image_keypoints(refer_keypoints)
        query_image_input = query_image.copy()
        refer_image_input = refer_image.copy()

        cv_kpts_query = cv_kpts_query_input
        cv_kpts_refer = cv_kpts_refer_input
        query_image_restored = query_image_input
        refer_image_restored = refer_image_input
        if crop:
            crop_mode = getattr(self, "crop_mode", "refer_box_only")
            if crop_mode == "dual_crop":
                cv_kpts_query = self.query_coordinate_transform(query_keypoints)
            cv_kpts_refer = self.coordinate_transform(refer_keypoints)
            # Re-read the uncropped pair for restored-space output.  Forward
            # query_is_image (an in-memory query cannot be re-read from disk).
            # image_read() resets the crop attributes, so keep this match's boxes.
            crop_state = (self.crop, self.refer_crop, self.query_crop)
            query_image_restored, refer_image_restored = self.image_read(
                query_path, refer_path, query_is_image=query_is_image)
            self.crop, self.refer_crop, self.query_crop = crop_state

        goodMatch = self._ratio_test_matches(query_desc, refer_desc)

        case_dir = self._ensure_case_dir()
        query_points_input, refer_points_input = self._points_from_matches(
            cv_kpts_query_input,
            cv_kpts_refer_input,
            goodMatch,
        )
        query_points_restored, refer_points_restored = self._points_from_matches(
            cv_kpts_query,
            cv_kpts_refer,
            goodMatch,
        )

        if save_match_images:
            cv2.imwrite(os.path.join(case_dir, "match_query_input.png"), query_image_input)
            cv2.imwrite(os.path.join(case_dir, "match_refer_input.png"), refer_image_input)
            cv2.imwrite(os.path.join(case_dir, "match_query_restored.png"), query_image_restored)
            cv2.imwrite(os.path.join(case_dir, "match_refer_restored.png"), refer_image_restored)

        self._save_pair_points(
            os.path.join(case_dir, "pairs_kps_input_space_pixels.txt"),
            query_points_input,
            refer_points_input,
            normalize=False,
        )
        self._save_pair_points(
            os.path.join(case_dir, "pairs_kps_restored_space_pixels.txt"),
            query_points_restored,
            refer_points_restored,
            normalize=False,
        )
        self._save_pair_points(
            os.path.join(case_dir, "pairs_kps.txt"),
            query_points_restored,
            refer_points_restored,
            normalize=True,
        )

        self.last_match_debug = {
            "query_image_input": query_image_input,
            "refer_image_input": refer_image_input,
            "query_image_restored": query_image_restored,
            "refer_image_restored": refer_image_restored,
            "query_points_input": query_points_input,
            "refer_points_input": refer_points_input,
            "query_points_restored": query_points_restored,
            "refer_points_restored": refer_points_restored,
            "good_match_count": int(len(goodMatch)),
        }

        if show:
            self._save_match_visualization(
                query_image_input,
                refer_image_input,
                query_points_input,
                refer_points_input,
                os.path.join(case_dir, "descriptor_matches_input_space.png"),
                "Descriptor Match (Input Space), #goodMatch: {}".format(len(goodMatch)),
            )
            self._save_match_visualization(
                query_image_restored,
                refer_image_restored,
                query_points_restored,
                refer_points_restored,
                os.path.join(case_dir, "paired_kps.png"),
                "Descriptor Match (Restored Space), #goodMatch: {}".format(len(goodMatch)),
            )

        return goodMatch, cv_kpts_query, cv_kpts_refer, query_image_restored, refer_image_restored

    # ------------------------------------------------------- registration

    def compute_homography(self, query_path, refer_path, query_is_image=False, show=False, crop=False):
        """Estimate a query->reference homography with LMEDS.

        Returns:
            ``(H_m, inliers_num_rate, raw_query_image, raw_refer_image, flag)``.
            ``H_m`` is None when fewer than 4 matches exist or OpenCV fails to
            estimate a model.  ``flag`` is always True; the ``match_scores``
            check is disabled.
        """
        goodMatch, cv_kpts_query, cv_kpts_refer, raw_query_image, raw_refer_image = \
            self.match(query_path, refer_path, query_is_image=query_is_image, show=show, crop=crop)
        H_m = None
        inliers_num_rate = 0
        if len(goodMatch) >= 4:
            src_pts, dst_pts = self._points_from_matches(cv_kpts_query, cv_kpts_refer, goodMatch)
            H_m, mask = cv2.findHomography(src_pts.reshape(-1, 1, 2), dst_pts.reshape(-1, 1, 2), cv2.LMEDS)
            # findHomography returns (None, None) when no model can be estimated.
            if mask is not None:
                inliers_num_rate = mask.sum() / len(mask.ravel())
        flag = True
        # if not self.match_scores(H_m, WORK_SIZE, len(goodMatch)):
        #     flag = False
        return H_m, inliers_num_rate, raw_query_image, raw_refer_image, flag

    def _save_ransac_inliers(self, inlier_mask):
        """Save RANSAC-inlier pairs/visualisations using the data from the last match()."""
        case_dir = self._ensure_case_dir()
        debug = self.last_match_debug
        if not debug or len(debug.get("query_points_restored", [])) != len(inlier_mask):
            return
        query_points_input = debug["query_points_input"][inlier_mask]
        refer_points_input = debug["refer_points_input"][inlier_mask]
        query_points_restored = debug["query_points_restored"][inlier_mask]
        refer_points_restored = debug["refer_points_restored"][inlier_mask]
        self._save_pair_points(
            os.path.join(case_dir, "ransac_inlier_pairs_input_space_pixels.txt"),
            query_points_input,
            refer_points_input,
            normalize=False,
        )
        self._save_pair_points(
            os.path.join(case_dir, "ransac_inlier_pairs_restored_space_pixels.txt"),
            query_points_restored,
            refer_points_restored,
            normalize=False,
        )
        self._save_match_visualization(
            debug["query_image_input"],
            debug["refer_image_input"],
            query_points_input,
            refer_points_input,
            os.path.join(case_dir, "ransac_inlier_pairs_input_space.png"),
            "RANSAC Inlier Match (Input Space), #inlier: {}".format(len(query_points_input)),
        )
        self._save_match_visualization(
            debug["query_image_restored"],
            debug["refer_image_restored"],
            query_points_restored,
            refer_points_restored,
            os.path.join(case_dir, "ransac_inlier_pairs_restored_space.png"),
            "RANSAC Inlier Match (Restored Space), #inlier: {}".format(len(query_points_restored)),
        )

    def regression(self, query_path, refer_path, query_is_image=False, show=False, crop=False, degree=2):
        """Warp the query onto the reference with a 2-D polynomial mapping.

        Matches are filtered by RANSAC (reprojection threshold 10 px).  One
        polynomial of ``degree`` is then least-squares fitted per output
        coordinate on the inliers.  Every query pixel is forward-mapped; pixels
        landing outside the canvas are dropped, and unmapped target pixels stay 0.

        Returns:
            ``(warps, raw_query_image, raw_refer_image, flag)``.  ``warps`` is None
            and ``flag`` False when fewer than 4 (inlier) matches exist or RANSAC fails.
        """
        goodMatch, cv_kpts_query, cv_kpts_refer, raw_query_image, raw_refer_image = \
            self.match(query_path, refer_path, query_is_image=query_is_image, show=show, crop=crop)
        if len(goodMatch) < 4:
            return None, raw_query_image, raw_refer_image, False

        src_pts, dst_pts = self._points_from_matches(cv_kpts_query, cv_kpts_refer, goodMatch)
        print(len(goodMatch))
        H_m, mask = cv2.findHomography(src_pts.reshape(-1, 1, 2), dst_pts.reshape(-1, 1, 2),
                                       cv2.RANSAC, ransacReprojThreshold=10.0)
        if mask is None:
            return None, raw_query_image, raw_refer_image, False
        inlier_mask = mask.ravel().astype(bool)
        goodMatch = [m for m, keep in zip(goodMatch, inlier_mask) if keep]
        print(len(goodMatch))
        self._save_ransac_inliers(inlier_mask)
        if len(goodMatch) < 4:
            return None, raw_query_image, raw_refer_image, False
        src_pts = src_pts[inlier_mask]
        dst_pts = dst_pts[inlier_mask]

        # Same polynomial features feed both the x- and the y-regression.
        poly = PolynomialFeatures(degree=degree)
        features = poly.fit_transform(src_pts.astype(np.float64))
        regr_X = LinearRegression().fit(features, dst_pts[:, 0])
        regr_Y = LinearRegression().fit(features, dst_pts[:, 1])

        # All query pixel coordinates as (x, y), x-major (matches the former nested loop order).
        h_src, w_src = raw_query_image.shape[:2]
        grid_x, grid_y = np.meshgrid(np.arange(w_src), np.arange(h_src), indexing="ij")
        raw = np.column_stack((grid_x.ravel(), grid_y.ravel()))
        grid_features = poly.transform(raw)
        polynomial_matrix = np.column_stack((regr_X.predict(grid_features), regr_Y.predict(grid_features)))
        new_poly_matrix = np.array(np.around(polynomial_matrix), dtype=int)

        warps = np.zeros((self.image_height, self.image_width), dtype=np.uint8)
        tx, ty = new_poly_matrix[:, 0], new_poly_matrix[:, 1]
        valid = (tx >= 0) & (tx < self.image_width) & (ty >= 0) & (ty < self.image_height)
        # Vectorised forward warp; when several source pixels hit the same target,
        # which one survives is unspecified (as with any forward mapping).
        warps[ty[valid], tx[valid]] = raw_query_image[raw[valid, 1], raw[valid, 0]]

        if show:
            plt.imshow(warps, cmap='gray')
            plt.show()
            plt.close()
        return warps, raw_query_image, raw_refer_image, True

    def _merge_overlay(self, moving_gray, raw_refer_image):
        """Build an RGB overlay: channel 0 = aligned query, channel 1 = reference (gray)."""
        h, w = self.image_height, self.image_width
        merged = np.zeros((h, w, 3), dtype=np.uint8)
        if len(raw_refer_image.shape) == 3:
            refer_gray = cv2.cvtColor(raw_refer_image, cv2.COLOR_BGR2GRAY)
        else:
            refer_gray = raw_refer_image
        merged[:, :, 0] = moving_gray
        merged[:, :, 1] = refer_gray
        return merged

    def _show_merged(self, merged, save_path=None):
        """Display the overlay; also save it to ``save_path`` when given."""
        fig = plt.figure(dpi=300, facecolor='black')
        ax = fig.add_subplot(111)
        ax.set_facecolor('black')
        ax.imshow(merged)
        ax.axis('off')
        ax.set_title('Registration Result', color='white')
        if save_path is not None:
            fig.savefig(
                save_path,
                dpi=300,
                facecolor=fig.get_facecolor(),
                bbox_inches='tight',
                pad_inches=0,
            )
        plt.show()
        plt.close(fig)

    def align_image_pair_regression(self, query_path, refer_path, show=False, crop=False):
        """Register via :meth:`regression` and return the overlay, or None on failure.

        On a failed match the case folder is deleted.  With ``show`` the overlay
        is displayed and saved as ``merged.png``.
        """
        case_dir = self._begin_case(refer_path)
        warps, raw_query_image, raw_refer_image, flag = self.regression(
            query_path, refer_path, show=show, crop=crop)
        if not flag:
            name = os.path.split(query_path)[-1]
            print(f"Matching Failed for {name}")
            shutil.rmtree(case_dir)
            return None
        if warps is None:
            print("Matched Failed!")
            return None
        merged = self._merge_overlay(warps, raw_refer_image)
        if show:
            self._show_merged(merged, os.path.join(case_dir, 'merged.png'))
        return merged

    def align_image_pair(self, query_path, refer_path, show=False, crop=False):
        """Register via :meth:`compute_homography` and return the overlay, or None on failure.

        With ``show`` the overlay is displayed (not saved).
        """
        case_dir = self._begin_case(refer_path)
        H_m, inliers_num_rate, raw_query_image, raw_refer_image, flag = self.compute_homography(
            query_path, refer_path, show=show, crop=crop)
        if not flag:
            name = os.path.split(query_path)[-1]
            print(f"Matching Failed for {name}")
            shutil.rmtree(case_dir)
            return None
        if H_m is None:
            print("Matched Failed!")
            return None
        h, w = self.image_height, self.image_width
        query_align = cv2.warpPerspective(raw_query_image, H_m, (w, h),
                                          borderMode=cv2.BORDER_CONSTANT, borderValue=(0))
        if len(query_align.shape) == 3:
            query_align = cv2.cvtColor(query_align, cv2.COLOR_BGR2GRAY)
        merged = self._merge_overlay(query_align, raw_refer_image)
        if show:
            self._show_merged(merged)
        return merged

    # ------------------------------------------------- precomputed features

    def model_run_one_image(self, image_path, save_path=None):
        """Detect keypoints/descriptors on one image file.

        Sets ``self.image_height``/``self.image_width`` from the file.  When
        ``save_path`` is given, saves ``{'kp', 'desc', 'scores'}`` with
        ``torch.save``.

        Returns:
            ``(keypoints, descriptors)`` of the image at model resolution.
        """
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        image = image[:, :, 1]
        self.image_height, self.image_width = image.shape[:2]
        image = pre_processing(image)
        # NOTE(review): unlike image_read(), the pre-processed image is neither
        # inverted nor scaled to uint8 before the transform — confirm intended.
        image_tensor = self.trasformer(Image.fromarray(image))
        inputs = image_tensor.unsqueeze(0).to(self.device)
        with torch.no_grad():
            detector_pred, descriptor_pred = self.model(inputs)
        keypoints, descriptors, scores = self._extract_keypoints(detector_pred, descriptor_pred)
        if save_path is not None:
            # NOTE(review): both x and y are divided by model_image_width, and
            # homography_from_tensor() treats 'kp' as model-resolution pixels — confirm.
            save_info = {'kp': keypoints[0].cpu() / self.model_image_width,
                         'desc': descriptors[0].cpu(),
                         'scores': scores}
            torch.save(save_info, save_path)
            print("Successfully saved! ", keypoints[0].shape)
        return keypoints[0], descriptors[0]

    def homography_from_tensor(self, query_info, refer_info):
        """Estimate an LMEDS homography from precomputed ``{'kp', 'desc'}`` dicts.

        Requires ``self.image_width``/``self.image_height`` to be set beforehand.

        Returns:
            ``(H_m, inliers_num)``.  ``H_m`` is None with fewer than 4 matches or
            when OpenCV cannot estimate a model.
        """
        query_keypoints, query_desc = query_info['kp'], query_info['desc']
        refer_keypoints, refer_desc = refer_info['kp'], refer_info['desc']
        query_desc = query_desc.permute(1, 0).numpy()
        refer_desc = refer_desc.permute(1, 0).numpy()
        cv_kpts_query = self._model_to_image_keypoints(query_keypoints)
        cv_kpts_refer = self._model_to_image_keypoints(refer_keypoints)

        goodMatch = self._ratio_test_matches(query_desc, refer_desc)

        H_m = None
        inliers_num = 0
        if len(goodMatch) >= 4:
            src_pts, dst_pts = self._points_from_matches(cv_kpts_query, cv_kpts_refer, goodMatch)
            H_m, mask = cv2.findHomography(src_pts.reshape(-1, 1, 2), dst_pts.reshape(-1, 1, 2), cv2.LMEDS)
            if mask is not None:
                inliers_num = mask.sum()
        return H_m, inliers_num


if __name__ == '__main__':
    import yaml

    config_path = 'config/test.yaml'
    if os.path.exists(config_path):
        with open(config_path) as f:
            config = yaml.safe_load(f)
    else:
        raise FileNotFoundError("Config File doesn't Exist")

    savepath = './data/bad_images/result'
    # Other pairs used during development:
    #   f1/f2 = './data/Crop_vessel_test20/Images/1001_Z2203_c1.jpg' / '..._a0.jpg'
    #   f1/f2 = './data/samples/1007_Z2204_a2.jpg' / './data/samples/1007_Z2204_c0.jpg'
    #   f1/f2 = './data/OCTA-1/OCTA/8_OD_OCTA_vessel.jpg' / './data/OCTA-1/Images/8_OD_FP_vessel.jpg'
    #   f1/f2 = './data/bad_images/octa/shandong-mm-phase1-00170-20230901-OS_001_vessel.jpg'
    #           / './data/bad_images/cfp/shandong-mm-phase1-00170-20230901-OS.fundus.jpg'
    f1 = './data/FIRE/Images/P35_1.jpg'
    f2 = './data/FIRE/Images/P35_2.jpg'

    img_path = './data/bad_images/'
    P = Predictor(config, img_path, savepath)
    P.align_image_pair_regression(f1, f2, crop=False, show=True)

    # Batch OCTA/CFP registration:
    # octa_imgs = os.listdir(os.path.join(img_path, 'octa'))
    # cfp_imgs = os.listdir(os.path.join(img_path, 'cfp'))
    # for img in cfp_imgs:
    #     f2 = os.path.join(img_path, 'cfp', img)
    #     octa = [i for i in octa_imgs if img[:-11] in i]
    #     f1 = os.path.join(img_path, 'octa', octa[0])
    #     P.align_image_pair(f1, f2, crop=True, show=True)
    #
    # Export keypoints/descriptors:
    # path = './data/OCTA-1/Crop_Images/'
    # for file in os.listdir(path):
    #     if '.jpg' in file:
    #         f1 = path + file
    #         P.model_run_one_image(f1, f1[:-4] + '.pt')