Spaces:
Sleeping
Sleeping
import functools
import os
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from numpy import ndarray

import visualize
# Stylesheet passed to gr.Blocks(css=...): centers the element with
# elem_id="desc" (the DESCRIPTION banner) and all of its descendants.
CSS = """
#desc, #desc * {
text-align: center !important;
justify-content: center !important;
align-items: center !important;
}
"""
# HTML banner rendered at the top of the "Image Matching" tab via
# gr.HTML(DESCRIPTION, elem_id="desc"); it lists the supported modality pairs.
DESCRIPTION = """
<div align="center">
<h1><ins>MapGlue</ins> 🗺️</h1>
<h2>
MapGlue: Multimodal Remote Sensing Image Matching
</h2>
<p>
Advanced feature matching system supporting various image modalities including:<br>
SAR-Visible, Map-Visible, Depth-Visible, Infrared-Visible, Day-Night matching
</p>
</div>
"""
# Demo inputs for gr.Examples. Each entry is a two-item list
# [image 0 path, image 1 path], grouped by modality pair.
examples = [
    list(pair)
    for pair in (
        # Day / night
        ("assets/day-night/L1.png", "assets/day-night/R1.png"),
        ("assets/day-night/L2.png", "assets/day-night/R2.png"),
        # Depth / visible
        ("assets/depth-visible/L1.jpg", "assets/depth-visible/R1.jpg"),
        ("assets/depth-visible/L2.png", "assets/depth-visible/R2.png"),
        # Infrared / visible
        ("assets/infrared-visible/L1.png", "assets/infrared-visible/R1.png"),
        ("assets/infrared-visible/L2.png", "assets/infrared-visible/R2.png"),
        # Map / visible
        ("assets/map-visible/L1.jpg", "assets/map-visible/R1.jpg"),
        ("assets/map-visible/L2.png", "assets/map-visible/R2.png"),
        # SAR / visible
        ("assets/sar-visible/L1.jpg", "assets/sar-visible/R1.jpg"),
        ("assets/sar-visible/L2.jpg", "assets/sar-visible/R2.jpg"),
        ("assets/sar-visible/L3.png", "assets/sar-visible/R3.png"),
    )
]
def fig_to_ndarray(fig: Figure) -> ndarray:
    """Render a matplotlib figure and return its pixels as an RGBA array.

    Args:
        fig: Figure whose canvas provides ``draw()`` and ``buffer_rgba()``
            (e.g. the Agg canvas).

    Returns:
        A ``uint8`` array of shape ``(height, width, 4)``. It is a copy, so it
        stays valid after the figure is closed or redrawn.
    """
    fig.canvas.draw()
    # buffer_rgba() already exposes a shaped (H, W, 4) buffer in *physical*
    # pixels. Letting NumPy take that shape avoids the reshape mismatch that
    # get_width_height() (logical pixels) causes when device_pixel_ratio != 1.
    # np.array(...) copies, decoupling the result from the canvas renderer.
    return np.array(fig.canvas.buffer_rgba(), dtype=np.uint8)
@functools.lru_cache(maxsize=1)
def load_mapglue_model():
    """Load the FastMapGlue TorchScript model, reusing it across calls.

    The result is memoized with ``functools.lru_cache``. Each Run click in the
    UI therefore no longer re-reads the weights from disk. A failed load
    raises and is *not* cached, so a later call retries (e.g. after the
    weights have been downloaded).

    Returns:
        Tuple ``(model, device)``: the TorchScript module in eval mode and the
        device string it was mapped to (currently always ``'cpu'``).

    Raises:
        FileNotFoundError: If ``./weights/fastmapglue_model.pt`` does not exist.
    """
    # Inference is pinned to CPU. To use a GPU, select 'cuda:0' when
    # torch.cuda.is_available().
    device = 'cpu'
    model_path = './weights/fastmapglue_model.pt'
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"Model file not found: {model_path}\n"
            f"Please ensure the HF_TOKEN environment variable is set to download the model."
        )
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    model.to(device)
    return model, device
def _render_matches(image0, image1, kpts0, kpts1, caption, line_width):
    """Plot both images side by side with lime match lines; return the figure pixels."""
    plt.figure(figsize=(12, 6))
    visualize.show_images([image0, image1])
    visualize.draw_matches(kpts0, kpts1, line_colors="lime", line_width=line_width)
    visualize.add_text(0, caption, font_size=16)
    return fig_to_ndarray(plt.gcf())


def _render_keypoints(image0, image1, kpts0, kpts1, caption):
    """Plot both images side by side with lime keypoints; return the figure pixels."""
    plt.figure(figsize=(12, 6))
    visualize.show_images([image0, image1])
    visualize.draw_keypoints([kpts0, kpts1], kp_color=["lime", "lime"], kp_size=20)
    visualize.add_text(0, caption, font_size=16)
    return fig_to_ndarray(plt.gcf())


