wupeihao committed on
Commit
a5b5add
·
verified ·
1 Parent(s): 32f97d7

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +345 -0
  2. visualize.py +156 -0
app.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Tuple
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ import cv2
7
+ import torch
8
+ import matplotlib.pyplot as plt
9
+ from matplotlib.figure import Figure
10
+ from numpy import ndarray
11
+ import visualize
12
+
13
# Stylesheet handed to gr.Blocks(css=CSS): centers the element whose id is
# "desc" (the gr.HTML banner) together with all of its descendants.
CSS = """
#desc, #desc * {
text-align: center !important;
justify-content: center !important;
align-items: center !important;
}
"""
20
+
21
# HTML banner rendered through gr.HTML(DESCRIPTION, elem_id="desc"): the demo
# title followed by the list of supported modality pairs.
DESCRIPTION = """
<div align="center">
<h1><ins>MapGlue</ins> 🗺️</h1>
<h2>
MapGlue: Multimodal Remote Sensing Image Matching
</h2>
<p>
Advanced feature matching system supporting various image modalities including:<br>
SAR-Visible, Map-Visible, Depth-Visible, Infrared-Visible, Day-Night matching
</p>
</div>
"""
33
+
34
# Example image pairs offered in the UI. Each entry is [left_path, right_path];
# both files of a pair live in the same modality folder and share an extension.
examples = [
    [f"assets/{modality}/L{index}.{ext}", f"assets/{modality}/R{index}.{ext}"]
    for modality, index, ext in (
        ("day-night", 1, "png"),
        ("day-night", 2, "png"),
        ("depth-visible", 1, "jpg"),
        ("depth-visible", 2, "png"),
        ("infrared-visible", 1, "png"),
        ("infrared-visible", 2, "png"),
        ("map-visible", 1, "jpg"),
        ("map-visible", 2, "png"),
        ("sar-visible", 1, "jpg"),
        ("sar-visible", 2, "jpg"),
        ("sar-visible", 3, "png"),
    )
]
80
+
81
+
82
def fig_to_ndarray(fig: Figure) -> ndarray:
    """Render a matplotlib figure and return its pixels as an RGBA array.

    The array shape is taken from the canvas buffer itself instead of from
    ``get_width_height()``, so the conversion does not break when the buffer
    dimensions differ from the logical canvas size that method reports.

    Args:
        fig: Figure whose canvas provides ``draw()`` and ``buffer_rgba()``.

    Returns:
        A writable ``uint8`` array with the buffer's own shape (rows, columns,
        RGBA channels). The array owns its data, so it does not alias the
        canvas' internal buffer.
    """
    fig.canvas.draw()
    # np.array copies the buffer, detaching the result from the canvas.
    return np.array(fig.canvas.buffer_rgba(), dtype=np.uint8)
89
+
90
# Loaded models keyed by (model_path, device) so the weights are read from
# disk once instead of on every matching request.
_MODEL_CACHE = {}


def load_mapglue_model(model_path='./weights/fastmapglue_model.pt', device='cpu'):
    """Load the MapGlue TorchScript model, reusing an already loaded instance.

    Args:
        model_path: Location of the TorchScript archive.
        device: Device passed to ``torch.jit.load`` as ``map_location`` and
            to ``model.to``. Defaults to CPU.

    Returns:
        Tuple ``(model, device)`` with the model switched to eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist. Missing files are
            never cached, so a later call succeeds once the file appears.
    """
    cache_key = (model_path, str(device))
    cached_model = _MODEL_CACHE.get(cache_key)
    if cached_model is not None:
        return cached_model, device

    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"Model file not found: {model_path}\n"
            f"Please ensure the HF_TOKEN environment variable is set to download the model."
        )

    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    model.to(device)
    _MODEL_CACHE[cache_key] = model
    return model, device
