MapGlue / visualize.py
wupeihao's picture
Upload 2 files
a5b5add verified
import matplotlib
import matplotlib.patheffects as peffects
import matplotlib.pyplot as plt
import numpy as np
import torch
def _image_to_numpy(img):
    """Return `img` as something `imshow` accepts; non-tensor inputs pass through unchanged."""
    if not isinstance(img, torch.Tensor):
        return img
    # detach() so tensors that require grad can be converted to NumPy.
    img = img.detach().cpu()
    if img.dim() == 3:
        # Channel-first (C, H, W) -> channel-last (H, W, C).
        img = img.permute(1, 2, 0)
        if img.shape[-1] == 1:
            # A single-channel tensor is shown as a plain (H, W) grayscale image.
            img = img[..., 0]
    return img.numpy()


def show_images(image_list, titles=None, colormaps="gray", dpi=100, pad=0.5, auto_size=True):
    """
    Display a set of images horizontally in a new figure.

    Args:
        image_list: List of images in either NumPy RGB (H, W, 3),
            PyTorch RGB (3, H, W) or grayscale (H, W) format. PyTorch
            tensors of shape (1, H, W) are displayed as grayscale.
        titles: List of titles for each image.
        colormaps: Colormap for grayscale images, or a list/tuple with one
            colormap per image.
        dpi: Figure resolution.
        pad: Padding between images.
        auto_size: Whether the figure size should adapt to the images' aspect ratios.

    Raises:
        ValueError: If `image_list` is empty.
    """
    # Convert every torch.Tensor (2-D or 3-D) to a NumPy array.
    images = [_image_to_numpy(img) for img in image_list]
    num_imgs = len(images)
    if num_imgs == 0:
        raise ValueError("show_images() requires at least one image.")

    if not isinstance(colormaps, (list, tuple)):
        colormaps = [colormaps] * num_imgs

    if auto_size:
        ratios = [im.shape[1] / im.shape[0] for im in images]  # width / height
    else:
        ratios = [4 / 3] * num_imgs

    # Each unit of aspect ratio gets 4.5 inches of width; the height is fixed at 4.5 inches.
    fig_size = [sum(ratios) * 4.5, 4.5]
    fig, axes = plt.subplots(1, num_imgs, figsize=fig_size, dpi=dpi, gridspec_kw={"width_ratios": ratios})
    if num_imgs == 1:
        # plt.subplots returns a bare Axes (not an array) for a single subplot.
        axes = [axes]

    for i, (ax, image) in enumerate(zip(axes, images)):
        ax.imshow(image, cmap=plt.get_cmap(colormaps[i]))
        # Hide ticks, axis decorations and the frame.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()
        for spine in ax.spines.values():
            spine.set_visible(False)
        if titles:
            ax.set_title(titles[i])
    fig.tight_layout(pad=pad)
def draw_keypoints(keypoints, kp_color="lime", kp_size=4, ax_list=None, alpha_value=1.0):
    """
    Plot keypoints on existing images.

    Args:
        keypoints: List of (N, 2) keypoint sets, one per axis. Each set may be
            an ndarray, a torch.Tensor (moved to CPU and detached) or a nested list.
        kp_color: Color for keypoints, or list of colors for each set.
        kp_size: Size of keypoints.
        ax_list: List of axes to plot keypoints on; defaults to current figure's axes.
        alpha_value: Opacity for keypoints, or list of opacities for each set.
    """
    if not isinstance(kp_color, list):
        kp_color = [kp_color] * len(keypoints)
    if not isinstance(alpha_value, list):
        alpha_value = [alpha_value] * len(keypoints)
    if ax_list is None:
        ax_list = plt.gcf().axes
    # zip stops at the shortest input, so extra axes or keypoint sets are ignored.
    for ax, pts, color, alpha in zip(ax_list, keypoints, kp_color, alpha_value):
        if isinstance(pts, torch.Tensor):
            # detach() so tensors that require grad can be converted to NumPy.
            pts = pts.detach().cpu().numpy()
        pts = np.asarray(pts)
        ax.scatter(pts[:, 0], pts[:, 1], c=color, s=kp_size, linewidths=0, alpha=alpha)
