| import argparse |
| import os |
| from os.path import join |
| import sys |
| import numpy as np |
| import cv2 |
| import torch |
| from matplotlib import pyplot as plt |
| from tqdm import tqdm |
|
|
| from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT |
| from gluestick.drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches |
| from line_matching.two_view_pipeline import TwoViewPipeline |
|
|
| from scalelsd.base import show, WireframeGraph |
| from scalelsd.ssl.datasets.transforms.homographic_transforms import sample_homography |
| from kornia.geometry import warp_perspective,transform_points |
|
|
class HADConfig:
    """Settings for the random homography used to synthesize a second view.

    Values are read as plain class attributes through the module-level
    ``had_cfg`` instance and forwarded to ``sample_homography``.
    """

    # Enable flags forwarded to sample_homography (translation / rotation /
    # scaling / perspective components).
    translation = True
    rotation = True
    scale = True
    perspective = True

    # Amplitudes forwarded to sample_homography.
    scaling_amplitude = 0.2
    perspective_amplitude_x = 0.2
    perspective_amplitude_y = 0.2

    # Forwarded to sample_homography; exact semantics are defined there.
    patch_ratio = 0.85
    allow_artifacts = False

    # Not read by any code visible in this file.
    num_iter = 1
    valid_border_margin = 3


had_cfg = HADConfig()
|
|
def sample_homographics(height, width, device=None):
    """Sample one random homography for an image of size ``height`` x ``width``.

    All sampling options come from the module-level ``had_cfg`` and are
    forwarded to ``sample_homography``.

    Args:
        height: Image height in pixels.
        width: Image width in pixels.
        device: Torch device for the returned tensors. Defaults to CUDA when
            it is available, otherwise CPU.

    Returns:
        dict with ``'h.1'``, the sampled homography as a float32 tensor with a
        leading batch axis of size 1, and ``'ih.1'``, its matrix inverse.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # NOTE(review): assumes element [0] of sample_homography's result is a
    # 3x3 numpy matrix -- confirm against scalelsd's implementation.
    homography = sample_homography(
        shape=(height, width),
        perspective=had_cfg.perspective,
        scaling=had_cfg.scale,
        rotation=had_cfg.rotation,
        translation=had_cfg.translation,
        scaling_amplitude=had_cfg.scaling_amplitude,
        perspective_amplitude_x=had_cfg.perspective_amplitude_x,
        perspective_amplitude_y=had_cfg.perspective_amplitude_y,
        patch_ratio=had_cfg.patch_ratio,
        allow_artifacts=had_cfg.allow_artifacts,
    )[0]

    # Add a batch axis: kornia's warp_perspective consumes (B, 3, 3) matrices.
    homography = torch.from_numpy(homography[None]).float().to(device)

    return {
        'h.1': homography,
        'ih.1': torch.inverse(homography),
    }
|
|
def trans_image_with_homograpy(image, device=None, display=True):
    """Warp an image with a freshly sampled random homography.

    Args:
        image: ``(H, W, C)`` or ``(H, W)`` numpy image. Values are assumed to
            lie in [0, 255] (e.g. uint8) since the result is cast to uint8.
        device: Torch device used for the warp. Defaults to CUDA when it is
            available, otherwise CPU.
        display: If True (the default), show the warped image with
            ``plt.imshow`` / ``plt.show()``. On interactive backends this
            blocks until the window is closed.

    Returns:
        The warped image as a uint8 numpy array with the same shape as
        ``image``. The sampled homography itself is not returned.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Give grayscale inputs a channel axis so the CHW permute below works.
    grayscale = image.ndim == 2
    if grayscale:
        image = image[..., None]

    h, w = image.shape[:2]
    # .to(device) keeps this correct whichever device the helper returned.
    homography = sample_homographics(height=h, width=w)['h.1'].to(device)

    # (H, W, C) -> (1, C, H, W) float32, the layout warp_perspective expects.
    image_t = torch.as_tensor(np.ascontiguousarray(image), dtype=torch.float32, device=device)
    image_t = image_t.permute(2, 0, 1)[None]
    warped = warp_perspective(image_t, homography, (h, w))

    # Round (rather than truncate) interpolated values before the uint8 cast.
    warped_np = warped[0].permute(1, 2, 0).round().clamp(0, 255).cpu().numpy().astype(np.uint8)
    if grayscale:
        warped_np = warped_np[..., 0]

    if display:
        plt.imshow(warped_np, cmap='gray' if grayscale else None)
        plt.show()
    return warped_np
|
|
|
|
def _imread_or_raise(path, flags):
    """Read an image with OpenCV, raising instead of silently returning None."""
    image = cv2.imread(path, flags)
    if image is None:
        raise FileNotFoundError(f'Could not read image: {path}')
    return image


