import matplotlib
import matplotlib.patheffects as peffects
import matplotlib.pyplot as plt
import numpy as np
import torch


def _as_numpy(data):
    """Return *data* as a NumPy array if it is a torch.Tensor, else unchanged."""
    if isinstance(data, torch.Tensor):
        # detach() is required because Tensor.numpy() refuses tensors that
        # require grad; cpu() makes device tensors host-accessible.
        return data.detach().cpu().numpy()
    return data


def show_images(image_list, titles=None, colormaps="gray", dpi=100, pad=0.5, auto_size=True):
    """
    Display a set of images horizontally in a new figure.

    Args:
        image_list: List of images in either NumPy RGB (H, W, 3), PyTorch RGB
            (3, H, W) or grayscale (H, W) format.
        titles: Optional list of titles, one per image.
        colormaps: Colormap for grayscale images, or a list/tuple holding one
            colormap per image.
        dpi: Figure resolution.
        pad: Padding between images (forwarded to ``tight_layout``).
        auto_size: Whether the figure size should adapt to the images' aspect
            ratios. When False every panel uses a 4:3 ratio.
    """
    # Convert every torch.Tensor to NumPy; 3-D tensors are assumed to be
    # channel-first and are moved to (H, W, C) for imshow.
    prepared = []
    for img in image_list:
        if isinstance(img, torch.Tensor) and img.dim() == 3:
            img = img.permute(1, 2, 0)
        prepared.append(_as_numpy(img))

    num_imgs = len(prepared)
    if not isinstance(colormaps, (list, tuple)):
        colormaps = [colormaps] * num_imgs

    if auto_size:
        ratios = [im.shape[1] / im.shape[0] for im in prepared]  # width / height
    else:
        ratios = [4 / 3] * num_imgs

    fig_size = [sum(ratios) * 4.5, 4.5]
    fig, axes = plt.subplots(
        1, num_imgs, figsize=fig_size, dpi=dpi, gridspec_kw={"width_ratios": ratios}
    )
    # plt.subplots returns a bare Axes (not an array) for a single panel.
    if num_imgs == 1:
        axes = [axes]

    for i, (ax, img) in enumerate(zip(axes, prepared)):
        ax.imshow(img, cmap=plt.get_cmap(colormaps[i]))
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()
        for spine in ax.spines.values():
            spine.set_visible(False)
        if titles:
            ax.set_title(titles[i])

    fig.tight_layout(pad=pad)


def draw_keypoints(keypoints, kp_color="lime", kp_size=4, ax_list=None, alpha_value=1.0):
    """
    Plot keypoints on existing images.

    Args:
        keypoints: List of arrays or tensors (N, 2), one per set of keypoints.
        kp_color: Color for keypoints, or list of colors for each set.
        kp_size: Size of keypoints.
        ax_list: List of axes to plot keypoints on; defaults to current
            figure's axes.
        alpha_value: Opacity for keypoints, or list of opacities for each set.
    """
    if not isinstance(kp_color, list):
        kp_color = [kp_color] * len(keypoints)
    if not isinstance(alpha_value, list):
        alpha_value = [alpha_value] * len(keypoints)
    if ax_list is None:
        ax_list = plt.gcf().axes

    # zip stops at the shortest input, so extra axes or point sets are ignored.
    for ax, pts, color, alpha in zip(ax_list, keypoints, kp_color, alpha_value):
        pts = _as_numpy(pts)
        ax.scatter(pts[:, 0], pts[:, 1], c=color, s=kp_size, linewidths=0, alpha=alpha)


def draw_matches(pts_left, pts_right, line_colors=None, line_width=1.5, endpoint_size=4,
                 alpha_value=1.0, labels=None, axes_pair=None):
    """
    Draw matches between a pair of images.

    Args:
        pts_left, pts_right: Corresponding keypoints for the two images (N, 2).
        line_colors: Either a single color (string or RGB tuple) applied to all
            matches, or one color per match given as a sequence of
            tuples/lists/arrays or an (N, C) array. If not provided, random
            colors will be generated.
        line_width: Width of the match lines (if 0, lines are not drawn).
        endpoint_size: Size of the endpoints (if 0, endpoints are not drawn).
        alpha_value: Opacity for the match lines.
        labels: Optional list of labels for each match.
        axes_pair: List of two axes [ax_left, ax_right] to plot the images;
            defaults to the first two axes in the current figure.

    Raises:
        AssertionError: If the two point sets differ in length.
    """
    fig = plt.gcf()
    if axes_pair is None:
        axs = fig.axes
        ax_left, ax_right = axs[0], axs[1]
    else:
        ax_left, ax_right = axes_pair

    pts_left = _as_numpy(pts_left)
    pts_right = _as_numpy(pts_right)
    assert len(pts_left) == len(pts_right), (
        f"pts_left and pts_right must have the same length, "
        f"got {len(pts_left)} and {len(pts_right)}"
    )

    if line_colors is None:
        line_colors = matplotlib.cm.hsv(np.random.rand(len(pts_left))).tolist()
    elif len(line_colors) > 0 and not isinstance(line_colors[0], (tuple, list, np.ndarray)):
        # First element is a scalar or character, so this is one color
        # (e.g. "lime" or (r, g, b)); replicate it for every match.
        line_colors = [line_colors] * len(pts_left)

    if line_width > 0:
        for i, (left, right) in enumerate(zip(pts_left, pts_right)):
            connector = matplotlib.patches.ConnectionPatch(
                xyA=(left[0], left[1]),
                xyB=(right[0], right[1]),
                coordsA=ax_left.transData,
                coordsB=ax_right.transData,
                axesA=ax_left,
                axesB=ax_right,
                zorder=1,
                color=line_colors[i],
                linewidth=line_width,
                clip_on=True,
                alpha=alpha_value,
                label=None if labels is None else labels[i],
                picker=5.0,
            )
            connector.set_annotation_clip(True)
            # The patch spans two axes, so it is attached to the figure.
            fig.add_artist(connector)

    # Freeze axis autoscaling to prevent changes.
    ax_left.autoscale(enable=False)
    ax_right.autoscale(enable=False)

    if endpoint_size > 0:
        ax_left.scatter(pts_left[:, 0], pts_left[:, 1], c=line_colors, s=endpoint_size)
        ax_right.scatter(pts_right[:, 0], pts_right[:, 1], c=line_colors, s=endpoint_size)


def add_text(axis_idx, text, pos=(0.01, 0.99), font_size=15, txt_color="w", border_color="k",
             border_width=2, h_align="left", v_align="top"):
    """
    Add an annotation with an outline to a specified axis.

    Args:
        axis_idx: Index of the axis in the current figure where the annotation
            will be added.
        text: The annotation text.
        pos: Position of the annotation in axis coordinates (e.g., (0.01, 0.99)).
        font_size: Font size of the text.
        txt_color: Text color.
        border_color: Outline color (if None, no outline is applied).
        border_width: Width of the outline.
        h_align: Horizontal alignment (e.g., "left").
        v_align: Vertical alignment (e.g., "top").
    """
    current_ax = plt.gcf().axes[axis_idx]
    annotation = current_ax.text(
        *pos,
        text,
        fontsize=font_size,
        ha=h_align,
        va=v_align,
        color=txt_color,
        transform=current_ax.transAxes,
    )
    if border_color is not None:
        # Stroke draws the outline; Normal then redraws the glyphs on top.
        annotation.set_path_effects([
            peffects.Stroke(linewidth=border_width, foreground=border_color),
            peffects.Normal(),
        ])