106
+
107
+
108
def _render_overlays(image0, image1, pts0, pts1, match_text, keypoint_text, line_width):
    """Render the match-line and keypoint overlays for one set of correspondences.

    Args:
        image0, image1: RGB images shown side by side.
        pts0, pts1: Matched point coordinates for the two images.
        match_text: Caption drawn on the match-line figure.
        keypoint_text: Caption drawn on the keypoint figure.
        line_width: Width of the match lines.

    Returns:
        Tuple ``(keypoint_image, matching_image)`` of rendered arrays.
    """
    # show_images opens its own figure via plt.subplots and makes it current,
    # so no separate plt.figure() call is needed before it.
    visualize.show_images([image0, image1])
    visualize.draw_matches(pts0, pts1, line_colors="lime", line_width=line_width)
    visualize.add_text(0, match_text, font_size=16)
    matching_image = fig_to_ndarray(plt.gcf())

    visualize.show_images([image0, image1])
    visualize.draw_keypoints([pts0, pts1], kp_color=["lime", "lime"], kp_size=20)
    visualize.add_text(0, keypoint_text, font_size=16)
    keypoint_image = fig_to_ndarray(plt.gcf())
    return keypoint_image, matching_image


def run_mapglue_matching(
    path0: str,
    path1: str,
    model_name: str,
    num_keypoints: int,
    ransac_threshold: float,
) -> Tuple[ndarray, ndarray, ndarray, ndarray]:
    """
    Run MapGlue matching on two input images using Homography RANSAC.

    Args:
        path0, path1: Paths to input images
        model_name: Name of the matching model. Currently unused: the
            FastMapGlue weights are always loaded.
        num_keypoints: Number of keypoints to extract
        ransac_threshold: RANSAC reprojection threshold

    Returns:
        Tuple of (raw_keypoint_fig, raw_matching_fig, ransac_keypoint_fig, ransac_matching_fig).
        The two RANSAC entries are None when fewer than four raw matches exist
        or no inliers are found. On any error, four blank 400x800 RGBA images
        are returned and the error is printed.
    """
    try:
        model, device = load_mapglue_model()

        # Load both images; cv2.imread returns None instead of raising.
        image0 = cv2.imread(path0)
        image1 = cv2.imread(path1)
        if image0 is None or image1 is None:
            raise ValueError("Could not load one or both images")

        # Convert BGR to RGB
        image0 = cv2.cvtColor(image0, cv2.COLOR_BGR2RGB)
        image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)

        image0_tensor = torch.from_numpy(image0).to(device)
        image1_tensor = torch.from_numpy(image1).to(device)
        # num_keypoints is annotated as int; coerce defensively in case the UI
        # delivers the slider value as a float.
        num_keypoints_tensor = torch.tensor(int(num_keypoints)).to(device)

        with torch.no_grad():
            points_tensor = model(image0_tensor, image1_tensor, num_keypoints_tensor)

        # Columns 0-1 hold the points in image 0, the remaining columns those in image 1.
        points0_np = points_tensor[:, :2].cpu().numpy()
        points1_np = points_tensor[:, 2:].cpu().numpy()
        num_raw = len(points0_np)

        raw_keypoint_fig, raw_matching_fig = _render_overlays(
            image0,
            image1,
            points0_np,
            points1_np,
            match_text=f'Raw matches: {num_raw}',
            keypoint_text=f'Raw keypoints: {num_raw}',
            line_width=0.8,
        )

        ransac_keypoint_fig = None
        ransac_matching_fig = None
        inlier_mask = None
        # A homography needs at least four correspondences; skip RANSAC below
        # that so the raw visualisations are still returned.
        if num_raw >= 4:
            _, inlier_mask = cv2.findHomography(
                points0_np, points1_np,
                cv2.USAC_MAGSAC,
                ransacReprojThreshold=ransac_threshold,
                maxIters=10000,
                confidence=0.9999
            )

        if inlier_mask is not None and inlier_mask.sum() > 0:
            keep = inlier_mask.ravel() > 0
            mkpts0 = points0_np[keep]
            mkpts1 = points1_np[keep]
            ransac_keypoint_fig, ransac_matching_fig = _render_overlays(
                image0,
                image1,
                mkpts0,
                mkpts1,
                match_text=f'RANSAC matches @{ransac_threshold}px: {len(mkpts0)}/{num_raw}',
                keypoint_text=f'RANSAC keypoints @{ransac_threshold}px: {len(mkpts0)}',
                line_width=1,
            )

        return (
            raw_keypoint_fig,
            raw_matching_fig,
            ransac_keypoint_fig,
            ransac_matching_fig,
        )

    except Exception as e:
        # UI boundary: report the failure and show blank panels instead of crashing.
        print(f"Error in matching: {str(e)}")
        empty_img = np.zeros((400, 800, 4), dtype=np.uint8)
        return (empty_img, empty_img, empty_img, empty_img)

    finally:
        # Always release the matplotlib figures, including on the error path.
        plt.close('all')
