| """Implements keypoint matching for a pair of images.""" |
| import os |
| import numpy as np |
| import PIL |
| import cv2 |
| import matplotlib.pyplot as plt |
|
|
|
|
def show_single_image(img, figsize=(7, 5), title="Single image", cmap="gray"):
    """Display a single image in its own figure, without axes.

    Args:
        img: image to draw (anything ``plt.imshow`` accepts).
        figsize: matplotlib figure size in inches.
        title: figure title.
        cmap: colormap used for 2-D (single-channel) images; matplotlib
            ignores it for RGB(A) input. Defaults to "gray" because this
            module mostly works with grayscale images.
    """
    plt.figure(figsize=figsize)
    plt.axis("off")
    plt.imshow(img, cmap=cmap)
    plt.title(title)
    plt.show()
|
|
|
|
def show_two_images(img1, img2, title="Two images", cmap="gray"):
    """Display a pair of images side by side, without axes.

    Args:
        img1: image drawn in the left panel.
        img2: image drawn in the right panel.
        title: figure-level title.
        cmap: colormap used for 2-D (single-channel) images; matplotlib
            ignores it for RGB(A) input. Defaults to "gray" because this
            module mostly works with grayscale images.
    """
    _, axes = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)
    for ax, img in zip(axes, (img1, img2)):
        ax.axis("off")
        ax.imshow(img, cmap=cmap)
    plt.suptitle(title)
    plt.show()
|
|
|
|
def show_three_images(img1, img2, img3, ax1_title="", ax2_title="", ax3_title="", title="Three images", cmap="gray"):
    """Display three images side by side, each with its own panel title.

    Args:
        img1, img2, img3: images drawn left to right.
        ax1_title, ax2_title, ax3_title: per-panel titles.
        title: figure-level title.
        cmap: colormap used for 2-D (single-channel) images; matplotlib
            ignores it for RGB(A) input. Defaults to "gray" because this
            module mostly works with grayscale images.
    """
    _, axes = plt.subplots(1, 3, figsize=(15, 5), constrained_layout=True)
    panels = zip(axes, (img1, img2, img3), (ax1_title, ax2_title, ax3_title))
    for ax, img, panel_title in panels:
        ax.axis("off")
        ax.imshow(img, cmap=cmap)
        ax.set_title(panel_title)
    plt.suptitle(title)
    plt.show()
