File size: 5,631 Bytes
7e31006 bc6ea96 7e31006 bc6ea96 7e31006 bc6ea96 7e31006 cce956c 7e31006 cce956c 7e31006 cce956c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import functools
import sys

import cv2
import gradio as gr
import matplotlib.cm as cm
import numpy as np
import torch

# Put the src/ directory itself on the import path so modules inside it can
# also be imported by their bare names.
sys.path.append("src")
from src.config.default import get_cfg_defaults  # noqa: E402
from src.edm import EDM  # noqa: E402
from src.utils.misc import lower_config  # noqa: E402
from src.utils.plotting import make_matching_figure  # noqa: E402
# HTML banner rendered at the top of the demo page via gr.Markdown:
# the title plus links to the arXiv paper and the GitHub repository.
HEADER = """
<div align="center">
<p>
<span style="font-size: 30px; vertical-align: bottom;"> 🎶 EDM: Efficient Deep Feature Matching</span>
</p>
<p style="margin-top: -15px;">
<a href="https://arxiv.org/pdf/2503.05122" target="_blank" style="color: grey;">ArXiv Paper</a>
<a href="https://github.com/chicleee/EDM" target="_blank" style="color: grey;">GitHub Repository</a>
</p>
</div>
"""
# Paper abstract, rendered as Markdown inside the collapsible "Abstract" accordion.
ABSTRACT = """
Recent feature matching methods have achieved remarkable performance but lack efficiency consideration. In this paper, we revisit the mainstream detector-free matching pipeline and improve all its stages considering both accuracy and efficiency. We propose an Efficient Deep feature Matching network, EDM. We first adopt a deeper CNN with fewer dimensions to extract multi-level features. Then we present a Correlation Injection Module that conducts feature transformation on high-level deep features, and progressively injects feature correlations from global to local for efficient multi-scale feature aggregation, improving both speed and performance. In the refinement stage, a novel lightweight bidirectional axis-based regression head is designed to directly predict subpixel-level correspondences from latent features, avoiding the significant computational cost of explicitly locating keypoints on high-resolution local feature heatmaps. Moreover, effective selection strategies are introduced to enhance matching accuracy. Extensive experiments show that our EDM achieves competitive matching accuracy on various benchmarks and exhibits excellent efficiency, offering valuable best practices for real-world applications."""
# Resolution (width, height) at which both images are matched; the UI note
# tells users images are resized to 832 x 832.
_MATCH_W, _MATCH_H = 832, 832


@functools.lru_cache(maxsize=4)
def _load_matcher(conf_thres, border_rm, topk):
    """Build an EDM matcher with the given coarse-matching settings and load its weights.

    Cached on the three settings so repeated requests with unchanged sliders
    reuse one model instead of re-parsing the config and re-reading the
    checkpoint from disk on every click. ``maxsize`` bounds how many model
    copies are kept in memory.
    """
    config = get_cfg_defaults()
    # Merge order matters: later files override earlier ones.
    config.merge_from_file("configs/edm/outdoor/edm_base.py")
    config.merge_from_file("configs/data/megadepth_test_1500.py")
    config.EDM.COARSE.MCONF_THR = conf_thres
    config.EDM.COARSE.BORDER_RM = border_rm
    config.EDM.COARSE.TOPK = topk
    matcher = EDM(config=lower_config(config)["edm"])
    checkpoint = torch.load("weights/edm_outdoor.ckpt", map_location=torch.device("cpu"))
    matcher.load_state_dict(checkpoint["state_dict"])
    # NOTE(review): the cached model is shared across requests; this assumes
    # EDM's forward keeps no per-call state on the module (Gradio's default
    # queue runs one event at a time anyway) - confirm if concurrency is raised.
    return matcher.eval()


def _to_gray_tensor(image):
    """Return *image* as a (1, 1, _MATCH_H, _MATCH_W) float tensor scaled by 1/255.

    ``gr.Image`` delivers RGB arrays, so colour input is converted with
    ``COLOR_RGB2GRAY`` (``COLOR_BGR2GRAY`` would swap the red/blue weights).
    2-D, single-channel and RGBA arrays are accepted as well.
    """
    if image.ndim == 3 and image.shape[2] == 1:
        image = np.ascontiguousarray(image[..., 0])
    if image.ndim == 2:
        gray = image
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    gray = cv2.resize(gray, (_MATCH_W, _MATCH_H))
    # The /255 normalisation presumes uint8 input, as gr.Image provides.
    return torch.from_numpy(gray)[None][None] / 255.


