File size: 6,470 Bytes
a5b5add
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import matplotlib
import matplotlib.patheffects as peffects
import matplotlib.pyplot as plt
import numpy as np
import torch

def show_images(image_list, titles=None, colormaps="gray", dpi=100, pad=0.5, auto_size=True):
    """Display a set of images horizontally.

    Args:
        image_list: List of images in either NumPy RGB (H, W, 3),
            PyTorch RGB (3, H, W) or grayscale (H, W) format.
        titles: List of titles for each image.
        colormaps: Colormap for grayscale images; a single name applied to
            all images, or a list/tuple with one name per image.
        dpi: Figure resolution.
        pad: Padding between images.
        auto_size: Whether the figure size should adapt to the images'
            aspect ratios.
    """
    # Convert torch.Tensor images to NumPy arrays. 3-D tensors are treated
    # as channel-first RGB (3, H, W) and permuted to (H, W, 3). 2-D
    # (grayscale) tensors previously slipped through unconverted — which
    # fails for GPU-resident tensors — so convert them as well.
    converted = []
    for img in image_list:
        if isinstance(img, torch.Tensor):
            if img.dim() == 3:
                img = img.permute(1, 2, 0).cpu().numpy()
            else:
                img = img.cpu().numpy()
        converted.append(img)
    image_list = converted

    num_imgs = len(image_list)
    # Broadcast a single colormap name to all images.
    if not isinstance(colormaps, (list, tuple)):
        colormaps = [colormaps] * num_imgs

    if auto_size:
        # Width ratio per image so each keeps its native aspect ratio.
        ratios = [im.shape[1] / im.shape[0] for im in image_list]  # width / height
    else:
        ratios = [4 / 3] * num_imgs
    fig_size = [sum(ratios) * 4.5, 4.5]
    fig, axes = plt.subplots(1, num_imgs, figsize=fig_size, dpi=dpi, gridspec_kw={"width_ratios": ratios})
    if num_imgs == 1:
        # plt.subplots returns a bare Axes (not an array) for a single plot.
        axes = [axes]
    for i in range(num_imgs):
        axes[i].imshow(image_list[i], cmap=plt.get_cmap(colormaps[i]))
        axes[i].set_xticks([])
        axes[i].set_yticks([])
        axes[i].set_axis_off()
        for spine in axes[i].spines.values():
            spine.set_visible(False)
        if titles:
            axes[i].set_title(titles[i])
    fig.tight_layout(pad=pad)


def draw_keypoints(keypoints, kp_color="lime", kp_size=4, ax_list=None, alpha_value=1.0):
    """Plot keypoints on existing images.

    Args:
        keypoints: List of ndarrays (N, 2) for each set of keypoints.
        kp_color: Color for keypoints, or list of colors for each set.
        kp_size: Size of keypoints.
        ax_list: List of axes to plot keypoints on; defaults to the current
            figure's axes.
        alpha_value: Opacity for keypoints (scalar, or list per set).
    """
    n_sets = len(keypoints)
    # Broadcast scalar color/alpha settings to one entry per keypoint set.
    colors = kp_color if isinstance(kp_color, list) else [kp_color] * n_sets
    alphas = alpha_value if isinstance(alpha_value, list) else [alpha_value] * n_sets
    axes = plt.gcf().axes if ax_list is None else ax_list
    for axis, points, color, alpha in zip(axes, keypoints, colors, alphas):
        if isinstance(points, torch.Tensor):
            points = points.cpu().numpy()
        axis.scatter(points[:, 0], points[:, 1], c=color, s=kp_size, linewidths=0, alpha=alpha)


