honeytian commited on
Commit
071150e
·
0 Parent(s):

first commit

Browse files
.dockerignore ADDED
File without changes
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ .idea/
3
+
4
+ __pycache__/
5
+
6
+ data/
7
+ logs/
8
+ temp/
9
+ trained_models/
10
+
11
+
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.12

WORKDIR /code

# Copy only the dependency manifest first so the pip layer below stays cached
# when application sources change.
COPY requirements.txt /code/requirements.txt

# Run "update" and "install" in the same layer: a separately cached
# "apt-get update" layer can go stale and make later installs fail.
RUN apt-get update \
    && apt-get install -y ffmpeg build-essential libnss3 \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir --upgrade pip \
    && pip install --no-cache-dir --upgrade -r /code/requirements.txt

# Run the application as an unprivileged user (uid 1000).
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

WORKDIR $HOME/app

# The application sources live here, owned by the runtime user.
COPY --chown=user . $HOME/app

CMD ["python3", "main.py"]
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: KeyPoints
3
+ emoji: 📈
4
+ colorFrom: yellow
5
+ colorTo: indigo
6
+ sdk: docker
7
+ pinned: false
8
+ license: apache-2.0
9
+ short_description: KeyPoints
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
examples/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
examples/keypoint_match/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
examples/keypoint_match/sift_multi_image_match.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 示例:使用 SIFT 提取关键点 + 多模板匹配(不做仿射/单应性)。
6
+
7
+ 运行:
8
+ python -m examples.keypoint_match.sift_multi_image_match
9
+
10
+ 或指定多个模板:
11
+ python -m examples.keypoint_match.sift_multi_image_match --template_paths a.png b.png --search_path big.png
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import argparse
17
+ from pathlib import Path
18
+ from typing import List, Tuple
19
+
20
+ import cv2
21
+ import numpy as np
22
+
23
+ from project_settings import project_path
24
+ from toolbox.keypoint_match import (
25
+ MultiImageDetector,
26
+ MultiImageDetectorConfig,
27
+ SiftExtractConfig,
28
+ SiftKeyPointExtract,
29
+ )
30
+
31
+
32
+ def _crop_template(image_bgr: np.ndarray, xyxy: Tuple[int, int, int, int]) -> np.ndarray:
33
+ x1, y1, x2, y2 = [int(v) for v in xyxy]
34
+ h, w = image_bgr.shape[:2]
35
+ x1 = max(0, min(w - 1, x1))
36
+ x2 = max(0, min(w, x2))
37
+ y1 = max(0, min(h - 1, y1))
38
+ y2 = max(0, min(h, y2))
39
+ if x2 <= x1 or y2 <= y1:
40
+ raise ValueError(f"无效裁剪框: {xyxy}, image_size={(h, w)}")
41
+ return image_bgr[y1:y2, x1:x2].copy()
42
+
43
+
44
def get_args():
    """Parse command-line options for the SIFT multi-template matching example.

    Returns:
        The parsed ``argparse.Namespace`` with ``search_path``, ``template_paths``,
        ``max_keypoints``, ``ratio`` and ``max_matches``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--search_path",
        type=str,
        default=(project_path / "data/images/keyboard/g98-v2-pink/keyboard3.jpg").as_posix(),
    )
    parser.add_argument(
        "--template_paths",
        type=str,
        nargs="*",
        default=[
            (project_path / "data/images/keyboard/g98-v2-pink/model/local/roller/roller1.png").as_posix(),
            (project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller1.png").as_posix(),
            (project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller2.png").as_posix(),
            (project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller3.png").as_posix(),
        ],
        # The previous help text promised a "crop a patch from the search image"
        # fallback that this script never implements; describe the real behaviour.
        help="提供一个或多个小图模板路径;不指定该参数时使用内置的默认模板列表",
    )
    parser.add_argument("--max_keypoints", type=int, default=2000)
    parser.add_argument("--ratio", type=float, default=0.95)
    parser.add_argument("--max_matches", type=int, default=120)
    return parser.parse_args()
70
+
71
+
72
def main():
    """Match several template images against one search image using SIFT.

    Reads the images named on the command line, runs ``MultiImageDetector.detect``,
    prints the per-template keypoint and match counts, and shows
    ``result.vis_search`` in an OpenCV window until a key is pressed.

    Raises:
        FileNotFoundError: if the search image or any template cannot be read.
    """
    args = get_args()

    search_image = cv2.imread(args.search_path)
    if search_image is None:
        raise FileNotFoundError(f"无法读取搜索图: {args.search_path}")

    # Load every template, failing on the first unreadable path.
    loaded = []
    for template_path in args.template_paths:
        template_image = cv2.imread(template_path)
        if template_image is None:
            raise FileNotFoundError(f"无法读取模板图: {template_path}")
        loaded.append((Path(template_path).stem, template_image))
    template_ids: List[str] = [name for name, _ in loaded]
    templates: List[np.ndarray] = [image for _, image in loaded]

    keypoint_budget = int(args.max_keypoints)
    extract_config = SiftExtractConfig(max_keypoints=keypoint_budget)
    extractor = SiftKeyPointExtract(extract_config)

    detector_config = MultiImageDetectorConfig(
        ratio=float(args.ratio),
        max_matches=int(args.max_matches),
        max_keypoints=int(args.max_keypoints),
    )
    detector = MultiImageDetector(extractor=extractor, config=detector_config)

    result = detector.detect(templates, search_image, template_ids=template_ids)

    for item in result.items:
        print(f"template={item.template_id} kp_t={item.template_kp.n} kp_s={item.search_kp.n} matches={item.matches.m}")

    # The window title doubles as a summary: "<template_id>:<match count>" per template.
    summary = " , ".join(f"{item.template_id}:{item.matches.m}" for item in result.items)
    window_title = "sift_multi_image_match | " + summary
    cv2.imshow(window_title, result.vis_search)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
114
+
115
+
116
+ if __name__ == "__main__":
117
+ main()
118
+
examples/keypoint_match/superpoint_multi_image_match.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 示例:使用 SuperPoint 提取关键点 + 多模板匹配(不做仿射/单应性)。
6
+
7
+ 运行:
8
+ python -m examples.keypoint_match.superpoint_multi_image_match
9
+
10
+ 或指定多个模板:
11
+ python -m examples.keypoint_match.superpoint_multi_image_match --template_paths a.png b.png --search_path big.png
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import argparse
17
+ from pathlib import Path
18
+ from typing import List, Tuple
19
+
20
+ import cv2
21
+ import numpy as np
22
+
23
+ from project_settings import project_path
24
+ from toolbox.keypoint_match import (
25
+ MultiImageDetector,
26
+ MultiImageDetectorConfig,
27
+ SuperPointExtractConfig,
28
+ SuperPointKeyPointExtract,
29
+ )
30
+
31
+
32
+ def _crop_template(image_bgr: np.ndarray, xyxy: Tuple[int, int, int, int]) -> np.ndarray:
33
+ x1, y1, x2, y2 = [int(v) for v in xyxy]
34
+ h, w = image_bgr.shape[:2]
35
+ x1 = max(0, min(w - 1, x1))
36
+ x2 = max(0, min(w, x2))
37
+ y1 = max(0, min(h - 1, y1))
38
+ y2 = max(0, min(h, y2))
39
+ if x2 <= x1 or y2 <= y1:
40
+ raise ValueError(f"无效裁剪框: {xyxy}, image_size={(h, w)}")
41
+ return image_bgr[y1:y2, x1:x2].copy()
42
+
43
+
44
def get_args():
    """Parse command-line options for the SuperPoint multi-template matching example.

    Returns:
        The parsed ``argparse.Namespace`` with ``search_path``, ``template_paths``,
        ``device``, ``max_keypoints``, ``ratio`` and ``max_matches``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--search_path",
        type=str,
        default=(project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png").as_posix(),
    )
    parser.add_argument(
        "--template_paths",
        type=str,
        nargs="*",
        default=[
            (project_path / "data/images/keyboard/g98-v2-pink/model/local/roller/roller1.png").as_posix(),
            (project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller1.png").as_posix(),
            (project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller2.png").as_posix(),
            (project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller3.png").as_posix(),
        ],
        # The previous help text promised a "crop a patch from the search image"
        # fallback that this script never implements; describe the real behaviour.
        help="提供一个或多个小图模板路径;不指定该参数时使用内置的默认模板列表",
    )
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--max_keypoints", type=int, default=2000)
    parser.add_argument("--ratio", type=float, default=0.99)
    parser.add_argument("--max_matches", type=int, default=120)
    return parser.parse_args()
68
+
69
+
70
def main():
    """Match several template images against one search image using SuperPoint.

    Reads the images named on the command line, runs ``MultiImageDetector.detect``,
    prints the per-template keypoint and match counts, and shows
    ``result.vis_search`` in an OpenCV window until a key is pressed.

    Raises:
        FileNotFoundError: if the search image or any template cannot be read.
    """
    args = get_args()

    search_image = cv2.imread(args.search_path)
    if search_image is None:
        raise FileNotFoundError(f"无法读取搜索图: {args.search_path}")

    # Load every template, failing on the first unreadable path.
    loaded = []
    for template_path in args.template_paths:
        template_image = cv2.imread(template_path)
        if template_image is None:
            raise FileNotFoundError(f"无法读取模板图: {template_path}")
        loaded.append((Path(template_path).stem, template_image))
    template_ids: List[str] = [name for name, _ in loaded]
    templates: List[np.ndarray] = [image for _, image in loaded]

    extract_config = SuperPointExtractConfig(
        device=str(args.device),
        max_keypoints=int(args.max_keypoints),
    )
    extractor = SuperPointKeyPointExtract(extract_config)

    detector_config = MultiImageDetectorConfig(
        ratio=float(args.ratio),
        max_matches=int(args.max_matches),
        max_keypoints=int(args.max_keypoints),
    )
    detector = MultiImageDetector(extractor=extractor, config=detector_config)

    result = detector.detect(templates, search_image, template_ids=template_ids)

    # Print the match count per template (more matches usually means a closer match).
    for item in result.items:
        print(f"template={item.template_id} kp_t={item.template_kp.n} kp_s={item.search_kp.n} matches={item.matches.m}")

    # The window title doubles as a summary: "<template_id>:<match count>" per template.
    summary = " , ".join(f"{item.template_id}:{item.matches.m}" for item in result.items)
    window_title = "superpoint_multi_image_match | " + summary
    cv2.imshow(window_title, result.vis_search)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
114
+
115
+
116
+ if __name__ == "__main__":
117
+ main()
118
+
examples/keypoint_match/video_track_and_collect_templates.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 通过视频录制目标多角度,用跟踪器持续估计目标在画面中的位置与大致范围(边界框),
6
+ 便于按帧裁剪保存为后续关键点匹配的模板小图。
7
+
8
+ 用法:
9
+ # 摄像头(Windows 上设备号多为 0)
10
+ python -m examples.keypoint_match.video_track_and_collect_templates --source 0
11
+
12
+ # 视频文件
13
+ python -m examples.keypoint_match.video_track_and_collect_templates --source path/to/video.mp4
14
+
15
+ 操作:
16
+ - 第一帧(或按 r 重选):按住鼠标左键拖拽画矩形框选目标,松开鼠标后自动开始跟踪
17
+ - s:将当前帧中跟踪框区域(可带边距)裁剪保存到 --output_dir
18
+ - r:进入重选模式,在当前帧重新拖拽框选(跟踪漂移或丢目标时使用)
19
+ - 空格:暂停/继续
20
+ - q 或 ESC:退出
21
+
22
+ 跟踪器:
23
+ 默认 mil(OpenCV 自带)。若本机有 Nano 权重可试 nano(需自行准备模型路径,见 --help)。
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import argparse
29
+ import sys
30
+ import time
31
+ from pathlib import Path
32
+ from typing import Optional, Tuple
33
+
34
+ import cv2
35
+ import numpy as np
36
+
37
+ from project_settings import temp_directory
38
+
39
+
40
+ Rect = Tuple[int, int, int, int] # x, y, w, h
41
+
42
+
43
def create_tracker(kind: str, nano_model: Optional[str] = None, nano_backend: Optional[str] = None):
    """Create the OpenCV tracker named by ``kind``.

    No fallback is attempted: if the installed OpenCV build lacks the requested
    tracker factory, a ``RuntimeError`` is raised instead.

    Args:
        kind: One of ``mil``, ``nano``, ``goturn``, ``dasiamrpn`` (case-insensitive,
            surrounding whitespace ignored). An empty value means ``mil``.
        nano_model: Optional backbone model path, used only for ``nano``.
        nano_backend: Optional DNN backend id, used only for ``nano``. It is
            converted with ``int()`` because ``TrackerNano_Params.backend`` is an
            integer field.

    Returns:
        The tracker object produced by the matching ``cv2.Tracker*_create`` factory.

    Raises:
        RuntimeError: if the requested tracker is missing from this OpenCV build.
        ValueError: if ``kind`` is unknown, or ``nano_backend`` is not an integer.
    """
    kind = (kind or "mil").lower().strip()

    if kind == "nano":
        if not hasattr(cv2, "TrackerNano_create"):
            raise RuntimeError("当前 OpenCV 未提供 TrackerNano_create,请改用 --tracker mil")
        params = cv2.TrackerNano_Params()
        if nano_model:
            params.backbone = nano_model
        if nano_backend:
            # The params struct has no "NNBackend" attribute; the backend is the
            # integer field "backend" (a cv2.dnn.DNN_BACKEND_* value).
            # NOTE(review): TrackerNano also has a "neckhead" model path that this
            # function never sets — confirm whether the default is acceptable.
            params.backend = int(nano_backend)
        return cv2.TrackerNano_create(params)

    if kind == "goturn":
        if not hasattr(cv2, "TrackerGOTURN_create"):
            raise RuntimeError("当前 OpenCV 未提供 TrackerGOTURN_create")
        return cv2.TrackerGOTURN_create()

    if kind == "dasiamrpn":
        if not hasattr(cv2, "TrackerDaSiamRPN_create"):
            raise RuntimeError("当前 OpenCV 未提供 TrackerDaSiamRPN_create")
        return cv2.TrackerDaSiamRPN_create()

    if kind == "mil":
        return cv2.TrackerMIL_create()

    raise ValueError(f"未知 tracker 类型: {kind},可选: mil, nano, goturn, dasiamrpn")
71
+
72
+
73
def clamp_rect(x: int, y: int, w: int, h: int, frame_w: int, frame_h: int) -> Rect:
    """Fit an ``(x, y, w, h)`` rectangle inside a ``frame_w`` x ``frame_h`` frame.

    The origin is clamped to ``[0, frame - 1]`` on each axis. The size is then
    limited to the space remaining from that origin and is never below 1.

    Returns:
        The adjusted ``(x, y, w, h)`` tuple.
    """

    def _bound(value: int, low: int, high: int) -> int:
        # The lower bound is applied last, so it wins when high < low.
        return max(low, min(value, high))

    left = _bound(x, 0, frame_w - 1)
    top = _bound(y, 0, frame_h - 1)
    width = _bound(w, 1, frame_w - left)
    height = _bound(h, 1, frame_h - top)
    return left, top, width, height
79
+
80
+
81
class ROISelector:
    """Rectangle selection by mouse drag on a named OpenCV window."""

    def __init__(self, window_name: str):
        self.window_name = window_name
        self.drawing = False
        # Drag start (x0, y0) and current/end point (x1, y1).
        self.x0 = 0
        self.y0 = 0
        self.x1 = 0
        self.y1 = 0

    def on_mouse(self, event, mx, my, flags, param):
        """Mouse callback: track the drag start, the live corner and the release point."""
        if event == cv2.EVENT_LBUTTONDOWN:
            self.drawing = True
            self.x0, self.y0 = mx, my
            self.x1, self.y1 = mx, my
            return
        if event == cv2.EVENT_LBUTTONUP:
            self.drawing = False
            self.x1, self.y1 = mx, my
            return
        if event == cv2.EVENT_MOUSEMOVE and self.drawing:
            self.x1, self.y1 = mx, my

    def attach(self):
        """Register ``on_mouse`` as the mouse callback of the window."""
        cv2.setMouseCallback(self.window_name, self.on_mouse)

    def current_rect(self, frame_w: int, frame_h: int) -> Optional[Rect]:
        """Return the dragged box as a clamped ``(x, y, w, h)``, or None if under 3 px on either axis."""
        span_x = abs(self.x1 - self.x0)
        span_y = abs(self.y1 - self.y0)
        if span_x < 3 or span_y < 3:
            return None
        left = min(self.x0, self.x1)
        top = min(self.y0, self.y1)
        right = max(self.x0, self.x1)
        bottom = max(self.y0, self.y1)
        return clamp_rect(left, top, right - left, bottom - top, frame_w, frame_h)

    def draw_preview(self, frame: np.ndarray) -> np.ndarray:
        """Return a copy of ``frame`` with the selection outlined when there is one to show."""
        canvas = frame.copy()
        has_selection = self.drawing or abs(self.x1 - self.x0) > 1
        if has_selection:
            top_left = (min(self.x0, self.x1), min(self.y0, self.y1))
            bottom_right = (max(self.x0, self.x1), max(self.y0, self.y1))
            cv2.rectangle(canvas, top_left, bottom_right, (0, 255, 255), 2)
        return canvas
118
+
119
+
120
def open_capture(source: str) -> cv2.VideoCapture:
    """Open ``source`` as a video capture.

    A string made only of digits is treated as a camera device index; anything
    else is passed to ``cv2.VideoCapture`` unchanged.

    Raises:
        RuntimeError: if the capture could not be opened.
    """
    target = int(source) if source.isdigit() else source
    capture = cv2.VideoCapture(target)
    if capture.isOpened():
        return capture
    raise RuntimeError(f"无法打开视频源: {source}")
128
+
129
+
130
def get_args():
    """Parse command-line options for the track-and-collect tool.

    Returns:
        The parsed ``argparse.Namespace`` with ``source``, ``output_dir``,
        ``tracker``, ``nano_model``, ``nano_backend`` and ``crop_pad``.
    """
    p = argparse.ArgumentParser(description="视频目标跟踪 + 裁剪保存模板小图")
    p.add_argument(
        "--source",
        type=str,
        default="0",
        help="摄像头设备号(如 0)或视频文件路径",
    )
    p.add_argument(
        "--output_dir",
        type=str,
        default=(temp_directory / "template_crops").as_posix(),
        help="按 s 保存裁剪图到此目录",
    )
    p.add_argument(
        "--tracker",
        type=str,
        default="mil",
        choices=["mil", "nano", "goturn", "dasiamrpn"],
        # Repaired: the help string contained mis-encoded replacement characters.
        help="OpenCV 跟踪器类型(默认可用 mil)",
    )
    p.add_argument("--nano_model", type=str, default="", help="TrackerNano backbone 模型路径(可选)")
    p.add_argument("--nano_backend", type=str, default="", help="TrackerNano NNBackend(可选,依 OpenCV 文档)")
    p.add_argument(
        "--crop_pad",
        type=float,
        default=0.08,
        help="保存裁剪时在框四周按比例扩展(相对框宽高),便于包含上下文",
    )
    return p.parse_args()
160
+
161
+
162
def expand_rect(rect: Rect, pad_ratio: float, fw: int, fh: int) -> Rect:
    """Grow ``rect`` on every side by ``pad_ratio`` of its width/height, clamped to the frame.

    The per-axis margin is truncated to an int before being applied.
    """
    left, top, width, height = rect
    margin_x = int(width * pad_ratio)
    margin_y = int(height * pad_ratio)
    return clamp_rect(
        left - margin_x,
        top - margin_y,
        width + 2 * margin_x,
        height + 2 * margin_y,
        fw,
        fh,
    )
167
+
168
+
169
def main():
    """Interactive loop: select a target by dragging, track it, and save crops on demand.

    Keys handled per iteration: ``q``/ESC quits, space toggles pause, ``r`` enters
    reselect mode (paused), ``s`` saves the current tracked box (expanded by
    ``--crop_pad``) as a PNG in the output directory.

    NOTE(review): the block nesting here was reconstructed from a flattened
    listing — confirm it against the original file.

    Returns:
        0 after the capture is released and the windows are closed.
    """
    args = get_args()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    cap = open_capture(args.source)
    win = "track_and_collect | drag ROI then release | s save | r reselect | space pause | q quit"
    cv2.namedWindow(win, cv2.WINDOW_NORMAL)
    selector = ROISelector(win)
    selector.attach()

    tracker = None
    paused = False
    need_init = True  # True while waiting for the user to drag a box
    ok_track = False
    bbox: Optional[Rect] = None
    save_idx = 0
    # Empty CLI strings are turned into None.
    nano_model = args.nano_model or None
    nano_backend = args.nano_backend or None
    prev_drawing = False  # selector.drawing as of the previous iteration

    print(__doc__)
    print(f"输出目录: {out_dir.resolve()}")

    while True:
        if not paused:
            ret, frame = cap.read()
            if not ret or frame is None:
                print("视频结束或读取失败")
                break
            fh, fw = frame.shape[:2]
        else:
            # Paused: keep working on the last frame that was read.
            fh, fw = frame.shape[:2]

        key = cv2.waitKey(1) & 0xFF
        if key in (ord("q"), 27):
            break
        if key == ord(" "):
            paused = not paused
            continue
        if key == ord("r"):
            need_init = True
            tracker = None
            ok_track = False
            paused = True
            print("重选模式:在本帧拖拽框选目标,松开鼠标后开始跟踪;可按空格继续/暂停")

        display = frame.copy()

        if need_init:
            display = selector.draw_preview(display)
            preview_rect = selector.current_rect(fw, fh)
            if preview_rect is not None:
                cv2.rectangle(
                    display,
                    (preview_rect[0], preview_rect[1]),
                    (preview_rect[0] + preview_rect[2], preview_rect[1] + preview_rect[3]),
                    (0, 255, 0),
                    2,
                )
            # Mouse released: on the dragging -> not-dragging transition, initialise
            # the tracker automatically if the box is valid.
            if prev_drawing and not selector.drawing:
                rect = selector.current_rect(fw, fh)
                if rect is not None:
                    tracker = create_tracker(args.tracker, nano_model=nano_model, nano_backend=nano_backend)
                    tracker.init(frame, rect)
                    bbox = rect
                    need_init = False
                    ok_track = True
                    paused = False
                    print(f"已初始化跟踪,bbox= {rect}")
        else:
            if tracker is not None and not paused:
                ok_track, bbox = tracker.update(frame)
            if bbox is not None:
                x, y, w, h = [int(v) for v in bbox]
                x, y, w, h = clamp_rect(x, y, w, h, fw, fh)
                # Green while the tracker reports success, red otherwise.
                color = (0, 255, 0) if ok_track else (0, 0, 255)
                cv2.rectangle(display, (x, y), (x + w, y + h), color, 2)
                label = "tracking OK" if ok_track else "tracking LOST?"
                cv2.putText(display, label, (x, max(0, y - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA)

        # Save only when there is a box and the last tracker update succeeded.
        if key == ord("s") and bbox is not None and ok_track:
            x, y, w, h = [int(v) for v in bbox]
            x, y, w, h = clamp_rect(x, y, w, h, fw, fh)
            ex = expand_rect((x, y, w, h), float(args.crop_pad), fw, fh)
            # Crop from the raw frame so the drawn overlay is not saved.
            crop = frame[ex[1] : ex[1] + ex[3], ex[0] : ex[0] + ex[2]]
            ts = int(time.time() * 1000)
            path = out_dir / f"template_{save_idx:04d}_{ts}.png"
            cv2.imwrite(str(path), crop)
            save_idx += 1
            print(f"已保存: {path}")

        hint = (
            "Drag ROI (release=init) | s save | r reselect | space pause | q quit"
            if need_init
            else "s save | r reselect | space pause | q quit"
        )
        cv2.putText(display, hint, (10, fh - 12), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (200, 200, 200), 1, cv2.LINE_AA)
        cv2.imshow(win, display)

        prev_drawing = selector.drawing

    cap.release()
    cv2.destroyAllWindows()
    return 0
275
+
276
+
277
+ if __name__ == "__main__":
278
+ sys.exit(main())
examples/keypoints/kornia/test.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/kornia/kornia
3
+ """
4
+ import cv2
5
+ import torch
6
+ import kornia as K
7
+ import kornia.feature as KF
8
+ import matplotlib.pyplot as plt
9
+ from kornia_moons.viz import draw_LAF_matches
10
+
11
+
12
def match_keyboards(img_path1, img_path2):
    """Match local features between two images with Kornia and plot the result.

    Each image is read once. Features are extracted with KeyNet + AffNet +
    HardNet, matched with AdaLAM on CPU, drawn with ``draw_LAF_matches``, and a
    verdict based on the match count is printed.

    Args:
        img_path1: Path of the first image.
        img_path2: Path of the second image.

    Raises:
        FileNotFoundError: if either image cannot be read by ``cv2.imread``.
    """
    # 1. Load each image once; cv2.imread returns None instead of raising on failure.
    bgr1 = cv2.imread(img_path1)
    bgr2 = cv2.imread(img_path2)
    for path, image in ((img_path1, bgr1), (img_path2, bgr2)):
        if image is None:
            raise FileNotFoundError(f"无法读取图片: {path}")

    # Kornia works on tensors; scale pixel values to [0, 1].
    img1 = K.image_to_tensor(bgr1, keepdim=False).float() / 255.0
    img2 = K.image_to_tensor(bgr2, keepdim=False).float() / 255.0

    # Feature extraction runs on grayscale images.
    img1_gray = K.color.rgb_to_grayscale(K.color.bgr_to_rgb(img1))
    img2_gray = K.color.rgb_to_grayscale(K.color.bgr_to_rgb(img2))

    # 2. Feature extractor (KeyNet + AffNet + HardNet) with an AdaLAM matcher.
    num_features = 2000
    matcher = KF.LocalFeatureMatcher(
        KF.KeyNetAffNetHardNet(num_features, pretrained=True),
        KF.DescriptorMatcher('adalam', torch.device('cpu'))
    )

    # 3. Run matching without autograd bookkeeping.
    input_dict = {"image0": img1_gray, "image1": img2_gray}
    with torch.no_grad():
        correspondences = matcher(input_dict)

    # Coordinates of the matched points.
    mkpts0 = correspondences['keypoints0'].cpu().numpy()
    mkpts1 = correspondences['keypoints1'].cpu().numpy()

    print(f"找到匹配点数量: {len(mkpts0)}")

    # 4. Visualise. Every match is drawn (not only the strongest 50).
    img1_cv = cv2.cvtColor(bgr1, cv2.COLOR_BGR2RGB)
    img2_cv = cv2.cvtColor(bgr2, cv2.COLOR_BGR2RGB)

    fig, ax = plt.subplots(figsize=(15, 8))
    # NOTE(review): the Kornia tutorials pass tentative indices shaped (N, 2) and
    # explicit scale/orientation tensors to laf_from_center_scale_ori — confirm
    # that the shapes used here are accepted by the installed kornia_moons.
    draw_LAF_matches(
        KF.laf_from_center_scale_ori(correspondences['keypoints0'].view(1, -1, 2)),
        KF.laf_from_center_scale_ori(correspondences['keypoints1'].view(1, -1, 2)),
        torch.arange(len(mkpts0)).view(1, -1, 1),
        img1_cv,
        img2_cv,
        draw_dict={'inliers_fallback': True}
    )
    plt.show()

    # 5. Verdict: more than 50 matches is treated as "very likely the same keyboard".
    if len(mkpts0) > 50:
        print("结论:两张图片匹配度极高,很可能是同一个键盘或极其相似。")
    else:
        print("结论:匹配点过少,可能是不同键盘或拍摄角度偏差过大。")
64
+
65
+ # 使用示例
66
+ match_keyboards(
67
+ 'keyboard_left.jpg',
68
+ 'keyboard_right.jpg'
69
+ )
examples/keypoints/superpoint/test.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/rpautrat/SuperPointPretrainedNetwork
3
+
4
+ """
5
+ import argparse
6
+ import os
7
+
8
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
9
+
10
+ import torch
11
+ import cv2
12
+ import numpy as np
13
+ import requests
14
+ from PIL import Image
15
+ from transformers import AutoImageProcessor
16
+ from transformers.models.superpoint.modeling_superpoint import SuperPointForKeypointDetection, SuperPointKeypointDescriptionOutput
17
+ from project_settings import project_path, temp_directory
18
+
19
+
20
def get_args():
    """Parse command-line options: model name, model cache directory and image path.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # (flag, default) pairs; every option is a plain string.
    string_options = (
        ("--model_name", "magic-leap-community/superpoint"),
        ("--model_cache_dir", (project_path / "../../hf_hub_models").as_posix()),
        ("--image_path", (project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png").as_posix()),
    )
    for flag, default_value in string_options:
        parser.add_argument(flag, default=default_value, type=str)
    return parser.parse_args()
39
+
40
+
41
def show_image(image):
    """Show ``image`` in an OpenCV window named "image", wait on ``cv2.waitKey(0)``, then close all windows."""
    cv2.imshow("image", image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
45
+
46
+
47
def main():
    """Detect SuperPoint keypoints on one image and display them.

    Loads the image processor and model named on the command line, runs inference
    on ``--image_path``, prints the keypoint count and descriptor shape, and shows
    the keypoints drawn on the image.
    """
    args = get_args()

    processor = AutoImageProcessor.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        cache_dir=args.model_cache_dir,
    )
    model = SuperPointForKeypointDetection.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        cache_dir=args.model_cache_dir,
    )

    image = Image.open(args.image_path).convert("RGB")

    inputs = processor(image, return_tensors="pt")
    # Inference only: no_grad avoids building an autograd graph for the forward pass.
    with torch.no_grad():
        output: SuperPointKeypointDescriptionOutput = model(**inputs)

    # Use the processor's post-processing to turn relative coordinates into pixel coordinates.
    image_size = (image.height, image.width)
    processed = processor.post_process_keypoint_detection(
        output,
        [image_size],
    )
    # `processed` has one entry per batch item; there is a single image here.
    keypoints = processed[0]["keypoints"]  # [N, 2], (x, y) in pixels
    scores = processed[0]["scores"]  # [N]
    descriptors = processed[0]["descriptors"]  # [N, D]
    scores = scores.detach().cpu().numpy()

    print(f"检测到关键点数量: {keypoints.shape[0]}")
    print(f"描述符维度: {descriptors.shape}")

    # Draw with OpenCV: PIL image -> numpy (RGB) -> BGR.
    image_np = np.array(image)
    image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    # Convert the SuperPoint keypoints into a list of cv2.KeyPoint.
    cv2_keypoints = []
    for (x, y), score in zip(keypoints, scores):
        # x, y are pixel coordinates; the score is stored as the response.
        # No angle is passed, so DRAW_RICH_KEYPOINTS will not draw the orientation
        # line (per the original note, OpenCV draws it only when angle != -1).
        kp = cv2.KeyPoint(
            x=float(x),
            y=float(y),
            size=7,
            response=float(score),
        )
        cv2_keypoints.append(kp)

    image_with_kp = cv2.drawKeypoints(
        image_bgr,
        cv2_keypoints,
        None,
        color=(0, 0, 255),
        flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
    )

    show_image(image_with_kp)
    return
108
+
109
+
110
+ if __name__ == "__main__":
111
+ main()
main.py ADDED
File without changes
project_settings.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ from pathlib import Path
5
+
6
+ from toolbox.os.environment import EnvironmentManager
7
+
8
+
9
# Absolute path of the directory containing this file, as a Path.
project_path = os.path.abspath(os.path.dirname(__file__))
project_path = Path(project_path)

time_zone_info = "Asia/Shanghai"

# Note: the logs/ and temp/ directories are created as a side effect of importing this module.
log_directory = project_path / "logs"
log_directory.mkdir(parents=True, exist_ok=True)

temp_directory = project_path / "temp"
temp_directory.mkdir(parents=True, exist_ok=True)

# The environment name is read from the "environment" variable, defaulting to "dev";
# the manager is given the "dotenv" path under the project root.
environment = EnvironmentManager(
    path=os.path.join(project_path, "dotenv"),
    env=os.environ.get("environment", "dev"),
)
24
+
25
+
26
+ if __name__ == "__main__":
27
+ pass
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ python-dotenv
2
+ transformers
3
+ torch
4
+ torchvision
5
+ opencv-python
6
+ Pillow
7
+ requests
toolbox/__init__.py ADDED
File without changes
toolbox/json/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/json/misc.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Callable
4
+
5
+
6
def traverse(js, callback: Callable, *args, **kwargs):
    """Rebuild a nested structure, applying ``callback`` to every int and str in it.

    Lists, tuples and dicts are walked recursively and returned as new
    containers of the same kind; dict keys are walked as well as values.
    ``callback(value, *args, **kwargs)`` replaces each int (bool included, since
    it is an int subclass) and each str. Any other value is returned unchanged.
    """

    def walk(node):
        return traverse(node, callback, *args, **kwargs)

    if isinstance(js, list):
        return [walk(item) for item in js]
    if isinstance(js, tuple):
        return tuple(walk(item) for item in js)
    if isinstance(js, dict):
        # Key first, then value, for each item in insertion order.
        return {walk(key): walk(value) for key, value in js.items()}
    if isinstance(js, (int, str)):
        return callback(js, *args, **kwargs)
    return js
32
+
33
+
34
def demo1():
    """Demonstrate ``traverse`` by stripping the leading ``$`` from string values.

    The sample config uses placeholder hosts and credentials only: real
    passwords and internal addresses must never be committed to source control.
    """
    d = {
        "env": "ppe",
        "mysql_connect": {
            "host": "$mysql_connect_host",
            "port": 3306,
            "user": "demo_user",
            "password": "<mysql-password>",
            "database": "demo_db",
            "charset": "utf8"
        },
        "es_connect": {
            "hosts": ["127.0.0.1"],
            "http_auth": ["demo_user", "<es-password>"],
            "port": 9200
        }
    }

    def callback(s):
        # Only strings that start with "$" are rewritten; ints pass through unchanged.
        if isinstance(s, str) and s.startswith('$'):
            return s[1:]
        return s

    result = traverse(d, callback=callback)
    print(result)
    return
60
+
61
+
62
+ if __name__ == '__main__':
63
+ demo1()
toolbox/keypoint_match/__init__.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from toolbox.keypoint_match.types import (
5
+ Detection,
6
+ DetectionResult,
7
+ ImageArray,
8
+ KeyPointSet,
9
+ MatchSet,
10
+ )
11
+ from toolbox.keypoint_match.base import (
12
+ KeyPointExtract,
13
+ KeyPointExtractConfig,
14
+ KeyPointMatch,
15
+ KeyPointMatchConfig,
16
+ KeyPointTemplateDetector,
17
+ RegionScorer,
18
+ RegionScorerConfig,
19
+ )
20
+ from toolbox.keypoint_match.detector import SimpleKeyPointTemplateDetector
21
+ from toolbox.keypoint_match.keypoint_extracter.sift import SiftExtractConfig, SiftKeyPointExtract
22
+ from toolbox.keypoint_match.keypoint_extracter.superpoint import SuperPointExtractConfig, SuperPointKeyPointExtract
23
+ from toolbox.keypoint_match.keypoint_match.single_image_match import SingleImageMatcher, SingleImageMatcherConfig
24
+ from toolbox.keypoint_match.keypoint_detector.single_image_detector import (
25
+ SingleImageDetector,
26
+ SingleImageDetectorConfig,
27
+ SingleImageDetectorResult,
28
+ )
29
+ from toolbox.keypoint_match.keypoint_detector.multi_image_detector import (
30
+ MultiImageDetector,
31
+ MultiImageDetectorConfig,
32
+ MultiImageDetectorItem,
33
+ MultiImageDetectorResult,
34
+ )
35
+
36
+ __all__ = [
37
+ "Detection",
38
+ "DetectionResult",
39
+ "ImageArray",
40
+ "KeyPointSet",
41
+ "MatchSet",
42
+ "KeyPointExtract",
43
+ "KeyPointExtractConfig",
44
+ "KeyPointMatch",
45
+ "KeyPointMatchConfig",
46
+ "KeyPointTemplateDetector",
47
+ "RegionScorer",
48
+ "RegionScorerConfig",
49
+ "SimpleKeyPointTemplateDetector",
50
+ "SiftExtractConfig",
51
+ "SiftKeyPointExtract",
52
+ "SuperPointExtractConfig",
53
+ "SuperPointKeyPointExtract",
54
+ "SingleImageMatcher",
55
+ "SingleImageMatcherConfig",
56
+ "SingleImageDetector",
57
+ "SingleImageDetectorConfig",
58
+ "SingleImageDetectorResult",
59
+ "MultiImageDetector",
60
+ "MultiImageDetectorConfig",
61
+ "MultiImageDetectorItem",
62
+ "MultiImageDetectorResult",
63
+ ]
64
+
toolbox/keypoint_match/base.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from __future__ import annotations
5
+
6
+ from abc import ABC, abstractmethod
7
+ from dataclasses import dataclass
8
+ from typing import Any, Mapping, Optional, Sequence, Tuple
9
+
10
+ import numpy as np
11
+
12
+ from toolbox.keypoint_match.types import (
13
+ DetectionResult,
14
+ ImageArray,
15
+ KeyPointSet,
16
+ MatchSet,
17
+ )
18
+
19
+
20
@dataclass(frozen=True)
class KeyPointExtractConfig:
    """Immutable settings shared by every keypoint/descriptor extractor."""

    # Requested cap on the number of keypoints; None requests no cap.
    max_keypoints: Optional[int] = None
    # Open-ended bag for backend-specific parameters, for example
    # nms_radius, score_threshold or scale_pyramid.
    extra: Optional[Mapping[str, Any]] = None
27
+
28
+
29
class KeyPointExtract(ABC):
    """
    Interface: extract keypoints and descriptors from a single image.

    Intended backends:
    - OpenCV: ORB / SIFT / AKAZE and similar
    - Kornia: KeyNetAffNetHardNet, DISK and similar
    - Deep models such as SuperPoint / LightGlue
    """

    def __init__(self, config: Optional[KeyPointExtractConfig] = None):
        # A missing (or falsy) config is replaced by the default one.
        if not config:
            config = KeyPointExtractConfig()
        self.config = config

    @abstractmethod
    def extract(self, image: ImageArray) -> KeyPointSet:
        """Take one image and return its keypoint set (descriptors optional)."""

    def batch_extract(self, images: Sequence[ImageArray]) -> Sequence[KeyPointSet]:
        """Call :meth:`extract` on every image, keeping the input order."""
        results = []
        for image in images:
            results.append(self.extract(image))
        return results
48
+
49
+
50
@dataclass(frozen=True)
class KeyPointMatchConfig:
    """Immutable settings shared by every matcher."""

    # Ratio-test threshold, for implementations that use KNN matching.
    ratio: Optional[float] = None
    # Lets an implementation bound how many matches it returns.
    max_matches: Optional[int] = None
    # Open-ended bag for implementation-specific parameters.
    extra: Optional[Mapping[str, Any]] = None
59
+
60
+
61
class KeyPointMatch(ABC):
    """
    Interface: match two images (or two descriptor sets) and return the pairs.
    """

    def __init__(self, config: Optional[KeyPointMatchConfig] = None):
        # A missing (or falsy) config is replaced by the default one.
        if not config:
            config = KeyPointMatchConfig()
        self.config = config

    @abstractmethod
    def match(self, query: KeyPointSet, train: KeyPointSet) -> MatchSet:
        """
        Match ``query`` against ``train``.

        query: usually taken from the small image / template
        train: usually taken from the large image / search image
        """
75
+
76
+
77
@dataclass(frozen=True)
class RegionScorerConfig:
    """
    Immutable settings for grouping matched points into "regions".

    The core idea of the algorithm is "a region with an unusually high number
    of matched points means the target was found", so the clustering, scoring
    and bbox-generation strategy is abstracted behind RegionScorer.
    """

    # Clustering radius on the large image, in pixels; e.g. a DBSCAN eps or
    # the cell_size of a grid vote.
    radius_px: float = 24.0
    # Minimum number of matched points for a region to count as "found".
    min_match_count: int = 12
    # Maximum number of candidate regions to output.
    topk: int = 10
    # Open-ended bag for implementation-specific parameters.
    extra: Optional[Mapping[str, Any]] = None
93
+
94
+
95
class RegionScorer(ABC):
    """
    Interface: turn a set of matches into candidate regions (bbox + score).

    The input carries both keypoint coordinates and match pairs, so an
    implementation is free to:
    - run a simple grid vote / density clustering
    - drop outliers with a homography / RANSAC before clustering
    - score by the local consistency of the matched points
    """

    def __init__(self, config: Optional[RegionScorerConfig] = None):
        # A missing (or falsy) config is replaced by the default one.
        if not config:
            config = RegionScorerConfig()
        self.config = config

    @abstractmethod
    def score(
        self,
        query: KeyPointSet,
        train: KeyPointSet,
        matches: MatchSet,
        *,
        template_id: Optional[str] = None,
        template_size: Optional[Tuple[int, int]] = None,  # (h, w)
    ) -> DetectionResult:
        """
        Score the matches between ``query`` and ``train``.

        template_size: size of the small image as (h, w); optional, helps to
            infer the bbox size from the matched points.
        """
121
+
122
+
123
class KeyPointTemplateDetector(ABC):
    """
    Interface: combine (extractor + matcher + region scorer) into a
    "template detection" step.

    A concrete detector can be used for:
    - multi-template retrieval (where does each of several small images
      appear in the large image)
    - single-template localisation (where is one small image in the large
      image)
    """

    @abstractmethod
    def detect(
        self,
        template_image: ImageArray,
        search_image: ImageArray,
        *,
        template_id: Optional[str] = None,
    ) -> DetectionResult:
        """Detect ``template_image`` inside ``search_image``."""
141
+
toolbox/keypoint_match/detector.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from __future__ import annotations
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+
9
+ from toolbox.keypoint_match.base import KeyPointExtract, KeyPointMatch, KeyPointTemplateDetector, RegionScorer
10
+ from toolbox.keypoint_match.types import DetectionResult, ImageArray, KeyPointSet, MatchSet
11
+
12
+
13
@dataclass
class SimpleKeyPointTemplateDetector(KeyPointTemplateDetector):
    """
    Default "composition" detector.

    Notes:
    - It works out of the box by chaining
      extract(template) + extract(search) + match + score.
    - The actual "find the region" logic lives in the RegionScorer, so the
      algorithm stays swappable and extensible.
    """

    extractor: KeyPointExtract
    matcher: KeyPointMatch
    region_scorer: RegionScorer

    def detect(
        self,
        template_image: ImageArray,
        search_image: ImageArray,
        *,
        template_id: Optional[str] = None,
    ) -> DetectionResult:
        """Run the extract -> match -> score pipeline for one template."""
        kp_template: KeyPointSet = self.extractor.extract(template_image)
        kp_search: KeyPointSet = self.extractor.extract(search_image)
        found: MatchSet = self.matcher.match(kp_template, kp_search)

        # Best effort: the (h, w) size is only a hint for the scorer, so any
        # failure to read ``shape`` degrades to None instead of aborting.
        size: Optional[Tuple[int, int]]
        try:
            size = (int(template_image.shape[0]), int(template_image.shape[1]))
        except Exception:
            size = None

        return self.region_scorer.score(
            kp_template,
            kp_search,
            found,
            template_id=template_id,
            template_size=size,
        )
52
+
toolbox/keypoint_match/keypoint_detector/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from toolbox.keypoint_match.keypoint_detector.single_image_detector import (
5
+ SingleImageDetector,
6
+ SingleImageDetectorConfig,
7
+ SingleImageDetectorResult,
8
+ )
9
+ from toolbox.keypoint_match.keypoint_detector.multi_image_detector import (
10
+ MultiImageDetector,
11
+ MultiImageDetectorConfig,
12
+ MultiImageDetectorItem,
13
+ MultiImageDetectorResult,
14
+ )
15
+
16
+ __all__ = [
17
+ "SingleImageDetector",
18
+ "SingleImageDetectorConfig",
19
+ "SingleImageDetectorResult",
20
+ "MultiImageDetector",
21
+ "MultiImageDetectorConfig",
22
+ "MultiImageDetectorItem",
23
+ "MultiImageDetectorResult",
24
+ ]
25
+
toolbox/keypoint_match/keypoint_detector/multi_image_detector.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ from pathlib import Path
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
10
+
11
+ import cv2
12
+ import numpy as np
13
+
14
+ from project_settings import project_path
15
+ from toolbox.keypoint_match.base import KeyPointExtract
16
+ from toolbox.keypoint_match.keypoint_match.single_image_match import (
17
+ SingleImageMatcher,
18
+ SingleImageMatcherConfig,
19
+ )
20
+ from toolbox.keypoint_match.types import ImageArray, KeyPointSet, MatchSet
21
+
22
+
23
@dataclass(frozen=True)
class MultiImageDetectorConfig:
    """
    Settings for keypoint-match detection of several templates (small images)
    inside one large image.

    - No affine / homography estimation is performed.
    - Only the well-matched keypoints of the large image are drawn.
    """

    ratio: float = 0.75
    max_matches: int = 120
    max_keypoints: int = 2000
    point_radius: int = 4
    point_thickness: int = 2
    # One BGR colour per template; the palette is cycled when there are more
    # templates than colours.
    colors_bgr: Sequence[Tuple[int, int, int]] = (
        (0, 0, 255),  # red
        (0, 255, 0),  # green
        (255, 0, 0),  # blue
        (0, 255, 255),  # yellow
        (255, 0, 255),  # magenta
        (255, 255, 0),  # cyan
    )
    extra: Optional[Dict[str, Any]] = None
47
+
48
+
49
@dataclass(frozen=True)
class MultiImageDetectorItem:
    """Per-template record: keypoints, matches and the colour used to draw."""

    template_id: str
    template_kp: KeyPointSet
    search_kp: KeyPointSet
    matches: MatchSet
    color_bgr: Tuple[int, int, int]
    # Optional extra statistics attached by the producer.
    meta: Optional[Dict[str, Any]] = None
57
+
58
+
59
@dataclass(frozen=True)
class MultiImageDetectorResult:
    """Result bundle: one item per template plus the shared visualisation."""

    items: Sequence[MultiImageDetectorItem]
    # Search image with the matched keypoints drawn on it.
    vis_search: np.ndarray
    # Optional run-level statistics.
    meta: Optional[Dict[str, Any]] = None
64
+
65
+
66
class MultiImageDetector:
    """
    Multi image detector: given several small template images, find the
    well-matched keypoints of each one inside a single large image and
    visualise them together.
    """

    def __init__(self, extractor: KeyPointExtract, config: Optional[MultiImageDetectorConfig] = None):
        self.extractor = extractor
        self.config = config if config else MultiImageDetectorConfig()

    def _color_for_index(self, i: int) -> Tuple[int, int, int]:
        """Return the palette colour for template ``i``, cycling when needed."""
        configured = self.config.colors_bgr
        # An empty palette falls back to a single red entry.
        palette = list(configured) if configured else [(0, 0, 255)]
        chosen = palette[i % len(palette)]
        return tuple(map(int, chosen))

    def _draw_points_on_search(
        self,
        base_image: ImageArray,
        search_kp: KeyPointSet,
        matches: MatchSet,
        color_bgr: Tuple[int, int, int],
    ) -> np.ndarray:
        """Return a copy of ``base_image`` with the matched search points drawn."""
        canvas = np.asarray(base_image).copy()
        if canvas.ndim == 2:
            canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)

        nothing_to_draw = matches.m == 0 or search_kp.n == 0
        if nothing_to_draw:
            return canvas

        # Keep only match indices that address an existing search keypoint.
        train_idx = matches.train_idx.astype(np.int64)
        in_range = (train_idx >= 0) & (train_idx < search_kp.n)
        train_idx = train_idx[in_range]
        if train_idx.size == 0:
            return canvas

        radius = int(self.config.point_radius)
        thickness = int(self.config.point_thickness)
        for x, y in search_kp.xy[train_idx]:
            center = (int(round(float(x))), int(round(float(y))))
            cv2.circle(
                canvas,
                center=center,
                radius=radius,
                color=color_bgr,
                thickness=thickness,
                lineType=cv2.LINE_AA,
            )
        return canvas

    def detect(
        self,
        template_images: Sequence[ImageArray],
        search_image: ImageArray,
        *,
        template_ids: Optional[Sequence[str]] = None,
    ) -> MultiImageDetectorResult:
        """
        Match every template against ``search_image`` and draw the results.

        template_ids: optional names, one per template; defaults to
            ``template_0``, ``template_1``, ...

        Raises ValueError when ``template_ids`` and ``template_images`` differ
        in length.
        """
        if template_ids is None:
            template_ids = [f"template_{i}" for i in range(len(template_images))]
        if len(template_ids) != len(template_images):
            raise ValueError("template_ids 的长度必须与 template_images 一致")

        matcher_config = SingleImageMatcherConfig(
            ratio=float(self.config.ratio),
            max_matches=int(self.config.max_matches),
        )
        matcher = SingleImageMatcher(extractor=self.extractor, config=matcher_config)

        # NOTE: reusing the matcher as-is re-extracts the search image's
        # keypoints for every template. This is kept for interface
        # simplicity; if performance matters later, the matcher could be
        # extended to accept an already extracted search_kp.
        canvas = np.asarray(search_image).copy()
        if canvas.ndim == 2:
            canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)

        items: List[MultiImageDetectorItem] = []
        for index, (tid, template) in enumerate(zip(template_ids, template_images)):
            name = str(tid)
            template_kp, search_kp, matches = matcher.match(template, search_image, template_id=name)
            color = self._color_for_index(index)

            # Draw on top of what previous templates already drew.
            canvas = self._draw_points_on_search(canvas, search_kp, matches, color_bgr=color)

            stats = {
                "kp_template": int(template_kp.n),
                "kp_search": int(search_kp.n),
                "match_count": int(matches.m),
            }
            items.append(
                MultiImageDetectorItem(
                    template_id=name,
                    template_kp=template_kp,
                    search_kp=search_kp,
                    matches=matches,
                    color_bgr=color,
                    meta=stats,
                )
            )

        meta: Dict[str, Any] = {
            "template_count": int(len(template_images)),
            "ratio": float(self.config.ratio),
            "max_matches": int(self.config.max_matches),
        }
        if self.config.extra:
            meta["extra"] = dict(self.config.extra)

        return MultiImageDetectorResult(items=items, vis_search=canvas, meta=meta)
168
+
169
+
170
def get_args():
    """Parse the command-line options of the multi-template demo."""
    default_templates = [
        (project_path / "data/images/keyboard/g98-v2-pink/model/local/roller/roller1.png").as_posix(),
        (project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller1.png").as_posix(),
        (project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller2.png").as_posix(),
        (project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller3.png").as_posix(),
    ]
    # Other sample search images referenced during development live next to
    # this one: model/keyboard1.png, keyboard3.jpg, keyboard4.jpg, keyboard6.jpg.
    default_search = (project_path / "data/images/keyboard/g98-v2-pink/keyboard2.jpg").as_posix()

    parser = argparse.ArgumentParser()
    parser.add_argument("--template_paths", type=str, nargs="+", default=default_templates)
    parser.add_argument("--search_path", type=str, default=default_search)
    parser.add_argument("--ratio", type=float, default=0.90)
    parser.add_argument("--max_matches", type=int, default=120)
    parser.add_argument("--max_keypoints", type=int, default=2000)
    args = parser.parse_args()
    return args
196
+
197
+
198
def main():
    """Demo entry point: match the templates, then show the result window."""
    from toolbox.keypoint_match.keypoint_extracter import SiftExtractConfig, SiftKeyPointExtract

    args = get_args()

    search = cv2.imread(args.search_path)
    if search is None:
        raise FileNotFoundError(f"无法读取搜索图: {args.search_path}")

    # Load every template up front; the file stem doubles as the template id.
    templates: List[np.ndarray] = []
    template_ids: List[str] = []
    for p in args.template_paths:
        loaded = cv2.imread(p)
        if loaded is None:
            raise FileNotFoundError(f"无法读取模板图: {p}")
        templates.append(loaded)
        template_ids.append(Path(p).stem)

    sift_config = SiftExtractConfig(max_keypoints=int(args.max_keypoints))
    extractor = SiftKeyPointExtract(sift_config)
    detector_config = MultiImageDetectorConfig(
        ratio=float(args.ratio),
        max_matches=int(args.max_matches),
        max_keypoints=int(args.max_keypoints),
    )
    detector = MultiImageDetector(extractor=extractor, config=detector_config)

    result = detector.detect(templates, search, template_ids=template_ids)

    # The window title lists the match count of every template.
    stat = ", ".join(f"{it.template_id}:{it.matches.m}" for it in result.items)
    title = f"multi_image_detector | {stat}"
    cv2.imshow(title, result.vis_search)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return
235
+
236
+
237
+ if __name__ == "__main__":
238
+ main()
239
+
toolbox/keypoint_match/keypoint_detector/single_image_detector.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ from dataclasses import dataclass
8
+ from typing import Any, Dict, Optional, Tuple
9
+
10
+ import cv2
11
+ import numpy as np
12
+
13
+ from project_settings import project_path
14
+ from toolbox.keypoint_match.base import KeyPointExtract
15
+ from toolbox.keypoint_match.keypoint_match.single_image_match import (
16
+ SingleImageMatcher,
17
+ SingleImageMatcherConfig,
18
+ )
19
+ from toolbox.keypoint_match.types import ImageArray, KeyPointSet, MatchSet
20
+
21
+
22
@dataclass(frozen=True)
class SingleImageDetectorConfig:
    """
    Settings for "keypoint-match detection" of one small target image inside
    a large image.

    Note: no affine / homography estimation is done here; only the
    well-matched points are kept and drawn onto the large image.
    """

    ratio: float = 0.75
    max_matches: int = 120
    max_keypoints: int = 2000
    # Point drawing parameters.
    point_radius: int = 4
    point_thickness: int = 2
    color_bgr: Tuple[int, int, int] = (0, 0, 255)
    # Whether to also draw the matched points on the template image, which
    # makes side-by-side comparison easier.
    draw_on_template: bool = False
    extra: Optional[Dict[str, Any]] = None
40
+
41
+
42
@dataclass(frozen=True)
class SingleImageDetectorResult:
    """Result bundle of one template-vs-search detection run."""

    template_kp: KeyPointSet
    search_kp: KeyPointSet
    matches: MatchSet
    # Search image with the matched keypoints drawn on it.
    vis_search: np.ndarray
    # Template image with its matched keypoints drawn; None when not requested.
    vis_template: Optional[np.ndarray] = None
    # Optional run-level statistics.
    meta: Optional[Dict[str, Any]] = None
50
+
51
+
52
class SingleImageDetector:
    """
    Single image detector: takes a small image and a large image and produces
    a visualisation of the large image's matched keypoints.

    - The keypoint extractor (SIFT / SuperPoint / ...) is passed in at
      construction time.
    - Matching is delegated to `SingleImageMatcher`.
    """

    def __init__(self, extractor: KeyPointExtract, config: Optional[SingleImageDetectorConfig] = None):
        self.extractor = extractor
        self.config = config if config else SingleImageDetectorConfig()

    def detect(self, template_image: ImageArray, search_image: ImageArray, *, template_id: Optional[str] = None) -> SingleImageDetectorResult:
        """Match ``template_image`` against ``search_image`` and draw the hits."""
        matcher_config = SingleImageMatcherConfig(
            ratio=float(self.config.ratio),
            max_matches=int(self.config.max_matches),
        )
        matcher = SingleImageMatcher(extractor=self.extractor, config=matcher_config)
        template_kp, search_kp, matches = matcher.match(template_image, search_image, template_id=template_id)

        vis_search = self.draw_matched_keypoints_on_search(search_image, search_kp, matches)
        if self.config.draw_on_template:
            vis_template = self.draw_matched_keypoints_on_template(template_image, template_kp, matches)
        else:
            vis_template = None

        meta = {
            "template_id": template_id,
            "kp_template": int(template_kp.n),
            "kp_search": int(search_kp.n),
            "match_count": int(matches.m),
            "ratio": float(self.config.ratio),
        }
        if self.config.extra:
            meta["extra"] = dict(self.config.extra)

        return SingleImageDetectorResult(
            template_kp=template_kp,
            search_kp=search_kp,
            matches=matches,
            vis_search=vis_search,
            vis_template=vis_template,
            meta=meta,
        )

    @staticmethod
    def _as_bgr_canvas(image: ImageArray) -> np.ndarray:
        """Copy ``image``; a 2-D (grayscale) array is expanded to BGR."""
        canvas = np.asarray(image).copy()
        if canvas.ndim == 2:
            canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)
        return canvas

    @staticmethod
    def _points_for_indices(kp: KeyPointSet, raw_idx) -> Optional[np.ndarray]:
        """Return ``kp.xy`` rows for the in-range indices, or None if none remain."""
        idx = raw_idx.astype(np.int64)
        idx = idx[(idx >= 0) & (idx < kp.n)]
        if idx.size == 0:
            return None
        return kp.xy[idx]

    def draw_matched_keypoints_on_search(self, search_image: ImageArray, search_kp: KeyPointSet, matches: MatchSet) -> np.ndarray:
        """Draw the search-image keypoints referenced by ``matches.train_idx``."""
        canvas = self._as_bgr_canvas(search_image)
        if matches.m == 0 or search_kp.n == 0:
            return canvas

        points = self._points_for_indices(search_kp, matches.train_idx)
        if points is None:
            return canvas

        for x, y in points:
            cv2.circle(
                canvas,
                center=(int(round(float(x))), int(round(float(y)))),
                radius=int(self.config.point_radius),
                color=tuple(int(c) for c in self.config.color_bgr),
                thickness=int(self.config.point_thickness),
                lineType=cv2.LINE_AA,
            )
        return canvas

    def draw_matched_keypoints_on_template(self, template_image: ImageArray, template_kp: KeyPointSet, matches: MatchSet) -> np.ndarray:
        """Draw the template keypoints referenced by ``matches.query_idx``."""
        canvas = self._as_bgr_canvas(template_image)
        if matches.m == 0 or template_kp.n == 0:
            return canvas

        points = self._points_for_indices(template_kp, matches.query_idx)
        if points is None:
            return canvas

        for x, y in points:
            # Template points use a fixed green and a slightly smaller radius
            # (never below 2 px).
            cv2.circle(
                canvas,
                center=(int(round(float(x))), int(round(float(y)))),
                radius=max(2, int(self.config.point_radius) - 1),
                color=(0, 255, 0),
                thickness=int(self.config.point_thickness),
                lineType=cv2.LINE_AA,
            )
        return canvas
148
+
149
+
150
def get_args():
    """
    Parse the command-line options of the single-template demo.

    The path defaults are converted with ``as_posix()``: argparse applies
    ``type`` only to string defaults, so a path-object default would otherwise
    be handed to ``cv2.imread`` unconverted. This also matches how the other
    demo scripts in this package build their defaults.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--template_path",
        type=str,
        default=(project_path / "data/images/keyboard/g98-v2-pink/model/local/roller.png").as_posix(),
    )
    parser.add_argument(
        "--search_path",
        type=str,
        # Alternative sample used during development: model/keyboard1.png
        default=(project_path / "data/images/keyboard/g98-v2-pink/keyboard1.jpg").as_posix(),
    )
    parser.add_argument("--ratio", type=float, default=0.75)
    parser.add_argument("--max_matches", type=int, default=120)
    parser.add_argument("--max_keypoints", type=int, default=2000)
    parser.add_argument("--draw_on_template", action="store_true")
    return parser.parse_args()
168
+
169
+
170
def main():
    """Demo entry point: match one template, then show the result window(s)."""
    from toolbox.keypoint_match.keypoint_extracter import SiftExtractConfig, SiftKeyPointExtract

    args = get_args()

    # Both images are read before either is validated.
    template = cv2.imread(args.template_path)
    search = cv2.imread(args.search_path)
    if template is None:
        raise FileNotFoundError(f"无法读取模板图: {args.template_path}")
    if search is None:
        raise FileNotFoundError(f"无法读取搜索图: {args.search_path}")

    sift_config = SiftExtractConfig(max_keypoints=int(args.max_keypoints))
    extractor = SiftKeyPointExtract(sift_config)
    detector_config = SingleImageDetectorConfig(
        ratio=float(args.ratio),
        max_matches=int(args.max_matches),
        max_keypoints=int(args.max_keypoints),
        draw_on_template=bool(args.draw_on_template),
    )
    detector = SingleImageDetector(extractor=extractor, config=detector_config)

    result = detector.detect(template, search, template_id="roller")

    title = f"single_image_detector | matches={result.matches.m}"
    cv2.imshow(title, result.vis_search)
    # The template window only exists when --draw_on_template was given.
    if result.vis_template is not None:
        cv2.imshow("template_matched_keypoints", result.vis_template)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return
202
+
203
+
204
+ if __name__ == "__main__":
205
+ main()
206
+
toolbox/keypoint_match/keypoint_extracter/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from toolbox.keypoint_match.keypoint_extracter.sift import SiftExtractConfig, SiftKeyPointExtract
5
+ from toolbox.keypoint_match.keypoint_extracter.superpoint import SuperPointExtractConfig, SuperPointKeyPointExtract
6
+
7
+ __all__ = [
8
+ "SiftExtractConfig",
9
+ "SiftKeyPointExtract",
10
+ "SuperPointExtractConfig",
11
+ "SuperPointKeyPointExtract",
12
+ ]
13
+
toolbox/keypoint_match/keypoint_extracter/sift.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, Optional
8
+
9
+ import numpy as np
10
+ import cv2
11
+
12
+ from project_settings import project_path
13
+ from toolbox.keypoint_match.base import KeyPointExtract, KeyPointExtractConfig
14
+ from toolbox.keypoint_match.types import ImageArray, KeyPointSet
15
+
16
+
17
@dataclass(frozen=True)
class SiftExtractConfig(KeyPointExtractConfig):
    """
    Wrapper around the commonly used OpenCV SIFT parameters.

    OpenCV's SIFT_create takes quite a few arguments; only the most common
    ones are exposed here, anything else can be placed in ``extra``.
    """

    # Each field below is forwarded to the cv2.SIFT_create argument of the
    # matching name (nfeatures, nOctaveLayers, contrastThreshold,
    # edgeThreshold, sigma).
    n_features: int = 0
    n_octave_layers: int = 3
    contrast_threshold: float = 0.04
    edge_threshold: float = 10.0
    sigma: float = 1.6
30
+
31
+
32
class SiftKeyPointExtract(KeyPointExtract):
    """Keypoint/descriptor extractor backed by OpenCV SIFT."""

    def __init__(self, config: Optional[SiftExtractConfig] = None):
        super().__init__(config=config or SiftExtractConfig())
        self.config: SiftExtractConfig
        self._sift = self._create_sift()

    def _create_sift(self):
        """
        Build the OpenCV SIFT detector from the current config.

        Raises:
            ImportError: importing ``cv2`` failed.
            RuntimeError: the installed OpenCV build has no ``SIFT_create``.
        """
        try:
            import cv2  # imported lazily so a missing OpenCV only fails here
        except Exception as e:  # pragma: no cover
            raise ImportError("使用 SiftKeyPointExtract 需要先安装 opencv-python 或 opencv-contrib-python") from e

        if not hasattr(cv2, "SIFT_create"):
            raise RuntimeError(
                "当前 OpenCV 不包含 SIFT(可能缺少 contrib 模块)。"
                "请安装/替换为 `opencv-contrib-python`。"
            )

        cfg = self.config
        return cv2.SIFT_create(
            nfeatures=int(cfg.n_features),
            nOctaveLayers=int(cfg.n_octave_layers),
            contrastThreshold=float(cfg.contrast_threshold),
            edgeThreshold=float(cfg.edge_threshold),
            sigma=float(cfg.sigma),
        )

    @staticmethod
    def _to_gray_u8(image: ImageArray) -> np.ndarray:
        """
        Convert ``image`` to the 8-bit grayscale array that SIFT is run on.

        - uint8 input: used as is (colour images are converted to gray)
        - other dtypes: values are scaled by 255 when the maximum is <= 1.0,
          then clipped to [0, 255] and cast to uint8
        - H x W x 1 input is treated as grayscale

        Raises:
            ValueError: the array is neither H x W nor H x W x C (C == 1 or
                C >= 3).
        """
        import cv2

        img = np.asarray(image)
        if img.ndim == 3 and img.shape[2] == 1:
            # A single-channel H x W x 1 image is already grayscale.
            img = img[:, :, 0]

        if img.ndim == 3 and img.shape[2] >= 3:
            # cv2.cvtColor only accepts 8-bit unsigned, 16-bit unsigned and
            # float32 input; anything else (notably NumPy's default float64)
            # is cast to float32 first so the colour conversion cannot fail
            # on the dtype.
            if img.dtype not in (np.uint8, np.uint16, np.float32):
                img = img.astype(np.float32)
            # OpenCV loads images as BGR; no colour space is enforced here,
            # the image is only reduced to gray.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        elif img.ndim != 2:
            raise ValueError(f"image 需要是 HxW 或 HxWxC,但得到 {img.shape=}")

        if img.dtype == np.uint8:
            return img

        img_f = img.astype(np.float32, copy=False)
        if img_f.size > 0 and img_f.max() <= 1.0:
            img_f = img_f * 255.0
        img_u8 = np.clip(img_f, 0.0, 255.0).astype(np.uint8)
        return img_u8

    def extract(self, image: ImageArray) -> KeyPointSet:
        """
        Detect SIFT keypoints in ``image`` and compute their descriptors.

        Returns a KeyPointSet whose ``xy`` holds (x, y) coordinates and whose
        ``scores`` holds the OpenCV keypoint responses. When nothing is
        detected, ``xy`` is an empty (0, 2) array and descriptors/scores are
        None.
        """
        gray = self._to_gray_u8(image)
        keypoints, descriptors = self._sift.detectAndCompute(gray, mask=None)

        if keypoints is None or len(keypoints) == 0:
            empty_xy = np.zeros((0, 2), dtype=np.float32)
            return KeyPointSet(xy=empty_xy, descriptors=None, scores=None, meta={"backend": "opencv_sift"})

        xy = np.array([kp.pt for kp in keypoints], dtype=np.float32)  # (x, y)
        scores = np.array([kp.response for kp in keypoints], dtype=np.float32)

        if descriptors is not None:
            descriptors = np.asarray(descriptors)

        # When max_keypoints is set, keep the highest-response keypoints only
        # (SIFT itself may already be limited by nfeatures).
        max_kp = self.config.max_keypoints
        if max_kp is not None and xy.shape[0] > int(max_kp):
            idx = np.argsort(-scores)[: int(max_kp)]
            xy = xy[idx]
            scores = scores[idx]
            if descriptors is not None:
                descriptors = descriptors[idx]

        meta: Dict[str, Any] = {
            "backend": "opencv_sift",
            "n_features": int(self.config.n_features),
            "n_octave_layers": int(self.config.n_octave_layers),
            "contrast_threshold": float(self.config.contrast_threshold),
            "edge_threshold": float(self.config.edge_threshold),
            "sigma": float(self.config.sigma),
        }
        if self.config.extra:
            meta["extra"] = dict(self.config.extra)

        return KeyPointSet(
            xy=xy,
            descriptors=descriptors,
            scores=scores,
            meta=meta,
        )
128
+
129
+
130
+
131
def get_args():
    """Parse the command-line options of the SIFT demo."""
    # Alternative sample used during development: model/local/roller.png
    default_image = (project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png").as_posix()

    parser = argparse.ArgumentParser()
    parser.add_argument("--image_path", default=default_image, type=str)
    parser.add_argument("--max_keypoints", type=int, default=8000)
    return parser.parse_args()
143
+
144
+
145
def main():
    """Demo entry point: extract SIFT keypoints and show them in a window."""
    args = get_args()

    image = cv2.imread(args.image_path)
    if image is None:
        raise FileNotFoundError(f"无法读取图片: {args.image_path}")

    sift_config = SiftExtractConfig(max_keypoints=int(args.max_keypoints))
    extractor = SiftKeyPointExtract(sift_config)
    kp_set = extractor.extract(image)

    # Rebuild cv2.KeyPoint objects for drawing; without scores every
    # response is 0.0.
    has_scores = kp_set.scores is not None
    scores = kp_set.scores if has_scores else np.zeros((kp_set.n,))
    cv2_keypoints = []
    for (x, y), score in zip(kp_set.xy, scores):
        response = float(score) if has_scores else 0.0
        cv2_keypoints.append(
            cv2.KeyPoint(
                x=float(x),
                y=float(y),
                size=7,
                response=response,
            )
        )

    image_with_kp = cv2.drawKeypoints(
        image,
        cv2_keypoints,
        None,
        color=(0, 0, 255),
        flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
    )

    cv2.imshow("sift_keypoints", image_with_kp)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return
181
+
182
+
183
+ if __name__ == "__main__":
184
+ main()
toolbox/keypoint_match/keypoint_extracter/superpoint.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import os
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, Optional
10
+
11
+ import numpy as np
12
+
13
+ from project_settings import project_path
14
+ from toolbox.keypoint_match.base import KeyPointExtract, KeyPointExtractConfig
15
+ from toolbox.keypoint_match.types import ImageArray, KeyPointSet
16
+
17
+
18
@dataclass(frozen=True)
class SuperPointExtractConfig(KeyPointExtractConfig):
    """Settings for the transformers-based SuperPoint extractor."""

    # Model identifier passed to ``from_pretrained``.
    model_name: str = "magic-leap-community/superpoint"
    # Directory passed to ``from_pretrained`` as ``cache_dir``.
    model_cache_dir: str = (project_path / "../../hf_hub_models").as_posix()
    device: str = "cpu"  # "cpu" / "cuda"
    # HuggingFace mirror endpoint (kept in line with the example script);
    # a falsy value leaves the environment untouched.
    hf_endpoint: Optional[str] = "https://hf-mirror.com"
25
+
26
+
27
class SuperPointKeyPointExtract(KeyPointExtract):
    """
    Keypoint/descriptor extractor backed by the transformers SuperPoint model.

    Reference: examples/keypoints/superpoint/test.py
    """

    def __init__(self, config: Optional[SuperPointExtractConfig] = None):
        super().__init__(config=config or SuperPointExtractConfig())
        self.config: SuperPointExtractConfig

        # setdefault: an HF_ENDPOINT already present in the environment wins.
        if self.config.hf_endpoint:
            os.environ.setdefault("HF_ENDPOINT", str(self.config.hf_endpoint))

        # Processor and model are created on first use, see _lazy_init().
        self._processor = None
        self._model = None

    def _lazy_init(self):
        """Load the processor and model once and move the model to the device."""
        if self._processor is not None and self._model is not None:
            return

        import torch
        from transformers import AutoImageProcessor
        from transformers.models.superpoint.modeling_superpoint import SuperPointForKeypointDetection

        self._processor = AutoImageProcessor.from_pretrained(
            pretrained_model_name_or_path=self.config.model_name,
            cache_dir=self.config.model_cache_dir,
        )
        self._model = SuperPointForKeypointDetection.from_pretrained(
            pretrained_model_name_or_path=self.config.model_name,
            cache_dir=self.config.model_cache_dir,
        )
        self._model.eval()

        device = torch.device(self.config.device)
        self._model.to(device)

    @staticmethod
    def _to_pil_rgb(image: ImageArray):
        """
        Convert ``image`` to an RGB PIL image.

        - H x W (or H x W x 1) input is treated as grayscale
        - H x W x C input with C >= 3 is treated as OpenCV BGR; only the
          first three channels are used
        - non-uint8 input is scaled by 255 when its maximum is <= 1.0, then
          clipped to [0, 255] and cast to uint8

        Raises:
            ValueError: the array has any other shape.
        """
        from PIL import Image
        import cv2

        img = np.asarray(image)
        if img.ndim == 3 and img.shape[2] == 1:
            # A single-channel H x W x 1 image is already grayscale.
            img = img[:, :, 0]

        if img.ndim == 2:
            color_code = cv2.COLOR_GRAY2RGB
        elif img.ndim == 3 and img.shape[2] >= 3:
            # Input is assumed to be in OpenCV's BGR order.
            color_code = cv2.COLOR_BGR2RGB
            img = img[:, :, :3]
        else:
            raise ValueError(f"image 需要是 HxW 或 HxWxC,但得到 {img.shape=}")

        # Normalise the dtype BEFORE the colour conversion: cv2.cvtColor
        # rejects depths such as float64. Both conversions used here only
        # copy / reorder channels, so doing them afterwards yields the same
        # pixel values.
        if img.dtype != np.uint8:
            img_f = img.astype(np.float32, copy=False)
            if img_f.size > 0 and img_f.max() <= 1.0:
                img_f = img_f * 255.0
            img = np.clip(img_f, 0.0, 255.0).astype(np.uint8)

        img = cv2.cvtColor(img, color_code)
        return Image.fromarray(img).convert("RGB")

    def extract(self, image: ImageArray) -> KeyPointSet:
        """
        Run SuperPoint on ``image`` and return keypoints, scores, descriptors.

        All returned arrays are float32 NumPy arrays; when ``max_keypoints``
        is set, only the highest-scoring keypoints are kept.
        """
        import torch

        self._lazy_init()
        assert self._processor is not None
        assert self._model is not None

        pil = self._to_pil_rgb(image)
        inputs = self._processor(pil, return_tensors="pt")

        # Move every input tensor to wherever the model's parameters live.
        device = next(self._model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            output = self._model(**inputs)

        image_size = (pil.height, pil.width)
        processed = self._processor.post_process_keypoint_detection(
            output,
            [image_size],
        )

        keypoints = processed[0]["keypoints"]  # [N,2] (x,y)
        scores = processed[0]["scores"]  # [N]
        descriptors = processed[0]["descriptors"]  # [N,D]

        keypoints_np = keypoints.detach().cpu().numpy().astype(np.float32)
        scores_np = scores.detach().cpu().numpy().astype(np.float32)
        desc_np = descriptors.detach().cpu().numpy().astype(np.float32)

        # Truncate to max_keypoints by descending score.
        max_kp = self.config.max_keypoints
        if max_kp is not None and keypoints_np.shape[0] > int(max_kp):
            idx = np.argsort(-scores_np)[: int(max_kp)]
            keypoints_np = keypoints_np[idx]
            scores_np = scores_np[idx]
            desc_np = desc_np[idx]

        meta: Dict[str, Any] = {
            "backend": "transformers_superpoint",
            "model_name": str(self.config.model_name),
            "device": str(self.config.device),
        }
        if self.config.extra:
            meta["extra"] = dict(self.config.extra)

        return KeyPointSet(
            xy=keypoints_np,
            descriptors=desc_np,
            scores=scores_np,
            meta=meta,
        )
139
+
140
+
141
def main():
    """Demo entry point: extract SuperPoint keypoints and show them."""
    import cv2

    default_image = (project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png").as_posix()
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_path", type=str, default=default_image)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--max_keypoints", type=int, default=2000)
    args = parser.parse_args()

    image = cv2.imread(args.image_path)
    if image is None:
        raise FileNotFoundError(f"无法读取图片: {args.image_path}")

    extract_config = SuperPointExtractConfig(
        device=str(args.device),
        max_keypoints=int(args.max_keypoints),
    )
    extractor = SuperPointKeyPointExtract(extract_config)
    kp_set = extractor.extract(image)

    # Rebuild cv2.KeyPoint objects for drawing; missing scores become zeros.
    if kp_set.scores is not None:
        scores = kp_set.scores
    else:
        scores = np.zeros((kp_set.n,), dtype=np.float32)
    cv2_keypoints = []
    for (x, y), score in zip(kp_set.xy, scores):
        cv2_keypoints.append(
            cv2.KeyPoint(
                x=float(x),
                y=float(y),
                size=7,
                response=float(score),
            )
        )

    image_with_kp = cv2.drawKeypoints(
        image,
        cv2_keypoints,
        None,
        color=(0, 0, 255),
        flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
    )
    cv2.imshow(f"superpoint_keypoints | n={kp_set.n}", image_with_kp)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return
187
+
188
+
189
+ if __name__ == "__main__":
190
+ main()
191
+
toolbox/keypoint_match/keypoint_match/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from toolbox.keypoint_match.keypoint_match.single_image_match import SingleImageMatcher, SingleImageMatcherConfig
5
+
6
# Names re-exported from `single_image_match` as this package's public API.
__all__ = [
    "SingleImageMatcher",
    "SingleImageMatcherConfig",
]
10
+
toolbox/keypoint_match/keypoint_match/single_image_match.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, Optional, Tuple
8
+
9
+ import cv2
10
+ import numpy as np
11
+
12
+ from project_settings import project_path
13
+ from toolbox.keypoint_match.base import KeyPointExtract
14
+ from toolbox.keypoint_match.types import ImageArray, KeyPointSet, MatchSet
15
+ from toolbox.keypoint_match.keypoint_extracter import SiftExtractConfig, SiftKeyPointExtract
16
+
17
+
18
@dataclass(frozen=True)
class SingleImageMatcherConfig:
    """
    Matching configuration for a single template image versus a single large image.
    """

    # Threshold of the KNN ratio test (Lowe's ratio test); smaller is stricter.
    ratio: float = 0.75
    # Maximum number of matches to keep (truncated in ascending distance order).
    max_matches: Optional[int] = 500
    # Whether to do a mutual check (keep a pair only when A->B and B->A agree).
    mutual_check: bool = False
    # Optional free-form data; `SingleImageMatcher.match` copies it into the
    # result meta under "extra" when it is non-empty.
    extra: Optional[Dict[str, Any]] = None
31
+
32
+
33
class SingleImageMatcher:
    """
    Match a small target (template) image against a large image that contains it.

    - The key point extractor (for example `SiftKeyPointExtract` or a SuperPoint
      extractor) is passed in at construction time.
    - `match(...)` extracts a KeyPointSet from both images, matches the
      descriptors with the OpenCV BFMatcher (KNN + ratio test) and returns a
      MatchSet.
    """

    def __init__(self, extractor: KeyPointExtract, config: Optional[SingleImageMatcherConfig] = None):
        self.extractor = extractor
        self.config = config or SingleImageMatcherConfig()

    @staticmethod
    def _infer_norm_type(desc: np.ndarray) -> int:
        """Pick the OpenCV norm from the descriptor dtype."""
        import cv2

        # ORB/BRIEF etc. are usually uint8 binary descriptors;
        # SIFT/SuperPoint etc. are usually float32.
        if desc.dtype == np.uint8:
            return cv2.NORM_HAMMING
        return cv2.NORM_L2

    def _bfmatcher(self, query_desc: np.ndarray):
        """Create a brute-force matcher suited to `query_desc`."""
        import cv2

        norm = self._infer_norm_type(query_desc)
        # crossCheck stays off: the ratio test needs knnMatch(k=2), and OpenCV's
        # crossCheck is meant for match(); the mutual check is done separately
        # in `_mutual_filter`.
        return cv2.BFMatcher(normType=norm, crossCheck=False)

    @staticmethod
    def _empty_match_arrays() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return the (query_idx, train_idx, distance) triple for "no matches"."""
        return (
            np.zeros((0,), dtype=np.int64),
            np.zeros((0,), dtype=np.int64),
            np.zeros((0,), dtype=np.float32),
        )

    @staticmethod
    def _prepare_descriptors(desc: Any) -> np.ndarray:
        """
        Return descriptors in a layout the OpenCV BFMatcher accepts.

        uint8 descriptors are kept as they are; every other dtype (float64,
        float16, ...) is converted to float32, because the matcher only
        handles uint8 and float32 input.
        """
        arr = np.asarray(desc)
        if arr.dtype != np.uint8:
            arr = arr.astype(np.float32, copy=False)
        return np.ascontiguousarray(arr)

    def _truncate(
        self, matches: Tuple[np.ndarray, np.ndarray, np.ndarray]
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Keep at most `config.max_matches` leading entries of each array."""
        limit = self.config.max_matches
        if limit is None:
            return matches
        q_idx, t_idx, dist = matches
        stop = int(limit)
        return q_idx[:stop], t_idx[:stop], dist[:stop]

    def _knn_ratio_match(
        self,
        bf,
        query_desc: np.ndarray,
        train_desc: np.ndarray,
        truncate: bool = True,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Run knnMatch(k=2) plus the ratio test.

        Returns one-dimensional arrays (query_idx, train_idx, distance) sorted
        by ascending distance. When `truncate` is true the result is cut to
        `config.max_matches` entries; pass `truncate=False` to get every match
        that passed the ratio test.
        """
        knn = bf.knnMatch(query_desc, train_desc, k=2)
        ratio = float(self.config.ratio)

        # A query with fewer than two neighbours cannot be ratio-tested and is dropped.
        kept = [
            (int(pair[0].queryIdx), int(pair[0].trainIdx), float(pair[0].distance))
            for pair in knn
            if len(pair) >= 2 and pair[0].distance < ratio * pair[1].distance
        ]
        if not kept:
            return self._empty_match_arrays()

        q_list, t_list, d_list = zip(*kept)
        q_idx_arr = np.asarray(q_list, dtype=np.int64)
        t_idx_arr = np.asarray(t_list, dtype=np.int64)
        dist_arr = np.asarray(d_list, dtype=np.float32)

        # Sort by ascending distance, then optionally truncate.
        order = np.argsort(dist_arr)
        result = (q_idx_arr[order], t_idx_arr[order], dist_arr[order])
        return self._truncate(result) if truncate else result

    @staticmethod
    def _mutual_filter(
        q_to_t: Tuple[np.ndarray, np.ndarray, np.ndarray],
        t_to_q: Tuple[np.ndarray, np.ndarray, np.ndarray],
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Keep only the query->train matches that also appear in train->query.

        q_to_t: (q_idx, t_idx, dist) from query->train
        t_to_q: (t_idx, q_idx, dist) from train->query (note the order)

        The order of the surviving query->train matches is preserved. If the
        reverse direction produced no matches, nothing is mutual and the result
        is empty.
        """
        q_idx, t_idx, dist = q_to_t
        t_idx2, q_idx2, _ = t_to_q
        if q_idx.size == 0:
            return q_idx, t_idx, dist

        # Normalise the reverse matches to (q, t) pairs for set membership tests.
        reverse_pairs = set(zip(q_idx2.tolist(), t_idx2.tolist()))
        keep_mask = np.fromiter(
            (pair in reverse_pairs for pair in zip(q_idx.tolist(), t_idx.tolist())),
            dtype=bool,
            count=int(q_idx.size),
        )
        return q_idx[keep_mask], t_idx[keep_mask], dist[keep_mask]

    def match(
        self,
        template_image: ImageArray,
        search_image: ImageArray,
        *,
        template_id: Optional[str] = None,
    ) -> Tuple[KeyPointSet, KeyPointSet, MatchSet]:
        """
        Extract key points from both images and match their descriptors.

        Returns: (template_kp, search_kp, matches). `matches.query_idx` indexes
        into `template_kp` and `matches.train_idx` into `search_kp`. The match
        set is empty when either image yields no descriptors.
        """
        template_kp = self.extractor.extract(template_image)
        search_kp = self.extractor.extract(search_image)

        # The same meta is attached to empty and non-empty results.
        meta: Dict[str, Any] = {
            "backend": "opencv_bf_knn_ratio",
            "ratio": float(self.config.ratio),
            "mutual_check": bool(self.config.mutual_check),
            "template_id": template_id,
        }
        if self.config.max_matches is not None:
            meta["max_matches"] = int(self.config.max_matches)
        if self.config.extra:
            meta["extra"] = dict(self.config.extra)

        raw_q = template_kp.descriptors
        raw_t = search_kp.descriptors
        if raw_q is None or raw_t is None or np.asarray(raw_q).size == 0 or np.asarray(raw_t).size == 0:
            q_idx, t_idx, dist = self._empty_match_arrays()
            return template_kp, search_kp, MatchSet(
                query_idx=q_idx,
                train_idx=t_idx,
                distance=dist,
                meta=meta,
            )

        q_desc = self._prepare_descriptors(raw_q)
        t_desc = self._prepare_descriptors(raw_t)

        if self.config.mutual_check:
            # Truncating each direction before intersecting could discard pairs
            # that are genuinely mutual, so max_matches is applied afterwards.
            q_to_t = self._knn_ratio_match(self._bfmatcher(q_desc), q_desc, t_desc, truncate=False)
            t_to_q = self._knn_ratio_match(self._bfmatcher(t_desc), t_desc, q_desc, truncate=False)
            q_idx, t_idx, dist = self._truncate(self._mutual_filter(q_to_t, t_to_q))
        else:
            q_idx, t_idx, dist = self._knn_ratio_match(self._bfmatcher(q_desc), q_desc, t_desc)

        matches = MatchSet(
            query_idx=q_idx,
            train_idx=t_idx,
            distance=dist,
            meta=meta,
        )
        return template_kp, search_kp, matches
177
+
178
+
179
def _to_cv2_keypoints(kp_set: KeyPointSet):
    """
    Convert a KeyPointSet into a list of `cv2.KeyPoint` objects.

    Each key point gets a fixed size of 7 and its score as `response`;
    when the set carries no scores, the response is 0 for every point.
    """
    import cv2

    scores = kp_set.scores
    if scores is None:
        scores = np.zeros((kp_set.n,), dtype=np.float32)

    converted = []
    for (x, y), score in zip(kp_set.xy, scores):
        converted.append(
            cv2.KeyPoint(x=float(x), y=float(y), size=7, response=float(score))
        )
    return converted
196
+
197
+
198
def _to_cv2_matches(match_set: MatchSet):
    """
    Convert a MatchSet into a list of `cv2.DMatch` objects.

    The image index is always 0; when the set carries no distances,
    the distance is 0 for every match.
    """
    import cv2

    distances = match_set.distance
    if distances is None:
        distances = np.zeros((match_set.m,), dtype=np.float32)

    converted = []
    for q, t, d in zip(match_set.query_idx, match_set.train_idx, distances):
        converted.append(
            cv2.DMatch(_queryIdx=int(q), _trainIdx=int(t), _imgIdx=0, _distance=float(d))
        )
    return converted
215
+
216
+
217
def get_args():
    """
    Parse the command-line arguments of the single-image matching demo.

    Returns:
        The parsed namespace with `template_path`, `search_path`, `ratio`,
        `max_matches` and `max_keypoints`.
    """
    parser = argparse.ArgumentParser()
    # argparse applies `type` only to string defaults, so the default paths are
    # converted to str here; otherwise the demo would receive Path objects.
    parser.add_argument(
        "--template_path",
        type=str,
        default=(project_path / "data/images/keyboard/g98-v2-pink/model/local/roller.png").as_posix(),
    )
    parser.add_argument(
        "--search_path",
        type=str,
        default=(project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png").as_posix(),
    )
    parser.add_argument("--ratio", type=float, default=0.75)
    parser.add_argument("--max_matches", type=int, default=80)
    parser.add_argument("--max_keypoints", type=int, default=2000)
    args = parser.parse_args()

    return args
235
+
236
+
237
def main():
    """
    Command-line demo: match a template image against a search image with SIFT.

    Loads both images, runs `SingleImageMatcher` with a SIFT extractor and shows
    the drawn matches in an OpenCV window until a key is pressed.

    Raises:
        FileNotFoundError: if either image cannot be read.
    """
    args = get_args()

    template = cv2.imread(args.template_path)
    search = cv2.imread(args.search_path)
    if template is None:
        raise FileNotFoundError(f"无法读取模板图: {args.template_path}")
    if search is None:
        raise FileNotFoundError(f"无法读取搜索图: {args.search_path}")

    sift_config = SiftExtractConfig(max_keypoints=int(args.max_keypoints))
    matcher_config = SingleImageMatcherConfig(
        ratio=float(args.ratio),
        max_matches=int(args.max_matches),
    )
    matcher = SingleImageMatcher(
        extractor=SiftKeyPointExtract(sift_config),
        config=matcher_config,
    )

    template_kp, search_kp, matches = matcher.match(template, search)

    # Convert the project types into the OpenCV structures drawMatches expects.
    vis = cv2.drawMatches(
        template,
        _to_cv2_keypoints(template_kp),
        search,
        _to_cv2_keypoints(search_kp),
        _to_cv2_matches(matches),
        None,
        flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
    )

    title = f"single_match | kp_t={template_kp.n} kp_s={search_kp.n} matches={matches.m}"
    cv2.imshow(title, vis)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return


if __name__ == "__main__":
    main()
toolbox/keypoint_match/types.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from __future__ import annotations
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Any, Mapping, Optional, Sequence, Tuple
8
+
9
+ import numpy as np
10
+
11
+
12
# Type aliases: all three are plain `np.ndarray`; the names only document intent.
ArrayF32 = np.ndarray
ArrayU8 = np.ndarray
ImageArray = np.ndarray  # HxWxC or HxW, any dtype (uint8 / float32 are both fine)
15
+
16
+
17
@dataclass(frozen=True)
class KeyPointSet:
    """
    Uniform representation of the key points of one image and their
    (optional) descriptors.

    Conventions:
    - xy: shape [N, 2], columns are (x, y), float32/float64
    - descriptors: shape [N, D] (optional), e.g. ORB: uint8; SIFT/SuperPoint: float32
    - scores: shape [N] (optional); a larger value means a "stronger" point
    """

    # NOTE(review): the dataclass-generated __eq__/__hash__ operate directly on
    # the ndarray fields - confirm instances are never compared or hashed.
    xy: ArrayF32
    descriptors: Optional[np.ndarray] = None
    scores: Optional[ArrayF32] = None
    meta: Optional[Mapping[str, Any]] = None

    def __post_init__(self) -> None:
        """Validate the shapes of xy, descriptors and scores; raise ValueError on mismatch."""
        xy = np.asarray(self.xy)
        if not (xy.ndim == 2 and xy.shape[1] == 2):
            raise ValueError(f"KeyPointSet.xy 必须是 [N,2],但得到 {xy.shape=}")
        expected_rows = xy.shape[0]

        if self.descriptors is not None:
            desc = np.asarray(self.descriptors)
            if not (desc.ndim == 2 and desc.shape[0] == expected_rows):
                raise ValueError(
                    "KeyPointSet.descriptors 必须是 [N,D] 且与 xy 的 N 一致,"
                    f"但得到 {desc.shape=} vs {xy.shape=}"
                )

        if self.scores is not None:
            sc = np.asarray(self.scores)
            if not (sc.ndim == 1 and sc.shape[0] == expected_rows):
                raise ValueError(
                    "KeyPointSet.scores 必须是 [N] 且与 xy 的 N 一致,"
                    f"但得到 {sc.shape=} vs {xy.shape=}"
                )

    @property
    def n(self) -> int:
        """Number of key points (rows of xy)."""
        points = np.asarray(self.xy)
        return int(points.shape[0])

    @property
    def has_descriptors(self) -> bool:
        """True when descriptors were provided."""
        return self.descriptors is not None
59
+
60
+
61
@dataclass(frozen=True)
class MatchSet:
    """
    Uniform representation of the matches between two sets of key points.

    - query_idx: indices of the "template / small image" key points, shape [M]
    - train_idx: indices of the "large / search image" key points, shape [M]
    - distance: distance / similarity measure of each match, shape [M]
      (smaller meaning more similar is the most common convention)
    """

    query_idx: np.ndarray
    train_idx: np.ndarray
    distance: Optional[np.ndarray] = None
    meta: Optional[Mapping[str, Any]] = None

    def __post_init__(self) -> None:
        """Validate that the index arrays (and distance, if given) are 1-D and equally long."""
        qi = np.asarray(self.query_idx)
        ti = np.asarray(self.train_idx)
        if not (qi.ndim == 1 and ti.ndim == 1 and qi.shape[0] == ti.shape[0]):
            raise ValueError(f"MatchSet 索引必须是同长度的一维数组,但得到 {qi.shape=} {ti.shape=}")

        if self.distance is None:
            return
        dist = np.asarray(self.distance)
        if not (dist.ndim == 1 and dist.shape[0] == qi.shape[0]):
            raise ValueError(f"MatchSet.distance 必须是 [M],但得到 {dist.shape=} vs {qi.shape=}")

    @property
    def m(self) -> int:
        """Number of matches (length of query_idx)."""
        indices = np.asarray(self.query_idx)
        return int(indices.shape[0])
89
+
90
+
91
@dataclass(frozen=True)
class Detection:
    """
    One candidate result of "the template was found in the large image".
    """

    bbox_xyxy: Tuple[float, float, float, float]  # (x1,y1,x2,y2)
    score: float
    match_count: int
    template_id: Optional[str] = None
    meta: Optional[Mapping[str, Any]] = None
102
+
103
+
104
@dataclass(frozen=True)
class DetectionResult:
    """
    The set of results produced by one detection (one template against one large image).
    """

    detections: Sequence[Detection]
    meta: Optional[Mapping[str, Any]] = None
112
+
toolbox/os/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
# No-op when this file is run directly as a script.
if __name__ == '__main__':
    pass
toolbox/os/command.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+
5
+
6
class Command(object):
    """
    Small helpers for running shell commands.

    Commands listed in `custom_command` are handled by the class method of the
    same name instead of being sent to the shell (a `cd` run in a child shell
    would not change this process's working directory).

    NOTE: commands are executed through the shell via `os.popen` / `os.system`;
    never pass untrusted input.
    """

    custom_command = [
        "cd"
    ]

    @staticmethod
    def _get_cmd(command):
        """
        Split a command line into `(cmd, args)`.

        Returns None for an empty or whitespace-only command. `args` is the
        remainder of the line after the first word, without leading whitespace.
        """
        command = str(command).strip()
        if command == "":
            return None
        # Split on any run of whitespace so "cd   /tmp" yields args "/tmp".
        parts = command.split(maxsplit=1)
        cmd = parts[0]
        args = parts[1] if len(parts) > 1 else ""
        return cmd, args

    @classmethod
    def popen(cls, command):
        """
        Run `command` and return its standard output as a string.

        A custom command returns whatever its handler returns; an empty
        command returns an empty string.
        """
        parsed = cls._get_cmd(command)
        if parsed is None:
            return ""
        cmd, args = parsed
        if cmd in cls.custom_command:
            method = getattr(cls, cmd)
            return method(args)
        # The context manager closes the pipe even if read() raises.
        with os.popen(command) as resp:
            return resp.read()

    @classmethod
    def cd(cls, args):
        """Change the working directory of this process to `args`."""
        # os.path.join keeps an absolute `args` as it is and resolves a relative
        # one against the current directory.
        os.chdir(os.path.join(os.getcwd(), args))

    @classmethod
    def system(cls, command):
        """Run `command` with `os.system` and return its result."""
        return os.system(command)

    def __init__(self):
        pass
48
+
49
+
50
def ps_ef_grep(keyword: str):
    """
    Return the `ps -ef` rows that contain `keyword`.

    Rows that contain "grep" (such as the grep process started by this call)
    are removed from the result.
    """
    import shlex

    # Quote the keyword so spaces and shell metacharacters reach grep as one
    # pattern argument; `--` keeps a keyword starting with "-" from being
    # parsed as an option.
    cmd = "ps -ef | grep -- {}".format(shlex.quote(keyword))
    output = Command.popen(cmd)
    rows = str(output).split("\n")
    return [row for row in rows if keyword in row and "grep" not in row]
56
+
57
+
58
# No-op when this file is run directly as a script.
if __name__ == "__main__":
    pass
toolbox/os/environment.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import json
4
+ import os
5
+
6
+ from dotenv import load_dotenv
7
+ from dotenv.main import DotEnv
8
+
9
+ from toolbox.json.misc import traverse
10
+
11
+
12
class EnvironmentManager(object):
    """
    Load `<path>/<env>.env` with `load_dotenv` and read typed values from
    `os.environ`.

    Every value returned by `get` is also recorded in `self._environ`.
    """

    def __init__(self, path, env, override=False):
        self.filename = os.path.join(path, f'{env}.env')
        load_dotenv(dotenv_path=self.filename, override=override)
        self._environ = dict()

    def open_dotenv(self, filename: str = None):
        """
        Read a dotenv file and return the result of `DotEnv.dict()`.

        Uses `filename` when given, otherwise the file chosen at construction.
        Interpolation is turned off.
        """
        target = filename or self.filename
        reader = DotEnv(
            dotenv_path=target,
            stream=None,
            verbose=False,
            interpolate=False,
            override=False,
            encoding="utf-8",
        )
        return reader.dict()

    def get(self, key, default=None, dtype=str):
        """
        Return the environment variable `key` converted with `dtype`.

        When the variable is not set, `default` is returned as it is (without
        conversion). The returned value is stored in `self._environ[key]`.
        """
        raw = os.environ.get(key)
        value = default if raw is None else dtype(raw)
        self._environ[key] = value
        return value
48
+
49
+
50
# Maps the dtype name used in a `$dtype:key` placeholder to the converter
# applied to the environment value; used by JsonConfig when no map is given.
_DEFAULT_DTYPE_MAP = {
    'int': int,
    'float': float,
    'str': str,
    'json.loads': json.loads
}
56
+
57
+
58
class JsonConfig(object):
    """
    Resolve environment placeholders inside JSON data.

    A string value of the form `$float:threshold` is handled as: look up
    `threshold` in the environment, then convert it to float. The part before
    the first `:` selects the converter from `dtype_map`; everything after it
    is the environment key.
    """

    def __init__(self, dtype_map: dict = None, environment: EnvironmentManager = None):
        self.dtype_map = dtype_map or _DEFAULT_DTYPE_MAP
        self.environment = environment or os.environ

    def sanitize_by_filename(self, filename: str):
        """Load the JSON file `filename` and resolve its placeholders."""
        with open(filename, 'r', encoding='utf-8') as f:
            js = json.load(f)

        return self.sanitize_by_json(js)

    def sanitize_by_json(self, js):
        """Resolve the placeholders in already-parsed JSON data via `traverse`."""
        js = traverse(
            js,
            callback=self.sanitize,
            environment=self.environment
        )
        return js

    def sanitize(self, string, environment):
        """
        Resolve one value; supports environment placeholders starting with `$`.

        Values that are not strings starting with `$` are returned unchanged.

        Raises:
            ValueError: the placeholder has no `:` separator.
            KeyError: the dtype name is not in `dtype_map`.
            AssertionError: the environment has no value for the key.
        """
        if not (isinstance(string, str) and string.startswith('$')):
            return string

        # Split on the first ':' only, so the key itself may contain ':'.
        dtype_name, separator, key = string[1:].partition(':')
        if not separator:
            raise ValueError(
                'invalid environment placeholder, expected `$dtype:key`: {}'.format(string)
            )
        dtype = self.dtype_map[dtype_name]

        value = environment.get(key)
        if value is None:
            raise AssertionError('environment not exist. key: {}'.format(key))

        return dtype(value)
96
+
97
+
98
def demo1():
    """
    Demo: load the `dev` dotenv file and print the `init_scenes` value
    (parsed with `json.loads`) followed by the values recorded so far.
    """
    import json

    from project_settings import project_path

    dotenv_dir = os.path.join(project_path, 'server/callbot_server/dotenv')
    environment = EnvironmentManager(path=dotenv_dir, env='dev')

    init_scenes = environment.get(key='init_scenes', dtype=json.loads)
    print(init_scenes)
    print(environment._environ)
    return


if __name__ == '__main__':
    demo1()
toolbox/os/other.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import inspect
3
+
4
+
5
def pwd():
    """Return the directory of the file from which this function is called."""
    caller = inspect.stack()[1]
    module = inspect.getmodule(caller[0])
    # inspect.getmodule() returns None when the calling code belongs to no
    # imported module (for example code run through exec), and a module may
    # lack __file__; fall back to the file name recorded on the caller's frame.
    file_path = getattr(module, "__file__", None) or caller.filename
    return os.path.dirname(os.path.abspath(file_path))