def find_matches(image_0, image_1, conf_thres=0.2, border_rm=2, topk=10000):
    """Match two images with EDM and return a visualisation of the correspondences.

    Args:
        image_0, image_1: Image arrays as delivered by ``gr.Image`` (RGB);
            grayscale and RGBA arrays are also accepted.
        conf_thres: Written to ``EDM.COARSE.MCONF_THR`` (coarse confidence threshold).
        border_rm: Written to ``EDM.COARSE.BORDER_RM`` (UI label: "Border Remove (x8)").
        topk: Written to ``EDM.COARSE.TOPK``.

    Returns:
        The result of ``make_matching_figure`` for the two original images,
        with matches mapped back to original resolution and coloured by
        confidence using the jet colormap.

    Raises:
        gr.Error: If either image is missing.
    """
    if image_0 is None or image_1 is None:
        raise gr.Error("Please provide both images.")

    # Slider values may arrive as floats; normalise types (this also keeps
    # the cache key stable for equal values).
    matcher = _load_matcher(float(conf_thres), int(border_rm), int(topk))

    batch = {"image0": _to_gray_tensor(image_0), "image1": _to_gray_tensor(image_1)}
    with torch.no_grad():
        # EDM writes its predictions into ``batch`` in place.
        matcher(batch)

    mkpts0 = batch["mkpts0_f"].cpu().numpy()
    mkpts1 = batch["mkpts1_f"].cpu().numpy()
    mconf = batch["mconf"].cpu().numpy()

    # Keypoints are in the resized _MATCH_W x _MATCH_H frame; map them back
    # to each image's original resolution (x by width, y by height).
    h0, w0 = image_0.shape[:2]
    h1, w1 = image_1.shape[:2]
    mkpts0[:, 0] *= w0 / _MATCH_W
    mkpts0[:, 1] *= h0 / _MATCH_H
    mkpts1[:, 0] *= w1 / _MATCH_W
    mkpts1[:, 1] *= h1 / _MATCH_H

    color = cm.jet(mconf)
    text = [
        'EDM',
        'Matches: {}'.format(len(mkpts0)),
    ]
    return make_matching_figure(image_0, image_1, mkpts0, mkpts1, color, text=text)
# Slider defaults. Examples run through _match_example with these same
# values, so example outputs (which Gradio may pre-compute depending on
# cache_examples / environment) match what users get from "Find Matches"
# with untouched sliders. Calling find_matches with only the two images
# would instead use its own keyword default topk=10000.
DEFAULT_CONF_THRES = 0.2
DEFAULT_TOPK = 5000
DEFAULT_BORDER_RM = 2


def _match_example(image_0, image_1):
    """Run find_matches on an example image pair using the slider defaults."""
    return find_matches(
        image_0,
        image_1,
        conf_thres=DEFAULT_CONF_THRES,
        border_rm=DEFAULT_BORDER_RM,
        topk=DEFAULT_TOPK,
    )


with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Accordion("Abstract (click to open)", open=False):
        gr.Image("assets/teaser.jpg")
        gr.Markdown(ABSTRACT)
    with gr.Row():
        image_1 = gr.Image()
        image_2 = gr.Image()
    with gr.Row():
        conf_thres = gr.Slider(minimum=0.01, maximum=1, value=DEFAULT_CONF_THRES, step=0.01, label="Coarse Confidence Threshold")
        topk = gr.Slider(minimum=100, maximum=10000, value=DEFAULT_TOPK, step=100, label="TopK")
        border_rm = gr.Slider(minimum=0, maximum=20, value=DEFAULT_BORDER_RM, step=1, label="Border Remove (x8)")
    gr.HTML(
        """
Note: images are actually resized to 832 x 832 for matching.
"""
    )
    with gr.Row():
        button = gr.Button(value="Find Matches")
        clear = gr.ClearButton(value="Clear")
    output = gr.Image()
    button.click(find_matches, [image_1, image_2, conf_thres, border_rm, topk], output)
    clear.add([image_1, image_2, output])
    gr.Examples(
        examples=[
            ["assets/scannet_sample_images/scene0707_00_15.jpg", "assets/scannet_sample_images/scene0707_00_45.jpg"],
            ["assets/scannet_sample_images/scene0758_00_165.jpg", "assets/scannet_sample_images/scene0758_00_510.jpg"],
            ["assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg", "assets/phototourism_sample_images/london_bridge_49190386_5209386933.jpg"],
            ["assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg", "assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"],
        ],
        inputs=[image_1, image_2],
        outputs=[output],
        fn=_match_example,
        cache_examples=None,
    )
# Launch the Gradio app only when run as a script, not when imported.
# (The stray trailing "|" after demo.launch() made this a syntax error.)
if __name__ == "__main__":
    demo.launch()