219
+
220
+
221
# Gradio interface: a banner row, then a row with an input/settings column
# followed by an output column, and finally the event wiring.
# NOTE(review): the block nesting below was reconstructed from a
# whitespace-stripped listing — confirm it against the intended layout.
with gr.Blocks(css=CSS) as demo:
    with gr.Tab("Image Matching"):
        # Title banner; elem_id="desc" is the id targeted by the CSS rules.
        with gr.Row():
            with gr.Column(scale=3):
                gr.HTML(DESCRIPTION, elem_id="desc")
        with gr.Row():
            # First column: model choice, the two input images, run/stop buttons and settings.
            with gr.Column():
                gr.Markdown("### Input Panels:")
                with gr.Row():
                    model_name = gr.Dropdown(
                        choices=["FastMapGlue"],
                        value="FastMapGlue",
                        label="Matching Model",
                    )
                with gr.Row():
                    # type="filepath": the handler receives file paths, which it reads with cv2.imread.
                    path0 = gr.Image(
                        height=300,
                        image_mode="RGB",
                        type="filepath",
                        label="Image 0",
                    )
                    path1 = gr.Image(
                        height=300,
                        image_mode="RGB",
                        type="filepath",
                        label="Image 1",
                    )
                with gr.Row():
                    stop = gr.Button(value="Stop", variant="stop")
                    run = gr.Button(value="Run", variant="primary")

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Accordion("Matching Settings"):
                        with gr.Row():
                            num_keypoints = gr.Slider(
                                minimum=512,
                                maximum=4096,
                                value=2048,
                                step=256,
                                label="Number of Keypoints",
                            )
                    with gr.Accordion("RANSAC Settings"):
                        with gr.Row():
                            ransac_threshold = gr.Slider(
                                minimum=0.5,
                                maximum=10.0,
                                value=5.0,
                                step=0.5,
                                label="RANSAC Threshold",
                            )

                with gr.Row():
                    with gr.Accordion("Example Pairs"):
                        # Clicking an example fills both image inputs.
                        gr.Examples(
                            examples=examples,
                            inputs=[path0, path1],
                            label="Click an example pair below",
                        )

            # Second column: the four rendered result images.
            with gr.Column():
                gr.Markdown(
                    "### Output Panels"
                )
                with gr.Accordion("Raw Keypoints", open=False):
                    raw_keypoint_fig = gr.Image(
                        format="png", type="numpy", label="Raw Keypoints"
                    )
                with gr.Accordion("Raw Matches"):
                    raw_matching_fig = gr.Image(
                        format="png", type="numpy", label="Raw Matches"
                    )
                with gr.Accordion("RANSAC Keypoints", open=False):
                    ransac_keypoint_fig = gr.Image(
                        format="png", type="numpy", label="RANSAC Keypoints"
                    )
                with gr.Accordion("RANSAC Matches"):
                    ransac_matching_fig = gr.Image(
                        format="png", type="numpy", label="RANSAC Matches"
                    )

        # Component order matches the positional parameters of run_mapglue_matching.
        inputs = [
            path0,
            path1,
            model_name,
            num_keypoints,
            ransac_threshold,
        ]
        # Component order matches the tuple returned by run_mapglue_matching.
        outputs = [
            raw_keypoint_fig,
            raw_matching_fig,
            ransac_keypoint_fig,
            ransac_matching_fig,
        ]

        running_event = run.click(
            fn=run_mapglue_matching, inputs=inputs, outputs=outputs
        )
        # Stop runs no function of its own; it only cancels the pending run event.
        stop.click(
            fn=None, inputs=None, outputs=None, cancels=[running_event]
        )
