"""Gradio demo: match two images with EDM (Efficient Deep Feature Matching).

Both uploaded images are converted to grayscale, resized to a fixed
resolution, matched with the outdoor EDM checkpoint, and the resulting
matches are drawn with ``make_matching_figure``.
"""
import functools
import cv2
import numpy as np
import gradio as gr
import torch
import matplotlib.cm as cm
import sys

# Make the project's `src` package importable when run from the repo root.
sys.path.append("src")
from src.utils.plotting import make_matching_figure
from src.edm import EDM
from src.config.default import get_cfg_defaults
from src.utils.misc import lower_config

# Config files and checkpoint, relative to the working directory.
DATA_CFG_PATH = "configs/data/megadepth_test_1500.py"
MAIN_CFG_PATH = "configs/edm/outdoor/edm_base.py"
CKPT_PATH = "weights/edm_outdoor.ckpt"

# Both images are resized to this fixed resolution before matching; the UI
# note below is generated from the same values so the two cannot drift.
INPUT_W, INPUT_H = 832, 832

# NOTE(review): "ArXiv Paper" / "GitHub Repository" render as plain text here;
# the links appear to be missing — confirm against the upstream page.
HEADER = (
    "\n\n"
    "🎶 EDM: Efficient Deep Feature Matching\n\n"
    "ArXiv Paper   GitHub Repository\n\n"
)

ABSTRACT = (
    " Recent feature matching methods have achieved remarkable performance but "
    "lack efficiency consideration. In this paper, we revisit the mainstream "
    "detector-free matching pipeline and improve all its stages considering both "
    "accuracy and efficiency. We propose an Efficient Deep feature Matching "
    "network, EDM. We first adopt a deeper CNN with fewer dimensions to extract "
    "multi-level features. Then we present a Correlation Injection Module that "
    "conducts feature transformation on high-level deep features, and "
    "progressively injects feature correlations from global to local for "
    "efficient multi-scale feature aggregation, improving both speed and "
    "performance. In the refinement stage, a novel lightweight bidirectional "
    "axis-based regression head is designed to directly predict subpixel-level "
    "correspondences from latent features, avoiding the significant "
    "computational cost of explicitly locating keypoints on high-resolution "
    "local feature heatmaps. Moreover, effective selection strategies are "
    "introduced to enhance matching accuracy. \n"
    "Extensive experiments show that our EDM achieves competitive matching "
    "accuracy on various benchmarks and exhibits excellent efficiency, offering "
    "valuable best practices for real-world applications."
)


@functools.lru_cache(maxsize=1)
def _load_state_dict():
    """Read the EDM outdoor checkpoint from disk once, mapped to CPU."""
    return torch.load(CKPT_PATH, map_location=torch.device("cpu"))["state_dict"]


@functools.lru_cache(maxsize=4)
def _build_matcher(conf_thres, border_rm, topk):
    """Build an eval-mode EDM matcher for the given coarse-matching settings.

    The settings are written into the config before ``EDM`` is constructed,
    so a matcher is tied to the values it was built with; caching per
    (conf_thres, border_rm, topk) avoids re-merging configs, rebuilding the
    network and re-reading the checkpoint on every request.
    """
    config = get_cfg_defaults()
    config.merge_from_file(MAIN_CFG_PATH)
    config.merge_from_file(DATA_CFG_PATH)
    config.EDM.COARSE.MCONF_THR = conf_thres
    config.EDM.COARSE.BORDER_RM = border_rm
    config.EDM.COARSE.TOPK = topk
    _config = lower_config(config)
    matcher = EDM(config=_config["edm"])
    # load_state_dict copies the tensors into the module, so sharing the
    # cached state dict between matchers is safe.
    matcher.load_state_dict(_load_state_dict())
    return matcher.eval()


def _to_model_input(image):
    """Convert an image array to a (1, 1, INPUT_H, INPUT_W) float tensor in [0, 1].

    Accepts grayscale (HxW or HxWx1), RGB (HxWx3) or RGBA (HxWx4) arrays;
    values are divided by 255, so uint8 input is assumed.
    """
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[..., 0]
    if image.ndim == 2:
        gray = image
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        # gr.Image delivers RGB arrays, so RGB2GRAY (not BGR2GRAY) applies the
        # luminance weights to the correct channels.
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    gray = cv2.resize(gray, (INPUT_W, INPUT_H))
    return torch.from_numpy(gray)[None][None] / 255.


