File size: 7,190 Bytes
7e31006 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
import bisect
import numpy as np
import cv2
def _compute_conf_thresh(data):
    """Return the error threshold associated with the batch's dataset.

    The dataset is identified by the first entry of ``data["dataset_name"]``,
    compared case-insensitively.

    Args:
        data: mapping whose ``"dataset_name"`` entry is an indexable sequence
            of strings.

    Returns:
        ``5e-4`` for "scannet", ``1e-4`` for "megadepth".

    Raises:
        ValueError: if the dataset name is neither of the two known ones.
    """
    thresholds = {"scannet": 5e-4, "megadepth": 1e-4}
    dataset_name = data["dataset_name"][0].lower()
    if dataset_name not in thresholds:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    return thresholds[dataset_name]
# --- VISUALIZATION --- #
def _as_bgr_uint8(img):
    """Return ``img`` as an (H, W, 3) uint8 array that fits the canvas.

    2-D and single-channel (H, W, 1) inputs are replicated across three
    channels. Non-uint8 data is clipped to [0, 255] before the cast so that
    out-of-range values saturate instead of wrapping around.
    """
    if img.ndim == 2:
        img = img[:, :, None]
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    if img.dtype != np.uint8:
        img = np.clip(img, 0, 255).astype(np.uint8)
    return img