def draw_matches(pts_left, pts_right, line_colors=None, line_width=1.5, endpoint_size=4, alpha_value=1.0, labels=None, axes_pair=None):
    """Draw matches between a pair of images.

    Args:
        pts_left, pts_right: Corresponding keypoints for the two images (N, 2).
        line_colors: Colors for the match lines. Either a single color (a
            string or an RGB tuple) applied to every match, or a list with
            one color per match. If not provided, random colors are
            generated.
        line_width: Width of the match lines (0 disables the lines).
        endpoint_size: Size of the endpoints (if 0, endpoints are not drawn).
        alpha_value: Opacity for the match lines.
        labels: Optional list of labels for each match.
        axes_pair: List of two axes [ax_left, ax_right] to plot the images;
            defaults to the first two axes in the current figure.
    """
    fig = plt.gcf()
    if axes_pair is None:
        axs = fig.axes
        ax_left, ax_right = axs[0], axs[1]
    else:
        ax_left, ax_right = axes_pair
    if isinstance(pts_left, torch.Tensor):
        pts_left = pts_left.cpu().numpy()
    if isinstance(pts_right, torch.Tensor):
        pts_right = pts_right.cpu().numpy()
    assert len(pts_left) == len(pts_right)
    if line_colors is None:
        # One random hue per match.
        line_colors = matplotlib.cm.hsv(np.random.rand(len(pts_left))).tolist()
    elif isinstance(line_colors, str) or (
        len(line_colors) > 0 and not isinstance(line_colors[0], (tuple, list, str))
    ):
        # A single color — either a named-color string or one RGB(A) tuple
        # (first element is a number) — is broadcast to every match. A list
        # of per-match color strings is left as-is; the previous check
        # wrapped such a list and replicated it, producing invalid colors.
        line_colors = [line_colors] * len(pts_left)

    if line_width > 0:
        for i in range(len(pts_left)):
            # ConnectionPatch spans the two axes' data coordinate systems,
            # so the line tracks the points through pans/zooms.
            connector = matplotlib.patches.ConnectionPatch(
                xyA=(pts_left[i, 0], pts_left[i, 1]),
                xyB=(pts_right[i, 0], pts_right[i, 1]),
                coordsA=ax_left.transData,
                coordsB=ax_right.transData,
                axesA=ax_left,
                axesB=ax_right,
                zorder=1,
                color=line_colors[i],
                linewidth=line_width,
                clip_on=True,
                alpha=alpha_value,
                label=None if labels is None else labels[i],
                picker=5.0,
            )
            connector.set_annotation_clip(True)
            fig.add_artist(connector)

    # Freeze axis autoscaling so the scatter below cannot shift the view.
    ax_left.autoscale(enable=False)
    ax_right.autoscale(enable=False)

    if endpoint_size > 0:
        ax_left.scatter(pts_left[:, 0], pts_left[:, 1], c=line_colors, s=endpoint_size)
        ax_right.scatter(pts_right[:, 0], pts_right[:, 1], c=line_colors, s=endpoint_size)


def add_text(axis_idx, text, pos=(0.01, 0.99), font_size=15, txt_color="w", border_color="k", border_width=2, h_align="left", v_align="top"):
    """Add an annotation with an outline to a specified axis.

    Args:
        axis_idx: Index of the axis in the current figure where the
            annotation will be added.
        text: The annotation text.
        pos: Position of the annotation in axis coordinates
            (e.g., (0.01, 0.99)).
        font_size: Font size of the text.
        txt_color: Text color.
        border_color: Outline color (if None, no outline is applied).
        border_width: Width of the outline.
        h_align: Horizontal alignment (e.g., "left").
        v_align: Vertical alignment (e.g., "top").
    """
    target_ax = plt.gcf().axes[axis_idx]
    # Place the text in axis-fraction coordinates so it stays anchored
    # regardless of the image's data extent.
    annotation = target_ax.text(
        pos[0],
        pos[1],
        text,
        fontsize=font_size,
        ha=h_align,
        va=v_align,
        color=txt_color,
        transform=target_ax.transAxes,
    )
    if border_color is None:
        return
    outline = [
        peffects.Stroke(linewidth=border_width, foreground=border_color),
        peffects.Normal(),
    ]
    annotation.set_path_effects(outline)