def run_mapglue_matching(
    path0: str,
    path1: str,
    model_name: str,
    num_keypoints: int,
    ransac_threshold: float,
) -> Tuple[ndarray, ndarray, Optional[ndarray], Optional[ndarray]]:
    """
    Run MapGlue matching on two images and filter the matches with a
    homography estimated by ``cv2.findHomography`` (USAC MAGSAC).

    Args:
        path0, path1: Paths to the input images.
        model_name: Name of the matching model. The UI offers only
            "FastMapGlue"; the value is currently not used to pick a model.
        num_keypoints: Number of keypoints to extract (forwarded to the model).
        ransac_threshold: Reprojection threshold, in pixels, for the
            homography estimation.

    Returns:
        Tuple of (raw_keypoint_fig, raw_matching_fig, ransac_keypoint_fig,
        ransac_matching_fig) as RGBA arrays. The two RANSAC entries are
        ``None`` when no homography inliers are found. This includes the case
        of fewer than 4 raw matches, where no homography can be estimated. On
        any error, all four entries are a blank 400x800 RGBA image and the
        error is printed.
    """
    try:
        if not path0 or not path1:
            raise ValueError("Both input images must be provided")

        model, device = load_mapglue_model()

        # Load images (OpenCV reads BGR) and convert to RGB.
        image0 = cv2.imread(path0)
        image1 = cv2.imread(path1)
        if image0 is None or image1 is None:
            raise ValueError("Could not load one or both images")
        image0 = cv2.cvtColor(image0, cv2.COLOR_BGR2RGB)
        image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)

        image0_tensor = torch.from_numpy(image0).to(device)
        image1_tensor = torch.from_numpy(image1).to(device)
        # The keypoint budget is a count. int() guards against a slider that
        # presumably could deliver a float such as 2048.0.
        num_keypoints_tensor = torch.tensor(int(num_keypoints)).to(device)

        with torch.no_grad():
            points_tensor = model(image0_tensor, image1_tensor, num_keypoints_tensor)

        # Columns 0-1 are points in image 0 and columns 2-3 their matches in
        # image 1. Convert to NumPy once so plotting and OpenCV see the same
        # contiguous arrays.
        points = points_tensor.cpu().numpy()
        points0_np = np.ascontiguousarray(points[:, :2])
        points1_np = np.ascontiguousarray(points[:, 2:])
        num_raw = len(points0_np)

        raw_matching_fig = _render_matches(
            image0, image1, points0_np, points1_np,
            f'Raw matches: {num_raw}', line_width=0.8,
        )
        raw_keypoint_fig = _render_keypoints(
            image0, image1, points0_np, points1_np,
            f'Raw keypoints: {num_raw}',
        )

        # cv2.findHomography needs at least 4 correspondences and raises
        # otherwise. Treat that case like "no inliers" so the raw results
        # still reach the UI.
        inlier_mask = None
        if num_raw >= 4:
            _, inlier_mask = cv2.findHomography(
                points0_np, points1_np,
                cv2.USAC_MAGSAC,
                ransacReprojThreshold=ransac_threshold,
                maxIters=10000,
                confidence=0.9999,
            )

        if inlier_mask is not None and inlier_mask.sum() > 0:
            keep = inlier_mask.ravel() > 0
            mkpts0 = points0_np[keep]
            mkpts1 = points1_np[keep]
            ransac_matching_fig = _render_matches(
                image0, image1, mkpts0, mkpts1,
                f'RANSAC matches @{ransac_threshold}px: {len(mkpts0)}/{num_raw}',
                line_width=1,
            )
            ransac_keypoint_fig = _render_keypoints(
                image0, image1, mkpts0, mkpts1,
                f'RANSAC keypoints @{ransac_threshold}px: {len(mkpts0)}',
            )
        else:
            # No inliers (or too few matches to estimate a homography).
            ransac_matching_fig = None
            ransac_keypoint_fig = None

        return (
            raw_keypoint_fig,
            raw_matching_fig,
            ransac_keypoint_fig,
            ransac_matching_fig,
        )
    except Exception as e:
        import traceback

        print(f"Error in matching: {str(e)}")
        traceback.print_exc()
        # Blank RGBA placeholders keep all four output panels populated.
        empty_img = np.zeros((400, 800, 4), dtype=np.uint8)
        return (empty_img, empty_img, empty_img, empty_img)
    finally:
        # Always release matplotlib figures, including on the error path;
        # otherwise they accumulate across requests.
        plt.close('all')
