# Source: CARe/Step3_Reg/predictor_server_regression.py
# (uploaded by Hongyang-Li in commit ffba4ae, "Upload 78 files")
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 24 13:57:44 2024
@author: tjy
"""
import configparser
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
import sys
import os
# Directory of this script and its "Src" sub-folder. The sub-folder is put at
# the front of sys.path so the `common` / `model` imports that follow resolve
# from it regardless of the current working directory.
STEP3_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.join(STEP3_DIR, "Src")
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)
from common.common_util import pre_processing, simple_nms, remove_borders, \
sample_keypoint_desc
from model.super_retina import SuperRetina
from PIL import Image
import re
import shutil
import pandas as pd
from sklearn.linear_model import LinearRegression #import linear regression module
from sklearn.preprocessing import PolynomialFeatures
class Predictor:
    """Match and register a query image against a reference image with SuperRetina.

    Both images are read, optionally cropped (see ``image_read``),
    pre-processed and resized to a 1000x1000 working space, then run through
    the SuperRetina network. Descriptors are matched with a brute-force L2
    k-NN matcher followed by a ratio test. The surviving matches drive either
    a homography (``align_image_pair``) or a degree-2 polynomial warp
    (``align_image_pair_regression``). Debug artefacts (point-pair text files
    and match overlays) are written to ``<savepath>/<name>/``.
    """

    def __init__(self, config, img_path=None, savepath=None):
        """Load the SuperRetina checkpoint and the matching parameters.

        Args:
            config: mapping whose ``'PREDICT'`` section provides ``device``,
                ``model_save_path``, ``nms_size``, ``nms_thresh``,
                ``knn_thresh``, ``model_image_width`` and ``model_image_height``.
            img_path: stored as ``self.img_path``; not read by this class.
            savepath: root directory under which per-case folders are created.
        """
        predict_config = config['PREDICT']
        # Use the configured device only when CUDA is available, else CPU.
        device = torch.device(predict_config['device'] if torch.cuda.is_available() else "cpu")
        model_save_path = predict_config['model_save_path']
        self.nms_size = predict_config['nms_size']
        self.nms_thresh = predict_config['nms_thresh']
        # Passed as the last argument of sample_keypoint_desc (presumably the
        # descriptor-map down-sampling factor).
        self.scale = 8
        self.knn_thresh = predict_config['knn_thresh']
        # Working-space size; set by image_read() / model_run_one_image().
        self.image_width = None
        self.image_height = None
        self.model_image_width = predict_config['model_image_width']
        self.model_image_height = predict_config['model_image_height']
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(model_save_path, map_location=device)
        model = SuperRetina()
        model.load_state_dict(checkpoint['net'])
        model.to(device)
        model.eval()
        self.device = device
        self.model = model
        self.knn_matcher = cv2.BFMatcher(cv2.NORM_L2)
        # Resize to the network input size and convert to a tensor.
        # (The misspelled attribute name is kept for backward compatibility.)
        self.trasformer = transforms.Compose([
            transforms.Resize((self.model_image_height, self.model_image_width)),
            transforms.ToTensor(),
        ])
        # Crop boxes [[y0, x0], [y1, x1]] in the 1000x1000 space, from the last image_read().
        self.crop = None
        self.refer_crop = None
        self.query_crop = None
        # Optional crop-box files; "dual_crop" mode additionally crops the query.
        self.crop_path = None
        self.query_crop_path = None
        self.crop_mode = "refer_box_only"
        self.savepath = savepath
        self.name = None  # per-case output sub-folder, set by align_image_pair*()
        self.img_path = img_path
        self.last_match_debug = {}

    # ------------------------------------------------------------------ I/O

    @staticmethod
    def _read_color(path):
        """Read ``path`` as a 3-channel BGR image, raising if OpenCV cannot read it."""
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError("cannot read image: {}".format(path))
        return image

    @staticmethod
    def _slice_box(image, box):
        """Return ``image[y0:y1, x0:x1]`` for ``box = [[y0, x0], [y1, x1]]`` (int-truncated)."""
        return image[int(box[0][0]):int(box[1][0]), int(box[0][1]):int(box[1][1])]

    def _load_crop_box(self, crop_path, name):
        """Load a 2x2 crop box ``[[y0, x0], [y1, x1]]`` from a whitespace-separated text file.

        Values are rounded to int32 and clipped to [0, 1000].

        Raises:
            ValueError: the file does not hold a 2x2 array, or the box is empty.
        """
        crop_size = np.atleast_2d(np.asarray(np.loadtxt(crop_path), dtype=np.float32))
        if crop_size.shape != (2, 2):
            raise ValueError("{} invalid format: {}".format(name, crop_path))
        crop_size = np.clip(np.rint(crop_size).astype(np.int32), 0, 1000)
        if crop_size[1][0] <= crop_size[0][0] or crop_size[1][1] <= crop_size[0][1]:
            raise ValueError("{} invalid: {}".format(name, crop_size.tolist()))
        return crop_size

    def _fovea_crop_box(self, query_path, refer_path):
        """Derive the reference crop box from the annotation files next to both images.

        Each annotation is read from ``<path minus its last 11 characters>.txt``
        as a 2x2 array. The values are scaled by 1000, so they are presumably
        normalised coordinates -- TODO confirm. The query extent is scaled by
        the ratio of the two annotations' column-1 spans and centred on the
        reference's second row.

        Returns:
            ``[[y0, x0], [y1, x1]]`` as floats. Truncation to int happens at
            slicing time, consistent with the original
            test_on_OCTA_regression*.py scripts (no rounding).

        Raises:
            ValueError: the query span is zero or the resulting box is empty.
        """
        refer_fovea = np.loadtxt(refer_path[:-11] + '.txt')
        query_fovea = np.loadtxt(query_path[:-11] + '.txt')
        query_len = abs(query_fovea[1][1] - query_fovea[0][1])
        refer_len = abs(refer_fovea[1][1] - refer_fovea[0][1])
        if query_len == 0:
            # Previously this produced inf/NaN and failed later inside int().
            raise ValueError("degenerate query fovea annotation: {}".format(query_path))
        rate = refer_len / query_len
        x0 = max(refer_fovea[1][1] * 1000 - query_fovea[1][1] * 1000 * rate, 0)
        x1 = min(refer_fovea[1][1] * 1000 + query_fovea[1][1] * 1000 * rate, 1000)
        y0 = max(refer_fovea[1][0] * 1000 - query_fovea[1][0] * 1000 * rate, 0)
        y1 = min(refer_fovea[1][0] * 1000 + query_fovea[1][0] * 1000 * rate, 1000)
        crop_box = [[y0, x0], [y1, x1]]
        if y1 <= y0 or x1 <= x0:
            raise ValueError("crop_size invalid: {}".format(crop_box))
        return crop_box

    def image_read(self, query_path, refer_path, query_is_image=False, crop=False):
        """Read, optionally crop, and pre-process a query/reference pair.

        Args:
            query_path: path of the query image, or (with ``query_is_image``)
                the query image array itself. An array is used without
                pre-processing or intensity inversion; it is presumably a
                float image in [0, 1] -- TODO confirm with callers.
            refer_path: path of the reference image.
            query_is_image: treat ``query_path`` as an already-loaded image.
            crop: crop the reference before pre-processing (and, in
                ``"dual_crop"`` mode, the query too). The reference box comes
                from ``self.crop_path`` when that file exists; otherwise it
                comes from the annotation files (see ``_fovea_crop_box``).

        Returns:
            ``(query_image, refer_image)``: uint8 arrays with the same
            1000x1000 spatial size. Also updates ``image_height`` /
            ``image_width`` and the ``crop`` / ``refer_crop`` / ``query_crop``
            boxes, which are reset to None first.

        Raises:
            FileNotFoundError: an image path cannot be read by OpenCV.
            ValueError: a crop box is missing or degenerate.
        """
        if query_is_image:
            query_image = query_path
        else:
            # green channel
            query_image = pre_processing(self._read_color(query_path)[:, :, 1])
        refer_image = cv2.resize(self._read_color(refer_path), (1000, 1000),
                                 interpolation=cv2.INTER_LINEAR)
        query_image = cv2.resize(query_image, (1000, 1000), interpolation=cv2.INTER_LINEAR)
        self.query_crop = None
        self.refer_crop = None
        self.crop = None
        if crop:
            crop_mode = getattr(self, "crop_mode", "refer_box_only")
            if self.crop_path is not None and os.path.exists(self.crop_path):
                refer_crop_size = self._load_crop_box(self.crop_path, "refer_crop_box_1000")
            else:
                refer_crop_size = self._fovea_crop_box(query_path, refer_path)
            refer_image = self._slice_box(refer_image, refer_crop_size)
            self.refer_crop = refer_crop_size
            self.crop = refer_crop_size
            if crop_mode == "dual_crop":
                if self.query_crop_path is None or not os.path.exists(self.query_crop_path):
                    raise ValueError("dual_crop mode is missing a valid query_crop_box_1000 path")
                query_crop_size = self._load_crop_box(self.query_crop_path, "query_crop_box_1000")
                query_image = cv2.resize(self._slice_box(query_image, query_crop_size),
                                         (1000, 1000), interpolation=cv2.INTER_LINEAR)
                self.query_crop = query_crop_size
        # Reference: green channel, pre-processed, back to the 1000x1000 working size.
        refer_image = pre_processing(refer_image[:, :, 1])
        refer_image = cv2.resize(refer_image, (1000, 1000), interpolation=cv2.INTER_LINEAR)
        assert query_image.shape[:2] == refer_image.shape[:2]
        self.image_height, self.image_width = query_image.shape[:2]
        # Invert non-zero intensities; zero background stays zero.
        refer_image = np.where(refer_image > 0.00, 1 - refer_image, refer_image)
        if not query_is_image:
            query_image = np.where(query_image > 0.00, 1 - query_image, query_image)
        query_image = (query_image * 255).astype(np.uint8)
        refer_image = (refer_image * 255).astype(np.uint8)
        return query_image, refer_image

    # --------------------------------------------------- coordinate mapping

    def _coordinate_transform(self, keypoints, crop_size):
        """Map model-space (x, y) keypoints of a cropped image back to the un-cropped space.

        Raises:
            ValueError: ``crop_size`` is None (no crop recorded).
        """
        if crop_size is None:
            raise ValueError("no crop box recorded; call image_read(..., crop=True) first")
        (top, left), (bottom, right) = crop_size
        scale_x = (right - left) / self.model_image_width
        scale_y = (bottom - top) / self.model_image_height
        return [
            cv2.KeyPoint(
                int((kp[0] * scale_x + left) * self.image_width / 1000),
                int((kp[1] * scale_y + top) * self.image_height / 1000),
                30,
            )
            for kp in keypoints
        ]

    def coordinate_transform(self, refer_keypoints):
        """Map reference keypoints out of ``self.refer_crop``."""
        return self._coordinate_transform(refer_keypoints, self.refer_crop)

    def query_coordinate_transform(self, query_keypoints):
        """Map query keypoints out of ``self.query_crop`` (dual_crop mode)."""
        return self._coordinate_transform(query_keypoints, self.query_crop)

    def _scale_keypoints(self, keypoints):
        """Scale model-space (x, y) keypoints to ``cv2.KeyPoint``s in the working image space."""
        return [
            cv2.KeyPoint(int(kp[0] / self.model_image_width * self.image_width),
                         int(kp[1] / self.model_image_height * self.image_height), 30)
            for kp in keypoints
        ]

    # --------------------------------------------------------- output utils

    def _ensure_case_dir(self):
        """Return ``<savepath>/<name>``, creating it and any missing parents."""
        case_dir = os.path.join(self.savepath, self.name)
        os.makedirs(case_dir, exist_ok=True)
        return case_dir

    def _points_from_matches(self, query_keypoints, refer_keypoints, matches):
        """Return float32 ``(N, 2)`` query/reference ``.pt`` arrays for ``matches`` (empty-safe)."""
        if len(matches) == 0:
            empty = np.zeros((0, 2), dtype=np.float32)
            return empty, empty
        query_points = np.array([query_keypoints[m.queryIdx].pt for m in matches], dtype=np.float32)
        refer_points = np.array([refer_keypoints[m.trainIdx].pt for m in matches], dtype=np.float32)
        return query_points, refer_points

    @staticmethod
    def _match_point_arrays(matches, kpts_query, kpts_refer):
        """Return float32 ``(N, 1, 2)`` query/reference point arrays for ``cv2.findHomography``."""
        src_pts = np.float32([kpts_query[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([kpts_refer[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
        return src_pts, dst_pts

    def _save_pair_points(self, save_path, query_points, refer_points, normalize=False):
        """Write one ``qx qy rx ry`` row per pair (12 decimals) and return ``save_path``.

        With ``normalize`` every value, x included, is divided by ``image_height``.
        """
        if len(query_points) == 0:
            pairs = np.zeros((0, 4), dtype=np.float64)
        else:
            pairs = np.column_stack((query_points, refer_points)).astype(np.float64)
        if normalize:
            pairs = pairs / float(self.image_height)
        np.savetxt(save_path, pairs, fmt="%.12f")
        return save_path

    def _to_bgr(self, image):
        """Return a 3-channel copy of ``image`` (gray inputs are expanded)."""
        if len(image.shape) == 2:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        return image.copy()

    def _save_match_visualization(self, query_image, refer_image, query_points, refer_points, save_path, title):
        """Save query (left) and reference (right) side by side with matched points joined.

        Lines are green and points red; a 70-px black title band is added on top.
        """
        query_bgr = self._to_bgr(query_image)
        refer_bgr = self._to_bgr(refer_image)
        h_a, w_a = query_bgr.shape[:2]
        h_b, w_b = refer_bgr.shape[:2]
        canvas = np.zeros((max(h_a, h_b), w_a + w_b, 3), dtype=np.uint8)
        canvas[0:h_a, 0:w_a] = query_bgr
        canvas[0:h_b, w_a:w_a + w_b] = refer_bgr
        for query_pt, refer_pt in zip(query_points, refer_points):
            pt_a = (int(round(query_pt[0])), int(round(query_pt[1])))
            pt_b = (int(round(refer_pt[0] + w_a)), int(round(refer_pt[1])))
            cv2.line(canvas, pt_a, pt_b, (0, 255, 0), 1, lineType=cv2.LINE_AA)
            cv2.circle(canvas, pt_a, 2, (0, 0, 255), -1, lineType=cv2.LINE_AA)
            cv2.circle(canvas, pt_b, 2, (0, 0, 255), -1, lineType=cv2.LINE_AA)
        title_height = 70
        vis = np.zeros((canvas.shape[0] + title_height, canvas.shape[1], 3), dtype=np.uint8)
        vis[title_height:, :] = canvas
        cv2.putText(vis, title, (20, 45), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (255, 255, 255), 2,
                    lineType=cv2.LINE_AA)
        cv2.imwrite(save_path, vis)
        return save_path

    def draw_result(self, query_image, refer_image, cv_kpts_query, cv_kpts_refer, matches, status, crop=False):
        """Plot k-NN matches with matplotlib and save ``paired_kps.png`` (legacy helper).

        This method is not called elsewhere in this module. ``matches`` is
        expected to hold ``(best, second)`` k-NN pairs aligned with the
        ``status`` flags. Accepted pairs are also written, divided by the
        query height, to ``pairs_kps.txt``. ``crop`` is unused.
        """
        def drawMatches(imageA, imageB, kpsA, kpsB, matches, status):
            (hA, wA) = imageA.shape[:2]
            (hB, wB) = imageB.shape[:2]
            vis = np.zeros((max(hA, hB), wA + wB, 3), dtype="uint8")
            # Convert each side independently so a gray/colour mix still stacks.
            if len(imageA.shape) == 2:
                imageA = cv2.cvtColor(imageA, cv2.COLOR_GRAY2RGB)
            if len(imageB.shape) == 2:
                imageB = cv2.cvtColor(imageB, cv2.COLOR_GRAY2RGB)
            vis[0:hA, 0:wA] = imageA
            vis[0:hB, wA:] = imageB
            pt = []
            for (match, _), s in zip(matches, status):
                trainIdx, queryIdx = match.trainIdx, match.queryIdx
                # only process the match if the keypoint was successfully matched
                if s == 1:
                    ptA = (int(kpsA[queryIdx].pt[0]), int(kpsA[queryIdx].pt[1]))
                    ptB = (int(kpsB[trainIdx].pt[0]) + wA, int(kpsB[trainIdx].pt[1]))
                    pt.append([kpsA[queryIdx].pt[0], kpsA[queryIdx].pt[1],
                               kpsB[trainIdx].pt[0], kpsB[trainIdx].pt[1]])
                    cv2.line(vis, ptA, ptB, (0, 255, 0), 2)
            pt = np.array(pt) / hA
            np.savetxt(os.path.join(self.savepath, self.name, 'pairs_kps.txt'), pt)
            return vis

        query_np = np.array([kp.pt for kp in cv_kpts_query], dtype=float).reshape(-1, 2)
        refer_np = np.array([kp.pt for kp in cv_kpts_refer], dtype=float).reshape(-1, 2)
        refer_np[:, 0] += query_image.shape[1]
        matched_image = drawMatches(query_image, refer_image, cv_kpts_query, cv_kpts_refer, matches, status)
        fig = plt.figure(dpi=300, facecolor='black')
        ax = fig.add_subplot(111)
        ax.set_facecolor('black')
        ax.scatter(query_np[:, 0], query_np[:, 1], s=0.1, c='r')
        ax.scatter(refer_np[:, 0], refer_np[:, 1], s=0.02, c='r')
        ax.axis('off')
        # np.sum works whether status is a list or an ndarray.
        ax.set_title('Match Result, #goodMatch: {}'.format(np.sum(status)), color='white')
        ax.imshow(cv2.cvtColor(matched_image, cv2.COLOR_BGR2RGB))
        fig.savefig(
            os.path.join(self.savepath, self.name, 'paired_kps.png'),
            dpi=300,
            facecolor=fig.get_facecolor(),
            bbox_inches='tight',
            pad_inches=0,
        )
        plt.show()
        plt.close(fig)

    # ------------------------------------------------------------ inference

    def _detect_and_describe(self, inputs):
        """Run SuperRetina on a batch; return per-image keypoints, descriptors and scores.

        Keypoints are (x, y) float CPU tensors in model-input pixels. They are
        kept where the NMS-suppressed detector score exceeds ``nms_thresh``,
        after ``remove_borders(..., 4, h, w)``. Descriptors are sampled at
        those keypoints.
        """
        inputs = inputs.to(self.device)
        with torch.no_grad():
            detector_pred, descriptor_pred = self.model(inputs)
        scores = simple_nms(detector_pred, self.nms_size)
        _, _, h, w = detector_pred.shape
        scores = scores.reshape(-1, h, w)
        keypoints = [torch.nonzero(s > self.nms_thresh) for s in scores]
        scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
        # Discard keypoints near the image borders
        keypoints, scores = list(zip(*[
            remove_borders(k, s, 4, h, w)
            for k, s in zip(keypoints, scores)]))
        # torch.nonzero yields (row, col); flip to (x, y).
        keypoints = [torch.flip(k, [1]).float().data for k in keypoints]
        descriptors = [sample_keypoint_desc(k[None], d[None], self.scale)[0].cpu()
                       for k, d in zip(keypoints, descriptor_pred)]
        keypoints = [k.cpu() for k in keypoints]
        return keypoints, descriptors, scores

    def model_run_pair(self, query_tensor, refer_tensor):
        """Detect and describe keypoints on a query/reference tensor pair.

        Returns:
            ``(keypoints, descriptors)``. ``keypoints`` is a (2,) object array
            holding one (x, y) tensor per image. ``descriptors`` is a list of
            two CPU tensors.
        """
        inputs = torch.cat((query_tensor.unsqueeze(0), refer_tensor.unsqueeze(0)))
        keypoints, descriptors, _ = self._detect_and_describe(inputs)
        # Fill element-wise: np.array(list_of_tensors, dtype=object) can build
        # an (2, N, 2) array instead when both images have the same keypoint count.
        paired = np.empty(2, dtype=object)
        paired[0], paired[1] = keypoints[0], keypoints[1]
        return paired, descriptors

    def match_scores(self, H_m, size, ngoodmatch, min_good_matches=15):
        """Sanity-check a homography by projecting the corners of a ``size`` square.

        Returns False in any of these cases:

        * there are fewer than ``min_good_matches`` matches;
        * ``H_m`` is None;
        * a corner maps to the projective singularity (zero denominator);
        * a projected coordinate exceeds ``size`` (only the upper bound is checked);
        * the projected corners violate the adjacent-corner ordering checks.
        """
        if ngoodmatch < min_good_matches or H_m is None:
            return False
        corners = [[0, 0], [0, size], [size, 0], [size, size]]
        new_pts = []
        for px, py in corners:
            denom = px * H_m[2][0] + py * H_m[2][1] + H_m[2][2]
            if denom == 0:
                return False
            x11 = (px * H_m[0][0] + py * H_m[0][1] + H_m[0][2]) / denom
            y11 = (px * H_m[1][0] + py * H_m[1][1] + H_m[1][2]) / denom
            if x11 > size or y11 > size:
                return False
            new_pts.append([x11, y11])
        if (new_pts[1][1] < new_pts[0][1] or new_pts[2][0] < new_pts[0][0]
                or new_pts[3][0] < new_pts[1][0] or new_pts[3][1] < new_pts[2][1]):
            return False
        return True

    def _ratio_test(self, query_desc, refer_desc):
        """k=2 brute-force matching followed by the ratio test with ``knn_thresh``.

        Returns:
            ``(good_matches, status)``. ``good_matches`` lists the accepted
            best ``cv2.DMatch`` objects. ``status`` holds one bool per query
            descriptor. Both are empty when there are no query descriptors,
            fewer than two reference descriptors, or OpenCV rejects the input.
        """
        good_matches, status = [], []
        if len(query_desc) == 0 or len(refer_desc) < 2:
            return good_matches, status
        try:
            knn_pairs = self.knn_matcher.knnMatch(query_desc, refer_desc, k=2)
        except cv2.error:
            return good_matches, status
        for pair in knn_pairs:
            if len(pair) < 2:
                status.append(False)
                continue
            best, second = pair
            keep = best.distance < self.knn_thresh * second.distance
            status.append(keep)
            if keep:
                good_matches.append(best)
        return good_matches, status

    # ------------------------------------------------------------- matching

    def match(self, query_path, refer_path, show=False, query_is_image=False, crop=False, save_match_images=True):
        """Detect, describe and ratio-test-match keypoints of a query/reference pair.

        Always writes the point-pair text files (``pairs_kps*.txt``) into the
        case folder and refreshes ``self.last_match_debug``. With
        ``save_match_images`` it also writes the four input/restored images.
        With ``show`` it also writes side-by-side match overlays (it does not
        open a window).

        Returns:
            A 5-tuple ``(goodMatch, cv_kpts_query, cv_kpts_refer,
            query_image_restored, refer_image_restored)``:

            * ``goodMatch``: the accepted ``cv2.DMatch`` list;
            * ``cv_kpts_query`` / ``cv_kpts_refer``: keypoints in the
              un-cropped 1000x1000 space (identical to the network-input
              space when ``crop`` is false);
            * ``query_image_restored`` / ``refer_image_restored``: the
              matching un-cropped images.
        """
        query_image, refer_image = self.image_read(query_path, refer_path, query_is_image, crop)
        query_tensor = self.trasformer(Image.fromarray(query_image))
        refer_tensor = self.trasformer(Image.fromarray(refer_image))
        keypoints, descriptors = self.model_run_pair(query_tensor, refer_tensor)
        query_keypoints, refer_keypoints = keypoints[0], keypoints[1]
        # One descriptor per row for OpenCV (assumes sample_keypoint_desc returns (D, N)).
        query_desc = descriptors[0].permute(1, 0).numpy()
        refer_desc = descriptors[1].permute(1, 0).numpy()

        # Keypoints/images in the (possibly cropped) network-input space.
        cv_kpts_query_input = self._scale_keypoints(query_keypoints)
        cv_kpts_refer_input = self._scale_keypoints(refer_keypoints)
        query_image_input = query_image.copy()
        refer_image_input = refer_image.copy()

        # Keypoints/images in the un-cropped ("restored") space.
        cv_kpts_query, cv_kpts_refer = cv_kpts_query_input, cv_kpts_refer_input
        query_image_restored, refer_image_restored = query_image_input, refer_image_input
        if crop:
            if getattr(self, "crop_mode", "refer_box_only") == "dual_crop":
                cv_kpts_query = self.query_coordinate_transform(query_keypoints)
            cv_kpts_refer = self.coordinate_transform(refer_keypoints)
            # Re-read the pair without cropping for the restored-space outputs.
            # query_is_image must be forwarded, or an in-memory query would be
            # passed to cv2.imread. image_read() also clears the crop boxes, so
            # restore the ones that produced these keypoints.
            crop_state = (self.crop, self.refer_crop, self.query_crop)
            query_image_restored, refer_image_restored = self.image_read(
                query_path, refer_path, query_is_image=query_is_image)
            self.crop, self.refer_crop, self.query_crop = crop_state

        goodMatch, _ = self._ratio_test(query_desc, refer_desc)

        case_dir = self._ensure_case_dir()
        query_points_input, refer_points_input = self._points_from_matches(
            cv_kpts_query_input, cv_kpts_refer_input, goodMatch)
        query_points_restored, refer_points_restored = self._points_from_matches(
            cv_kpts_query, cv_kpts_refer, goodMatch)
        if save_match_images:
            cv2.imwrite(os.path.join(case_dir, "match_query_input.png"), query_image_input)
            cv2.imwrite(os.path.join(case_dir, "match_refer_input.png"), refer_image_input)
            cv2.imwrite(os.path.join(case_dir, "match_query_restored.png"), query_image_restored)
            cv2.imwrite(os.path.join(case_dir, "match_refer_restored.png"), refer_image_restored)
        self._save_pair_points(
            os.path.join(case_dir, "pairs_kps_input_space_pixels.txt"),
            query_points_input, refer_points_input, normalize=False)
        self._save_pair_points(
            os.path.join(case_dir, "pairs_kps_restored_space_pixels.txt"),
            query_points_restored, refer_points_restored, normalize=False)
        # Legacy format: restored-space pairs divided by the image height.
        self._save_pair_points(
            os.path.join(case_dir, "pairs_kps.txt"),
            query_points_restored, refer_points_restored, normalize=True)
        self.last_match_debug = {
            "query_image_input": query_image_input,
            "refer_image_input": refer_image_input,
            "query_image_restored": query_image_restored,
            "refer_image_restored": refer_image_restored,
            "query_points_input": query_points_input,
            "refer_points_input": refer_points_input,
            "query_points_restored": query_points_restored,
            "refer_points_restored": refer_points_restored,
            "good_match_count": int(len(goodMatch)),
        }
        if show:
            self._save_match_visualization(
                query_image_input, refer_image_input,
                query_points_input, refer_points_input,
                os.path.join(case_dir, "descriptor_matches_input_space.png"),
                "Descriptor Match (Input Space), #goodMatch: {}".format(len(goodMatch)),
            )
            self._save_match_visualization(
                query_image_restored, refer_image_restored,
                query_points_restored, refer_points_restored,
                os.path.join(case_dir, "paired_kps.png"),
                "Descriptor Match (Restored Space), #goodMatch: {}".format(len(goodMatch)),
            )
        return goodMatch, cv_kpts_query, cv_kpts_refer, query_image_restored, refer_image_restored

    def compute_homography(self, query_path, refer_path, query_is_image=False, show=False, crop=False):
        """Estimate a query->reference homography with LMEDS from the ratio-test matches.

        Returns:
            A 5-tuple ``(H_m, inliers_num_rate, raw_query_image,
            raw_refer_image, flag)``:

            * ``H_m``: None with fewer than 4 matches or when OpenCV cannot
              estimate it;
            * ``inliers_num_rate``: the LMEDS inlier fraction (0 on failure);
            * ``flag``: always True, because the ``match_scores`` check is
              disabled.
        """
        goodMatch, cv_kpts_query, cv_kpts_refer, raw_query_image, raw_refer_image = \
            self.match(query_path, refer_path, query_is_image=query_is_image, show=show, crop=crop)
        H_m = None
        inliers_num_rate = 0
        if len(goodMatch) >= 4:
            src_pts, dst_pts = self._match_point_arrays(goodMatch, cv_kpts_query, cv_kpts_refer)
            H_m, mask = cv2.findHomography(src_pts, dst_pts, cv2.LMEDS)
            # The mask can be None when estimation fails.
            if mask is not None and mask.size:
                inliers_num_rate = mask.sum() / mask.size
        flag = True
        return H_m, inliers_num_rate, raw_query_image, raw_refer_image, flag

    # ----------------------------------------------------------- regression

    def _save_ransac_inliers(self, case_dir, inlier_mask):
        """Write RANSAC-inlier pairs and overlays from ``last_match_debug``.

        Skipped when the cached debug points do not line up with ``inlier_mask``.
        """
        debug = self.last_match_debug
        if not debug or len(debug.get("query_points_restored", [])) != len(inlier_mask):
            return
        query_points_input = debug["query_points_input"][inlier_mask]
        refer_points_input = debug["refer_points_input"][inlier_mask]
        query_points_restored = debug["query_points_restored"][inlier_mask]
        refer_points_restored = debug["refer_points_restored"][inlier_mask]
        self._save_pair_points(
            os.path.join(case_dir, "ransac_inlier_pairs_input_space_pixels.txt"),
            query_points_input, refer_points_input, normalize=False)
        self._save_pair_points(
            os.path.join(case_dir, "ransac_inlier_pairs_restored_space_pixels.txt"),
            query_points_restored, refer_points_restored, normalize=False)
        self._save_match_visualization(
            debug["query_image_input"], debug["refer_image_input"],
            query_points_input, refer_points_input,
            os.path.join(case_dir, "ransac_inlier_pairs_input_space.png"),
            "RANSAC Inlier Match (Input Space), #inlier: {}".format(len(query_points_input)),
        )
        self._save_match_visualization(
            debug["query_image_restored"], debug["refer_image_restored"],
            query_points_restored, refer_points_restored,
            os.path.join(case_dir, "ransac_inlier_pairs_restored_space.png"),
            "RANSAC Inlier Match (Restored Space), #inlier: {}".format(len(query_points_restored)),
        )

    def _polynomial_warp(self, src_xy, dst_xy, image, degree=2):
        """Forward-warp ``image`` with a polynomial map fitted on matched points.

        How it works:

        * ``x' = P_x(x, y)`` and ``y' = P_y(x, y)`` are fitted by least squares
          on ``src_xy -> dst_xy`` (both ``(N, 2)``).
        * Every pixel of ``image`` is mapped and written to the rounded target
          in a ``(image_height, image_width)`` uint8 canvas.
        * Targets outside the canvas are dropped; unhit pixels stay 0.
        * When several source pixels hit one target, the last one in x-outer /
          y-inner order wins, matching the original per-pixel loop.
        """
        # float64 features like the original list input; targets keep their dtype.
        src_xy = np.asarray(src_xy, dtype=np.float64)
        dst_xy = np.asarray(dst_xy)
        poly = PolynomialFeatures(degree=degree)
        features = poly.fit_transform(src_xy)
        regr_x = LinearRegression().fit(features, dst_xy[:, 0])
        regr_y = LinearRegression().fit(features, dst_xy[:, 1])

        height, width = image.shape[:2]
        grid_x = np.repeat(np.arange(width), height)
        grid_y = np.tile(np.arange(height), width)
        grid_features = poly.transform(np.column_stack((grid_x, grid_y)).astype(np.float64))
        mapped_x = np.around(regr_x.predict(grid_features)).astype(int)
        mapped_y = np.around(regr_y.predict(grid_features)).astype(int)

        out_h, out_w = self.image_height, self.image_width
        warps = np.zeros((out_h, out_w), dtype=np.uint8)
        inside = (mapped_x >= 0) & (mapped_x < out_w) & (mapped_y >= 0) & (mapped_y < out_h)
        dest_x, dest_y = mapped_x[inside], mapped_y[inside]
        values = image[grid_y[inside], grid_x[inside]]
        # NumPy leaves the winner of duplicate fancy-index writes unspecified,
        # so keep only the last occurrence per target explicitly.
        flat = dest_y * out_w + dest_x
        _, first_from_end = np.unique(flat[::-1], return_index=True)
        last = flat.size - 1 - first_from_end
        warps[dest_y[last], dest_x[last]] = values[last]
        return warps

    def regression(self, query_path, refer_path, query_is_image=False, show=False, crop=False):
        """Warp the query onto the reference with a degree-2 polynomial fitted on RANSAC inliers.

        RANSAC (reprojection threshold 10 px) is used only to filter matches.
        The inlier pairs and overlays are written to the case folder. With
        ``show`` the warped query is displayed.

        Returns:
            A 4-tuple ``(warps, raw_query_image, raw_refer_image, flag)``.
            ``warps`` is None and ``flag`` is False when there are fewer than
            4 matches or inliers, or when RANSAC fails.
        """
        goodMatch, cv_kpts_query, cv_kpts_refer, raw_query_image, raw_refer_image = \
            self.match(query_path, refer_path, query_is_image=query_is_image, show=show, crop=crop)
        if len(goodMatch) < 4:
            return None, raw_query_image, raw_refer_image, False
        src_pts, dst_pts = self._match_point_arrays(goodMatch, cv_kpts_query, cv_kpts_refer)
        print(len(goodMatch))
        _, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransacReprojThreshold=10.0)
        if mask is None:
            return None, raw_query_image, raw_refer_image, False
        inlier_mask = mask.ravel().astype(bool)
        goodMatch = [m for m, keep in zip(goodMatch, inlier_mask) if keep]
        print(len(goodMatch))
        self._save_ransac_inliers(self._ensure_case_dir(), inlier_mask)
        if len(goodMatch) < 4:
            return None, raw_query_image, raw_refer_image, False
        src_pts, dst_pts = self._match_point_arrays(goodMatch, cv_kpts_query, cv_kpts_refer)
        warps = self._polynomial_warp(src_pts[:, 0], dst_pts[:, 0], raw_query_image)
        if show:
            plt.imshow(warps, cmap='gray')
            plt.show()
            plt.close()
        return warps, raw_query_image, raw_refer_image, True

    # ------------------------------------------------------------ alignment

    def _overlay(self, moving_gray, refer_image):
        """Stack the aligned query into red and the (gray) reference into green."""
        merged = np.zeros((self.image_height, self.image_width, 3), dtype=np.uint8)
        if len(refer_image.shape) == 3:
            refer_image = cv2.cvtColor(refer_image, cv2.COLOR_BGR2GRAY)
        merged[:, :, 0] = moving_gray
        merged[:, :, 1] = refer_image
        return merged

    def _show_overlay(self, merged, save_path=None):
        """Show ``merged`` on a black figure, saving it to ``save_path`` first if given."""
        fig = plt.figure(dpi=300, facecolor='black')
        ax = fig.add_subplot(111)
        ax.set_facecolor('black')
        ax.imshow(merged)
        ax.axis('off')
        ax.set_title('Registration Result', color='white')
        if save_path is not None:
            fig.savefig(save_path, dpi=300, facecolor=fig.get_facecolor(),
                        bbox_inches='tight', pad_inches=0)
        plt.show()
        plt.close(fig)

    def align_image_pair_regression(self, query_path, refer_path, show=False, crop=False):
        """Register ``query_path`` onto ``refer_path`` with the polynomial warp.

        The case folder is ``<savepath>/<reference file name minus its last 4
        characters>``. It is deleted again when matching fails.

        Returns:
            The red (warped query) / green (reference) overlay, or None on
            failure. With ``show`` it is also saved as ``merged.png`` and
            displayed.
        """
        # Strips a 4-character extension such as ".jpg".
        self.name = os.path.split(refer_path)[-1][:-4]
        case_dir = self._ensure_case_dir()
        warps, raw_query_image, raw_refer_image, flag = \
            self.regression(query_path, refer_path, show=show, crop=crop)
        if not flag:
            name = os.path.split(query_path)[-1]
            print(f"Matching Failed for {name}")
            shutil.rmtree(case_dir)
            return None
        if warps is None:
            print("Matched Failed!")
            return None
        merged = self._overlay(warps, raw_refer_image)
        if show:
            self._show_overlay(merged, os.path.join(case_dir, 'merged.png'))
        return merged

    def align_image_pair(self, query_path, refer_path, show=False, crop=False):
        """Register ``query_path`` onto ``refer_path`` with the LMEDS homography.

        Returns:
            The red (warped query) / green (reference) overlay, or None when
            no homography is found. The overlay is displayed only when
            ``show`` is true.
        """
        self.name = os.path.split(refer_path)[-1][:-4]
        case_dir = self._ensure_case_dir()
        H_m, inliers_num_rate, raw_query_image, raw_refer_image, flag = \
            self.compute_homography(query_path, refer_path, show=show, crop=crop)
        if not flag:
            name = os.path.split(query_path)[-1]
            print(f"Matching Failed for {name}")
            shutil.rmtree(case_dir)
            return None
        if H_m is None:
            print("Matched Failed!")
            return None
        h, w = self.image_height, self.image_width
        query_align = cv2.warpPerspective(raw_query_image, H_m, (w, h),
                                          borderMode=cv2.BORDER_CONSTANT, borderValue=0)
        if len(query_align.shape) == 3:
            query_align = cv2.cvtColor(query_align, cv2.COLOR_BGR2GRAY)
        merged = self._overlay(query_align, raw_refer_image)
        if show:
            self._show_overlay(merged)
        return merged

    # ------------------------------------------------- single-image export

    def model_run_one_image(self, image_path, save_path=None):
        """Detect and describe keypoints on one image (green channel).

        Sets ``image_height`` / ``image_width`` from the raw image. Unlike
        ``image_read``, the pre-processed image is fed without inversion or
        uint8 conversion -- NOTE(review): confirm this is intended. When
        ``save_path`` is given, ``{'kp': keypoints / model_image_width,
        'desc', 'scores'}`` is written with ``torch.save``.

        Returns:
            ``(keypoints, descriptors)`` for the image.
        """
        image = self._read_color(image_path)[:, :, 1]
        self.image_height, self.image_width = image.shape[:2]
        image = pre_processing(image)
        image_tensor = self.trasformer(Image.fromarray(image))
        keypoints, descriptors, scores = self._detect_and_describe(image_tensor.unsqueeze(0))
        if save_path is not None:
            save_info = {'kp': keypoints[0].cpu() / self.model_image_width,
                         'desc': descriptors[0].cpu(), 'scores': scores}
            torch.save(save_info, save_path)
            print("Successfully saved! ", keypoints[0].shape)
        return keypoints[0], descriptors[0]

    def homography_from_tensor(self, query_info, refer_info):
        """Estimate an LMEDS homography from precomputed ``{'kp', 'desc'}`` dicts.

        Requires ``image_width`` / ``image_height`` to be set already.

        NOTE(review): ``model_run_one_image`` saves ``kp`` already divided by
        ``model_image_width``, while ``_scale_keypoints`` divides by it again.
        Confirm which convention ``query_info`` / ``refer_info`` use.

        Returns:
            ``(H_m, inliers_num)``. ``H_m`` is None and ``inliers_num`` is 0
            when there are fewer than 4 matches or estimation fails.
        """
        query_desc = query_info['desc'].permute(1, 0).numpy()
        refer_desc = refer_info['desc'].permute(1, 0).numpy()
        cv_kpts_query = self._scale_keypoints(query_info['kp'])
        cv_kpts_refer = self._scale_keypoints(refer_info['kp'])
        goodMatch, _ = self._ratio_test(query_desc, refer_desc)
        H_m = None
        inliers_num = 0
        if len(goodMatch) >= 4:
            src_pts, dst_pts = self._match_point_arrays(goodMatch, cv_kpts_query, cv_kpts_refer)
            H_m, mask = cv2.findHomography(src_pts, dst_pts, cv2.LMEDS)
            if mask is not None:
                inliers_num = mask.sum()
        return H_m, inliers_num
if __name__ == '__main__':
    # Demo: register one image pair with the polynomial (regression) warp.
    import yaml

    config_path = 'config/test.yaml'
    if not os.path.exists(config_path):
        raise FileNotFoundError("Config File doesn't Exist")
    with open(config_path) as f:
        config = yaml.safe_load(f)

    # Other pairs tried during development:
    #   f1 = './data/Crop_vessel_test20/Images/1001_Z2203_c1.jpg'
    #   f2 = './data/Crop_vessel_test20/Images/1001_Z2203_a0.jpg'
    #   f1 = './data/samples/1007_Z2204_a2.jpg'
    #   f2 = './data/samples/1007_Z2204_c0.jpg'
    #   f1 = './data/OCTA-1/OCTA/8_OD_OCTA_vessel.jpg'
    #   f2 = './data/OCTA-1/Images/8_OD_FP_vessel.jpg'
    #   f1 = './data/bad_images/octa/shandong-mm-phase1-00170-20230901-OS_001_vessel.jpg'
    #   f2 = './data/bad_images/cfp/shandong-mm-phase1-00170-20230901-OS.fundus.jpg'
    f1 = './data/FIRE/Images/P35_1.jpg'
    f2 = './data/FIRE/Images/P35_2.jpg'
    savepath = './data/bad_images/result'
    img_path = './data/bad_images/'
    # Per-case output folders are created inside savepath, so make sure it exists.
    os.makedirs(savepath, exist_ok=True)

    P = Predictor(config, img_path, savepath)
    # P.match(f1, f2, show=True, crop=True)
    P.align_image_pair_regression(f1, f2, crop=False, show=True)

    # Batch OCTA -> CFP registration over img_path/{octa,cfp} (disabled):
    # octa_imgs = os.listdir(os.path.join(img_path, 'octa'))
    # cfp_imgs = os.listdir(os.path.join(img_path, 'cfp'))
    # for img in cfp_imgs:
    #     # if os.path.exists(os.path.join(savepath, img[:-4])):
    #     #     continue
    #     f2 = os.path.join(img_path, 'cfp', img)
    #     octa = [i for i in octa_imgs if img[:-11] in i]
    #     f1 = os.path.join(img_path, 'octa', octa[0])
    #     # try:
    #     P.align_image_pair(f1, f2, crop=True, show=True)
    #     # except Exception as e:
    #     #     shutil.rmtree(os.path.join(savepath, img[:-4]))
    #     #     print(e)

    # Export keypoints / descriptors for every .jpg in a folder (disabled):
    # path = './data/OCTA-1/Crop_Images/'
    # # path = './data/freiburg_sequence/'
    # for file in os.listdir(path):
    #     if '.jpg' in file:
    #         f1 = path + file
    #         P.model_run_one_image(f1, f1[:-4] + '.pt')