import functools
import sys

import cv2
import gradio as gr
import matplotlib.cm as cm
import numpy as np
import torch

sys.path.append("src")
from src.utils.plotting import make_matching_figure
from src.edm import EDM
from src.config.default import get_cfg_defaults
from src.utils.misc import lower_config
# Markdown rendered at the top of the demo page; currently just a single
# newline, i.e. effectively empty.
HEADER = """
"""
# Paper abstract, rendered as Markdown inside the collapsible
# "Abstract (click to open)" accordion of the demo.
ABSTRACT = """
Recent feature matching methods have achieved remarkable performance but lack efficiency consideration. In this paper, we revisit the mainstream detector-free matching pipeline and improve all its stages considering both accuracy and efficiency. We propose an Efficient Deep feature Matching network, EDM. We first adopt a deeper CNN with fewer dimensions to extract multi-level features. Then we present a Correlation Injection Module that conducts feature transformation on high-level deep features, and progressively injects feature correlations from global to local for efficient multi-scale feature aggregation, improving both speed and performance. In the refinement stage, a novel lightweight bidirectional axis-based regression head is designed to directly predict subpixel-level correspondences from latent features, avoiding the significant computational cost of explicitly locating keypoints on high-resolution local feature heatmaps. Moreover, effective selection strategies are introduced to enhance matching accuracy. Extensive experiments show that our EDM achieves competitive matching accuracy on various benchmarks and exhibits excellent efficiency, offering valuable best practices for real-world applications."""
def find_matches(image_0, image_1, conf_thres=0.2, border_rm=2, topk=10000):
config = get_cfg_defaults()
data_cfg_path = "configs/data/megadepth_test_1500.py"
main_cfg_path = "configs/edm/outdoor/edm_base.py"
config.merge_from_file(main_cfg_path)
config.merge_from_file(data_cfg_path)
W, H = 832, 832
config.EDM.COARSE.MCONF_THR = conf_thres
config.EDM.COARSE.BORDER_RM = border_rm
config.EDM.COARSE.TOPK = topk
_config = lower_config(config)
matcher = EDM(config=_config["edm"])
state_dict = torch.load("weights/edm_outdoor.ckpt", map_location=torch.device('cpu'))["state_dict"]
matcher.load_state_dict(state_dict)
matcher = matcher.eval()
# Load example images
img0_bgr = image_0
img1_bgr = image_1
h0, w0 = img0_bgr.shape[:2]
h1, w1 = img1_bgr.shape[:2]
h0_scale = h0 / H
w0_scale = w0 / W
h1_scale = h1 / H
w1_scale = w1 / W
# For inference
img0_raw = cv2.cvtColor(img0_bgr, cv2.COLOR_BGR2GRAY)
img1_raw = cv2.cvtColor(img1_bgr, cv2.COLOR_BGR2GRAY)
img0_raw = cv2.resize(img0_raw, (W, H))
img1_raw = cv2.resize(img1_raw, (W, H))
img0 = torch.from_numpy(img0_raw)[None][None] / 255.
img1 = torch.from_numpy(img1_raw)[None][None]/ 255.
batch = {'image0': img0, 'image1': img1}
# Inference with EDM and get prediction
with torch.no_grad():
matcher(batch)
mkpts0 = batch['mkpts0_f'].numpy()
mkpts1 = batch['mkpts1_f'].numpy()
mconf = batch['mconf'].numpy()
mkpts0[:, 0] *= w0_scale
mkpts0[:, 1] *= h0_scale
mkpts1[:, 0] *= w1_scale
mkpts1[:, 1] *= h1_scale
color = cm.jet(mconf)
# Draw
text = [
'EDM',
'Matches: {}'.format(len(mkpts0)),
]
fig = make_matching_figure(img0_bgr, img1_bgr, mkpts0, mkpts1, color, text=text)
return fig
# Gradio UI: header, collapsible abstract, two image inputs, coarse-matching
# sliders, a match/clear button row, the result image and example pairs.
# Component creation order (and `with` nesting) defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    # Collapsible panel with the teaser figure and the paper abstract.
    with gr.Accordion("Abstract (click to open)", open=False):
        gr.Image("assets/teaser.jpg")
        gr.Markdown(ABSTRACT)
    # The two images to be matched (passed to find_matches as image_0 / image_1).
    with gr.Row():
        image_1 = gr.Image()
        image_2 = gr.Image()
    # Coarse-matching settings forwarded to find_matches.
    # NOTE(review): the TopK slider defaults to 5000, while find_matches defaults
    # to topk=10000 — the examples below (which pass only the two images) use the
    # function default, not the slider value.
    with gr.Row():
        conf_thres = gr.Slider(minimum=0.01, maximum=1, value=0.2, step=0.01, label="Coarse Confidence Threshold")
        topk = gr.Slider(minimum=100, maximum=10000, value=5000, step=100, label="TopK")
        border_rm = gr.Slider(minimum=0, maximum=20, value=2, step=1, label="Border Remove (x8)")
    # Keep this note in sync with the fixed matching resolution used in find_matches.
    gr.HTML(
"""
Note: images are actually resized to 832 x 832 for matching.
"""
    )
    with gr.Row():
        button = gr.Button(value="Find Matches")
        clear = gr.ClearButton(value="Clear")
    output = gr.Image()
    # Inputs are passed positionally, so this list order must match
    # find_matches(image_0, image_1, conf_thres, border_rm, topk).
    button.click(find_matches, [image_1, image_2, conf_thres, border_rm, topk], output)
    # The Clear button resets both inputs and the result.
    clear.add([image_1, image_2, output])
    # Example image pairs; clicking one fills the two image inputs, and
    # find_matches is run with its default threshold/border/topk arguments.
    gr.Examples(
        examples=[
            ["assets/scannet_sample_images/scene0707_00_15.jpg", "assets/scannet_sample_images/scene0707_00_45.jpg"],
            ["assets/scannet_sample_images/scene0758_00_165.jpg", "assets/scannet_sample_images/scene0758_00_510.jpg"],
            ["assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg", "assets/phototourism_sample_images/london_bridge_49190386_5209386933.jpg"],
            ["assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg", "assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"],
        ],
        inputs=[image_1, image_2],
        outputs=[output],
        fn=find_matches,
        cache_examples=None,
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()