321
+
322
if __name__ == "__main__":
    # Download model weights on startup if HF_TOKEN is available
    HF_TOKEN = os.getenv("HF_TOKEN")
    if HF_TOKEN:
        model_path = './weights/fastmapglue_model.pt'
        if not os.path.exists(model_path):
            # Download to a temporary name and rename on success, so an
            # interrupted transfer never leaves a truncated file that the
            # exists() check above would accept on the next start.
            partial_path = model_path + '.part'
            try:
                import requests
                # Use the "resolve" endpoint to download the file directly
                model_url = "https://huggingface.co/wupeihao/mapglue/resolve/main/fastmapglue_model.pt"
                headers = {"Authorization": f"Bearer {HF_TOKEN}"}

                print("Downloading MapGlue model...")
                # Stream in chunks instead of holding the whole model in memory;
                # the timeout keeps startup from hanging on a stalled connection.
                with requests.get(model_url, headers=headers, stream=True, timeout=60) as response:
                    response.raise_for_status()

                    os.makedirs('./weights', exist_ok=True)
                    with open(partial_path, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=1024 * 1024):
                            f.write(chunk)
                os.replace(partial_path, model_path)
                print("Model downloaded successfully!")
            except Exception as e:
                # Best effort: the app still launches; matching reports the missing model later.
                print(f"Failed to download model: {str(e)}")
                if os.path.exists(partial_path):
                    os.remove(partial_path)

    demo.launch()
visualize.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ import matplotlib.patheffects as peffects
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import torch
6
+
7
def show_images(image_list, titles=None, colormaps="gray", dpi=100, pad=0.5, auto_size=True):
    """
    Display a set of images horizontally in a new figure.

    Args:
        image_list: List of images in either NumPy RGB (H, W, 3),
            PyTorch RGB (3, H, W) or grayscale (H, W) format.
        titles: List of titles for each image.
        colormaps: Colormap for grayscale images, or one colormap per image.
        dpi: Figure resolution.
        pad: Padding between images.
        auto_size: Whether the figure size should adapt to the images' aspect ratios.

    Returns:
        List with one matplotlib axes object per image, in display order.
    """
    # Convert every torch.Tensor to NumPy. Channel-first 3-D tensors are moved
    # to (H, W, 3); other tensors (e.g. 2-D grayscale) keep their layout.
    converted = []
    for img in image_list:
        if isinstance(img, torch.Tensor):
            if img.dim() == 3:
                img = img.permute(1, 2, 0)
            img = img.detach().cpu().numpy()
        converted.append(img)
    image_list = converted

    num_imgs = len(image_list)
    if not isinstance(colormaps, (list, tuple)):
        colormaps = [colormaps] * num_imgs

    if auto_size:
        ratios = [im.shape[1] / im.shape[0] for im in image_list]  # width / height
    else:
        ratios = [4 / 3] * num_imgs
    fig_size = [sum(ratios) * 4.5, 4.5]
    fig, axes = plt.subplots(1, num_imgs, figsize=fig_size, dpi=dpi, gridspec_kw={"width_ratios": ratios})
    # plt.subplots returns a bare axes object for a single image; normalise to a list.
    axes = [axes] if num_imgs == 1 else list(axes)
    for i, (ax, image) in enumerate(zip(axes, image_list)):
        ax.imshow(image, cmap=plt.get_cmap(colormaps[i]))
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()
        for spine in ax.spines.values():
            spine.set_visible(False)
        if titles:
            ax.set_title(titles[i])
    fig.tight_layout(pad=pad)
    return axes
49
+
50
+
51
def draw_keypoints(keypoints, kp_color="lime", kp_size=4, ax_list=None, alpha_value=1.0):
    """
    Scatter one set of keypoints onto each of the given axes.

    Args:
        keypoints: List of (N, 2) point sets (ndarray or torch.Tensor), one per axes.
        kp_color: A single color applied to every set, or a list with one color per set.
        kp_size: Marker size passed to scatter.
        ax_list: Axes to draw on; when None, the axes of the current figure are used.
        alpha_value: A single opacity, or a list with one opacity per set.
    """
    num_sets = len(keypoints)
    colors = kp_color if isinstance(kp_color, list) else [kp_color] * num_sets
    opacities = alpha_value if isinstance(alpha_value, list) else [alpha_value] * num_sets
    target_axes = plt.gcf().axes if ax_list is None else ax_list

    for axis, points, color, opacity in zip(target_axes, keypoints, colors, opacities):
        # Tensors are moved to the CPU and converted before plotting.
        coords = points.cpu().numpy() if isinstance(points, torch.Tensor) else points
        axis.scatter(coords[:, 0], coords[:, 1], c=color, s=kp_size, linewidths=0, alpha=opacity)