|
|
|
|
class KeypointMatcher:
    """Detects, describes and matches SIFT keypoints for a pair of images."""

    def __init__(self, **sift_args) -> None:
        # Keyword arguments are forwarded verbatim to cv2.SIFT_create
        # (e.g. contrastThreshold, edgeThreshold).
        self.SIFT = cv2.SIFT_create(**sift_args)
        # Brute-force matcher constructed with OpenCV's default arguments.
        self.BFMatcher = cv2.BFMatcher()

    @staticmethod
    def _check_images(img1: np.ndarray, img2: np.ndarray) -> None:
        """Validate that both images are 2-D (single-channel) NumPy arrays.

        Raises:
            TypeError: if either image is not a ``numpy.ndarray``.
            ValueError: if either image is not 2-D.
        """
        for name, img in (("img1", img1), ("img2", img2)):
            if not isinstance(img, np.ndarray):
                raise TypeError(f"{name} must be a numpy.ndarray, got {type(img).__name__}")
            if img.ndim != 2:
                raise ValueError(f"{name} must be a 2-D (grayscale) array, got shape {img.shape}")

    @staticmethod
    def _to_grayscale(img) -> np.ndarray:
        """Return ``img`` (a PIL image or array) as a single-channel array.

        2-D input is returned as-is, (H, W, 1) input is squeezed, 3-channel
        input is treated as RGB and 4-channel input as RGBA. Other shapes are
        returned unchanged and rejected later by ``_check_images``.
        """
        arr = np.asarray(img)
        if arr.ndim == 3:
            channels = arr.shape[2]
            if channels == 1:
                return arr[:, :, 0]
            if channels == 3:
                return cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
            if channels == 4:
                return cv2.cvtColor(arr, cv2.COLOR_RGBA2GRAY)
        return arr

    @staticmethod
    def _show_matches(img1, kp1, img2, kp2, matches, K=10, figsize=(10, 5), drawMatches_args=None):
        """Draw up to ``K`` distinct, randomly chosen matches and display them.

        Args:
            img1, img2: the two images the keypoints were detected in.
            kp1, kp2: keypoints of ``img1`` / ``img2``.
            matches: sequence of cv2.DMatch linking ``kp1`` to ``kp2``.
            K: maximum number of matches to draw (capped at ``len(matches)``).
            figsize: matplotlib figure size in inches.
            drawMatches_args: extra keyword arguments for ``cv2.drawMatches``;
                ``None`` selects ``matchesThickness=3, singlePointColor=(0, 0, 0)``.

        Returns:
            The image produced by ``cv2.drawMatches``.
        """
        # Build the default per call so no mutable dict is shared across calls.
        if drawMatches_args is None:
            drawMatches_args = dict(matchesThickness=3, singlePointColor=(0, 0, 0))
        # Sample without replacement so the same match is never drawn twice,
        # and never ask for more matches than exist.
        k = min(K, len(matches))
        chosen = np.random.choice(len(matches), size=k, replace=False) if k else []
        selected_matches = [matches[i] for i in chosen]
        img3 = cv2.drawMatches(img1, kp1, img2, kp2, selected_matches, outImg=None, **drawMatches_args)
        show_single_image(img3, figsize=figsize, title=f"Randomly selected K = {k} matches between the pair of images.")
        return img3

    def match(self, img1: "PIL.Image.Image | np.ndarray", img2: "PIL.Image.Image | np.ndarray", show_matches: bool = True):
        """Find, describe and match SIFT keypoints in a pair of images.

        Color inputs (RGB/RGBA) are converted to grayscale; grayscale inputs
        (e.g. from ``cv2.imread(..., cv2.IMREAD_GRAYSCALE)``) are used as-is.

        Args:
            img1, img2: PIL images or NumPy arrays.
            show_matches: if True and any matches exist, display a random
                subset of them.

        Returns:
            ``(matches, kp1, des1, kp2, des2)``. ``matches`` is a list of
            cv2.DMatch sorted by ascending descriptor distance. It is empty when
            either image yields no descriptors, in which case the
            corresponding ``des`` is None.
        """
        gray1 = self._to_grayscale(img1)
        gray2 = self._to_grayscale(img2)
        self._check_images(gray1, gray2)

        kp1, des1 = self.SIFT.detectAndCompute(gray1, None)
        kp2, des2 = self.SIFT.detectAndCompute(gray2, None)

        # detectAndCompute gives None descriptors when no keypoints are found,
        # and BFMatcher.match cannot handle None, so report "no matches".
        if des1 is None or des2 is None:
            matches = []
        else:
            matches = sorted(self.BFMatcher.match(des1, des2), key=lambda m: m.distance)

        if show_matches and matches:
            self._show_matches(gray1, kp1, gray2, kp2, matches)

        return matches, kp1, des1, kp2, des2
|
|
|
|
def warp(im, M, output_shape):
    """Warp ``im`` with a 6-parameter affine transform (nearest-neighbour sampling).

    ``M`` = ``[m1, m2, m3, m4, t1, t2]`` uses the layout produced by
    ``rigid_body_transform_params``. It maps a source point ``(x, y)`` =
    (column, row) to ``(m1*x + m2*y + t1, m3*x + m4*y + t2)``, the same
    convention as ``cv2.KeyPoint.pt`` and ``cv2.warpAffine`` without
    ``WARP_INVERSE_MAP``.

    Each output pixel is filled by mapping it back through the inverse
    transform (pseudo-inverse, so a degenerate ``M`` does not raise) and
    reading the nearest source pixel. Output pixels whose source falls
    outside ``im`` stay 0.

    Args:
        im: source image, shape (H, W) or (H, W, C).
        M: 6 affine parameters (any array-like that flattens to 6 values).
        output_shape: ``(rows, cols, ...)``; only the first two entries are used.

    Returns:
        Array of shape ``(rows, cols) + im.shape[2:]`` with ``im``'s dtype.
    """
    im = np.asarray(im)
    params = np.asarray(M, dtype=float).ravel()
    A = params[:4].reshape(2, 2)
    t = params[4:6]
    A_inv = np.linalg.pinv(A)

    rows, cols = int(output_shape[0]), int(output_shape[1])
    out = np.zeros((rows, cols) + im.shape[2:], dtype=im.dtype)

    # Destination pixel grid as (x, y) = (column, row) coordinates.
    ys, xs = np.mgrid[0:rows, 0:cols]
    xs = xs.ravel()
    ys = ys.ravel()
    dst = np.stack([xs, ys]).astype(float)

    # Inverse mapping: source = A^-1 (dst - t), rounded to the nearest pixel
    # (np.rint rounds half to even, like Python's round()).
    src = A_inv @ (dst - t[:, None])
    src_x = np.rint(src[0]).astype(np.intp)
    src_y = np.rint(src[1]).astype(np.intp)

    inside = (src_x >= 0) & (src_x < im.shape[1]) & (src_y >= 0) & (src_y < im.shape[0])
    out[ys[inside], xs[inside]] = im[src_y[inside], src_x[inside]]
    return out