def main():
    """Run the GlueStick line-matching demo and save visualizations.

    Parses the command line, builds a ``TwoViewPipeline`` with the wireframe
    extractor and GlueStick matcher, and matches ``-img1`` against ``-img2``.
    When ``-img2`` is omitted, a homography-warped copy of ``-img1`` is written
    to ``temp_output/matching_results/warped_image.png`` and used as the second
    view.

    Files written to ``temp_output/matching_results``:
      * ``det.png``: detected line segments in both images.
      * ``mat.png``: matched line segments.
      * ``det1.png`` / ``det2.png``: per-image wireframes passed to
        ``show.image_canvas`` as ``fig_file``.

    Raises:
        ValueError: If neither image path is provided.
        FileNotFoundError: If an input image cannot be read by OpenCV.
    """
    parser = argparse.ArgumentParser(
        prog='GlueStick Demo',
        description='Demo app to show the point and line matches obtained by GlueStick')
    parser.add_argument('-img1', default='assets/figs/sa_1119229.jpg')
    parser.add_argument('-img2', default=None)
    parser.add_argument('--max_pts', type=int, default=1000)
    parser.add_argument('--max_lines', type=int, default=300)
    # NOTE(review): --model is parsed but never used below.
    parser.add_argument('--model', type=str, default='models/paper-sa1b-997pkgs-model.pt')
    args = parser.parse_args()

    if args.img1 is None and args.img2 is None:
        raise ValueError("Input at least one path of image1 or image2")

    conf = {
        'name': 'two_view_pipeline',
        'use_lines': True,
        'extractor': {
            'name': 'wireframe',
            'sp_params': {
                'force_num_keypoints': False,
                'max_num_keypoints': args.max_pts,
            },
            'wireframe_params': {
                'merge_points': True,
                'merge_line_endpoints': True,
            },
            'max_n_lines': args.max_lines,
        },
        'matcher': {
            'name': 'gluestick',
            'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'),
            'trainable': False,
        },
        'ground_truth': {
            'from_pose_depth': False,
        },
    }

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    pipeline_model = TwoViewPipeline(conf).to(device).eval()
    pipeline_model.extractor.update_conf(None)

    saveto = 'temp_output/matching_results'
    os.makedirs(saveto, exist_ok=True)

    if args.img2 is None:
        # Synthesize the second view by warping the first one.
        image1 = cv2.cvtColor(_imread_or_raise(args.img1, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        image2 = trans_image_with_homograpy(image1)
        warped_path = f'{saveto}/warped_image.png'
        # cv2.imwrite expects BGR; writing the RGB array directly would swap
        # R and B and make the grayscale read below inconsistent with image1.
        cv2.imwrite(warped_path, cv2.cvtColor(image2, cv2.COLOR_RGB2BGR))
        args.img2 = warped_path

    gray0 = _imread_or_raise(args.img1, cv2.IMREAD_GRAYSCALE)
    gray1 = _imread_or_raise(args.img2, cv2.IMREAD_GRAYSCALE)

    torch_gray0 = numpy_image_to_torch(gray0).to(device)[None]
    torch_gray1 = numpy_image_to_torch(gray1).to(device)[None]

    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        pred = pipeline_model({'image0': torch_gray0, 'image1': torch_gray1})
    pred = batch_to_np(pred)

    # Point matches are also available as pred["keypoints0/1"] / pred["matches0"];
    # only the line matches are visualized here.
    line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
    line_matches = pred["line_matches0"]
    has_match = line_matches != -1  # -1 marks an unmatched line
    matched_lines0 = line_seg0[has_match]
    matched_lines1 = line_seg1[line_matches[has_match]]

    img0 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR)
    img1 = cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)

    plot_images([img0, img1], dpi=200, pad=2.0)
    plot_lines([line_seg0, line_seg1], ps=4, lw=2)
    plt.gcf().canvas.manager.set_window_title('Detected Lines')
    plt.savefig(f'{saveto}/det.png')
    plt.close(plt.gcf())

    plot_images([img0, img1], dpi=200, pad=2.0)
    plot_color_line_matches([matched_lines0, matched_lines1], lw=3)
    plt.gcf().canvas.manager.set_window_title('Line Matches')
    plt.savefig(f'{saveto}/mat.png')
    plt.close(plt.gcf())

    show.Canvas.white_overlay = 1  # white background overlay for the renderings
    painter = show.painters.HAWPainter()
    for image_path, lines, fig_name in (
        (args.img1, line_seg0, 'det1.png'),
        (args.img2, line_seg1, 'det2.png'),
    ):
        outputs = {'lines_pred': lines.reshape(-1, 4)}
        with show.image_canvas(image_path, fig_file=f'{saveto}/{fig_name}') as ax:
            painter.draw_wireframe(ax, outputs, edge_color='midnightblue', vertex_color='deeppink')
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
|
|