File size: 5,631 Bytes
7e31006
 
bc6ea96
 
7e31006
 
 
bc6ea96
7e31006
bc6ea96
7e31006
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cce956c
7e31006
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cce956c
 
 
7e31006
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cce956c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import sys
from functools import lru_cache

import cv2
import gradio as gr
import matplotlib.cm as cm
import numpy as np
import torch

sys.path.append("src")

from src.config.default import get_cfg_defaults
from src.edm import EDM
from src.utils.misc import lower_config
from src.utils.plotting import make_matching_figure
    
# HTML banner rendered at the top of the demo page: project title plus links
# to the arXiv paper and the GitHub repository.
HEADER = """
<div align="center">
    <p>
        <span style="font-size: 30px; vertical-align: bottom;"> 🎶 EDM: Efficient Deep Feature Matching</span>
    </p>
    <p style="margin-top: -15px;">
        <a href="https://arxiv.org/pdf/2503.05122" target="_blank" style="color: grey;">ArXiv Paper</a>
        &nbsp;
        <a href="https://github.com/chicleee/EDM" target="_blank" style="color: grey;">GitHub Repository</a>
    </p>

</div>
"""

# Paper abstract shown inside the collapsible "Abstract" accordion below.
ABSTRACT = """
Recent feature matching methods have achieved remarkable performance but lack efficiency consideration. In this paper, we revisit the mainstream detector-free matching pipeline and improve all its stages considering both accuracy and efficiency. We propose an Efficient Deep feature Matching network, EDM. We first adopt a deeper CNN with fewer dimensions to extract multi-level features. Then we present a Correlation Injection Module that conducts feature transformation on high-level deep features, and progressively injects feature correlations from global to local for efficient multi-scale feature aggregation, improving both speed and performance. In the refinement stage, a novel lightweight bidirectional axis-based regression head is designed to directly predict subpixel-level correspondences from latent features, avoiding the significant computational cost of explicitly locating keypoints on high-resolution local feature heatmaps. Moreover, effective selection strategies are introduced to enhance matching accuracy. Extensive experiments show that our EDM achieves competitive matching accuracy on various benchmarks and exhibits excellent efficiency, offering valuable best practices for real-world applications."""

@lru_cache(maxsize=4)
def _load_matcher(conf_thres, border_rm, topk):
    """Build an EDM matcher configured with the given coarse-stage settings.

    Config parsing and checkpoint loading from disk dominate the cost of a
    demo request, so the constructed matcher is memoized per parameter
    combination instead of being rebuilt on every click.
    """
    config = get_cfg_defaults()
    # Merge order preserved from the original script: model config first,
    # then the dataset config.
    config.merge_from_file("configs/edm/outdoor/edm_base.py")
    config.merge_from_file("configs/data/megadepth_test_1500.py")

    config.EDM.COARSE.MCONF_THR = conf_thres
    config.EDM.COARSE.BORDER_RM = border_rm
    config.EDM.COARSE.TOPK = topk

    matcher = EDM(config=lower_config(config)["edm"])
    # NOTE(review): torch.load unpickles arbitrary objects; acceptable here
    # because the checkpoint ships with the repository, not user-supplied.
    state_dict = torch.load(
        "weights/edm_outdoor.ckpt", map_location=torch.device("cpu")
    )["state_dict"]
    matcher.load_state_dict(state_dict)
    return matcher.eval()