def make_matching_figure(
    img0,
    img1,
    mkpts0,
    mkpts1,
    color,
    kpts0=None,
    kpts1=None,
    text=None,
    path=None,
):
    """Draw a side-by-side match visualization with OpenCV.

    Args:
        img0: first image; 3-channel (BGR) or single-channel/2-D grayscale.
        img1: second image, same conventions as ``img0``.
        mkpts0: matched points in the first image (Nx2 array of x, y).
        mkpts1: matched points in the second image (Nx2 array of x, y).
        color: per-match colors, indexable by match index. The first three
            components of each entry are scaled from [0, 1] to [0, 255] and
            handed to OpenCV in the given order; any further component
            (e.g. alpha) is ignored.
        kpts0: all keypoints of the first image (optional).
        kpts1: all keypoints of the second image (optional). Keypoints are
            drawn only when both ``kpts0`` and ``kpts1`` are given.
        text: iterable of text lines drawn in the top-left corner (optional).
        path: if truthy, the canvas is written there with ``cv2.imwrite``.

    Returns:
        The rendered canvas as an (max(h0, h1), w0 + w1, 3) uint8 array.

    Raises:
        AssertionError: if ``mkpts0`` and ``mkpts1`` differ in length.
    """
    # Both point sets must describe the same matches.
    assert mkpts0.shape[0] == mkpts1.shape[0], \
        f"mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}"

    # Grayscale or non-uint8 inputs cannot be assigned into the 3-channel
    # uint8 canvas (nor passed to cvtColor below), so normalize them first.
    img0 = _as_bgr_uint8(img0)
    img1 = _as_bgr_uint8(img1)
    h0, w0 = img0.shape[:2]
    h1, w1 = img1.shape[:2]

    # White canvas holding both images side by side; the shorter image leaves
    # a white strip below it.
    canvas = np.ones((max(h0, h1), w0 + w1, 3), dtype=np.uint8) * 255
    canvas[:h0, :w0] = img0
    canvas[:h1, w0:w0 + w1] = img1

    # Draw all keypoints (if provided) as small white dots.
    if kpts0 is not None and kpts1 is not None:
        for x, y in kpts0.astype(np.int32).tolist():
            cv2.circle(canvas, (x, y), 1, (255, 255, 255), -1)
        for x, y in kpts1.astype(np.int32).tolist():
            # Second-image coordinates are shifted by the first image's width.
            cv2.circle(canvas, (x + w0, y), 1, (255, 255, 255), -1)

    if mkpts0.shape[0] > 0:
        # tolist() yields plain Python ints, which every OpenCV build accepts
        # as point coordinates.
        pts0 = mkpts0.astype(np.int32).tolist()
        pts1 = mkpts1.astype(np.int32).tolist()
        # NOTE(review): the color components are used in the order given, with
        # no RGB->BGR swap; confirm this matches how the canvas is displayed.
        matches = [
            (
                (x0, y0),
                (x1 + w0, y1),
                tuple(int(c * 255) for c in color[i][:3]),
            )
            for i, ((x0, y0), (x1, y1)) in enumerate(zip(pts0, pts1))
        ]
        # Two passes so that the end points are painted on top of all lines.
        for pt0, pt1, match_color in matches:
            cv2.line(canvas, pt0, pt1, match_color, 1)
        for pt0, pt1, match_color in matches:
            cv2.circle(canvas, pt0, 2, match_color, -1)
            cv2.circle(canvas, pt1, 2, match_color, -1)

    if text:
        # Pick black or white text depending on how bright the top-left
        # region of the first image is.
        roi = img0[:100, :200] if h0 > 100 and w0 > 200 else img0
        brightness = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY).mean()
        text_color = (0, 0, 0) if brightness > 200 else (255, 255, 255)
        for i, line in enumerate(text):
            cv2.putText(
                canvas, line, (10, 30 + i * 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, text_color, 2
            )

    # Save the image (if a path was given).
    if path:
        cv2.imwrite(path, canvas)
    return canvas
def _make_evaluation_figure(data, b_id, alpha="dynamic"):
    """Render the evaluation figure for one element of a batch.

    Reads ``m_bids``, ``image0``, ``image1``, ``mkpts0_f``, ``mkpts1_f``,
    ``epi_errs`` and ``conf_matrix_gt`` from ``data`` (plus ``scale0`` /
    ``scale1`` when present), colors every match by its epipolar error and
    annotates the figure with match count, precision and recall.

    Args:
        data: batch dictionary holding the entries listed above.
        b_id: index of the batch element to draw.
        alpha: opacity forwarded to ``error_colormap``; the string
            ``"dynamic"`` derives it from the number of matches via
            ``dynamic_alpha``.

    Returns:
        The image returned by ``make_matching_figure``.
    """
    b_mask = data["m_bids"] == b_id
    conf_thr = _compute_conf_thresh(data)

    def _to_bgr_uint8(channel):
        # make_matching_figure documents BGR input and draws on a 3-channel
        # uint8 canvas, so expand the single channel and use uint8.
        # Assumes `channel` is a 2-D tensor with values in [0, 1] - TODO confirm.
        gray = (channel.cpu().numpy() * 255).round()
        gray = np.clip(gray, 0, 255).astype(np.uint8)
        return np.repeat(gray[:, :, None], 3, axis=2)

    img0 = _to_bgr_uint8(data["image0"][b_id][0])
    img1 = _to_bgr_uint8(data["image1"][b_id][0])
    kpts0 = data["mkpts0_f"][b_mask].clone().detach().cpu().numpy()
    kpts1 = data["mkpts1_f"][b_mask].clone().detach().cpu().numpy()

    # for megadepth, we visualize matches on the resized image
    if "scale0" in data:
        # [[1, 0]] swaps the two scale components before dividing
        # (presumably into the same x, y order as the keypoints - verify).
        kpts0 = kpts0 / data["scale0"][b_id].cpu().numpy()[[1, 0]]
        kpts1 = kpts1 / data["scale1"][b_id].cpu().numpy()[[1, 0]]

    epi_errs = data["epi_errs"][b_mask].cpu().numpy()
    correct_mask = epi_errs < conf_thr
    precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
    n_correct = np.sum(correct_mask)
    n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu())
    recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
    # recall might be larger than 1, since the calculation of conf_matrix_gt
    # uses groundtruth depths and camera poses, but epipolar distance is used here.

    # matching info
    if alpha == "dynamic":
        alpha = dynamic_alpha(len(correct_mask))
    color = error_colormap(epi_errs, conf_thr, alpha=alpha)

    text = [
        f"#Matches {len(kpts0)}",
        f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}",
        f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}",
    ]

    # make the figure
    figure = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)
    return figure