|
|
|
|
def project_2d_to_6d(X: np.ndarray):
    """Expand N points X (N x 2) into the (2N x 6) design matrix of the affine model.

    For point k = (x, y), row 2k is ``[x, y, 0, 0, 1, 0]`` and row 2k+1 is
    ``[0, 0, x, y, 0, 1]``. Multiplying by the 6 parameters therefore yields
    the interleaved transformed coordinates.
    """
    n_points = len(X)
    assert X.shape == (n_points, 2)

    design = np.zeros((2 * n_points, 6))
    # Basic slicing returns views, so writing through them fills `design`.
    even_rows = design[0::2]
    odd_rows = design[1::2]
    even_rows[:, 0:2] = X
    even_rows[:, 4] = 1.0
    odd_rows[:, 2:4] = X
    odd_rows[:, 5] = 1.0
    return design
|
|
|
|
def project_6d_to_2d(Z: np.ndarray):
    """Recover points X (N x 2) from a (2N x 6) affine design matrix.

    Even rows hold (x, y) in columns 0-1 and odd rows hold the same (x, y)
    in columns 2-3. An AssertionError is raised if the two copies disagree.
    The returned array is a view into ``Z``.
    """
    n_points = len(Z) // 2
    assert Z.shape == (2 * n_points, 6)

    xy_even = Z[0::2, 0:2]
    assert np.all(xy_even == Z[1::2, 2:4])
    return xy_even
|
|
|
|
|
|
def project_2d_to_1d(X: np.ndarray):
    """Flatten points X (N x 2) into an interleaved float vector of shape (2N,).

    The result is ``[x_0, y_0, x_1, y_1, ...]``, matching the row order of
    ``project_2d_to_6d``. A new float64 array is always returned, so the
    caller's ``X`` is never aliased.
    """
    X = np.asarray(X)
    N = len(X)
    # Same shape contract as the sibling projection helpers; previously
    # extra columns were silently dropped.
    assert X.shape == (N, 2)
    # np.array(..., dtype=float) copies; row-major reshape interleaves x, y.
    return np.array(X, dtype=float).reshape(2 * N)
|
|
|
|
def project_1d_to_2d(Z: np.ndarray):
    """Fold an interleaved vector Z of shape (2N,) back into points X (N x 2).

    ``[x_0, y_0, x_1, y_1, ...]`` becomes ``[[x_0, y_0], [x_1, y_1], ...]``.
    A new float64 array is returned.
    """
    half = len(Z) // 2
    assert Z.shape == (2 * half,)

    points = np.zeros((half, 2))
    points[:, 0], points[:, 1] = Z[0::2], Z[1::2]
    return points
|
|
|
|
def rigid_body_transform(X: np.ndarray, params: np.ndarray):
    """Apply the 6-parameter model ``params`` (shape (6,)) to points X (N x 2).

    Despite the name, the model is a general 2-D affine map:
    ``(x, y) -> (p0*x + p1*y + p4, p2*x + p3*y + p5)``.
    Returns the transformed points as an (N x 2) array.
    """
    num_points = len(X)
    assert X.shape == (num_points, 2)

    interleaved = np.matmul(project_2d_to_6d(X), params)
    result = project_1d_to_2d(interleaved)

    assert result.shape == (num_points, 2)
    return result