def find_matches(image_0, image_1, conf_thres=0.2, border_rm=2, topk=10000):
    """Run EDM on an image pair and return a matplotlib matching figure.

    Args:
        image_0: first image as an HxWx3 numpy array (gradio supplies RGB).
        image_1: second image, same format.
        conf_thres: coarse confidence threshold for keeping matches.
        border_rm: number of 8-px border cells to discard.
        topk: maximum number of coarse matches to keep.

    Returns:
        matplotlib Figure visualizing the matched keypoints, colored by
        match confidence.
    """
    matcher = _load_matcher(conf_thres, border_rm, topk)

    # Fixed network input resolution; keypoints are predicted in this frame.
    W, H = 832, 832

    h0, w0 = image_0.shape[:2]
    h1, w1 = image_1.shape[:2]

    # Gradio hands images over as RGB, so convert with RGB2GRAY (the previous
    # BGR2GRAY applied the luminance weights to the wrong channels).
    img0_raw = cv2.resize(cv2.cvtColor(image_0, cv2.COLOR_RGB2GRAY), (W, H))
    img1_raw = cv2.resize(cv2.cvtColor(image_1, cv2.COLOR_RGB2GRAY), (W, H))

    # Shape (1, 1, H, W), values in [0, 1].
    img0 = torch.from_numpy(img0_raw)[None][None] / 255.0
    img1 = torch.from_numpy(img1_raw)[None][None] / 255.0
    batch = {"image0": img0, "image1": img1}

    # EDM writes its predictions back into `batch` in place.
    with torch.no_grad():
        matcher(batch)

    mkpts0 = batch["mkpts0_f"].numpy()
    mkpts1 = batch["mkpts1_f"].numpy()
    mconf = batch["mconf"].numpy()

    # Map keypoints from the 832x832 network frame back to each image's
    # original resolution.
    mkpts0[:, 0] *= w0 / W
    mkpts0[:, 1] *= h0 / H
    mkpts1[:, 0] *= w1 / W
    mkpts1[:, 1] *= h1 / H

    color = cm.jet(mconf)
    text = [
        'EDM',
        'Matches: {}'.format(len(mkpts0)),
    ]
    return make_matching_figure(image_0, image_1, mkpts0, mkpts1, color, text=text)


# Declarative Gradio UI: banner, collapsible abstract, two input images,
# three tuning sliders, and an output pane for the rendered match figure.
with gr.Blocks() as demo:

    gr.Markdown(HEADER)
    with gr.Accordion("Abstract (click to open)", open=False):
        gr.Image("assets/teaser.jpg")
        gr.Markdown(ABSTRACT)

    # Input image pair.
    with gr.Row():
        image_1 = gr.Image()
        image_2 = gr.Image()
    # Matching hyperparameters, forwarded to find_matches on each click.
    with gr.Row():
        conf_thres = gr.Slider(minimum=0.01, maximum=1, value=0.2, step=0.01, label="Coarse Confidence Threshold")
        topk = gr.Slider(minimum=100, maximum=10000, value=5000, step=100, label="TopK")
        border_rm = gr.Slider(minimum=0, maximum=20, value=2, step=1, label="Border Remove (x8)")
    gr.HTML(
        """
        Note: images are actually resized to 832 x 832 for matching.
        """
    )
    with gr.Row():
        button = gr.Button(value="Find Matches")
        clear = gr.ClearButton(value="Clear")
        
    output = gr.Image()
    # Inputs are mapped positionally onto the find_matches signature
    # (conf_thres, border_rm, topk) — note this differs from the slider
    # layout order above, where topk is created before border_rm.
    button.click(find_matches, [image_1, image_2, conf_thres, border_rm, topk], output)
    clear.add([image_1, image_2, output])

    # Bundled example pairs; slider values fall back to find_matches'
    # parameter defaults since only the two images are wired as inputs.
    gr.Examples(
        examples=[
            ["assets/scannet_sample_images/scene0707_00_15.jpg", "assets/scannet_sample_images/scene0707_00_45.jpg"],
            ["assets/scannet_sample_images/scene0758_00_165.jpg", "assets/scannet_sample_images/scene0758_00_510.jpg"],
            ["assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg", "assets/phototourism_sample_images/london_bridge_49190386_5209386933.jpg"],
            ["assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg", "assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"],
        ],
        inputs=[image_1, image_2],
        outputs=[output],
        fn=find_matches,
        cache_examples=None,
    )

# Script entry point: start the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()