def _resolve_match_colors(line_colors, num_matches):
    """Expand `line_colors` into a list holding one color per match."""
    if line_colors is None:
        # Random hue per match.
        return matplotlib.cm.hsv(np.random.rand(num_matches)).tolist()
    if isinstance(line_colors, str):
        # A color name such as "lime" is a single color shared by all matches.
        return [line_colors] * num_matches
    if isinstance(line_colors, np.ndarray):
        if line_colors.ndim >= 2:
            # (N, 3) / (N, 4) array: one row per match.
            return line_colors.tolist()
        return [tuple(line_colors.tolist())] * num_matches
    if len(line_colors) > 0 and not isinstance(line_colors[0], (tuple, list, str, np.ndarray)):
        # A flat sequence of numbers is a single RGB(A) color.
        return [line_colors] * num_matches
    # Already a sequence of per-match colors (tuples, lists, arrays or names).
    return line_colors


def draw_matches(pts_left, pts_right, line_colors=None, line_width=1.5, endpoint_size=4, alpha_value=1.0, labels=None, axes_pair=None):
    """
    Draw matches between a pair of images.

    Args:
        pts_left, pts_right: Corresponding keypoints for the two images (N, 2),
            as ndarrays, torch.Tensors or nested lists.
        line_colors: A single color (string or RGB(A) tuple) used for every match,
            or a sequence / (N, 3|4) array with one color per match.
            If not provided, random colors will be generated.
        line_width: Width of the match lines (if 0, lines are not drawn).
        endpoint_size: Size of the endpoints (if 0, endpoints are not drawn).
        alpha_value: Opacity for the match lines.
        labels: Optional list of labels for each match.
        axes_pair: List of two axes [ax_left, ax_right] to plot the images; defaults to the first two axes in the current figure.
    """
    if axes_pair is None:
        fig = plt.gcf()
        axs = fig.axes
        ax_left, ax_right = axs[0], axs[1]
    else:
        ax_left, ax_right = axes_pair
        # Draw on the figure that owns the given axes, which is not
        # necessarily the current pyplot figure.
        fig = ax_left.figure

    if isinstance(pts_left, torch.Tensor):
        # detach() so tensors that require grad can be converted to NumPy.
        pts_left = pts_left.detach().cpu().numpy()
    if isinstance(pts_right, torch.Tensor):
        pts_right = pts_right.detach().cpu().numpy()
    pts_left = np.asarray(pts_left)
    pts_right = np.asarray(pts_right)
    # Kept as an assert so the exception type seen by callers is unchanged.
    assert len(pts_left) == len(pts_right), "pts_left and pts_right must contain the same number of points"

    line_colors = _resolve_match_colors(line_colors, len(pts_left))

    if line_width > 0:
        for i, (left, right) in enumerate(zip(pts_left, pts_right)):
            # A figure-level artist is needed because each line spans two axes.
            connector = matplotlib.patches.ConnectionPatch(
                xyA=(left[0], left[1]),
                xyB=(right[0], right[1]),
                coordsA=ax_left.transData,
                coordsB=ax_right.transData,
                axesA=ax_left,
                axesB=ax_right,
                zorder=1,
                color=line_colors[i],
                linewidth=line_width,
                clip_on=True,
                alpha=alpha_value,
                label=None if labels is None else labels[i],
                picker=5.0,
            )
            connector.set_annotation_clip(True)
            fig.add_artist(connector)

    # Freeze axis autoscaling to prevent changes.
    ax_left.autoscale(enable=False)
    ax_right.autoscale(enable=False)

    if endpoint_size > 0:
        ax_left.scatter(pts_left[:, 0], pts_left[:, 1], c=line_colors, s=endpoint_size)
        ax_right.scatter(pts_right[:, 0], pts_right[:, 1], c=line_colors, s=endpoint_size)
def add_text(axis_idx, text, pos=(0.01, 0.99), font_size=15, txt_color="w", border_color="k", border_width=2, h_align="left", v_align="top"):
    """
    Write outlined text onto one axis of the current figure.

    Args:
        axis_idx: Index into the current figure's axes list selecting the target axis.
        text: The string to draw.
        pos: Anchor position in axis coordinates (e.g., (0.01, 0.99)).
        font_size: Font size of the text.
        txt_color: Fill color of the text.
        border_color: Color of the outline; pass None to draw no outline.
        border_width: Line width of the outline.
        h_align: Horizontal alignment (e.g., "left").
        v_align: Vertical alignment (e.g., "top").
    """
    target_ax = plt.gcf().axes[axis_idx]
    text_options = {
        "fontsize": font_size,
        "ha": h_align,
        "va": v_align,
        "color": txt_color,
        # Interpret `pos` relative to the axis rather than in data units.
        "transform": target_ax.transAxes,
    }
    label = target_ax.text(*pos, text, **text_options)

    if border_color is None:
        return

    # Stroke draws the outline first; Normal then draws the text itself on top.
    outline = peffects.Stroke(linewidth=border_width, foreground=border_color)
    label.set_path_effects([outline, peffects.Normal()])