def find_matches(image_0, image_1, conf_thres=0.2, border_rm=2, topk=10000):
    """Match two images with EDM and return a visualization of the matches.

    Args:
        image_0, image_1: Image arrays as provided by ``gr.Image()`` (numpy,
            RGB by Gradio's default). They are matched at INPUT_W x INPUT_H
            and the matches are drawn on the original images.
        conf_thres: Value for ``EDM.COARSE.MCONF_THR``.
        border_rm: Value for ``EDM.COARSE.BORDER_RM`` (the UI labels it
            "x8", presumably coarse-grid cells — confirm).
        topk: Value for ``EDM.COARSE.TOPK``.

    Returns:
        The result of ``make_matching_figure`` for the two images, with matches
        colored by confidence (jet colormap) and an "EDM" / match-count caption.

    Raises:
        gr.Error: If either image is missing, so the UI shows a message
            instead of a traceback.
    """
    if image_0 is None or image_1 is None:
        raise gr.Error("Please upload both images before matching.")

    # Cast count-like settings to int in case the sliders deliver floats.
    matcher = _build_matcher(float(conf_thres), int(border_rm), int(topk))

    h0, w0 = image_0.shape[:2]
    h1, w1 = image_1.shape[:2]

    batch = {"image0": _to_model_input(image_0), "image1": _to_model_input(image_1)}
    # The matcher writes its predictions back into `batch`.
    # NOTE(review): the cached matcher is shared across requests; this assumes
    # EDM.forward mutates only `batch`, not module state — confirm.
    with torch.no_grad():
        matcher(batch)
    mkpts0 = batch["mkpts0_f"].numpy()
    mkpts1 = batch["mkpts1_f"].numpy()
    mconf = batch["mconf"].numpy()

    # Keypoints are in the resized INPUT_W x INPUT_H frame; map them back to
    # each image's original resolution (x by width ratio, y by height ratio).
    mkpts0[:, 0] *= w0 / INPUT_W
    mkpts0[:, 1] *= h0 / INPUT_H
    mkpts1[:, 0] *= w1 / INPUT_W
    mkpts1[:, 1] *= h1 / INPUT_H

    color = cm.jet(mconf)
    text = ["EDM", f"Matches: {len(mkpts0)}"]
    # NOTE(review): this value is sent straight to a gr.Image output; confirm
    # make_matching_figure returns a type gr.Image can render.
    return make_matching_figure(image_0, image_1, mkpts0, mkpts1, color, text=text)


with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Accordion("Abstract (click to open)", open=False):
        gr.Image("assets/teaser.jpg")
        gr.Markdown(ABSTRACT)

    with gr.Row():
        image_1 = gr.Image()
        image_2 = gr.Image()
    with gr.Row():
        conf_thres = gr.Slider(minimum=0.01, maximum=1, value=0.2, step=0.01, label="Coarse Confidence Threshold")
        topk = gr.Slider(minimum=100, maximum=10000, value=5000, step=100, label="TopK")
        border_rm = gr.Slider(minimum=0, maximum=20, value=2, step=1, label="Border Remove (x8)")
    gr.HTML(f" Note: images are actually resized to {INPUT_W} x {INPUT_H} for matching. \n")

    with gr.Row():
        button = gr.Button(value="Find Matches")
        clear = gr.ClearButton(value="Clear")
    output = gr.Image()

    button.click(find_matches, [image_1, image_2, conf_thres, border_rm, topk], output)
    clear.add([image_1, image_2, output])

    # NOTE(review): examples pass only the two images, so find_matches uses its
    # own defaults (topk=10000) rather than the TopK slider's value (5000) —
    # confirm which is intended.
    gr.Examples(
        examples=[
            ["assets/scannet_sample_images/scene0707_00_15.jpg", "assets/scannet_sample_images/scene0707_00_45.jpg"],
            ["assets/scannet_sample_images/scene0758_00_165.jpg", "assets/scannet_sample_images/scene0758_00_510.jpg"],
            ["assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg", "assets/phototourism_sample_images/london_bridge_49190386_5209386933.jpg"],
            ["assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg", "assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"],
        ],
        inputs=[image_1, image_2],
        outputs=[output],
        fn=find_matches,
        cache_examples=None,
    )

if __name__ == "__main__":
    demo.launch()