|
|
|
|
def rigid_body_transform_params(X1: np.ndarray, X2: np.ndarray):
    """Least-squares estimate of the 6 affine parameters mapping X1 onto X2.

    ``X1`` and ``X2`` are corresponding (N x 2) point arrays. The system
    ``A @ params ≈ b`` is solved with ``A = project_2d_to_6d(X1)`` and
    ``b = project_2d_to_1d(X2)`` via the Moore-Penrose pseudo-inverse.
    The return value has shape (6,).
    """
    n = len(X1)
    assert X1.shape == X2.shape
    assert X1.shape == (n, 2)

    design = project_2d_to_6d(X1)
    assert design.shape == (2 * n, 6)

    target = project_2d_to_1d(X2)
    assert target.shape == (2 * n,)

    return np.dot(np.linalg.pinv(design), target)
|
|
|
|
class ImageAlignment:
    """Aligns a pair of images from matched keypoints using RANSAC and a 6-parameter affine model."""

    def __init__(self) -> None:
        pass

    @staticmethod
    def show_transformed_points(img1, img2, X1, kp1, kp2, matches, params, num_inliers, num_to_show=20):
        """Plot random keypoints of image 1 next to where ``params`` maps them in image 2.

        The two 2-D images are drawn side by side (the shorter one is zero-padded
        at the bottom). Each selected keypoint of image 1 is joined by a dashed
        line to its transformed location, marked with an ``x`` on the image-2 half.
        The figure is saved to
        ``./results/match_transformed_inliers_{num_inliers}.png`` and shown.

        ``X1`` and ``kp2`` are accepted for interface compatibility but are unused.
        """
        import matplotlib.cm as cm

        H1, W1 = img1.shape
        H2, W2 = img2.shape
        # np.hstack needs equal heights, so pad onto a shared canvas instead.
        canvas = np.zeros((max(H1, H2), W1 + W2), dtype=np.result_type(img1, img2))
        canvas[:H1, :W1] = img1
        canvas[:H2, W1:] = img2

        # Distinct matches only, and never more than are available.
        count = min(num_to_show, len(matches))
        chosen = np.random.choice(len(matches), size=count, replace=False) if count else []
        colors = cm.rainbow(np.linspace(0, 1, count))

        _, ax = plt.subplots(1, 1, figsize=(15, 6))
        # Draw the background once rather than once per plotted match.
        ax.imshow(canvas, cmap="gray")
        ax.axis("off")

        for color, idx in zip(colors, chosen):
            x1, y1 = kp1[matches[idx].queryIdx].pt
            A = project_2d_to_6d(np.array([[x1, y1]]))
            x1_transformed, y1_transformed = np.dot(A, params)
            # Offset by W1 so the transformed point lands on the image-2 half.
            ax.scatter(x1_transformed + W1, y1_transformed, s=200, marker="x", color=color)
            ax.plot(
                (x1, x1_transformed + W1), (y1, y1_transformed),
                linestyle="--", color=color, marker="o",
            )

        ax.set_title(
            f"Points in image 1 mapped to transformed points estimated by {num_inliers} points.",
            fontsize=18,
        )

        os.makedirs("./results/", exist_ok=True)
        plt.savefig(f"./results/match_transformed_inliers_{num_inliers}.png", bbox_inches="tight")
        plt.show()

    def ransac(
        self, img1, kp1, img2, kp2, matches, num_matches=6, max_iter=500,
        radius_in_px=10, show_transformed=True, inlier_th_for_show=1000
    ):
        """Estimate the affine transform from image 1 to image 2 with RANSAC.

        In each of ``max_iter`` iterations:
        1. Fit the model by least squares to ``num_matches`` distinct, randomly
           sampled correspondences (capped at ``len(matches)``).
        2. Map every image-1 keypoint through the fitted model.
        3. Count inliers: matches whose transformed point lies within
           ``radius_in_px`` pixels of the matched image-2 keypoint.

        The parameters with the most inliers are kept. Whenever a new best
        exceeds ``inlier_th_for_show`` inliers and ``show_transformed`` is set,
        the estimate is plotted.

        Returns:
            The best parameters (shape (6,)), or None if no iteration produced
            any inlier.

        Raises:
            ValueError: if ``matches`` is empty.
        """
        if len(matches) == 0:
            raise ValueError("RANSAC needs at least one match.")

        # Matched point coordinates in (x, y) order.
        X1 = np.array([kp1[m.queryIdx].pt for m in matches])
        X2 = np.array([kp2[m.trainIdx].pt for m in matches])
        sample_size = min(num_matches, len(matches))

        best_inlier_count = 0
        best_params = None

        for _ in range(max_iter):
            # Sampling without replacement avoids duplicated points, which
            # would make the minimal sample degenerate.
            sample = np.random.choice(len(matches), size=sample_size, replace=False)
            params = rigid_body_transform_params(X1[sample], X2[sample])

            residuals = np.linalg.norm(rigid_body_transform(X1, params) - X2, axis=1)
            num_inliers = int(np.count_nonzero(residuals < radius_in_px))
            if num_inliers <= best_inlier_count:
                continue

            print(f"Found {num_inliers} inliers!")
            best_params = params
            best_inlier_count = num_inliers

            if show_transformed and num_inliers > inlier_th_for_show:
                self.show_transformed_points(img1, img2, X1, kp1, kp2, matches, best_params, num_inliers)

        return best_params

    def align(
        self, img1, kp1, img2, kp2, matches, num_matches=6,
        max_iter=500, show_warped_image=True,
        save_warped=False, path="results/sample.png",
        method="custom"
    ):
        """Estimate the image-1 -> image-2 transform with RANSAC and warp image 1.

        Args:
            img1, img2: the image pair; ``img1`` is warped.
            kp1, kp2, matches: keypoints and matches as returned by
                ``KeypointMatcher.match``.
            num_matches, max_iter: forwarded to ``ransac``.
            show_warped_image: display image 1, image 2 and the warped image 1.
            save_warped: write the warped image to ``path``, creating its
                directory if needed.
            method: "custom" uses the module's ``warp``; any other value uses
                ``cv2.warpAffine``.

        Returns:
            The best affine parameters found by RANSAC (shape (6,)).

        Raises:
            RuntimeError: if RANSAC found no transformation with any inlier.
        """
        best_params = self.ransac(img1, kp1, img2, kp2, matches, max_iter=max_iter, num_matches=num_matches)
        if best_params is None:
            raise RuntimeError("RANSAC did not find any transformation with inliers.")

        rows, cols = img1.shape[:2]

        if method == "custom":
            img1_warped = warp(img1, best_params, (rows, cols))
        else:
            # cv2.warpAffine expects the 2 x 3 forward matrix
            # [[m1, m2, t1], [m3, m4, t2]]; dsize is (width, height).
            M = np.array([
                [best_params[0], best_params[1], best_params[4]],
                [best_params[2], best_params[3], best_params[5]],
            ])
            img1_warped = cv2.warpAffine(img1, M, (cols, rows))

        if show_warped_image:
            show_three_images(
                img1, img2, img1_warped, title="",
                ax1_title="Image 1", ax2_title="Image 2", ax3_title="Transformation: Image 1 to Image 2",
            )

        if save_warped:
            out_dir = os.path.dirname(path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            # cmap only affects 2-D images; matplotlib ignores it for RGB(A).
            plt.imsave(path, img1_warped, cmap="gray")

        return best_params
|
|
|
|
if __name__ == "__main__":
    # Load the example pair as single-channel grayscale images.
    boat1 = cv2.imread('boat1.pgm', cv2.IMREAD_GRAYSCALE)
    boat2 = cv2.imread('boat2.pgm', cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals an unreadable or missing file by returning None
    # rather than raising, so fail loudly here instead of deep in plotting.
    for _name, _img in (("boat1.pgm", boat1), ("boat2.pgm", boat2)):
        if _img is None:
            raise FileNotFoundError(f"Could not read image {_name!r}")
    show_two_images(boat1, boat2, title="Given pair of images.")

    kp_matcher = KeypointMatcher(contrastThreshold=0.1, edgeThreshold=5)
    matches, kp1, des1, kp2, des2 = kp_matcher.match(boat1, boat2, show_matches=True)