def _make_confidence_figure(data, b_id):
    """Placeholder for the confidence figure of one batch element.

    Not implemented yet: every call raises ``NotImplementedError``.
    """
    # TODO: Implement confidence figure
    raise NotImplementedError
def make_matching_figures(data, config, mode="evaluation"):
    """Make matching figures for a batch.

    Args:
        data (Dict): batch dictionary; ``data["image0"].size(0)`` gives the
            number of figures to produce.
        config: matcher config; ``config.TRAINER.PLOT_MATCHES_ALPHA`` is read
            in "evaluation" mode only.
        mode (str): "evaluation" or "confidence".

    Returns:
        Dict[str, list]: ``{mode: [figure, ...]}`` with one figure per batch
        element, each being whatever the per-element helper returns.

    Raises:
        AssertionError: if ``mode`` is not one of "evaluation", "confidence"
            or "gt".
        ValueError: if a figure is requested for a mode without a renderer.
    """
    # NOTE(review): "gt" passes this check but has no renderer below, so it
    # raises ValueError as soon as the batch contains at least one element.
    assert mode in ["evaluation", "confidence", "gt"]  # 'confidence'

    def _render(b_id):
        # Dispatch on the requested mode for a single batch element.
        if mode == "evaluation":
            return _make_evaluation_figure(
                data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA
            )
        if mode == "confidence":
            return _make_confidence_figure(data, b_id)
        raise ValueError(f"Unknown plot mode: {mode}")

    batch_size = data["image0"].size(0)
    return {mode: [_render(b_id) for b_id in range(batch_size)]}
def dynamic_alpha(
    n_matches, milestones=(0, 300, 1000, 2000), alphas=(1.0, 0.8, 0.4, 0.2)
):
    """Pick a drawing opacity that decreases as the number of matches grows.

    ``alphas[i]`` is the opacity at ``milestones[i]``; between two consecutive
    milestones the opacity is interpolated linearly.

    Args:
        n_matches: number of matches to be drawn.
        milestones: ascending match counts (any sequence).
        alphas: opacity at each milestone (any sequence).

    Returns:
        ``1.0`` when ``n_matches`` is 0, ``alphas[0]`` below the first
        milestone, the last alpha at or beyond the last milestone, and the
        interpolated value otherwise.
    """
    if n_matches == 0:
        return 1.0
    # Work on list copies so tuples (or other sequences) are accepted too.
    milestones = list(milestones)
    alphas = list(alphas)

    loc = bisect.bisect_right(milestones, n_matches) - 1
    if loc < 0:
        # Below the first milestone: an index of -1 would otherwise silently
        # select the *last* alpha.
        return alphas[0]
    if loc == len(alphas) - 1:
        # At or beyond the last milestone there is nothing to interpolate.
        return alphas[loc]

    start_alpha, end_alpha = alphas[loc], alphas[loc + 1]
    return end_alpha + (milestones[loc + 1] - n_matches) / (
        milestones[loc + 1] - milestones[loc]
    ) * (start_alpha - end_alpha)
def error_colormap(err, thr, alpha=1.0):
    """Map errors to 4-component colors fading from green to red.

    A score ``1 - clip(err / (2 * thr), 0, 1)`` is computed per error. The
    first channel is ``2 - 2 * score``, the second ``2 * score``, the third is
    zero and the fourth is ``alpha``; everything is clipped to [0, 1]. An
    error of 0 therefore gives (0, 1, 0, alpha) and any error of at least
    ``2 * thr`` gives (1, 0, 0, alpha).

    Args:
        err: array of errors.
        thr: error threshold.
        alpha: opacity in (0, 1].

    Returns:
        Array of shape ``err.shape + (4,)``.

    Raises:
        AssertionError: if ``alpha`` lies outside (0, 1].
    """
    assert 0 < alpha <= 1.0, f"Invaid alpha value: {alpha}"
    score = 1 - np.clip(err / (thr * 2), 0, 1)
    first = 2 - score * 2
    second = score * 2
    third = np.zeros_like(score)
    opacity = np.ones_like(score) * alpha
    channels = np.stack([first, second, third, opacity], -1)
    return np.clip(channels, 0, 1)
|