72
+
73
+
74
def draw_matches(pts_left, pts_right, line_colors=None, line_width=1.5, endpoint_size=4, alpha_value=1.0, labels=None, axes_pair=None):
    """
    Connect corresponding points of two side-by-side images with lines.

    Args:
        pts_left, pts_right: Matched points (N, 2) for the left and right image,
            as ndarray or torch.Tensor.
        line_colors: A color string or RGB tuple applied to every match; a
            sequence whose first element is a tuple or list is used as-is
            (one color per match). Random HSV colors are drawn when None.
        line_width: Width of the match lines; lines are skipped when not positive.
        endpoint_size: Marker size of the endpoints; endpoints are skipped when not positive.
        alpha_value: Opacity of the match lines.
        labels: Optional per-match labels attached to the line artists.
        axes_pair: [ax_left, ax_right]; defaults to the first two axes of the current figure.
    """
    figure = plt.gcf()
    if axes_pair is not None:
        ax_left, ax_right = axes_pair
    else:
        ax_left, ax_right = figure.axes[0], figure.axes[1]

    if isinstance(pts_left, torch.Tensor):
        pts_left = pts_left.cpu().numpy()
    if isinstance(pts_right, torch.Tensor):
        pts_right = pts_right.cpu().numpy()

    num_matches = len(pts_left)
    assert num_matches == len(pts_right)

    if line_colors is None:
        line_colors = matplotlib.cm.hsv(np.random.rand(num_matches)).tolist()
    elif len(line_colors) > 0 and not isinstance(line_colors[0], (tuple, list)):
        # A single color specification: repeat it for every match.
        line_colors = [line_colors] * num_matches

    if line_width > 0:
        for idx, left_pt in enumerate(pts_left):
            right_pt = pts_right[idx]
            match_label = labels[idx] if labels is not None else None
            # The patch spans both axes, so it is attached to the figure rather than to one axes.
            link = matplotlib.patches.ConnectionPatch(
                xyA=(left_pt[0], left_pt[1]),
                xyB=(right_pt[0], right_pt[1]),
                coordsA=ax_left.transData,
                coordsB=ax_right.transData,
                axesA=ax_left,
                axesB=ax_right,
                zorder=1,
                color=line_colors[idx],
                linewidth=line_width,
                clip_on=True,
                alpha=alpha_value,
                label=match_label,
                picker=5.0,
            )
            link.set_annotation_clip(True)
            figure.add_artist(link)

    # Disable autoscaling so the artists added here do not change the view limits.
    for axis in (ax_left, ax_right):
        axis.autoscale(enable=False)

    if endpoint_size > 0:
        ax_left.scatter(pts_left[:, 0], pts_left[:, 1], c=line_colors, s=endpoint_size)
        ax_right.scatter(pts_right[:, 0], pts_right[:, 1], c=line_colors, s=endpoint_size)
131
+
132
+
133
def add_text(axis_idx, text, pos=(0.01, 0.99), font_size=15, txt_color="w", border_color="k", border_width=2, h_align="left", v_align="top"):
    """
    Write a text label, optionally outlined, onto one axes of the current figure.

    Args:
        axis_idx: Index into the current figure's axes list.
        text: Label content.
        pos: Label position in axes coordinates.
        font_size: Font size of the label.
        txt_color: Fill color of the label.
        border_color: Outline color; pass None to draw the label without an outline.
        border_width: Line width of the outline stroke.
        h_align: Horizontal alignment of the label.
        v_align: Vertical alignment of the label.
    """
    target_ax = plt.gcf().axes[axis_idx]
    label = target_ax.text(
        *pos,
        text,
        fontsize=font_size,
        ha=h_align,
        va=v_align,
        color=txt_color,
        transform=target_ax.transAxes,
    )
    if border_color is None:
        return

    # Stroke draws the outline first; Normal then draws the regular text on top.
    outline = [
        peffects.Stroke(linewidth=border_width, foreground=border_color),
        peffects.Normal(),
    ]
    label.set_path_effects(outline)