# Gradio interface: one "Image Matching" tab with two columns.
# - Input column: model choice, image pair, Run/Stop, advanced settings and
#   example pairs.
# - Output column: the four visualizations returned by run_mapglue_matching.
# NOTE(review): the nesting of the `with` blocks below was reconstructed from a
# copy whose indentation was lost. Confirm the layout against the deployed app.
with gr.Blocks(css=CSS) as demo:
    with gr.Tab("Image Matching"):
        # Header banner; elem_id="desc" is the selector that CSS centers.
        with gr.Row():
            with gr.Column(scale=3):
                gr.HTML(DESCRIPTION, elem_id="desc")
        with gr.Row():
            # ---- Input column ----
            with gr.Column():
                gr.Markdown("### Input Panels:")
                with gr.Row():
                    # Single choice; forwarded to run_mapglue_matching as model_name.
                    model_name = gr.Dropdown(
                        choices=["FastMapGlue"],
                        value="FastMapGlue",
                        label="Matching Model",
                    )
                with gr.Row():
                    # type="filepath": the handler receives path strings and
                    # reads them with cv2.imread.
                    path0 = gr.Image(
                        height=300,
                        image_mode="RGB",
                        type="filepath",
                        label="Image 0",
                    )
                    path1 = gr.Image(
                        height=300,
                        image_mode="RGB",
                        type="filepath",
                        label="Image 1",
                    )
                with gr.Row():
                    stop = gr.Button(value="Stop", variant="stop")
                    run = gr.Button(value="Run", variant="primary")
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Accordion("Matching Settings"):
                        with gr.Row():
                            # Keypoint budget passed to the model.
                            num_keypoints = gr.Slider(
                                minimum=512,
                                maximum=4096,
                                value=2048,
                                step=256,
                                label="Number of Keypoints",
                            )
                    with gr.Accordion("RANSAC Settings"):
                        with gr.Row():
                            # Reprojection threshold (pixels) used when
                            # estimating the homography.
                            ransac_threshold = gr.Slider(
                                minimum=0.5,
                                maximum=10.0,
                                value=5.0,
                                step=0.5,
                                label="RANSAC Threshold",
                            )
                with gr.Row():
                    with gr.Accordion("Example Pairs"):
                        # Each example row supplies values for [path0, path1].
                        gr.Examples(
                            examples=examples,
                            inputs=[path0, path1],
                            label="Click an example pair below",
                        )
            # ---- Output column ----
            with gr.Column():
                gr.Markdown(
                    "### Output Panels"
                )
                with gr.Accordion("Raw Keypoints", open=False):
                    raw_keypoint_fig = gr.Image(
                        format="png", type="numpy", label="Raw Keypoints"
                    )
                with gr.Accordion("Raw Matches"):
                    raw_matching_fig = gr.Image(
                        format="png", type="numpy", label="Raw Matches"
                    )
                with gr.Accordion("RANSAC Keypoints", open=False):
                    ransac_keypoint_fig = gr.Image(
                        format="png", type="numpy", label="RANSAC Keypoints"
                    )
                with gr.Accordion("RANSAC Matches"):
                    ransac_matching_fig = gr.Image(
                        format="png", type="numpy", label="RANSAC Matches"
                    )
        # Order must match run_mapglue_matching's positional parameters.
        inputs = [
            path0,
            path1,
            model_name,
            num_keypoints,
            ransac_threshold,
        ]
        # Order must match the tuple returned by run_mapglue_matching.
        outputs = [
            raw_keypoint_fig,
            raw_matching_fig,
            ransac_keypoint_fig,
            ransac_matching_fig,
        ]
        running_event = run.click(
            fn=run_mapglue_matching, inputs=inputs, outputs=outputs
        )
        # Stop runs no function of its own; it only cancels the in-flight Run event.
        stop.click(
            fn=None, inputs=None, outputs=None, cancels=[running_event]
        )
if __name__ == "__main__":
    # Download the model weights on startup if HF_TOKEN is available.
    HF_TOKEN = os.getenv("HF_TOKEN")
    if HF_TOKEN:
        model_path = './weights/fastmapglue_model.pt'
        if not os.path.exists(model_path):
            # Download into a temporary file and rename it only after success.
            # An interrupted download must never leave a truncated file at
            # model_path, which the exists() check above would then accept as
            # a valid model on every later startup.
            tmp_path = model_path + ".part"
            try:
                import requests

                # Use the "resolve" endpoint to download the raw file directly.
                model_url = "https://huggingface.co/wupeihao/mapglue/resolve/main/fastmapglue_model.pt"
                headers = {"Authorization": f"Bearer {HF_TOKEN}"}
                print("Downloading MapGlue model...")
                os.makedirs(os.path.dirname(model_path), exist_ok=True)
                # stream=True + iter_content avoids holding the whole file in
                # memory. The timeout prevents startup from hanging forever on
                # a stalled connection.
                with requests.get(model_url, headers=headers, stream=True, timeout=60) as response:
                    response.raise_for_status()
                    with open(tmp_path, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=1 << 20):
                            f.write(chunk)
                os.replace(tmp_path, model_path)
                print("Model downloaded successfully!")
            except Exception as e:
                print(f"Failed to download model: {str(e)}")
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
    demo.launch()